diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/providers/retry_provider.py | 186 |
1 files changed, 146 insertions, 40 deletions
diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py index d64e8471..e2520437 100644 --- a/g4f/providers/retry_provider.py +++ b/g4f/providers/retry_provider.py @@ -3,18 +3,16 @@ from __future__ import annotations import asyncio import random -from ..typing import Type, List, CreateResult, Messages, Iterator +from ..typing import Type, List, CreateResult, Messages, Iterator, AsyncResult from .types import BaseProvider, BaseRetryProvider from .. import debug from ..errors import RetryProviderError, RetryNoProviderError -class RetryProvider(BaseRetryProvider): +class NewBaseRetryProvider(BaseRetryProvider): def __init__( self, providers: List[Type[BaseProvider]], - shuffle: bool = True, - single_provider_retry: bool = False, - max_retries: int = 3, + shuffle: bool = True ) -> None: """ Initialize the BaseRetryProvider. @@ -26,8 +24,6 @@ class RetryProvider(BaseRetryProvider): """ self.providers = providers self.shuffle = shuffle - self.single_provider_retry = single_provider_retry - self.max_retries = max_retries self.working = True self.last_provider: Type[BaseProvider] = None @@ -56,7 +52,146 @@ class RetryProvider(BaseRetryProvider): exceptions = {} started: bool = False + for provider in providers: + self.last_provider = provider + try: + if debug.logging: + print(f"Using {provider.__name__} provider") + for token in provider.create_completion(model, messages, stream, **kwargs): + yield token + started = True + if started: + return + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + if started: + raise e + + raise_exceptions(exceptions) + + async def create_async( + self, + model: str, + messages: Messages, + **kwargs, + ) -> str: + """ + Asynchronously create a completion using available providers. + Args: + model (str): The model to be used for completion. + messages (Messages): The messages to be used for generating completion. + Returns: + str: The result of the asynchronous completion. + Raises: + Exception: Any exception encountered during the asynchronous completion process. + """ + providers = self.providers + if self.shuffle: + random.shuffle(providers) + + exceptions = {} + + for provider in providers: + self.last_provider = provider + try: + if debug.logging: + print(f"Using {provider.__name__} provider") + return await asyncio.wait_for( + provider.create_async(model, messages, **kwargs), + timeout=kwargs.get("timeout", 60), + ) + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + + raise_exceptions(exceptions) + + def get_providers(self, stream: bool): + providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers + if self.shuffle: + random.shuffle(providers) + return providers + + async def create_async_generator( + self, + model: str, + messages: Messages, + stream: bool = True, + **kwargs + ) -> AsyncResult: + exceptions = {} + started: bool = False + + for provider in self.get_providers(stream): + self.last_provider = provider + try: + if debug.logging: + print(f"Using {provider.__name__} provider") + if not stream: + yield await provider.create_async(model, messages, **kwargs) + elif hasattr(provider, "create_async_generator"): + async for token in provider.create_async_generator(model, messages, stream, **kwargs): + yield token + else: + for token in provider.create_completion(model, messages, stream, **kwargs): + yield token + started = True + if started: + return + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + if started: + raise e + + raise_exceptions(exceptions) + +class RetryProvider(NewBaseRetryProvider): + def __init__( + self, + providers: List[Type[BaseProvider]], + shuffle: bool = True, + single_provider_retry: bool = False, + max_retries: int = 3, + ) -> None: + """ + Initialize the BaseRetryProvider. + Args: + providers (List[Type[BaseProvider]]): List of providers to use. + shuffle (bool): Whether to shuffle the providers list. + single_provider_retry (bool): Whether to retry a single provider if it fails. + max_retries (int): Maximum number of retries for a single provider. + """ + super().__init__(providers, shuffle) + self.single_provider_retry = single_provider_retry + self.max_retries = max_retries + + def create_completion( + self, + model: str, + messages: Messages, + stream: bool = False, + **kwargs, + ) -> CreateResult: + """ + Create a completion using available providers, with an option to stream the response. + Args: + model (str): The model to be used for completion. + messages (Messages): The messages to be used for generating completion. + stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False. + Yields: + CreateResult: Tokens or results from the completion. + Raises: + Exception: Any exception encountered during the completion process. + """ + providers = self.get_providers(stream) if self.single_provider_retry and len(providers) == 1: + exceptions = {} + started: bool = False provider = providers[0] self.last_provider = provider for attempt in range(self.max_retries): @@ -74,25 +209,9 @@ class RetryProvider(BaseRetryProvider): print(f"{provider.__name__}: {e.__class__.__name__}: {e}") if started: raise e + raise_exceptions(exceptions) else: - for provider in providers: - self.last_provider = provider - try: - if debug.logging: - print(f"Using {provider.__name__} provider") - for token in provider.create_completion(model, messages, stream, **kwargs): - yield token - started = True - if started: - return - except Exception as e: - exceptions[provider.__name__] = e - if debug.logging: - print(f"{provider.__name__}: {e.__class__.__name__}: {e}") - if started: - raise e - - raise_exceptions(exceptions) + yield from super().create_completion(model, messages, stream, **kwargs) async def create_async( self, @@ -131,22 +250,9 @@ class RetryProvider(BaseRetryProvider): exceptions[provider.__name__] = e if debug.logging: print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + raise_exceptions(exceptions) else: - for provider in providers: - self.last_provider = provider - try: - if debug.logging: - print(f"Using {provider.__name__} provider") - return await asyncio.wait_for( - provider.create_async(model, messages, **kwargs), - timeout=kwargs.get("timeout", 60), - ) - except Exception as e: - exceptions[provider.__name__] = e - if debug.logging: - print(f"{provider.__name__}: {e.__class__.__name__}: {e}") - - raise_exceptions(exceptions) + return await super().create_async(model, messages, **kwargs) class IterProvider(BaseRetryProvider): __name__ = "IterProvider" |