diff options
Diffstat (limited to 'g4f/providers')
-rw-r--r-- | g4f/providers/base_provider.py | 280 | ||||
-rw-r--r-- | g4f/providers/create_images.py | 155 | ||||
-rw-r--r-- | g4f/providers/helper.py | 61 | ||||
-rw-r--r-- | g4f/providers/retry_provider.py | 119 | ||||
-rw-r--r-- | g4f/providers/types.py | 117 |
5 files changed, 732 insertions, 0 deletions
diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py new file mode 100644 index 00000000..b8649ba5 --- /dev/null +++ b/g4f/providers/base_provider.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +import sys +import asyncio +from asyncio import AbstractEventLoop +from concurrent.futures import ThreadPoolExecutor +from abc import abstractmethod +from inspect import signature, Parameter +from ..typing import CreateResult, AsyncResult, Messages, Union +from .types import BaseProvider +from ..errors import NestAsyncioError, ModelNotSupportedError +from .. import debug + +if sys.version_info < (3, 10): + NoneType = type(None) +else: + from types import NoneType + +# Set Windows event loop policy for better compatibility with asyncio and curl_cffi +if sys.platform == 'win32': + if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy): + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + +def get_running_loop() -> Union[AbstractEventLoop, None]: + try: + loop = asyncio.get_running_loop() + if not hasattr(loop.__class__, "_nest_patched"): + raise NestAsyncioError( + 'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.' + ) + return loop + except RuntimeError: + pass + +class AbstractProvider(BaseProvider): + """ + Abstract class for providing asynchronous functionality to derived classes. + """ + + @classmethod + async def create_async( + cls, + model: str, + messages: Messages, + *, + loop: AbstractEventLoop = None, + executor: ThreadPoolExecutor = None, + **kwargs + ) -> str: + """ + Asynchronously creates a result based on the given model and messages. + + Args: + cls (type): The class on which this method is called. + model (str): The model to use for creation. + messages (Messages): The messages to process. + loop (AbstractEventLoop, optional): The event loop to use. Defaults to None. + executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + str: The created result as a string. + """ + loop = loop or asyncio.get_running_loop() + + def create_func() -> str: + return "".join(cls.create_completion(model, messages, False, **kwargs)) + + return await asyncio.wait_for( + loop.run_in_executor(executor, create_func), + timeout=kwargs.get("timeout") + ) + + @classmethod + @property + def params(cls) -> str: + """ + Returns the parameters supported by the provider. + + Args: + cls (type): The class on which this property is called. + + Returns: + str: A string listing the supported parameters. + """ + sig = signature( + cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else + cls.create_async if issubclass(cls, AsyncProvider) else + cls.create_completion + ) + + def get_type_name(annotation: type) -> str: + return annotation.__name__ if hasattr(annotation, "__name__") else str(annotation) + + args = "" + for name, param in sig.parameters.items(): + if name in ("self", "kwargs") or (name == "stream" and not cls.supports_stream): + continue + args += f"\n {name}" + args += f": {get_type_name(param.annotation)}" if param.annotation is not Parameter.empty else "" + args += f' = "{param.default}"' if param.default == "" else f" = {param.default}" if param.default is not Parameter.empty else "" + + return f"g4f.Provider.{cls.__name__} supports: ({args}\n)" + + +class AsyncProvider(AbstractProvider): + """ + Provides asynchronous functionality for creating completions. + """ + + @classmethod + def create_completion( + cls, + model: str, + messages: Messages, + stream: bool = False, + **kwargs + ) -> CreateResult: + """ + Creates a completion result synchronously. + + Args: + cls (type): The class on which this method is called. + model (str): The model to use for creation. + messages (Messages): The messages to process. + stream (bool): Indicates whether to stream the results. Defaults to False. + loop (AbstractEventLoop, optional): The event loop to use. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + CreateResult: The result of the completion creation. + """ + get_running_loop() + yield asyncio.run(cls.create_async(model, messages, **kwargs)) + + @staticmethod + @abstractmethod + async def create_async( + model: str, + messages: Messages, + **kwargs + ) -> str: + """ + Abstract method for creating asynchronous results. + + Args: + model (str): The model to use for creation. + messages (Messages): The messages to process. + **kwargs: Additional keyword arguments. + + Raises: + NotImplementedError: If this method is not overridden in derived classes. + + Returns: + str: The created result as a string. + """ + raise NotImplementedError() + + +class AsyncGeneratorProvider(AsyncProvider): + """ + Provides asynchronous generator functionality for streaming results. + """ + supports_stream = True + + @classmethod + def create_completion( + cls, + model: str, + messages: Messages, + stream: bool = True, + **kwargs + ) -> CreateResult: + """ + Creates a streaming completion result synchronously. + + Args: + cls (type): The class on which this method is called. + model (str): The model to use for creation. + messages (Messages): The messages to process. + stream (bool): Indicates whether to stream the results. Defaults to True. + loop (AbstractEventLoop, optional): The event loop to use. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + CreateResult: The result of the streaming completion creation. + """ + loop = get_running_loop() + new_loop = False + if not loop: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + new_loop = True + + generator = cls.create_async_generator(model, messages, stream=stream, **kwargs) + gen = generator.__aiter__() + + # Fix for RuntimeError: async generator ignored GeneratorExit + async def await_callback(callback): + return await callback() + + try: + while True: + yield loop.run_until_complete(await_callback(gen.__anext__)) + except StopAsyncIteration: + ... + # Fix for: ResourceWarning: unclosed event loop + finally: + if new_loop: + loop.close() + asyncio.set_event_loop(None) + + @classmethod + async def create_async( + cls, + model: str, + messages: Messages, + **kwargs + ) -> str: + """ + Asynchronously creates a result from a generator. + + Args: + cls (type): The class on which this method is called. + model (str): The model to use for creation. + messages (Messages): The messages to process. + **kwargs: Additional keyword arguments. + + Returns: + str: The created result as a string. + """ + return "".join([ + chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs) + if not isinstance(chunk, Exception) + ]) + + @staticmethod + @abstractmethod + async def create_async_generator( + model: str, + messages: Messages, + stream: bool = True, + **kwargs + ) -> AsyncResult: + """ + Abstract method for creating an asynchronous generator. + + Args: + model (str): The model to use for creation. + messages (Messages): The messages to process. + stream (bool): Indicates whether to stream the results. Defaults to True. + **kwargs: Additional keyword arguments. + + Raises: + NotImplementedError: If this method is not overridden in derived classes. + + Returns: + AsyncResult: An asynchronous generator yielding results. + """ + raise NotImplementedError() + +class ProviderModelMixin: + default_model: str + models: list[str] = [] + model_aliases: dict[str, str] = {} + + @classmethod + def get_models(cls) -> list[str]: + return cls.models + + @classmethod + def get_model(cls, model: str) -> str: + if not model: + model = cls.default_model + elif model in cls.model_aliases: + model = cls.model_aliases[model] + elif model not in cls.get_models(): + raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}") + debug.last_model = model + return model
\ No newline at end of file diff --git a/g4f/providers/create_images.py b/g4f/providers/create_images.py new file mode 100644 index 00000000..29a2a041 --- /dev/null +++ b/g4f/providers/create_images.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import re +import asyncio + +from .. import debug +from ..typing import CreateResult, Messages +from .types import BaseProvider, ProviderType + +system_message = """ +You can generate images, pictures, photos or img with the DALL-E 3 image generator. +To generate an image with a prompt, do this: + +<img data-prompt=\"keywords for the image\"> + +Never use own image links. Don't wrap it in backticks. +It is important to use a only a img tag with a prompt. + +<img data-prompt=\"image caption\"> +""" + +class CreateImagesProvider(BaseProvider): + """ + Provider class for creating images based on text prompts. + + This provider handles image creation requests embedded within message content, + using provided image creation functions. + + Attributes: + provider (ProviderType): The underlying provider to handle non-image related tasks. + create_images (callable): A function to create images synchronously. + create_images_async (callable): A function to create images asynchronously. + system_message (str): A message that explains the image creation capability. + include_placeholder (bool): Flag to determine whether to include the image placeholder in the output. + __name__ (str): Name of the provider. + url (str): URL of the provider. + working (bool): Indicates if the provider is operational. + supports_stream (bool): Indicates if the provider supports streaming. + """ + + def __init__( + self, + provider: ProviderType, + create_images: callable, + create_async: callable, + system_message: str = system_message, + include_placeholder: bool = True + ) -> None: + """ + Initializes the CreateImagesProvider. + + Args: + provider (ProviderType): The underlying provider. + create_images (callable): Function to create images synchronously. + create_async (callable): Function to create images asynchronously. + system_message (str, optional): System message to be prefixed to messages. Defaults to a predefined message. + include_placeholder (bool, optional): Whether to include image placeholders in the output. Defaults to True. + """ + self.provider = provider + self.create_images = create_images + self.create_images_async = create_async + self.system_message = system_message + self.include_placeholder = include_placeholder + self.__name__ = provider.__name__ + self.url = provider.url + self.working = provider.working + self.supports_stream = provider.supports_stream + + def create_completion( + self, + model: str, + messages: Messages, + stream: bool = False, + **kwargs + ) -> CreateResult: + """ + Creates a completion result, processing any image creation prompts found within the messages. + + Args: + model (str): The model to use for creation. + messages (Messages): The messages to process, which may contain image prompts. + stream (bool, optional): Indicates whether to stream the results. Defaults to False. + **kwargs: Additional keywordarguments for the provider. + + Yields: + CreateResult: Yields chunks of the processed messages, including image data if applicable. + + Note: + This method processes messages to detect image creation prompts. When such a prompt is found, + it calls the synchronous image creation function and includes the resulting image in the output. + """ + messages.insert(0, {"role": "system", "content": self.system_message}) + buffer = "" + for chunk in self.provider.create_completion(model, messages, stream, **kwargs): + if isinstance(chunk, str) and buffer or "<" in chunk: + buffer += chunk + if ">" in buffer: + match = re.search(r'<img data-prompt="(.*?)">', buffer) + if match: + placeholder, prompt = match.group(0), match.group(1) + start, append = buffer.split(placeholder, 1) + if start: + yield start + if self.include_placeholder: + yield placeholder + if debug.logging: + print(f"Create images with prompt: {prompt}") + yield from self.create_images(prompt) + if append: + yield append + else: + yield buffer + buffer = "" + else: + yield chunk + + async def create_async( + self, + model: str, + messages: Messages, + **kwargs + ) -> str: + """ + Asynchronously creates a response, processing any image creation prompts found within the messages. + + Args: + model (str): The model to use for creation. + messages (Messages): The messages to process, which may contain image prompts. + **kwargs: Additional keyword arguments for the provider. + + Returns: + str: The processed response string, including asynchronously generated image data if applicable. + + Note: + This method processes messages to detect image creation prompts. When such a prompt is found, + it calls the asynchronous image creation function and includes the resulting image in the output. + """ + messages.insert(0, {"role": "system", "content": self.system_message}) + response = await self.provider.create_async(model, messages, **kwargs) + matches = re.findall(r'(<img data-prompt="(.*?)">)', response) + results = [] + placeholders = [] + for placeholder, prompt in matches: + if placeholder not in placeholders: + if debug.logging: + print(f"Create images with prompt: {prompt}") + results.append(self.create_images_async(prompt)) + placeholders.append(placeholder) + results = await asyncio.gather(*results) + for idx, result in enumerate(results): + placeholder = placeholder[idx] + if self.include_placeholder: + result = placeholder + result + response = response.replace(placeholder, result) + return response
\ No newline at end of file diff --git a/g4f/providers/helper.py b/g4f/providers/helper.py new file mode 100644 index 00000000..49d033d1 --- /dev/null +++ b/g4f/providers/helper.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import random +import secrets +import string +from aiohttp import BaseConnector + +from ..typing import Messages, Optional +from ..errors import MissingRequirementsError + +def format_prompt(messages: Messages, add_special_tokens=False) -> str: + """ + Format a series of messages into a single string, optionally adding special tokens. + + Args: + messages (Messages): A list of message dictionaries, each containing 'role' and 'content'. + add_special_tokens (bool): Whether to add special formatting tokens. + + Returns: + str: A formatted string containing all messages. + """ + if not add_special_tokens and len(messages) <= 1: + return messages[0]["content"] + formatted = "\n".join([ + f'{message["role"].capitalize()}: {message["content"]}' + for message in messages + ]) + return f"{formatted}\nAssistant:" + +def get_random_string(length: int = 10) -> str: + """ + Generate a random string of specified length, containing lowercase letters and digits. + + Args: + length (int, optional): Length of the random string to generate. Defaults to 10. + + Returns: + str: A random string of the specified length. + """ + return ''.join( + random.choice(string.ascii_lowercase + string.digits) + for _ in range(length) + ) + +def get_random_hex() -> str: + """ + Generate a random hexadecimal string of a fixed length. + + Returns: + str: A random hexadecimal string of 32 characters (16 bytes). + """ + return secrets.token_hex(16).zfill(32) + +def get_connector(connector: BaseConnector = None, proxy: str = None) -> Optional[BaseConnector]: + if proxy and not connector: + try: + from aiohttp_socks import ProxyConnector + connector = ProxyConnector.from_url(proxy) + except ImportError: + raise MissingRequirementsError('Install "aiohttp_socks" package for proxy support') + return connector
\ No newline at end of file diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py new file mode 100644 index 00000000..a7ab2881 --- /dev/null +++ b/g4f/providers/retry_provider.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import asyncio +import random + +from ..typing import CreateResult, Messages +from .types import BaseRetryProvider +from .. import debug +from ..errors import RetryProviderError, RetryNoProviderError + +class RetryProvider(BaseRetryProvider): + """ + A provider class to handle retries for creating completions with different providers. + + Attributes: + providers (list): A list of provider instances. + shuffle (bool): A flag indicating whether to shuffle providers before use. + exceptions (dict): A dictionary to store exceptions encountered during retries. + last_provider (BaseProvider): The last provider that was used. + """ + + def create_completion( + self, + model: str, + messages: Messages, + stream: bool = False, + **kwargs + ) -> CreateResult: + """ + Create a completion using available providers, with an option to stream the response. + + Args: + model (str): The model to be used for completion. + messages (Messages): The messages to be used for generating completion. + stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False. + + Yields: + CreateResult: Tokens or results from the completion. + + Raises: + Exception: Any exception encountered during the completion process. + """ + providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers + if self.shuffle: + random.shuffle(providers) + + self.exceptions = {} + started: bool = False + for provider in providers: + self.last_provider = provider + try: + if debug.logging: + print(f"Using {provider.__name__} provider") + for token in provider.create_completion(model, messages, stream, **kwargs): + yield token + started = True + if started: + return + except Exception as e: + self.exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + if started: + raise e + + self.raise_exceptions() + + async def create_async( + self, + model: str, + messages: Messages, + **kwargs + ) -> str: + """ + Asynchronously create a completion using available providers. + + Args: + model (str): The model to be used for completion. + messages (Messages): The messages to be used for generating completion. + + Returns: + str: The result of the asynchronous completion. + + Raises: + Exception: Any exception encountered during the asynchronous completion process. + """ + providers = self.providers + if self.shuffle: + random.shuffle(providers) + + self.exceptions = {} + for provider in providers: + self.last_provider = provider + try: + return await asyncio.wait_for( + provider.create_async(model, messages, **kwargs), + timeout=kwargs.get("timeout", 60) + ) + except Exception as e: + self.exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + + self.raise_exceptions() + + def raise_exceptions(self) -> None: + """ + Raise a combined exception if any occurred during retries. + + Raises: + RetryProviderError: If any provider encountered an exception. + RetryNoProviderError: If no provider is found. + """ + if self.exceptions: + raise RetryProviderError("RetryProvider failed:\n" + "\n".join([ + f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in self.exceptions.items() + ])) + + raise RetryNoProviderError("No provider found")
\ No newline at end of file diff --git a/g4f/providers/types.py b/g4f/providers/types.py new file mode 100644 index 00000000..7b11ec43 --- /dev/null +++ b/g4f/providers/types.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Union, List, Dict, Type +from ..typing import Messages, CreateResult + +class BaseProvider(ABC): + """ + Abstract base class for a provider. + + Attributes: + url (str): URL of the provider. + working (bool): Indicates if the provider is currently working. + needs_auth (bool): Indicates if the provider needs authentication. + supports_stream (bool): Indicates if the provider supports streaming. + supports_gpt_35_turbo (bool): Indicates if the provider supports GPT-3.5 Turbo. + supports_gpt_4 (bool): Indicates if the provider supports GPT-4. + supports_message_history (bool): Indicates if the provider supports message history. + params (str): List parameters for the provider. + """ + + url: str = None + working: bool = False + needs_auth: bool = False + supports_stream: bool = False + supports_gpt_35_turbo: bool = False + supports_gpt_4: bool = False + supports_message_history: bool = False + params: str + + @classmethod + @abstractmethod + def create_completion( + cls, + model: str, + messages: Messages, + stream: bool, + **kwargs + ) -> CreateResult: + """ + Create a completion with the given parameters. + + Args: + model (str): The model to use. + messages (Messages): The messages to process. + stream (bool): Whether to use streaming. + **kwargs: Additional keyword arguments. + + Returns: + CreateResult: The result of the creation process. + """ + raise NotImplementedError() + + @classmethod + @abstractmethod + async def create_async( + cls, + model: str, + messages: Messages, + **kwargs + ) -> str: + """ + Asynchronously create a completion with the given parameters. + + Args: + model (str): The model to use. + messages (Messages): The messages to process. + **kwargs: Additional keyword arguments. + + Returns: + str: The result of the creation process. + """ + raise NotImplementedError() + + @classmethod + def get_dict(cls) -> Dict[str, str]: + """ + Get a dictionary representation of the provider. + + Returns: + Dict[str, str]: A dictionary with provider's details. + """ + return {'name': cls.__name__, 'url': cls.url} + +class BaseRetryProvider(BaseProvider): + """ + Base class for a provider that implements retry logic. + + Attributes: + providers (List[Type[BaseProvider]]): List of providers to use for retries. + shuffle (bool): Whether to shuffle the providers list. + exceptions (Dict[str, Exception]): Dictionary of exceptions encountered. + last_provider (Type[BaseProvider]): The last provider used. + """ + + __name__: str = "RetryProvider" + supports_stream: bool = True + + def __init__( + self, + providers: List[Type[BaseProvider]], + shuffle: bool = True + ) -> None: + """ + Initialize the BaseRetryProvider. + + Args: + providers (List[Type[BaseProvider]]): List of providers to use. + shuffle (bool): Whether to shuffle the providers list. + """ + self.providers = providers + self.shuffle = shuffle + self.working = True + self.exceptions: Dict[str, Exception] = {} + self.last_provider: Type[BaseProvider] = None + +ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
\ No newline at end of file |