Diffstat (limited to 'g4f')
-rw-r--r--  g4f/Provider/Cohere.py                 | 106
-rw-r--r--  g4f/Provider/DeepInfra.py              |  10
-rw-r--r--  g4f/Provider/HuggingChat.py            |  14
-rw-r--r--  g4f/Provider/PerplexityLabs.py         |   5
-rw-r--r--  g4f/Provider/You.py                    |   8
-rw-r--r--  g4f/Provider/__init__.py               |   1
-rw-r--r--  g4f/Provider/needs_auth/Openai.py      |   1
-rw-r--r--  g4f/Provider/needs_auth/OpenaiChat.py  |  12
-rw-r--r--  g4f/Provider/you/har_file.py           |  87
-rw-r--r--  g4f/client/async_client.py             |   4
-rw-r--r--  g4f/client/service.py                  |   3
-rw-r--r--  g4f/gui/client/static/js/chat.v1.js    |  28
-rw-r--r--  g4f/models.py                          |  43
-rw-r--r--  g4f/providers/retry_provider.py        | 114
14 files changed, 347 insertions, 89 deletions
diff --git a/g4f/Provider/Cohere.py b/g4f/Provider/Cohere.py
new file mode 100644
index 00000000..4f9fd30a
--- /dev/null
+++ b/g4f/Provider/Cohere.py
@@ -0,0 +1,106 @@
+from __future__ import annotations
+
+import json, random, requests, threading
+from aiohttp import ClientSession
+
+from ..typing import CreateResult, Messages
+from .base_provider import AbstractProvider
+from .helper import format_prompt
+
+class Cohere(AbstractProvider):
+    url = "https://cohereforai-c4ai-command-r-plus.hf.space"
+    working = True
+    supports_gpt_35_turbo = False
+    supports_gpt_4 = False
+    supports_stream = True
+
+    @staticmethod
+    def create_completion(
+        model: str,
+        messages: Messages,
+        stream: bool,
+        proxy: str = None,
+        max_retries: int = 6,
+        **kwargs
+    ) -> CreateResult:
+
+        prompt = format_prompt(messages)
+
+        headers = {
+            'accept': 'text/event-stream',
+            'accept-language': 'en,fr-FR;q=0.9,fr;q=0.8,es-ES;q=0.7,es;q=0.6,en-US;q=0.5,am;q=0.4,de;q=0.3',
+            'cache-control': 'no-cache',
+            'pragma': 'no-cache',
+            'referer': 'https://cohereforai-c4ai-command-r-plus.hf.space/?__theme=light',
+            'sec-ch-ua': '"Google Chrome";v="123", "Not:A-Brand";v="8", "Chromium";v="123"',
+            'sec-ch-ua-mobile': '?0',
+            'sec-ch-ua-platform': '"macOS"',
+            'sec-fetch-dest': 'empty',
+            'sec-fetch-mode': 'cors',
+            'sec-fetch-site': 'same-origin',
+            'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36',
+        }
+
+        session_hash = ''.join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=11))
+
+        params = {
+            'fn_index': '1',
+            'session_hash': session_hash,
+        }
+
+        response = requests.get(
+            'https://cohereforai-c4ai-command-r-plus.hf.space/queue/join',
+            params=params,
+            headers=headers,
+            stream=True
+        )
+
+        completion = ''
+
+        for line in response.iter_lines():
+            if line:
+                json_data = json.loads(line[6:])
+
+                if b"send_data" in line:
+                    event_id = json_data["event_id"]
+                    threading.Thread(target=send_data, args=[session_hash, event_id, prompt]).start()
+
+                if b"process_generating" in line or b"process_completed" in line:
+                    token = json_data['output']['data'][0][0][1]
+                    yield token.replace(completion, "")
+                    completion = token
+
+def send_data(session_hash, event_id, prompt):
+    headers = {
+        'accept': '*/*',
+        'accept-language': 'en,fr-FR;q=0.9,fr;q=0.8,es-ES;q=0.7,es;q=0.6,en-US;q=0.5,am;q=0.4,de;q=0.3',
+        'cache-control': 'no-cache',
+        'content-type': 'application/json',
+        'origin': 'https://cohereforai-c4ai-command-r-plus.hf.space',
+        'pragma': 'no-cache',
+        'referer': 'https://cohereforai-c4ai-command-r-plus.hf.space/?__theme=light',
+        'sec-ch-ua': '"Google Chrome";v="123", "Not:A-Brand";v="8", "Chromium";v="123"',
+        'sec-ch-ua-mobile': '?0',
+        'sec-ch-ua-platform': '"macOS"',
+        'sec-fetch-dest': 'empty',
+        'sec-fetch-mode': 'cors',
+        'sec-fetch-site': 'same-origin',
+        'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36',
+    }
+
+    json_data = {
+        'data': [
+            prompt,
+            '',
+            [],
+        ],
+        'event_data': None,
+        'fn_index': 1,
+        'session_hash': session_hash,
+        'event_id': event_id
+    }
+
+    requests.post('https://cohereforai-c4ai-command-r-plus.hf.space/queue/data',
+        json=json_data, headers=headers)
\ No newline at end of file
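Note: a minimal usage sketch for the new provider, assuming g4f's standard synchronous client interface; the Gradio queue protocol above does the actual work, so the model string is effectively fixed by the HF Space.

    # Hedged sketch: exercise the new Cohere provider through g4f's client.
    # Assumes g4f.client.Client and the chat.completions API used elsewhere in g4f.
    from g4f.client import Client
    from g4f.Provider import Cohere

    client = Client(provider=Cohere)
    response = client.chat.completions.create(
        model="command-r+",
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
    )
    print(response.choices[0].message.content)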
diff --git a/g4f/Provider/DeepInfra.py b/g4f/Provider/DeepInfra.py
index 68aaf8b9..971424b7 100644
--- a/g4f/Provider/DeepInfra.py
+++ b/g4f/Provider/DeepInfra.py
@@ -11,7 +11,7 @@ class DeepInfra(Openai):
     needs_auth = False
     supports_stream = True
     supports_message_history = True
-    default_model = 'meta-llama/Llama-2-70b-chat-hf'
+    default_model = 'HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1'
 
     @classmethod
     def get_models(cls):
@@ -32,6 +32,14 @@ class DeepInfra(Openai):
         max_tokens: int = 1028,
         **kwargs
     ) -> AsyncResult:
+
+        if not '/' in model:
+            models = {
+                'mixtral-8x22b': 'HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1',
+                'dbrx-instruct': 'databricks/dbrx-instruct',
+            }
+            model = models.get(model, model)
+
         headers = {
             'Accept-Encoding': 'gzip, deflate, br',
             'Accept-Language': 'en-US',
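Note: the new hunk in create_async_generator is a plain alias lookup. Names without a slash are treated as short aliases and expanded to full repository IDs; full IDs pass through untouched. A standalone sketch of that rule:

    # Sketch of DeepInfra's alias resolution; mapping copied from the hunk above.
    aliases = {
        'mixtral-8x22b': 'HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1',
        'dbrx-instruct': 'databricks/dbrx-instruct',
    }

    def resolve(model: str) -> str:
        # Only short names (no '/') are candidates for expansion.
        return aliases.get(model, model) if '/' not in model else model

    assert resolve('dbrx-instruct') == 'databricks/dbrx-instruct'
    assert resolve('databricks/dbrx-instruct') == 'databricks/dbrx-instruct'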
diff --git a/g4f/Provider/HuggingChat.py b/g4f/Provider/HuggingChat.py
index b80795fe..882edb78 100644
--- a/g4f/Provider/HuggingChat.py
+++ b/g4f/Provider/HuggingChat.py
@@ -14,13 +14,12 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
     working = True
     default_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
     models = [
-        "mistralai/Mixtral-8x7B-Instruct-v0.1",
-        "google/gemma-7b-it",
-        "meta-llama/Llama-2-70b-chat-hf",
-        "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
-        "codellama/CodeLlama-34b-Instruct-hf",
-        "mistralai/Mistral-7B-Instruct-v0.2",
-        "openchat/openchat-3.5-0106",
+        "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
+        'CohereForAI/c4ai-command-r-plus',
+        'mistralai/Mixtral-8x7B-Instruct-v0.1',
+        'google/gemma-1.1-7b-it',
+        'NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO',
+        'mistralai/Mistral-7B-Instruct-v0.2'
     ]
     model_aliases = {
         "openchat/openchat_3.5": "openchat/openchat-3.5-0106",
@@ -48,6 +47,7 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
         **kwargs
     ) -> AsyncResult:
         options = {"model": cls.get_model(model)}
+
         system_prompt = "\n".join([message["content"] for message in messages if message["role"] == "system"])
         if system_prompt:
             options["preprompt"] = system_prompt
diff --git a/g4f/Provider/PerplexityLabs.py b/g4f/Provider/PerplexityLabs.py
index 6c80efee..ba956100 100644
--- a/g4f/Provider/PerplexityLabs.py
+++ b/g4f/Provider/PerplexityLabs.py
@@ -19,13 +19,14 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
         "sonar-small-online", "sonar-medium-online", "sonar-small-chat", "sonar-medium-chat",
         "mistral-7b-instruct", "codellama-70b-instruct", "llava-v1.5-7b-wrapper", "llava-v1.6-34b",
         "mixtral-8x7b-instruct", "gemma-2b-it", "gemma-7b-it"
-        "mistral-medium", "related"
+        "mistral-medium", "related", "dbrx-instruct"
     ]
     model_aliases = {
         "mistralai/Mistral-7B-Instruct-v0.1": "mistral-7b-instruct",
         "mistralai/Mixtral-8x7B-Instruct-v0.1": "mixtral-8x7b-instruct",
         "codellama/CodeLlama-70b-Instruct-hf": "codellama-70b-instruct",
-        "llava-v1.5-7b": "llava-v1.5-7b-wrapper"
+        "llava-v1.5-7b": "llava-v1.5-7b-wrapper",
+        'databricks/dbrx-instruct': "dbrx-instruct"
     }
 
     @classmethod
diff --git a/g4f/Provider/You.py b/g4f/Provider/You.py
index be4ab523..3ebd40f2 100644
--- a/g4f/Provider/You.py
+++ b/g4f/Provider/You.py
@@ -65,6 +65,7 @@ class You(AsyncGeneratorProvider, ProviderModelMixin):
             timeout=(30, timeout)
         ) as session:
             cookies = await cls.get_cookies(session) if chat_mode != "default" else None
+
             upload = json.dumps([await cls.upload_file(session, cookies, to_bytes(image), image_name)]) if image else ""
             headers = {
                 "Accept": "text/event-stream",
@@ -131,6 +132,7 @@ class You(AsyncGeneratorProvider, ProviderModelMixin):
 
     @classmethod
     async def get_cookies(cls, client: StreamSession) -> Cookies:
+
         if not cls._cookies or cls._cookies_used >= 5:
             cls._cookies = await cls.create_cookies(client)
             cls._cookies_used = 0
@@ -151,8 +153,8 @@ class You(AsyncGeneratorProvider, ProviderModelMixin):
         }}).encode()).decode()
 
         def get_auth() -> str:
-            auth_uuid = "507a52ad-7e69-496b-aee0-1c9863c7c8"
-            auth_token = f"public-token-live-{auth_uuid}bb:public-token-live-{auth_uuid}19"
+            auth_uuid = "507a52ad-7e69-496b-aee0-1c9863c7c819"
+            auth_token = f"public-token-live-{auth_uuid}:public-token-live-{auth_uuid}"
             auth = base64.standard_b64encode(auth_token.encode()).decode()
             return f"Basic {auth}"
@@ -172,12 +174,12 @@ class You(AsyncGeneratorProvider, ProviderModelMixin):
                 "dfp_telemetry_id": await get_dfp_telemetry_id(),
                 "email": f"{user_uuid}@gmail.com",
                 "password": f"{user_uuid}#{user_uuid}",
-                "dfp_telemetry_id": f"{uuid.uuid4()}",
                 "session_duration_minutes": 129600
             }
         ) as response:
             await raise_for_status(response)
             session = (await response.json())["data"]
+
             return {
                 "stytch_session": session["session_token"],
                 'stytch_session_jwt': session["session_jwt"],
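Note: the get_auth change repairs a credential that had been split mid-UUID; both halves of the Basic token now carry the same full UUID. A self-contained sketch of what the corrected header evaluates to:

    # Sketch of the corrected Basic-auth construction from You.py's get_auth;
    # the UUID literal is the public token that appears in the diff itself.
    import base64

    auth_uuid = "507a52ad-7e69-496b-aee0-1c9863c7c819"
    auth_token = f"public-token-live-{auth_uuid}:public-token-live-{auth_uuid}"
    print(f"Basic {base64.standard_b64encode(auth_token.encode()).decode()}")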
diff --git a/g4f/Provider/__init__.py b/g4f/Provider/__init__.py
index b818a752..ea64f80a 100644
--- a/g4f/Provider/__init__.py
+++ b/g4f/Provider/__init__.py
@@ -46,6 +46,7 @@ from .ReplicateImage import ReplicateImage
 from .Vercel import Vercel
 from .WhiteRabbitNeo import WhiteRabbitNeo
 from .You import You
+from .Cohere import Cohere
 
 import sys
diff --git a/g4f/Provider/needs_auth/Openai.py b/g4f/Provider/needs_auth/Openai.py
index 81ba5981..80318f6d 100644
--- a/g4f/Provider/needs_auth/Openai.py
+++ b/g4f/Provider/needs_auth/Openai.py
@@ -51,6 +51,7 @@ class Openai(AsyncGeneratorProvider, ProviderModelMixin):
             stream=stream,
             **extra_data
         )
+
         async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response:
             await raise_for_status(response)
             if not stream:
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index 3145161a..b34daa3e 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -44,7 +44,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
     supports_system_message = True
     default_model = None
     models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo"]
-    model_aliases = {"text-davinci-002-render-sha": "gpt-3.5-turbo", "": "gpt-3.5-turbo"}
+    model_aliases = {"text-davinci-002-render-sha": "gpt-3.5-turbo", "": "gpt-3.5-turbo", "gpt-4-turbo-preview": "gpt-4"}
     _api_key: str = None
     _headers: dict = None
     _cookies: Cookies = None
@@ -334,6 +334,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
         Raises:
             RuntimeError: If an error occurs during processing.
         """
+
         async with StreamSession(
             proxies={"all": proxy},
             impersonate="chrome",
@@ -359,6 +360,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                 if debug.logging:
                     print("OpenaiChat: Load default_model failed")
                     print(f"{e.__class__.__name__}: {e}")
+
             arkose_token = None
 
             if cls.default_model is None:
@@ -369,6 +371,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                 except NoValidHarFileError:
                     ...
                 if cls._api_key is None:
+                    if debug.logging:
+                        print("Getting access token with nodriver.")
                     await cls.nodriver_access_token()
                 cls.default_model = cls.get_model(await cls.get_default_model(session, cls._headers))
@@ -384,6 +388,9 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                 blob = data["arkose"]["dx"]
                 need_arkose = data["arkose"]["required"]
                 chat_token = data["token"]
+
+            if debug.logging:
+                print(f'Arkose: {need_arkose} Turnstile: {data["turnstile"]["required"]}')
 
             if need_arkose and arkose_token is None:
                 arkose_token, api_key, cookies = await getArkoseAndAccessToken(proxy)
@@ -582,6 +589,7 @@ this.fetch = async (url, options) => {
             user_data_dir = user_config_dir("g4f-nodriver")
         except:
             user_data_dir = None
+
         browser = await uc.start(user_data_dir=user_data_dir)
         page = await browser.get("https://chat.openai.com/")
         while await page.query_selector("#prompt-textarea") is None:
@@ -781,4 +789,4 @@ class Response():
     async def get_messages(self) -> list:
         messages = self._messages
         messages.append({"role": "assistant", "content": await self.message()})
-        return messages
\ No newline at end of file
+        return messages
diff --git a/g4f/Provider/you/har_file.py b/g4f/Provider/you/har_file.py
index 281f37e2..a6981296 100644
--- a/g4f/Provider/you/har_file.py
+++ b/g4f/Provider/you/har_file.py
@@ -4,6 +4,8 @@ import json
 import os
 import random
 import uuid
+import asyncio
+import requests
 
 from ...requests import StreamSession, raise_for_status
@@ -65,8 +67,89 @@ async def sendRequest(tmpArk: arkReq, proxy: str = None):
     return await response.text()
 
 async def get_dfp_telemetry_id(proxy: str = None):
-    return str(uuid.uuid4())
+    return await telemetry_id_with_driver(proxy)
     global chatArks
     if chatArks is None:
         chatArks = readHAR()
-    return await sendRequest(random.choice(chatArks), proxy)
\ No newline at end of file
+    return await sendRequest(random.choice(chatArks), proxy)
+
+async def telemetry_id_with_driver(proxy: str = None):
+    from ...debug import logging
+    if logging:
+        print('getting telemetry_id for you.com with nodriver')
+    try:
+        import nodriver as uc
+        from nodriver import start, cdp, loop
+    except ImportError:
+        if logging:
+            print('nodriver not found, random uuid (may fail)')
+        return str(uuid.uuid4())
+
+    CAN_EVAL = False
+    payload_received = False
+    payload = None
+
+    try:
+        browser = await start()
+        tab = browser.main_tab
+
+        async def send_handler(event: cdp.network.RequestWillBeSent):
+            nonlocal CAN_EVAL, payload_received, payload
+            if 'telemetry.js' in event.request.url:
+                CAN_EVAL = True
+            if "/submit" in event.request.url:
+                payload = event.request.post_data
+                payload_received = True
+
+        tab.add_handler(cdp.network.RequestWillBeSent, send_handler)
+        await browser.get("https://you.com")
+
+        while not CAN_EVAL:
+            await tab.sleep(1)
+
+        await tab.evaluate('window.GetTelemetryID("public-token-live-507a52ad-7e69-496b-aee0-1c9863c7c819", "https://telemetry.stytch.com/submit");')
+
+        while not payload_received:
+            await tab.sleep(.1)
+
+    except Exception as e:
+        print(f"Error occurred: {str(e)}")
+
+    finally:
+        try:
+            await tab.close()
+        except Exception as e:
+            print(f"Error occurred while closing tab: {str(e)}")
+
+        try:
+            await browser.stop()
+        except Exception as e:
+            pass
+
+    headers = {
+        'Accept': '*/*',
+        'Accept-Language': 'en,fr-FR;q=0.9,fr;q=0.8,es-ES;q=0.7,es;q=0.6,en-US;q=0.5,am;q=0.4,de;q=0.3',
+        'Connection': 'keep-alive',
+        'Content-type': 'application/x-www-form-urlencoded',
+        'Origin': 'https://you.com',
+        'Referer': 'https://you.com/',
+        'Sec-Fetch-Dest': 'empty',
+        'Sec-Fetch-Mode': 'cors',
+        'Sec-Fetch-Site': 'cross-site',
+        'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36',
+        'sec-ch-ua': '"Google Chrome";v="123", "Not:A-Brand";v="8", "Chromium";v="123"',
+        'sec-ch-ua-mobile': '?0',
+        'sec-ch-ua-platform': '"macOS"',
+    }
+
+    proxies = {
+        'http': proxy,
+        'https': proxy} if proxy else None
+
+    response = requests.post('https://telemetry.stytch.com/submit',
+        headers=headers, data=payload, proxies=proxies)
+
+    if '-' in response.text:
+        print(f'telemetry generated: {response.text}')
+
+    return response.text
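Note: telemetry_id_with_driver is built around one pattern, register a CDP network listener, load the page, then poll until the listener has captured what you need. A reduced sketch of that pattern; the nodriver calls mirror the ones in the diff, and exact signatures should be treated as assumptions:

    # Hedged sketch of the nodriver listen-then-poll pattern used above:
    # capture the URL of the first network request a page makes.
    import nodriver as uc
    from nodriver import cdp

    async def first_request_url(page_url: str) -> str:
        seen = None

        async def on_request(event: cdp.network.RequestWillBeSent):
            nonlocal seen
            if seen is None:
                seen = event.request.url

        browser = await uc.start()
        tab = browser.main_tab
        tab.add_handler(cdp.network.RequestWillBeSent, on_request)
        await browser.get(page_url)
        while seen is None:
            await tab.sleep(.1)
        return seen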
diff --git a/g4f/client/async_client.py b/g4f/client/async_client.py
index 51a9cf83..8e1ee33c 100644
--- a/g4f/client/async_client.py
+++ b/g4f/client/async_client.py
@@ -144,7 +144,7 @@ class Completions():
             proxy=self.client.get_proxy() if proxy is None else proxy,
             max_tokens=max_tokens,
             stop=stop,
-            api_key=self.client.api_key if api_key is None else api_key
+            api_key=self.client.api_key if api_key is None else api_key,
             **kwargs
         )
         response = iter_response(response, stream, response_format, max_tokens, stop)
@@ -207,4 +207,4 @@ class Images():
         result = iter_image_response(response)
         if result is None:
             raise NoImageResponseError()
-        return result
\ No newline at end of file
+        return result
diff --git a/g4f/client/service.py b/g4f/client/service.py
index f3565f6d..d25c923d 100644
--- a/g4f/client/service.py
+++ b/g4f/client/service.py
@@ -55,9 +55,10 @@ def get_model_and_provider(model : Union[Model, str],
         provider = convert_to_provider(provider)
 
     if isinstance(model, str):
+
         if model in ModelUtils.convert:
             model = ModelUtils.convert[model]
-
+
     if not provider:
         if isinstance(model, str):
             raise ModelNotFoundError(f'Model not found: {model}')
diff --git a/g4f/gui/client/static/js/chat.v1.js b/g4f/gui/client/static/js/chat.v1.js
index 8933b442..7f4011a2 100644
--- a/g4f/gui/client/static/js/chat.v1.js
+++ b/g4f/gui/client/static/js/chat.v1.js
@@ -1074,7 +1074,7 @@ async function load_version() {
 }
 setTimeout(load_version, 2000);
 
-for (const el of [imageInput, cameraInput]) {
+[imageInput, cameraInput].forEach((el) => {
     el.addEventListener('click', async () => {
         el.value = '';
         if (imageInput.dataset.src) {
@@ -1082,7 +1082,7 @@ for (const el of [imageInput, cameraInput]) {
             delete imageInput.dataset.src
         }
     });
-}
+});
 
 fileInput.addEventListener('click', async (event) => {
     fileInput.value = '';
@@ -1261,31 +1261,26 @@ if (SpeechRecognition) {
     recognition.interimResults = true;
     recognition.maxAlternatives = 1;
 
-    function may_stop() {
-        if (microLabel.classList.contains("recognition")) {
-            recognition.stop();
-        }
-    }
-
     let startValue;
-    let timeoutHandle;
+    let shouldStop;
     let lastDebounceTranscript;
     recognition.onstart = function() {
         microLabel.classList.add("recognition");
         startValue = messageInput.value;
+        shouldStop = false;
         lastDebounceTranscript = "";
-        timeoutHandle = window.setTimeout(may_stop, 10000);
     };
     recognition.onend = function() {
-        microLabel.classList.remove("recognition");
-        messageInput.focus();
+        if (shouldStop) {
+            messageInput.focus();
+        } else {
+            recognition.start();
+        }
     };
     recognition.onresult = function(event) {
         if (!event.results) {
             return;
         }
-        window.clearTimeout(timeoutHandle);
-
         let result = event.results[event.resultIndex];
         let isFinal = result.isFinal && (result[0].confidence > 0);
         let transcript = result[0].transcript;
@@ -1303,14 +1298,13 @@ if (SpeechRecognition) {
             messageInput.style.height = messageInput.scrollHeight + "px";
             messageInput.scrollTop = messageInput.scrollHeight;
         }
-
-        timeoutHandle = window.setTimeout(may_stop, transcript ? 10000 : 8000);
     };
 
     microLabel.addEventListener("click", () => {
         if (microLabel.classList.contains("recognition")) {
-            window.clearTimeout(timeoutHandle);
+            shouldStop = true;
             recognition.stop();
+            microLabel.classList.remove("recognition");
         } else {
             const lang = document.getElementById("recognition-language")?.value;
             recognition.lang = lang || navigator.language;
diff --git a/g4f/models.py b/g4f/models.py
index 4480dc10..fe99958c 100644
--- a/g4f/models.py
+++ b/g4f/models.py
@@ -20,6 +20,7 @@ from .Provider import (
     Vercel,
     Gemini,
     Koala,
+    Cohere,
     Bing,
     You,
     Pi,
@@ -77,6 +78,7 @@ gpt_35_turbo = Model(
         You,
         ChatgptNext,
         Koala,
+        OpenaiChat,
     ])
 )
@@ -161,11 +163,11 @@ mistral_7b_v02 = Model(
     best_provider = DeepInfra
 )
 
-# mixtral_8x22b = Model(
-#     name = "mistralai/Mixtral-8x22B-v0.1",
-#     base_provider = "huggingface",
-#     best_provider = DeepInfra
-# )
+mixtral_8x22b = Model(
+    name = "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
+    base_provider = "huggingface",
+    best_provider = RetryProvider([HuggingChat, DeepInfra])
+)
 
 # Misc models
 dolphin_mixtral_8x7b = Model(
@@ -265,6 +267,18 @@ pi = Model(
     best_provider = Pi
 )
 
+dbrx_instruct = Model(
+    name = 'databricks/dbrx-instruct',
+    base_provider = 'mistral',
+    best_provider = RetryProvider([DeepInfra, PerplexityLabs])
+)
+
+command_r_plus = Model(
+    name = 'CohereForAI/c4ai-command-r-plus',
+    base_provider = 'mistral',
+    best_provider = RetryProvider([HuggingChat, Cohere])
+)
+
 class ModelUtils:
     """
     Utility class for mapping string identifiers to Model instances.
@@ -299,20 +313,29 @@ class ModelUtils:
         'gigachat' : gigachat,
         'gigachat_plus': gigachat_plus,
         'gigachat_pro' : gigachat_pro,
-
+
+        # Mistral Opensource
         'mixtral-8x7b': mixtral_8x7b,
         'mistral-7b': mistral_7b,
         'mistral-7b-v02': mistral_7b_v02,
-        # 'mixtral-8x22b': mixtral_8x22b,
+        'mixtral-8x22b': mixtral_8x22b,
         'dolphin-mixtral-8x7b': dolphin_mixtral_8x7b,
-        'lzlv-70b': lzlv_70b,
-        'airoboros-70b': airoboros_70b,
-        'openchat_3.5': openchat_35,
+
+        # google gemini
         'gemini': gemini,
         'gemini-pro': gemini_pro,
+
+        # anthropic
         'claude-v2': claude_v2,
         'claude-3-opus': claude_3_opus,
         'claude-3-sonnet': claude_3_sonnet,
+
+        # other
+        'command-r+': command_r_plus,
+        'dbrx-instruct': dbrx_instruct,
+        'lzlv-70b': lzlv_70b,
+        'airoboros-70b': airoboros_70b,
+        'openchat_3.5': openchat_35,
         'pi': pi
     }
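Note: once registered in ModelUtils.convert, the new identifiers behave like any other model string; get_model_and_provider (see the service.py hunk above) resolves them before dispatch. A hedged sketch, assuming the top-level ChatCompletion API:

    # Hypothetical usage: 'command-r+' resolves through ModelUtils.convert to
    # command_r_plus, whose RetryProvider tries HuggingChat, then Cohere.
    import g4f

    response = g4f.ChatCompletion.create(
        model="command-r+",
        messages=[{"role": "user", "content": "Summarize this commit in one line."}],
    )
    print(response)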
diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py
index 52f473e9..d64e8471 100644
--- a/g4f/providers/retry_provider.py
+++ b/g4f/providers/retry_provider.py
@@ -12,46 +12,40 @@ class RetryProvider(BaseRetryProvider):
     def __init__(
         self,
         providers: List[Type[BaseProvider]],
-        shuffle: bool = True
+        shuffle: bool = True,
+        single_provider_retry: bool = False,
+        max_retries: int = 3,
     ) -> None:
         """
         Initialize the BaseRetryProvider.
-
         Args:
             providers (List[Type[BaseProvider]]): List of providers to use.
             shuffle (bool): Whether to shuffle the providers list.
+            single_provider_retry (bool): Whether to retry a single provider if it fails.
+            max_retries (int): Maximum number of retries for a single provider.
         """
         self.providers = providers
         self.shuffle = shuffle
+        self.single_provider_retry = single_provider_retry
+        self.max_retries = max_retries
         self.working = True
         self.last_provider: Type[BaseProvider] = None
-        """
-        A provider class to handle retries for creating completions with different providers.
-
-        Attributes:
-            providers (list): A list of provider instances.
-            shuffle (bool): A flag indicating whether to shuffle providers before use.
-            last_provider (BaseProvider): The last provider that was used.
-        """
 
     def create_completion(
         self,
         model: str,
         messages: Messages,
         stream: bool = False,
-        **kwargs
+        **kwargs,
     ) -> CreateResult:
         """
         Create a completion using available providers, with an option to stream the response.
-
         Args:
             model (str): The model to be used for completion.
             messages (Messages): The messages to be used for generating completion.
            stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.
-
         Yields:
             CreateResult: Tokens or results from the completion.
-
         Raises:
             Exception: Any exception encountered during the completion process.
         """
@@ -61,22 +55,42 @@ class RetryProvider(BaseRetryProvider):
 
         exceptions = {}
         started: bool = False
-        for provider in providers:
+
+        if self.single_provider_retry and len(providers) == 1:
+            provider = providers[0]
             self.last_provider = provider
-            try:
-                if debug.logging:
-                    print(f"Using {provider.__name__} provider")
-                for token in provider.create_completion(model, messages, stream, **kwargs):
-                    yield token
-                    started = True
-            if started:
-                return
-            except Exception as e:
-                exceptions[provider.__name__] = e
-                if debug.logging:
-                    print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
-                if started:
-                    raise e
+            for attempt in range(self.max_retries):
+                try:
+                    if debug.logging:
+                        print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
+                    for token in provider.create_completion(model, messages, stream, **kwargs):
+                        yield token
+                        started = True
+                    if started:
+                        return
+                except Exception as e:
+                    exceptions[provider.__name__] = e
+                    if debug.logging:
+                        print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+                    if started:
+                        raise e
+        else:
+            for provider in providers:
+                self.last_provider = provider
+                try:
+                    if debug.logging:
+                        print(f"Using {provider.__name__} provider")
+                    for token in provider.create_completion(model, messages, stream, **kwargs):
+                        yield token
+                        started = True
+                    if started:
+                        return
+                except Exception as e:
+                    exceptions[provider.__name__] = e
+                    if debug.logging:
+                        print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+                    if started:
+                        raise e
 
         raise_exceptions(exceptions)
 
@@ -84,18 +98,15 @@ class RetryProvider(BaseRetryProvider):
         self,
         model: str,
         messages: Messages,
-        **kwargs
+        **kwargs,
     ) -> str:
         """
         Asynchronously create a completion using available providers.
-
         Args:
             model (str): The model to be used for completion.
             messages (Messages): The messages to be used for generating completion.
-
         Returns:
             str: The result of the asynchronous completion.
-
         Raises:
             Exception: Any exception encountered during the asynchronous completion process.
""" @@ -104,17 +115,36 @@ class RetryProvider(BaseRetryProvider): random.shuffle(providers) exceptions = {} - for provider in providers: + + if self.single_provider_retry and len(providers) == 1: + provider = providers[0] self.last_provider = provider - try: - return await asyncio.wait_for( - provider.create_async(model, messages, **kwargs), - timeout=kwargs.get("timeout", 60) - ) - except Exception as e: - exceptions[provider.__name__] = e - if debug.logging: - print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + for attempt in range(self.max_retries): + try: + if debug.logging: + print(f"Using {provider.__name__} provider (attempt {attempt + 1})") + return await asyncio.wait_for( + provider.create_async(model, messages, **kwargs), + timeout=kwargs.get("timeout", 60), + ) + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + else: + for provider in providers: + self.last_provider = provider + try: + if debug.logging: + print(f"Using {provider.__name__} provider") + return await asyncio.wait_for( + provider.create_async(model, messages, **kwargs), + timeout=kwargs.get("timeout", 60), + ) + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") raise_exceptions(exceptions) |