diff options
Diffstat (limited to 'g4f/providers/retry_provider.py')
-rw-r--r-- | g4f/providers/retry_provider.py | 114 |
1 files changed, 72 insertions, 42 deletions
diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py index 52f473e9..d64e8471 100644 --- a/g4f/providers/retry_provider.py +++ b/g4f/providers/retry_provider.py @@ -12,46 +12,40 @@ class RetryProvider(BaseRetryProvider): def __init__( self, providers: List[Type[BaseProvider]], - shuffle: bool = True + shuffle: bool = True, + single_provider_retry: bool = False, + max_retries: int = 3, ) -> None: """ Initialize the BaseRetryProvider. - Args: providers (List[Type[BaseProvider]]): List of providers to use. shuffle (bool): Whether to shuffle the providers list. + single_provider_retry (bool): Whether to retry a single provider if it fails. + max_retries (int): Maximum number of retries for a single provider. """ self.providers = providers self.shuffle = shuffle + self.single_provider_retry = single_provider_retry + self.max_retries = max_retries self.working = True self.last_provider: Type[BaseProvider] = None - """ - A provider class to handle retries for creating completions with different providers. - - Attributes: - providers (list): A list of provider instances. - shuffle (bool): A flag indicating whether to shuffle providers before use. - last_provider (BaseProvider): The last provider that was used. - """ def create_completion( self, model: str, messages: Messages, stream: bool = False, - **kwargs + **kwargs, ) -> CreateResult: """ Create a completion using available providers, with an option to stream the response. - Args: model (str): The model to be used for completion. messages (Messages): The messages to be used for generating completion. stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False. - Yields: CreateResult: Tokens or results from the completion. - Raises: Exception: Any exception encountered during the completion process. """ @@ -61,22 +55,42 @@ class RetryProvider(BaseRetryProvider): exceptions = {} started: bool = False - for provider in providers: + + if self.single_provider_retry and len(providers) == 1: + provider = providers[0] self.last_provider = provider - try: - if debug.logging: - print(f"Using {provider.__name__} provider") - for token in provider.create_completion(model, messages, stream, **kwargs): - yield token + for attempt in range(self.max_retries): + try: + if debug.logging: + print(f"Using {provider.__name__} provider (attempt {attempt + 1})") + for token in provider.create_completion(model, messages, stream, **kwargs): + yield token started = True - if started: - return - except Exception as e: - exceptions[provider.__name__] = e - if debug.logging: - print(f"{provider.__name__}: {e.__class__.__name__}: {e}") - if started: - raise e + if started: + return + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + if started: + raise e + else: + for provider in providers: + self.last_provider = provider + try: + if debug.logging: + print(f"Using {provider.__name__} provider") + for token in provider.create_completion(model, messages, stream, **kwargs): + yield token + started = True + if started: + return + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + if started: + raise e raise_exceptions(exceptions) @@ -84,18 +98,15 @@ class RetryProvider(BaseRetryProvider): self, model: str, messages: Messages, - **kwargs + **kwargs, ) -> str: """ Asynchronously create a completion using available providers. - Args: model (str): The model to be used for completion. messages (Messages): The messages to be used for generating completion. - Returns: str: The result of the asynchronous completion. - Raises: Exception: Any exception encountered during the asynchronous completion process. """ @@ -104,17 +115,36 @@ class RetryProvider(BaseRetryProvider): random.shuffle(providers) exceptions = {} - for provider in providers: + + if self.single_provider_retry and len(providers) == 1: + provider = providers[0] self.last_provider = provider - try: - return await asyncio.wait_for( - provider.create_async(model, messages, **kwargs), - timeout=kwargs.get("timeout", 60) - ) - except Exception as e: - exceptions[provider.__name__] = e - if debug.logging: - print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + for attempt in range(self.max_retries): + try: + if debug.logging: + print(f"Using {provider.__name__} provider (attempt {attempt + 1})") + return await asyncio.wait_for( + provider.create_async(model, messages, **kwargs), + timeout=kwargs.get("timeout", 60), + ) + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + else: + for provider in providers: + self.last_provider = provider + try: + if debug.logging: + print(f"Using {provider.__name__} provider") + return await asyncio.wait_for( + provider.create_async(model, messages, **kwargs), + timeout=kwargs.get("timeout", 60), + ) + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") raise_exceptions(exceptions) |