From e9f96ced9c534f313fd2d3b82b2464cd8424281a Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Thu, 21 Sep 2023 20:10:59 +0200 Subject: Add RetryProvider --- g4f/Provider/__init__.py | 6 +++- g4f/Provider/retry_provider.py | 81 ++++++++++++++++++++++++++++++++++++++++++ g4f/__init__.py | 10 ++---- g4f/models.py | 29 +++++++++++---- 4 files changed, 110 insertions(+), 16 deletions(-) create mode 100644 g4f/Provider/retry_provider.py diff --git a/g4f/Provider/__init__.py b/g4f/Provider/__init__.py index 0ca22533..b9ee2544 100644 --- a/g4f/Provider/__init__.py +++ b/g4f/Provider/__init__.py @@ -38,10 +38,14 @@ from .FastGpt import FastGpt from .V50 import V50 from .Wuguokai import Wuguokai -from .base_provider import BaseProvider, AsyncProvider, AsyncGeneratorProvider +from .base_provider import BaseProvider, AsyncProvider, AsyncGeneratorProvider +from .retry_provider import RetryProvider __all__ = [ 'BaseProvider', + 'AsyncProvider', + 'AsyncGeneratorProvider', + 'RetryProvider', 'Acytoo', 'Aichat', 'Ails', diff --git a/g4f/Provider/retry_provider.py b/g4f/Provider/retry_provider.py new file mode 100644 index 00000000..e1a9cd1f --- /dev/null +++ b/g4f/Provider/retry_provider.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import random + +from ..typing import CreateResult +from .base_provider import BaseProvider, AsyncProvider + + +class RetryProvider(AsyncProvider): + __name__ = "RetryProvider" + working = True + needs_auth = False + supports_stream = True + supports_gpt_35_turbo = False + supports_gpt_4 = False + + def __init__( + self, + providers: list[type[BaseProvider]], + shuffle: bool = True + ) -> None: + self.providers = providers + self.shuffle = shuffle + + + def create_completion( + self, + model: str, + messages: list[dict[str, str]], + stream: bool = False, + **kwargs + ) -> CreateResult: + if stream: + providers = [provider for provider in self.providers if provider.supports_stream] + else: + providers = self.providers + if self.shuffle: + random.shuffle(providers) + + self.exceptions = {} + started = False + for provider in providers: + try: + for token in provider.create_completion(model, messages, stream, **kwargs): + yield token + started = True + if started: + return + except Exception as e: + self.exceptions[provider.__name__] = e + if started: + break + + self.raise_exceptions() + + async def create_async( + self, + model: str, + messages: list[dict[str, str]], + **kwargs + ) -> str: + providers = [provider for provider in self.providers if issubclass(provider, AsyncProvider)] + if self.shuffle: + random.shuffle(providers) + + self.exceptions = {} + for provider in providers: + try: + return await provider.create_async(model, messages, **kwargs) + except Exception as e: + self.exceptions[provider.__name__] = e + + self.raise_exceptions() + + def raise_exceptions(self): + if self.exceptions: + raise RuntimeError("\n".join(["All providers failed:"] + [ + f"{p}: {self.exceptions[p].__class__.__name__}: {self.exceptions[p]}" for p in self.exceptions + ])) + + raise RuntimeError("No provider found") \ No newline at end of file diff --git a/g4f/__init__.py b/g4f/__init__.py index e42be8a8..8fdfe21f 100644 --- a/g4f/__init__.py +++ b/g4f/__init__.py @@ -14,13 +14,7 @@ def get_model_and_provider(model: Union[models.Model, str], provider: type[BaseP raise Exception(f'The model: {model} does not exist') if not provider: - if isinstance(model.best_provider, list): - if stream: - provider = random.choice([p for p in model.best_provider if p.supports_stream]) - else: - provider = random.choice(model.best_provider) - else: - provider = model.best_provider + provider = model.best_provider if not provider: raise Exception(f'No provider found for model: {model}') @@ -70,7 +64,7 @@ class ChatCompletion: model, provider = get_model_and_provider(model, provider, False) - if not issubclass(provider, AsyncProvider): + if not issubclass(type(provider), AsyncProvider): raise Exception(f"Provider: {provider.__name__} doesn't support create_async") return await provider.create_async(model.name, messages, **kwargs) diff --git a/g4f/models.py b/g4f/models.py index 1066e1aa..f1b0ec31 100644 --- a/g4f/models.py +++ b/g4f/models.py @@ -1,17 +1,23 @@ from __future__ import annotations from dataclasses import dataclass from .typing import Union -from .Provider import BaseProvider +from .Provider import BaseProvider, RetryProvider from .Provider import ( ChatgptLogin, - CodeLinkAva, ChatgptAi, ChatBase, Vercel, DeepAi, Aivvm, Bard, - H2o + H2o, + GptGo, + Bing, + PerplexityAi, + Wewordle, + Yqcloud, + AItianhu, + Aichat, ) @dataclass(unsafe_hash=True) @@ -24,15 +30,24 @@ class Model: # Works for Liaobots, H2o, OpenaiChat, Yqcloud, You default = Model( name = "", - base_provider = "huggingface") + base_provider = "", + best_provider = RetryProvider([ + Bing, # Not fully GPT 3 or 4 + PerplexityAi, # Adds references to sources + Wewordle, # Responds with markdown + Yqcloud, # Answers short questions in chinese + ChatBase, # Don't want to answer creatively + DeepAi, ChatgptLogin, ChatgptAi, Aivvm, GptGo, AItianhu, Aichat, + ]) +) # GPT-3.5 / GPT-4 gpt_35_turbo = Model( name = 'gpt-3.5-turbo', base_provider = 'openai', - best_provider = [ - DeepAi, CodeLinkAva, ChatgptLogin, ChatgptAi, ChatBase, Aivvm - ] + best_provider = RetryProvider([ + DeepAi, ChatgptLogin, ChatgptAi, Aivvm, GptGo, AItianhu, Aichat, + ]) ) gpt_4 = Model( -- cgit v1.2.3