Diffstat (limited to 'g4f/providers')
-rw-r--r--   g4f/providers/base_provider.py    280
-rw-r--r--   g4f/providers/create_images.py    155
-rw-r--r--   g4f/providers/helper.py            61
-rw-r--r--   g4f/providers/retry_provider.py   119
-rw-r--r--   g4f/providers/types.py            117
5 files changed, 732 insertions, 0 deletions
diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py
new file mode 100644
index 00000000..b8649ba5
--- /dev/null
+++ b/g4f/providers/base_provider.py
@@ -0,0 +1,280 @@
+from __future__ import annotations
+
+import sys
+import asyncio
+from asyncio import AbstractEventLoop
+from concurrent.futures import ThreadPoolExecutor
+from abc import abstractmethod
+from inspect import signature, Parameter
+from ..typing import CreateResult, AsyncResult, Messages, Union
+from .types import BaseProvider
+from ..errors import NestAsyncioError, ModelNotSupportedError
+from .. import debug
+
+if sys.version_info < (3, 10):
+    NoneType = type(None)
+else:
+    from types import NoneType
+
+# Set Windows event loop policy for better compatibility with asyncio and curl_cffi
+if sys.platform == 'win32':
+    if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
+        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+
+def get_running_loop() -> Union[AbstractEventLoop, None]:
+    try:
+        loop = asyncio.get_running_loop()
+        if not hasattr(loop.__class__, "_nest_patched"):
+            raise NestAsyncioError(
+                'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.'
+            )
+        return loop
+    except RuntimeError:
+        pass
+
+class AbstractProvider(BaseProvider):
+    """
+    Abstract class for providing asynchronous functionality to derived classes.
+    """
+
+    @classmethod
+    async def create_async(
+        cls,
+        model: str,
+        messages: Messages,
+        *,
+        loop: AbstractEventLoop = None,
+        executor: ThreadPoolExecutor = None,
+        **kwargs
+    ) -> str:
+        """
+        Asynchronously creates a result based on the given model and messages.
+
+        Args:
+            cls (type): The class on which this method is called.
+            model (str): The model to use for creation.
+            messages (Messages): The messages to process.
+            loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
+            executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The created result as a string.
+        """
+        loop = loop or asyncio.get_running_loop()
+
+        def create_func() -> str:
+            return "".join(cls.create_completion(model, messages, False, **kwargs))
+
+        return await asyncio.wait_for(
+            loop.run_in_executor(executor, create_func),
+            timeout=kwargs.get("timeout")
+        )
+
+    @classmethod
+    @property
+    def params(cls) -> str:
+        """
+        Returns the parameters supported by the provider.
+
+        Args:
+            cls (type): The class on which this property is called.
+
+        Returns:
+            str: A string listing the supported parameters.
+        """
+        sig = signature(
+            cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else
+            cls.create_async if issubclass(cls, AsyncProvider) else
+            cls.create_completion
+        )
+
+        def get_type_name(annotation: type) -> str:
+            return annotation.__name__ if hasattr(annotation, "__name__") else str(annotation)
+
+        args = ""
+        for name, param in sig.parameters.items():
+            if name in ("self", "kwargs") or (name == "stream" and not cls.supports_stream):
+                continue
+            args += f"\n {name}"
+            args += f": {get_type_name(param.annotation)}" if param.annotation is not Parameter.empty else ""
+            args += f' = "{param.default}"' if param.default == "" else f" = {param.default}" if param.default is not Parameter.empty else ""
+
+        return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
+
+
+class AsyncProvider(AbstractProvider):
+    """
+    Provides asynchronous functionality for creating completions.
+    """
+
+    @classmethod
+    def create_completion(
+        cls,
+        model: str,
+        messages: Messages,
+        stream: bool = False,
+        **kwargs
+    ) -> CreateResult:
+        """
+        Creates a completion result synchronously.
+
+        Args:
+            cls (type): The class on which this method is called.
+            model (str): The model to use for creation.
+            messages (Messages): The messages to process.
+            stream (bool): Indicates whether to stream the results. Defaults to False.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            CreateResult: The result of the completion creation.
+        """
+        get_running_loop()
+        yield asyncio.run(cls.create_async(model, messages, **kwargs))
+
+    @staticmethod
+    @abstractmethod
+    async def create_async(
+        model: str,
+        messages: Messages,
+        **kwargs
+    ) -> str:
+        """
+        Abstract method for creating asynchronous results.
+
+        Args:
+            model (str): The model to use for creation.
+            messages (Messages): The messages to process.
+            **kwargs: Additional keyword arguments.
+
+        Raises:
+            NotImplementedError: If this method is not overridden in derived classes.
+
+        Returns:
+            str: The created result as a string.
+        """
+        raise NotImplementedError()
+
+
+class AsyncGeneratorProvider(AsyncProvider):
+    """
+    Provides asynchronous generator functionality for streaming results.
+    """
+    supports_stream = True
+
+    @classmethod
+    def create_completion(
+        cls,
+        model: str,
+        messages: Messages,
+        stream: bool = True,
+        **kwargs
+    ) -> CreateResult:
+        """
+        Creates a streaming completion result synchronously.
+
+        Args:
+            cls (type): The class on which this method is called.
+            model (str): The model to use for creation.
+            messages (Messages): The messages to process.
+            stream (bool): Indicates whether to stream the results. Defaults to True.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            CreateResult: The result of the streaming completion creation.
+        """
+        loop = get_running_loop()
+        new_loop = False
+        if not loop:
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            new_loop = True
+
+        generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
+        gen = generator.__aiter__()
+
+        # Fix for RuntimeError: async generator ignored GeneratorExit
+        async def await_callback(callback):
+            return await callback()
+
+        try:
+            while True:
+                yield loop.run_until_complete(await_callback(gen.__anext__))
+        except StopAsyncIteration:
+            ...
+        # Fix for: ResourceWarning: unclosed event loop
+        finally:
+            if new_loop:
+                loop.close()
+                asyncio.set_event_loop(None)
+
+    @classmethod
+    async def create_async(
+        cls,
+        model: str,
+        messages: Messages,
+        **kwargs
+    ) -> str:
+        """
+        Asynchronously creates a result from a generator.
+
+        Args:
+            cls (type): The class on which this method is called.
+            model (str): The model to use for creation.
+            messages (Messages): The messages to process.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The created result as a string.
+        """
+        return "".join([
+            chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
+            if not isinstance(chunk, Exception)
+        ])
+
+    @staticmethod
+    @abstractmethod
+    async def create_async_generator(
+        model: str,
+        messages: Messages,
+        stream: bool = True,
+        **kwargs
+    ) -> AsyncResult:
+        """
+        Abstract method for creating an asynchronous generator.
+
+        Args:
+            model (str): The model to use for creation.
+            messages (Messages): The messages to process.
+            stream (bool): Indicates whether to stream the results. Defaults to True.
+            **kwargs: Additional keyword arguments.
+
+        Raises:
+            NotImplementedError: If this method is not overridden in derived classes.
+
+        Returns:
+            AsyncResult: An asynchronous generator yielding results.
+        """
+        raise NotImplementedError()
+
+class ProviderModelMixin:
+    default_model: str
+    models: list[str] = []
+    model_aliases: dict[str, str] = {}
+
+    @classmethod
+    def get_models(cls) -> list[str]:
+        return cls.models
+
+    @classmethod
+    def get_model(cls, model: str) -> str:
+        if not model:
+            model = cls.default_model
+        elif model in cls.model_aliases:
+            model = cls.model_aliases[model]
+        elif model not in cls.get_models():
+            raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
+        debug.last_model = model
+        return model
\ No newline at end of file
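For orientation, here is a minimal sketch of how a concrete provider builds on the classes above: a subclass only implements create_async_generator, and AsyncGeneratorProvider derives both the synchronous create_completion and the awaitable create_async from it. EchoProvider and its output are hypothetical, purely illustrative of the pattern; the import paths assume the package layout introduced by this patch.

import asyncio

from g4f.providers.base_provider import AsyncGeneratorProvider
from g4f.typing import AsyncResult, Messages

class EchoProvider(AsyncGeneratorProvider):
    # Hypothetical provider: streams back the words of the last user message.
    url = "https://example.invalid"
    working = True

    @classmethod
    async def create_async_generator(
        cls, model: str, messages: Messages, stream: bool = True, **kwargs
    ) -> AsyncResult:
        for word in messages[-1]["content"].split():
            yield word + " "

messages = [{"role": "user", "content": "hello streaming world"}]

# Synchronous, chunk-by-chunk consumption via the inherited create_completion:
for chunk in EchoProvider.create_completion("any-model", messages):
    print(chunk, end="")

# Joined result via the inherited create_async:
print(asyncio.run(EchoProvider.create_async("any-model", messages)))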
diff --git a/g4f/providers/create_images.py b/g4f/providers/create_images.py
new file mode 100644
index 00000000..29a2a041
--- /dev/null
+++ b/g4f/providers/create_images.py
@@ -0,0 +1,155 @@
+from __future__ import annotations
+
+import re
+import asyncio
+
+from .. import debug
+from ..typing import CreateResult, Messages
+from .types import BaseProvider, ProviderType
+
+system_message = """
+You can generate images, pictures, photos or img with the DALL-E 3 image generator.
+To generate an image with a prompt, do this:
+
+<img data-prompt=\"keywords for the image\">
+
+Never use your own image links. Don't wrap the tag in backticks.
+It is important to use only an img tag with a prompt.
+
+<img data-prompt=\"image caption\">
+"""
+
+class CreateImagesProvider(BaseProvider):
+    """
+    Provider class for creating images based on text prompts.
+
+    This provider handles image creation requests embedded within message content,
+    using provided image creation functions.
+
+    Attributes:
+        provider (ProviderType): The underlying provider to handle non-image related tasks.
+        create_images (callable): A function to create images synchronously.
+        create_images_async (callable): A function to create images asynchronously.
+        system_message (str): A message that explains the image creation capability.
+        include_placeholder (bool): Flag to determine whether to include the image placeholder in the output.
+        __name__ (str): Name of the provider.
+        url (str): URL of the provider.
+        working (bool): Indicates if the provider is operational.
+        supports_stream (bool): Indicates if the provider supports streaming.
+    """
+
+    def __init__(
+        self,
+        provider: ProviderType,
+        create_images: callable,
+        create_async: callable,
+        system_message: str = system_message,
+        include_placeholder: bool = True
+    ) -> None:
+        """
+        Initializes the CreateImagesProvider.
+
+        Args:
+            provider (ProviderType): The underlying provider.
+            create_images (callable): Function to create images synchronously.
+            create_async (callable): Function to create images asynchronously.
+            system_message (str, optional): System message to be prefixed to messages. Defaults to a predefined message.
+            include_placeholder (bool, optional): Whether to include image placeholders in the output. Defaults to True.
+        """
+        self.provider = provider
+        self.create_images = create_images
+        self.create_images_async = create_async
+        self.system_message = system_message
+        self.include_placeholder = include_placeholder
+        self.__name__ = provider.__name__
+        self.url = provider.url
+        self.working = provider.working
+        self.supports_stream = provider.supports_stream
+
+    def create_completion(
+        self,
+        model: str,
+        messages: Messages,
+        stream: bool = False,
+        **kwargs
+    ) -> CreateResult:
+        """
+        Creates a completion result, processing any image creation prompts found within the messages.
+
+        Args:
+            model (str): The model to use for creation.
+            messages (Messages): The messages to process, which may contain image prompts.
+            stream (bool, optional): Indicates whether to stream the results. Defaults to False.
+            **kwargs: Additional keyword arguments for the provider.
+
+        Yields:
+            CreateResult: Yields chunks of the processed messages, including image data if applicable.
+
+        Note:
+            This method processes messages to detect image creation prompts. When such a prompt is found,
+            it calls the synchronous image creation function and includes the resulting image in the output.
+        """
+        messages.insert(0, {"role": "system", "content": self.system_message})
+        buffer = ""
+        for chunk in self.provider.create_completion(model, messages, stream, **kwargs):
+            if isinstance(chunk, str) and (buffer or "<" in chunk):
+                buffer += chunk
+                if ">" in buffer:
+                    match = re.search(r'<img data-prompt="(.*?)">', buffer)
+                    if match:
+                        placeholder, prompt = match.group(0), match.group(1)
+                        start, append = buffer.split(placeholder, 1)
+                        if start:
+                            yield start
+                        if self.include_placeholder:
+                            yield placeholder
+                        if debug.logging:
+                            print(f"Create images with prompt: {prompt}")
+                        yield from self.create_images(prompt)
+                        if append:
+                            yield append
+                    else:
+                        yield buffer
+                    buffer = ""
+            else:
+                yield chunk
+
+    async def create_async(
+        self,
+        model: str,
+        messages: Messages,
+        **kwargs
+    ) -> str:
+        """
+        Asynchronously creates a response, processing any image creation prompts found within the messages.
+
+        Args:
+            model (str): The model to use for creation.
+            messages (Messages): The messages to process, which may contain image prompts.
+            **kwargs: Additional keyword arguments for the provider.
+
+        Returns:
+            str: The processed response string, including asynchronously generated image data if applicable.
+
+        Note:
+            This method processes messages to detect image creation prompts. When such a prompt is found,
+            it calls the asynchronous image creation function and includes the resulting image in the output.
+        """
+        messages.insert(0, {"role": "system", "content": self.system_message})
+        response = await self.provider.create_async(model, messages, **kwargs)
+        matches = re.findall(r'(<img data-prompt="(.*?)">)', response)
+        results = []
+        placeholders = []
+        for placeholder, prompt in matches:
+            if placeholder not in placeholders:
+                if debug.logging:
+                    print(f"Create images with prompt: {prompt}")
+                results.append(self.create_images_async(prompt))
+                placeholders.append(placeholder)
+        results = await asyncio.gather(*results)
+        for idx, result in enumerate(results):
+            placeholder = placeholders[idx]
+            if self.include_placeholder:
+                result = placeholder + result
+            response = response.replace(placeholder, result)
+        return response
\ No newline at end of file
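A hypothetical wiring sketch for the wrapper above, assuming the layout of this patch: FakeChatProvider, make_images and make_images_async are placeholders standing in for a real chat provider and real image backends, purely to show how the pieces compose.

from g4f.providers.base_provider import AbstractProvider
from g4f.providers.create_images import CreateImagesProvider
from g4f.typing import CreateResult, Messages

class FakeChatProvider(AbstractProvider):
    # Placeholder chat provider that always answers with one image tag.
    url = "https://example.invalid"
    working = True
    supports_stream = True

    @classmethod
    def create_completion(cls, model: str, messages: Messages, stream: bool, **kwargs) -> CreateResult:
        yield 'Here you go: <img data-prompt="a cat wearing a hat">'

def make_images(prompt: str):
    # Placeholder synchronous image backend; yields whatever should replace the tag.
    yield f"[image for: {prompt}]"

async def make_images_async(prompt: str) -> str:
    # Placeholder asynchronous image backend used by create_async().
    return f"[image for: {prompt}]"

wrapped = CreateImagesProvider(
    FakeChatProvider,
    make_images,
    make_images_async,
    include_placeholder=False,  # drop the raw <img data-prompt> tag from the output
)

messages = [{"role": "user", "content": "Draw a cat"}]
for chunk in wrapped.create_completion("any-model", messages):
    print(chunk, end="")  # -> Here you go: [image for: a cat wearing a hat]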
diff --git a/g4f/providers/helper.py b/g4f/providers/helper.py
new file mode 100644
index 00000000..49d033d1
--- /dev/null
+++ b/g4f/providers/helper.py
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+import random
+import secrets
+import string
+from aiohttp import BaseConnector
+
+from ..typing import Messages, Optional
+from ..errors import MissingRequirementsError
+
+def format_prompt(messages: Messages, add_special_tokens=False) -> str:
+    """
+    Format a series of messages into a single string, optionally adding special tokens.
+
+    Args:
+        messages (Messages): A list of message dictionaries, each containing 'role' and 'content'.
+        add_special_tokens (bool): Whether to add special formatting tokens.
+
+    Returns:
+        str: A formatted string containing all messages.
+    """
+    if not add_special_tokens and len(messages) <= 1:
+        return messages[0]["content"]
+    formatted = "\n".join([
+        f'{message["role"].capitalize()}: {message["content"]}'
+        for message in messages
+    ])
+    return f"{formatted}\nAssistant:"
+
+def get_random_string(length: int = 10) -> str:
+    """
+    Generate a random string of specified length, containing lowercase letters and digits.
+
+    Args:
+        length (int, optional): Length of the random string to generate. Defaults to 10.
+
+    Returns:
+        str: A random string of the specified length.
+    """
+    return ''.join(
+        random.choice(string.ascii_lowercase + string.digits)
+        for _ in range(length)
+    )
+
+def get_random_hex() -> str:
+    """
+    Generate a random hexadecimal string of a fixed length.
+
+    Returns:
+        str: A random hexadecimal string of 32 characters (16 bytes).
+    """
+    return secrets.token_hex(16).zfill(32)
+
+def get_connector(connector: BaseConnector = None, proxy: str = None) -> Optional[BaseConnector]:
+    """Return the given connector, or build a proxy connector from "aiohttp_socks" if a proxy URL is set."""
+    if proxy and not connector:
+        try:
+            from aiohttp_socks import ProxyConnector
+            connector = ProxyConnector.from_url(proxy)
+        except ImportError:
+            raise MissingRequirementsError('Install "aiohttp_socks" package for proxy support')
+    return connector
\ No newline at end of file
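A quick usage sketch for these helpers; the printed values are illustrative, and the random outputs will of course differ per run.

from g4f.providers.helper import format_prompt, get_random_hex, get_random_string

messages = [
    {"role": "system", "content": "You are concise."},
    {"role": "user", "content": "Hi!"},
]

print(format_prompt(messages))
# System: You are concise.
# User: Hi!
# Assistant:

print(get_random_string(8))  # e.g. "k3v9x1qz"
print(get_random_hex())      # 32 hexadecimal characters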
diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py
new file mode 100644
index 00000000..a7ab2881
--- /dev/null
+++ b/g4f/providers/retry_provider.py
@@ -0,0 +1,119 @@
+from __future__ import annotations
+
+import asyncio
+import random
+
+from ..typing import CreateResult, Messages
+from .types import BaseRetryProvider
+from .. import debug
+from ..errors import RetryProviderError, RetryNoProviderError
+
+class RetryProvider(BaseRetryProvider):
+    """
+    A provider class to handle retries for creating completions with different providers.
+
+    Attributes:
+        providers (list): A list of provider instances.
+        shuffle (bool): A flag indicating whether to shuffle providers before use.
+        exceptions (dict): A dictionary to store exceptions encountered during retries.
+        last_provider (BaseProvider): The last provider that was used.
+    """
+
+    def create_completion(
+        self,
+        model: str,
+        messages: Messages,
+        stream: bool = False,
+        **kwargs
+    ) -> CreateResult:
+        """
+        Create a completion using available providers, with an option to stream the response.
+
+        Args:
+            model (str): The model to be used for completion.
+            messages (Messages): The messages to be used for generating completion.
+            stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.
+
+        Yields:
+            CreateResult: Tokens or results from the completion.
+
+        Raises:
+            Exception: Any exception encountered during the completion process.
+        """
+        providers = [p for p in self.providers if p.supports_stream] if stream else self.providers
+        if self.shuffle:
+            random.shuffle(providers)
+
+        self.exceptions = {}
+        started: bool = False
+        for provider in providers:
+            self.last_provider = provider
+            try:
+                if debug.logging:
+                    print(f"Using {provider.__name__} provider")
+                for token in provider.create_completion(model, messages, stream, **kwargs):
+                    yield token
+                    started = True
+                if started:
+                    return
+            except Exception as e:
+                self.exceptions[provider.__name__] = e
+                if debug.logging:
+                    print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+                if started:
+                    raise e
+
+        self.raise_exceptions()
+
+    async def create_async(
+        self,
+        model: str,
+        messages: Messages,
+        **kwargs
+    ) -> str:
+        """
+        Asynchronously create a completion using available providers.
+
+        Args:
+            model (str): The model to be used for completion.
+            messages (Messages): The messages to be used for generating completion.
+
+        Returns:
+            str: The result of the asynchronous completion.
+
+        Raises:
+            Exception: Any exception encountered during the asynchronous completion process.
+        """
+        providers = self.providers
+        if self.shuffle:
+            random.shuffle(providers)
+
+        self.exceptions = {}
+        for provider in providers:
+            self.last_provider = provider
+            try:
+                return await asyncio.wait_for(
+                    provider.create_async(model, messages, **kwargs),
+                    timeout=kwargs.get("timeout", 60)
+                )
+            except Exception as e:
+                self.exceptions[provider.__name__] = e
+                if debug.logging:
+                    print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+
+        self.raise_exceptions()
+
+    def raise_exceptions(self) -> None:
+        """
+        Raise a combined exception if any occurred during retries.
+
+        Raises:
+            RetryProviderError: If any provider encountered an exception.
+            RetryNoProviderError: If no provider is found.
+        """
+        if self.exceptions:
+            raise RetryProviderError("RetryProvider failed:\n" + "\n".join([
+                f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in self.exceptions.items()
+            ]))
+
+        raise RetryNoProviderError("No provider found")
\ No newline at end of file
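A hypothetical sketch of the retry flow, assuming this patch's layout: AlwaysFails and AlwaysWorks are stand-ins, not real providers. The first provider raises, the exception is recorded in retry.exceptions, and the second provider serves the request.

from g4f.providers.base_provider import AbstractProvider
from g4f.providers.retry_provider import RetryProvider
from g4f.typing import CreateResult, Messages

class AlwaysFails(AbstractProvider):
    url = "https://example.invalid"
    working = True

    @classmethod
    def create_completion(cls, model: str, messages: Messages, stream: bool, **kwargs) -> CreateResult:
        raise RuntimeError("simulated outage")
        yield  # unreachable, but keeps this a generator function

class AlwaysWorks(AbstractProvider):
    url = "https://example.invalid"
    working = True

    @classmethod
    def create_completion(cls, model: str, messages: Messages, stream: bool, **kwargs) -> CreateResult:
        yield "ok"

retry = RetryProvider([AlwaysFails, AlwaysWorks], shuffle=False)
messages = [{"role": "user", "content": "ping"}]
print("".join(retry.create_completion("any-model", messages)))  # -> ok
print(retry.exceptions)  # {'AlwaysFails': RuntimeError('simulated outage')}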
diff --git a/g4f/providers/types.py b/g4f/providers/types.py
new file mode 100644
index 00000000..7b11ec43
--- /dev/null
+++ b/g4f/providers/types.py
@@ -0,0 +1,117 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Union, List, Dict, Type
+from ..typing import Messages, CreateResult
+
+class BaseProvider(ABC):
+    """
+    Abstract base class for a provider.
+
+    Attributes:
+        url (str): URL of the provider.
+        working (bool): Indicates if the provider is currently working.
+        needs_auth (bool): Indicates if the provider needs authentication.
+        supports_stream (bool): Indicates if the provider supports streaming.
+        supports_gpt_35_turbo (bool): Indicates if the provider supports GPT-3.5 Turbo.
+        supports_gpt_4 (bool): Indicates if the provider supports GPT-4.
+        supports_message_history (bool): Indicates if the provider supports message history.
+        params (str): A string listing the parameters supported by the provider.
+    """
+
+    url: str = None
+    working: bool = False
+    needs_auth: bool = False
+    supports_stream: bool = False
+    supports_gpt_35_turbo: bool = False
+    supports_gpt_4: bool = False
+    supports_message_history: bool = False
+    params: str
+
+    @classmethod
+    @abstractmethod
+    def create_completion(
+        cls,
+        model: str,
+        messages: Messages,
+        stream: bool,
+        **kwargs
+    ) -> CreateResult:
+        """
+        Create a completion with the given parameters.
+
+        Args:
+            model (str): The model to use.
+            messages (Messages): The messages to process.
+            stream (bool): Whether to use streaming.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            CreateResult: The result of the creation process.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    @abstractmethod
+    async def create_async(
+        cls,
+        model: str,
+        messages: Messages,
+        **kwargs
+    ) -> str:
+        """
+        Asynchronously create a completion with the given parameters.
+
+        Args:
+            model (str): The model to use.
+            messages (Messages): The messages to process.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The result of the creation process.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def get_dict(cls) -> Dict[str, str]:
+        """
+        Get a dictionary representation of the provider.
+
+        Returns:
+            Dict[str, str]: A dictionary with provider's details.
+        """
+        return {'name': cls.__name__, 'url': cls.url}
+
+class BaseRetryProvider(BaseProvider):
+    """
+    Base class for a provider that implements retry logic.
+
+    Attributes:
+        providers (List[Type[BaseProvider]]): List of providers to use for retries.
+        shuffle (bool): Whether to shuffle the providers list.
+        exceptions (Dict[str, Exception]): Dictionary of exceptions encountered.
+        last_provider (Type[BaseProvider]): The last provider used.
+    """
+
+    __name__: str = "RetryProvider"
+    supports_stream: bool = True
+
+    def __init__(
+        self,
+        providers: List[Type[BaseProvider]],
+        shuffle: bool = True
+    ) -> None:
+        """
+        Initialize the BaseRetryProvider.
+
+        Args:
+            providers (List[Type[BaseProvider]]): List of providers to use.
+            shuffle (bool): Whether to shuffle the providers list.
+        """
+        self.providers = providers
+        self.shuffle = shuffle
+        self.working = True
+        self.exceptions: Dict[str, Exception] = {}
+        self.last_provider: Type[BaseProvider] = None
+
+ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
\ No newline at end of file
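A small sketch of how this type layer is meant to be consumed, assuming the modules added in this patch: ProviderType covers both plain provider classes and retry wrappers, and get_dict() yields a serializable summary. DummyProvider is hypothetical and only here to make the example runnable.

from g4f.providers.base_provider import AbstractProvider
from g4f.providers.retry_provider import RetryProvider
from g4f.providers.types import ProviderType
from g4f.typing import CreateResult, Messages

class DummyProvider(AbstractProvider):
    # Hypothetical minimal provider used only for this illustration.
    url = "https://example.invalid"
    working = True

    @classmethod
    def create_completion(cls, model: str, messages: Messages, stream: bool, **kwargs) -> CreateResult:
        yield "hello"

def describe(provider: ProviderType) -> dict:
    # Both provider classes and BaseRetryProvider instances expose get_dict().
    return provider.get_dict()

print(describe(DummyProvider))                   # {'name': 'DummyProvider', 'url': 'https://example.invalid'}
print(describe(RetryProvider([DummyProvider])))  # {'name': 'RetryProvider', 'url': None}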