diff options
Diffstat (limited to 'g4f')
-rw-r--r-- | g4f/Provider/needs_auth/OpenaiChat.py | 3 | ||||
-rw-r--r-- | g4f/providers/retry_provider.py | 114 |
2 files changed, 75 insertions, 42 deletions
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index 3145161a..ff3446ac 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -334,6 +334,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): Raises: RuntimeError: If an error occurs during processing. """ + async with StreamSession( proxies={"all": proxy}, impersonate="chrome", @@ -359,6 +360,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): if debug.logging: print("OpenaiChat: Load default_model failed") print(f"{e.__class__.__name__}: {e}") + arkose_token = None if cls.default_model is None: @@ -582,6 +584,7 @@ this.fetch = async (url, options) => { user_data_dir = user_config_dir("g4f-nodriver") except: user_data_dir = None + browser = await uc.start(user_data_dir=user_data_dir) page = await browser.get("https://chat.openai.com/") while await page.query_selector("#prompt-textarea") is None: diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py index 52f473e9..d64e8471 100644 --- a/g4f/providers/retry_provider.py +++ b/g4f/providers/retry_provider.py @@ -12,46 +12,40 @@ class RetryProvider(BaseRetryProvider): def __init__( self, providers: List[Type[BaseProvider]], - shuffle: bool = True + shuffle: bool = True, + single_provider_retry: bool = False, + max_retries: int = 3, ) -> None: """ Initialize the BaseRetryProvider. - Args: providers (List[Type[BaseProvider]]): List of providers to use. shuffle (bool): Whether to shuffle the providers list. + single_provider_retry (bool): Whether to retry a single provider if it fails. + max_retries (int): Maximum number of retries for a single provider. """ self.providers = providers self.shuffle = shuffle + self.single_provider_retry = single_provider_retry + self.max_retries = max_retries self.working = True self.last_provider: Type[BaseProvider] = None - """ - A provider class to handle retries for creating completions with different providers. - - Attributes: - providers (list): A list of provider instances. - shuffle (bool): A flag indicating whether to shuffle providers before use. - last_provider (BaseProvider): The last provider that was used. - """ def create_completion( self, model: str, messages: Messages, stream: bool = False, - **kwargs + **kwargs, ) -> CreateResult: """ Create a completion using available providers, with an option to stream the response. - Args: model (str): The model to be used for completion. messages (Messages): The messages to be used for generating completion. stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False. - Yields: CreateResult: Tokens or results from the completion. - Raises: Exception: Any exception encountered during the completion process. """ @@ -61,22 +55,42 @@ class RetryProvider(BaseRetryProvider): exceptions = {} started: bool = False - for provider in providers: + + if self.single_provider_retry and len(providers) == 1: + provider = providers[0] self.last_provider = provider - try: - if debug.logging: - print(f"Using {provider.__name__} provider") - for token in provider.create_completion(model, messages, stream, **kwargs): - yield token + for attempt in range(self.max_retries): + try: + if debug.logging: + print(f"Using {provider.__name__} provider (attempt {attempt + 1})") + for token in provider.create_completion(model, messages, stream, **kwargs): + yield token started = True - if started: - return - except Exception as e: - exceptions[provider.__name__] = e - if debug.logging: - print(f"{provider.__name__}: {e.__class__.__name__}: {e}") - if started: - raise e + if started: + return + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + if started: + raise e + else: + for provider in providers: + self.last_provider = provider + try: + if debug.logging: + print(f"Using {provider.__name__} provider") + for token in provider.create_completion(model, messages, stream, **kwargs): + yield token + started = True + if started: + return + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + if started: + raise e raise_exceptions(exceptions) @@ -84,18 +98,15 @@ class RetryProvider(BaseRetryProvider): self, model: str, messages: Messages, - **kwargs + **kwargs, ) -> str: """ Asynchronously create a completion using available providers. - Args: model (str): The model to be used for completion. messages (Messages): The messages to be used for generating completion. - Returns: str: The result of the asynchronous completion. - Raises: Exception: Any exception encountered during the asynchronous completion process. """ @@ -104,17 +115,36 @@ class RetryProvider(BaseRetryProvider): random.shuffle(providers) exceptions = {} - for provider in providers: + + if self.single_provider_retry and len(providers) == 1: + provider = providers[0] self.last_provider = provider - try: - return await asyncio.wait_for( - provider.create_async(model, messages, **kwargs), - timeout=kwargs.get("timeout", 60) - ) - except Exception as e: - exceptions[provider.__name__] = e - if debug.logging: - print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + for attempt in range(self.max_retries): + try: + if debug.logging: + print(f"Using {provider.__name__} provider (attempt {attempt + 1})") + return await asyncio.wait_for( + provider.create_async(model, messages, **kwargs), + timeout=kwargs.get("timeout", 60), + ) + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + else: + for provider in providers: + self.last_provider = provider + try: + if debug.logging: + print(f"Using {provider.__name__} provider") + return await asyncio.wait_for( + provider.create_async(model, messages, **kwargs), + timeout=kwargs.get("timeout", 60), + ) + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") raise_exceptions(exceptions) |