diff options
author | H Lohaus <hlohaus@users.noreply.github.com> | 2024-01-14 15:32:51 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-14 15:32:51 +0100 |
commit | 1ca80ed48b55d6462b4bd445e66d4f7de7442c2b (patch) | |
tree | 05a94b53b83461b8249de965e093b4fd3722e2d1 /g4f/Provider/base_provider.py | |
parent | Merge pull request #1466 from hlohaus/upp (diff) | |
parent | Change doctypes style to Google (diff) | |
download | gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.tar gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.tar.gz gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.tar.bz2 gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.tar.lz gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.tar.xz gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.tar.zst gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.zip |
Diffstat (limited to 'g4f/Provider/base_provider.py')
-rw-r--r-- | g4f/Provider/base_provider.py | 198 |
1 files changed, 133 insertions, 65 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index e7e88841..fd92d17a 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -1,28 +1,29 @@ from __future__ import annotations - import sys import asyncio -from asyncio import AbstractEventLoop +from asyncio import AbstractEventLoop from concurrent.futures import ThreadPoolExecutor -from abc import abstractmethod -from inspect import signature, Parameter -from .helper import get_event_loop, get_cookies, format_prompt -from ..typing import CreateResult, AsyncResult, Messages -from ..base_provider import BaseProvider +from abc import abstractmethod +from inspect import signature, Parameter +from .helper import get_event_loop, get_cookies, format_prompt +from ..typing import CreateResult, AsyncResult, Messages +from ..base_provider import BaseProvider if sys.version_info < (3, 10): NoneType = type(None) else: from types import NoneType -# Change event loop policy on windows for curl_cffi +# Set Windows event loop policy for better compatibility with asyncio and curl_cffi if sys.platform == 'win32': - if isinstance( - asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy - ): + if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) class AbstractProvider(BaseProvider): + """ + Abstract class for providing asynchronous functionality to derived classes. + """ + @classmethod async def create_async( cls, @@ -33,62 +34,67 @@ class AbstractProvider(BaseProvider): executor: ThreadPoolExecutor = None, **kwargs ) -> str: - if not loop: - loop = get_event_loop() + """ + Asynchronously creates a result based on the given model and messages. + + Args: + cls (type): The class on which this method is called. + model (str): The model to use for creation. + messages (Messages): The messages to process. + loop (AbstractEventLoop, optional): The event loop to use. Defaults to None. + executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + str: The created result as a string. + """ + loop = loop or get_event_loop() def create_func() -> str: - return "".join(cls.create_completion( - model, - messages, - False, - **kwargs - )) + return "".join(cls.create_completion(model, messages, False, **kwargs)) return await asyncio.wait_for( - loop.run_in_executor( - executor, - create_func - ), + loop.run_in_executor(executor, create_func), timeout=kwargs.get("timeout", 0) ) @classmethod @property def params(cls) -> str: - if issubclass(cls, AsyncGeneratorProvider): - sig = signature(cls.create_async_generator) - elif issubclass(cls, AsyncProvider): - sig = signature(cls.create_async) - else: - sig = signature(cls.create_completion) + """ + Returns the parameters supported by the provider. + + Args: + cls (type): The class on which this property is called. + + Returns: + str: A string listing the supported parameters. + """ + sig = signature( + cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else + cls.create_async if issubclass(cls, AsyncProvider) else + cls.create_completion + ) def get_type_name(annotation: type) -> str: - if hasattr(annotation, "__name__"): - annotation = annotation.__name__ - elif isinstance(annotation, NoneType): - annotation = "None" - return str(annotation) - + return annotation.__name__ if hasattr(annotation, "__name__") else str(annotation) + args = "" for name, param in sig.parameters.items(): - if name in ("self", "kwargs"): - continue - if name == "stream" and not cls.supports_stream: + if name in ("self", "kwargs") or (name == "stream" and not cls.supports_stream): continue - if args: - args += ", " - args += "\n " + name - if name != "model" and param.annotation is not Parameter.empty: - args += f": {get_type_name(param.annotation)}" - if param.default == "": - args += ' = ""' - elif param.default is not Parameter.empty: - args += f" = {param.default}" + args += f"\n {name}" + args += f": {get_type_name(param.annotation)}" if param.annotation is not Parameter.empty else "" + args += f' = "{param.default}"' if param.default == "" else f" = {param.default}" if param.default is not Parameter.empty else "" return f"g4f.Provider.{cls.__name__} supports: ({args}\n)" class AsyncProvider(AbstractProvider): + """ + Provides asynchronous functionality for creating completions. + """ + @classmethod def create_completion( cls, @@ -99,8 +105,21 @@ class AsyncProvider(AbstractProvider): loop: AbstractEventLoop = None, **kwargs ) -> CreateResult: - if not loop: - loop = get_event_loop() + """ + Creates a completion result synchronously. + + Args: + cls (type): The class on which this method is called. + model (str): The model to use for creation. + messages (Messages): The messages to process. + stream (bool): Indicates whether to stream the results. Defaults to False. + loop (AbstractEventLoop, optional): The event loop to use. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + CreateResult: The result of the completion creation. + """ + loop = loop or get_event_loop() coro = cls.create_async(model, messages, **kwargs) yield loop.run_until_complete(coro) @@ -111,10 +130,27 @@ class AsyncProvider(AbstractProvider): messages: Messages, **kwargs ) -> str: + """ + Abstract method for creating asynchronous results. + + Args: + model (str): The model to use for creation. + messages (Messages): The messages to process. + **kwargs: Additional keyword arguments. + + Raises: + NotImplementedError: If this method is not overridden in derived classes. + + Returns: + str: The created result as a string. + """ raise NotImplementedError() class AsyncGeneratorProvider(AsyncProvider): + """ + Provides asynchronous generator functionality for streaming results. + """ supports_stream = True @classmethod @@ -127,15 +163,24 @@ class AsyncGeneratorProvider(AsyncProvider): loop: AbstractEventLoop = None, **kwargs ) -> CreateResult: - if not loop: - loop = get_event_loop() - generator = cls.create_async_generator( - model, - messages, - stream=stream, - **kwargs - ) + """ + Creates a streaming completion result synchronously. + + Args: + cls (type): The class on which this method is called. + model (str): The model to use for creation. + messages (Messages): The messages to process. + stream (bool): Indicates whether to stream the results. Defaults to True. + loop (AbstractEventLoop, optional): The event loop to use. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + CreateResult: The result of the streaming completion creation. + """ + loop = loop or get_event_loop() + generator = cls.create_async_generator(model, messages, stream=stream, **kwargs) gen = generator.__aiter__() + while True: try: yield loop.run_until_complete(gen.__anext__()) @@ -149,21 +194,44 @@ class AsyncGeneratorProvider(AsyncProvider): messages: Messages, **kwargs ) -> str: + """ + Asynchronously creates a result from a generator. + + Args: + cls (type): The class on which this method is called. + model (str): The model to use for creation. + messages (Messages): The messages to process. + **kwargs: Additional keyword arguments. + + Returns: + str: The created result as a string. + """ return "".join([ - chunk async for chunk in cls.create_async_generator( - model, - messages, - stream=False, - **kwargs - ) if not isinstance(chunk, Exception) + chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs) + if not isinstance(chunk, Exception) ]) @staticmethod @abstractmethod - def create_async_generator( + async def create_async_generator( model: str, messages: Messages, stream: bool = True, **kwargs ) -> AsyncResult: - raise NotImplementedError() + """ + Abstract method for creating an asynchronous generator. + + Args: + model (str): The model to use for creation. + messages (Messages): The messages to process. + stream (bool): Indicates whether to stream the results. Defaults to True. + **kwargs: Additional keyword arguments. + + Raises: + NotImplementedError: If this method is not overridden in derived classes. + + Returns: + AsyncResult: An asynchronous generator yielding results. + """ + raise NotImplementedError()
\ No newline at end of file |