summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/retry_provider.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/retry_provider.py58
1 files changed, 49 insertions, 9 deletions
diff --git a/g4f/Provider/retry_provider.py b/g4f/Provider/retry_provider.py
index 4d3e77ac..9cc026fc 100644
--- a/g4f/Provider/retry_provider.py
+++ b/g4f/Provider/retry_provider.py
@@ -7,8 +7,17 @@ from ..base_provider import BaseRetryProvider
from .. import debug
from ..errors import RetryProviderError, RetryNoProviderError
-
class RetryProvider(BaseRetryProvider):
+ """
+ A provider class to handle retries for creating completions with different providers.
+
+ Attributes:
+ providers (list): A list of provider instances.
+ shuffle (bool): A flag indicating whether to shuffle providers before use.
+ exceptions (dict): A dictionary to store exceptions encountered during retries.
+ last_provider (BaseProvider): The last provider that was used.
+ """
+
def create_completion(
self,
model: str,
@@ -16,10 +25,21 @@ class RetryProvider(BaseRetryProvider):
stream: bool = False,
**kwargs
) -> CreateResult:
- if stream:
- providers = [provider for provider in self.providers if provider.supports_stream]
- else:
- providers = self.providers
+ """
+ Create a completion using available providers, with an option to stream the response.
+
+ Args:
+ model (str): The model to be used for completion.
+ messages (Messages): The messages to be used for generating completion.
+ stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.
+
+ Yields:
+ CreateResult: Tokens or results from the completion.
+
+ Raises:
+ Exception: Any exception encountered during the completion process.
+ """
+ providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers
if self.shuffle:
random.shuffle(providers)
@@ -50,10 +70,23 @@ class RetryProvider(BaseRetryProvider):
messages: Messages,
**kwargs
) -> str:
+ """
+ Asynchronously create a completion using available providers.
+
+ Args:
+ model (str): The model to be used for completion.
+ messages (Messages): The messages to be used for generating completion.
+
+ Returns:
+ str: The result of the asynchronous completion.
+
+ Raises:
+ Exception: Any exception encountered during the asynchronous completion process.
+ """
providers = self.providers
if self.shuffle:
random.shuffle(providers)
-
+
self.exceptions = {}
for provider in providers:
self.last_provider = provider
@@ -66,13 +99,20 @@ class RetryProvider(BaseRetryProvider):
self.exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
-
+
self.raise_exceptions()
-
+
def raise_exceptions(self) -> None:
+ """
+ Raise a combined exception if any occurred during retries.
+
+ Raises:
+ RetryProviderError: If any provider encountered an exception.
+ RetryNoProviderError: If no provider is found.
+ """
if self.exceptions:
raise RetryProviderError("RetryProvider failed:\n" + "\n".join([
f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in self.exceptions.items()
]))
-
+
raise RetryNoProviderError("No provider found") \ No newline at end of file