From c31f5435c43ede7847dae0f3ed007357e7ff198c Mon Sep 17 00:00:00 2001
From: Heiner Lohaus
Date: Thu, 28 Nov 2024 17:46:46 +0100
Subject: Fix api with default providers, add unittests for RetryProvider

---
 g4f/client/__init__.py          |   2 +-
 g4f/client/service.py           |   4 +-
 g4f/providers/base_provider.py  |   6 +-
 g4f/providers/retry_provider.py | 145 +++++++++++++++++-----------------------
 4 files changed, 69 insertions(+), 88 deletions(-)

(limited to 'g4f')

diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py
index b00c5a65..ea47ec73 100644
--- a/g4f/client/__init__.py
+++ b/g4f/client/__init__.py
@@ -474,7 +474,7 @@ class AsyncCompletions:
             **kwargs
         )

-        if not isinstance(response, AsyncIterator):
+        if not hasattr(response, "__aiter__"):
             response = to_async_iterator(response)
         response = async_iter_response(response, stream, response_format, max_tokens, stop)
         response = async_iter_append_model_and_provider(response)
diff --git a/g4f/client/service.py b/g4f/client/service.py
index 45230c79..80dc70df 100644
--- a/g4f/client/service.py
+++ b/g4f/client/service.py
@@ -7,14 +7,14 @@ from ..errors import ProviderNotFoundError, ModelNotFoundError, ProviderNotWorki
 from ..models import Model, ModelUtils, default
 from ..Provider import ProviderUtils
 from ..providers.types import BaseRetryProvider, ProviderType
-from ..providers.retry_provider import IterProvider
+from ..providers.retry_provider import IterListProvider

 def convert_to_provider(provider: str) -> ProviderType:
     if " " in provider:
         provider_list = [ProviderUtils.convert[p] for p in provider.split() if p in ProviderUtils.convert]
         if not provider_list:
             raise ProviderNotFoundError(f'Providers not found: {provider}')
-        provider = IterProvider(provider_list)
+        provider = IterListProvider(provider_list, False)
     elif provider in ProviderUtils.convert:
         provider = ProviderUtils.convert[provider]
     elif provider:
diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py
index 80a9e09d..e8a47154 100644
--- a/g4f/providers/base_provider.py
+++ b/g4f/providers/base_provider.py
@@ -57,7 +57,9 @@ class AbstractProvider(BaseProvider):
         loop = loop or asyncio.get_running_loop()

         def create_func() -> str:
-            return "".join(cls.create_completion(model, messages, False, **kwargs))
+            chunks = [str(chunk) for chunk in cls.create_completion(model, messages, False, **kwargs) if chunk]
+            if chunks:
+                return "".join(chunks)

         return await asyncio.wait_for(
             loop.run_in_executor(executor, create_func),
@@ -205,7 +207,7 @@ class AsyncGeneratorProvider(AsyncProvider):
         """
         return "".join([
             str(chunk) async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
-            if not isinstance(chunk, (Exception, FinishReason, BaseConversation, SynthesizeData))
+            if chunk and not isinstance(chunk, (Exception, FinishReason, BaseConversation, SynthesizeData))
         ])

     @staticmethod
diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py
index efcae375..92386955 100644
--- a/g4f/providers/retry_provider.py
+++ b/g4f/providers/retry_provider.py
@@ -8,6 +8,8 @@ from .types import BaseProvider, BaseRetryProvider, ProviderType
 from .. import debug
 from ..errors import RetryProviderError, RetryNoProviderError

+DEFAULT_TIMEOUT = 60
+
 class IterListProvider(BaseRetryProvider):
     def __init__(
         self,
@@ -50,12 +52,12 @@ class IterListProvider(BaseRetryProvider):

         for provider in self.get_providers(stream):
             self.last_provider = provider
+            debug.log(f"Using {provider.__name__} provider")
             try:
-                if debug.logging:
-                    print(f"Using {provider.__name__} provider")
-                for token in provider.create_completion(model, messages, stream, **kwargs):
-                    yield token
-                    started = True
+                for chunk in provider.create_completion(model, messages, stream, **kwargs):
+                    if chunk:
+                        yield chunk
+                        started = True
                 if started:
                     return
             except Exception as e:
@@ -87,13 +89,14 @@ class IterListProvider(BaseRetryProvider):

         for provider in self.get_providers(False):
             self.last_provider = provider
+            debug.log(f"Using {provider.__name__} provider")
             try:
-                if debug.logging:
-                    print(f"Using {provider.__name__} provider")
-                return await asyncio.wait_for(
+                chunk = await asyncio.wait_for(
                     provider.create_async(model, messages, **kwargs),
-                    timeout=kwargs.get("timeout", 60),
+                    timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
                 )
+                if chunk:
+                    return chunk
             except Exception as e:
                 exceptions[provider.__name__] = e
                 if debug.logging:
@@ -119,16 +122,21 @@ class IterListProvider(BaseRetryProvider):

         for provider in self.get_providers(stream):
             self.last_provider = provider
+            debug.log(f"Using {provider.__name__} provider")
             try:
-                if debug.logging:
-                    print(f"Using {provider.__name__} provider")
                 if not stream:
-                    yield await provider.create_async(model, messages, **kwargs)
-                    started = True
-                elif hasattr(provider, "create_async_generator"):
-                    async for token in provider.create_async_generator(model, messages, stream=stream, **kwargs):
-                        yield token
+                    chunk = await asyncio.wait_for(
+                        provider.create_async(model, messages, **kwargs),
+                        timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
+                    )
+                    if chunk:
+                        yield chunk
                         started = True
+                elif hasattr(provider, "create_async_generator"):
+                    async for chunk in provider.create_async_generator(model, messages, stream=stream, **kwargs):
+                        if chunk:
+                            yield chunk
+                            started = True
                 else:
                     for token in provider.create_completion(model, messages, stream, **kwargs):
                         yield token
@@ -137,8 +145,7 @@
                     return
             except Exception as e:
                 exceptions[provider.__name__] = e
-                if debug.logging:
-                    print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+                debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}")

         if started:
             raise e
@@ -243,76 +250,48 @@
         else:
             return await super().create_async(model, messages, **kwargs)

-class IterProvider(BaseRetryProvider):
-    __name__ = "IterProvider"
-
-    def __init__(
-        self,
-        providers: List[BaseProvider],
-    ) -> None:
-        providers.reverse()
-        self.providers: List[BaseProvider] = providers
-        self.working: bool = True
-        self.last_provider: BaseProvider = None
-
-    def create_completion(
-        self,
-        model: str,
-        messages: Messages,
-        stream: bool = False,
-        **kwargs
-    ) -> CreateResult:
-        exceptions: dict = {}
-        started: bool = False
-        for provider in self.iter_providers():
-            if stream and not provider.supports_stream:
-                continue
-            try:
-                for token in provider.create_completion(model, messages, stream, **kwargs):
-                    yield token
-                    started = True
-                if started:
-                    return
-            except Exception as e:
-                exceptions[provider.__name__] = e
-                if debug.logging:
-                    print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
-                if started:
-                    raise e
-        raise_exceptions(exceptions)
-
-    async def create_async(
+    async def create_async_generator(
         self,
         model: str,
         messages: Messages,
+        stream: bool = True,
         **kwargs
-    ) -> str:
-        exceptions: dict = {}
-        for provider in self.iter_providers():
-            try:
-                return await asyncio.wait_for(
-                    provider.create_async(model, messages, **kwargs),
-                    timeout=kwargs.get("timeout", 60)
-                )
-            except Exception as e:
-                exceptions[provider.__name__] = e
-                if debug.logging:
-                    print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
-        raise_exceptions(exceptions)
+    ) -> AsyncResult:
+        exceptions = {}
+        started = False

-    def iter_providers(self) -> Iterator[BaseProvider]:
-        used_provider = []
-        try:
-            while self.providers:
-                provider = self.providers.pop()
-                used_provider.append(provider)
-                self.last_provider = provider
-                if debug.logging:
-                    print(f"Using {provider.__name__} provider")
-                yield provider
-        finally:
-            used_provider.reverse()
-            self.providers = [*used_provider, *self.providers]
+        if self.single_provider_retry:
+            provider = self.providers[0]
+            self.last_provider = provider
+            for attempt in range(self.max_retries):
+                try:
+                    debug.log(f"Using {provider.__name__} provider (attempt {attempt + 1})")
+                    if not stream:
+                        chunk = await asyncio.wait_for(
+                            provider.create_async(model, messages, **kwargs),
+                            timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
+                        )
+                        if chunk:
+                            started = True
+                    elif hasattr(provider, "create_async_generator"):
+                        async for chunk in provider.create_async_generator(model, messages, stream=stream, **kwargs):
+                            if chunk:
+                                yield chunk
+                                started = True
+                    else:
+                        for token in provider.create_completion(model, messages, stream, **kwargs):
+                            yield token
+                            started = True
+                    if started:
+                        return
+                except Exception as e:
+                    exceptions[provider.__name__] = e
+                    if debug.logging:
+                        print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+            raise_exceptions(exceptions)
+        else:
+            async for chunk in super().create_async_generator(model, messages, stream, **kwargs):
+                yield chunk

 def raise_exceptions(exceptions: dict) -> None:
     """
--
cgit v1.2.3
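
The commit subject also mentions unittests for RetryProvider; those sit outside the g4f/ tree and are therefore not part of this path-limited diff. As a rough sketch only, not code from the commit, a test in that spirit could exercise the fallback behaviour of IterListProvider after this change. The fake provider classes, the model string, and the attributes they expose (working, supports_stream) are assumptions about what get_providers inspects, not repository code; the two-argument constructor call mirrors the IterListProvider(provider_list, False) usage introduced in g4f/client/service.py above.

# Hypothetical test sketch, not part of the commit. Assumes IterListProvider
# filters providers on their `working`/`supports_stream` attributes and keeps
# their order when shuffling is disabled (the second constructor argument).
import unittest

from g4f.providers.retry_provider import IterListProvider


class RaisingProvider:
    # Fake provider that always fails, so IterListProvider must fall through.
    working = True
    supports_stream = True

    @classmethod
    def create_completion(cls, model, messages, stream=False, **kwargs):
        raise RuntimeError("provider is down")
        yield  # unreachable; keeps the method a generator like real providers


class EchoProvider:
    # Fake provider that succeeds and yields a single chunk.
    working = True
    supports_stream = True

    @classmethod
    def create_completion(cls, model, messages, stream=False, **kwargs):
        yield "Hello"


class TestIterListProvider(unittest.TestCase):
    def test_falls_back_to_working_provider(self):
        # The failing provider is tried first, its exception is recorded,
        # and the second provider supplies the actual output.
        provider = IterListProvider([RaisingProvider, EchoProvider], False)
        messages = [{"role": "user", "content": "Hi"}]
        chunks = list(provider.create_completion("any-model", messages, False))
        self.assertEqual(chunks, ["Hello"])


if __name__ == "__main__":
    unittest.main()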