Diffstat (limited to 'g4f')
-rw-r--r--   g4f/Provider/Cohere.py                 106
-rw-r--r--   g4f/Provider/DeepInfra.py               10
-rw-r--r--   g4f/Provider/HuggingChat.py             14
-rw-r--r--   g4f/Provider/PerplexityLabs.py           5
-rw-r--r--   g4f/Provider/You.py                      8
-rw-r--r--   g4f/Provider/__init__.py                 1
-rw-r--r--   g4f/Provider/needs_auth/Openai.py        1
-rw-r--r--   g4f/Provider/needs_auth/OpenaiChat.py   12
-rw-r--r--   g4f/Provider/you/har_file.py            87
-rw-r--r--   g4f/client/async_client.py               4
-rw-r--r--   g4f/client/service.py                    3
-rw-r--r--   g4f/gui/client/static/js/chat.v1.js     28
-rw-r--r--   g4f/models.py                           43
-rw-r--r--   g4f/providers/retry_provider.py        114
14 files changed, 347 insertions(+), 89 deletions(-)
diff --git a/g4f/Provider/Cohere.py b/g4f/Provider/Cohere.py
new file mode 100644
index 00000000..4f9fd30a
--- /dev/null
+++ b/g4f/Provider/Cohere.py
@@ -0,0 +1,108 @@
+from __future__ import annotations
+
+import json, random, requests, threading
+from aiohttp import ClientSession
+
+from ..typing import CreateResult, Messages
+from .base_provider import AbstractProvider
+from .helper import format_prompt
+
+class Cohere(AbstractProvider):
+    url = "https://cohereforai-c4ai-command-r-plus.hf.space"
+    working = True
+    supports_gpt_35_turbo = False
+    supports_gpt_4 = False
+    supports_stream = True
+
+    @staticmethod
+    def create_completion(
+        model: str,
+        messages: Messages,
+        stream: bool,
+        proxy: str = None,
+        max_retries: int = 6,
+        **kwargs
+    ) -> CreateResult:
+
+        prompt = format_prompt(messages)
+
+        headers = {
+            'accept': 'text/event-stream',
+            'accept-language': 'en,fr-FR;q=0.9,fr;q=0.8,es-ES;q=0.7,es;q=0.6,en-US;q=0.5,am;q=0.4,de;q=0.3',
+            'cache-control': 'no-cache',
+            'pragma': 'no-cache',
+            'referer': 'https://cohereforai-c4ai-command-r-plus.hf.space/?__theme=light',
+            'sec-ch-ua': '"Google Chrome";v="123", "Not:A-Brand";v="8", "Chromium";v="123"',
+            'sec-ch-ua-mobile': '?0',
+            'sec-ch-ua-platform': '"macOS"',
+            'sec-fetch-dest': 'empty',
+            'sec-fetch-mode': 'cors',
+            'sec-fetch-site': 'same-origin',
+            'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36',
+        }
+
+        # Random session id for the Space's Gradio queue.
+        session_hash = ''.join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=11))
+
+        params = {
+            'fn_index': '1',
+            'session_hash': session_hash,
+        }
+
+        # Join the generation queue; the response is a server-sent event stream.
+        response = requests.get(
+            'https://cohereforai-c4ai-command-r-plus.hf.space/queue/join',
+            params=params,
+            headers=headers,
+            stream=True
+        )
+
+        completion = ''
+
+        for line in response.iter_lines():
+            if line:
+                json_data = json.loads(line[6:])
+
+                if b"send_data" in line:
+                    event_id = json_data["event_id"]
+                    # The queue is ready for input: post the prompt payload.
+                    threading.Thread(target=send_data, args=[session_hash, event_id, prompt]).start()
+
+                if b"process_generating" in line or b"process_completed" in line:
+                    token = json_data['output']['data'][0][0][1]
+                    # Yield only the delta since the last event.
+                    yield token.replace(completion, "")
+                    completion = token
+
+def send_data(session_hash, event_id, prompt):
+    headers = {
+        'accept': '*/*',
+        'accept-language': 'en,fr-FR;q=0.9,fr;q=0.8,es-ES;q=0.7,es;q=0.6,en-US;q=0.5,am;q=0.4,de;q=0.3',
+        'cache-control': 'no-cache',
+        'content-type': 'application/json',
+        'origin': 'https://cohereforai-c4ai-command-r-plus.hf.space',
+        'pragma': 'no-cache',
+        'referer': 'https://cohereforai-c4ai-command-r-plus.hf.space/?__theme=light',
+        'sec-ch-ua': '"Google Chrome";v="123", "Not:A-Brand";v="8", "Chromium";v="123"',
+        'sec-ch-ua-mobile': '?0',
+        'sec-ch-ua-platform': '"macOS"',
+        'sec-fetch-dest': 'empty',
+        'sec-fetch-mode': 'cors',
+        'sec-fetch-site': 'same-origin',
+        'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36',
+    }
+
+    json_data = {
+        'data': [
+            prompt,
+            '',
+            [],
+        ],
+        'event_data': None,
+        'fn_index': 1,
+        'session_hash': session_hash,
+        'event_id': event_id
+    }
+
+    requests.post('https://cohereforai-c4ai-command-r-plus.hf.space/queue/data',
+        json=json_data, headers=headers)
\ No newline at end of file
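The provider drives a Gradio Space's queue API: join the queue over a server-sent event stream, answer the "send_data" event by posting the payload to /queue/data, then read tokens from "process_generating"/"process_completed" events. A minimal single-threaded sketch of the same handshake (the "msg" and "event_id" field names are inferred from the provider code above, not taken from Gradio documentation):

import json
import random
import requests

BASE = "https://cohereforai-c4ai-command-r-plus.hf.space"

def ask(prompt: str) -> str:
    session_hash = "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=11))
    # Join the queue; each SSE line looks like b'data: {...}'.
    stream = requests.get(
        f"{BASE}/queue/join",
        params={"fn_index": "1", "session_hash": session_hash},
        headers={"accept": "text/event-stream"},
        stream=True,
    )
    completion = ""
    for line in stream.iter_lines():
        if not line:
            continue
        event = json.loads(line[6:])  # strip the b"data: " prefix
        if event.get("msg") == "send_data":
            # The queue asked for our input: echo event_id with the payload.
            requests.post(f"{BASE}/queue/data", json={
                "data": [prompt, "", []],
                "event_data": None,
                "fn_index": 1,
                "session_hash": session_hash,
                "event_id": event["event_id"],
            })
        elif event.get("msg") in ("process_generating", "process_completed"):
            completion = event["output"]["data"][0][0][1]
            if event["msg"] == "process_completed":
                break
    return completion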
diff --git a/g4f/Provider/DeepInfra.py b/g4f/Provider/DeepInfra.py
index 68aaf8b9..971424b7 100644
--- a/g4f/Provider/DeepInfra.py
+++ b/g4f/Provider/DeepInfra.py
@@ -11,7 +11,7 @@ class DeepInfra(Openai):
     needs_auth = False
     supports_stream = True
     supports_message_history = True
-    default_model = 'meta-llama/Llama-2-70b-chat-hf'
+    default_model = 'HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1'

     @classmethod
     def get_models(cls):
@@ -32,6 +32,15 @@ class DeepInfra(Openai):
         max_tokens: int = 1028,
         **kwargs
     ) -> AsyncResult:
+
+        # Map short model names (no vendor prefix) to full DeepInfra ids.
+        if '/' not in model:
+            models = {
+                'mixtral-8x22b': 'HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1',
+                'dbrx-instruct': 'databricks/dbrx-instruct',
+            }
+            model = models.get(model, model)
+
         headers = {
             'Accept-Encoding': 'gzip, deflate, br',
             'Accept-Language': 'en-US',
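The added block gives DeepInfra the same short-name convenience as other providers: bare names without a vendor prefix are rewritten, everything else passes through. In isolation (the resolve() helper is hypothetical, for illustration only):

aliases = {
    'mixtral-8x22b': 'HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1',
    'dbrx-instruct': 'databricks/dbrx-instruct',
}

def resolve(model: str) -> str:
    # Only bare names (no vendor prefix) are rewritten.
    return aliases.get(model, model) if '/' not in model else model

assert resolve('dbrx-instruct') == 'databricks/dbrx-instruct'
assert resolve('databricks/dbrx-instruct') == 'databricks/dbrx-instruct'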
diff --git a/g4f/Provider/HuggingChat.py b/g4f/Provider/HuggingChat.py
index b80795fe..882edb78 100644
--- a/g4f/Provider/HuggingChat.py
+++ b/g4f/Provider/HuggingChat.py
@@ -14,13 +14,12 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
     working = True
     default_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
     models = [
-        "mistralai/Mixtral-8x7B-Instruct-v0.1",
-        "google/gemma-7b-it",
-        "meta-llama/Llama-2-70b-chat-hf",
-        "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
-        "codellama/CodeLlama-34b-Instruct-hf",
-        "mistralai/Mistral-7B-Instruct-v0.2",
-        "openchat/openchat-3.5-0106",
+        "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
+        "CohereForAI/c4ai-command-r-plus",
+        "mistralai/Mixtral-8x7B-Instruct-v0.1",
+        "google/gemma-1.1-7b-it",
+        "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
+        "mistralai/Mistral-7B-Instruct-v0.2"
     ]
     model_aliases = {
         "openchat/openchat_3.5": "openchat/openchat-3.5-0106",
@@ -48,6 +47,7 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
         **kwargs
     ) -> AsyncResult:
         options = {"model": cls.get_model(model)}
+
         system_prompt = "\n".join([message["content"] for message in messages if message["role"] == "system"])
         if system_prompt:
             options["preprompt"] = system_prompt
diff --git a/g4f/Provider/PerplexityLabs.py b/g4f/Provider/PerplexityLabs.py
index 6c80efee..ba956100 100644
--- a/g4f/Provider/PerplexityLabs.py
+++ b/g4f/Provider/PerplexityLabs.py
@@ -19,13 +19,14 @@ class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
         "sonar-small-online", "sonar-medium-online", "sonar-small-chat", "sonar-medium-chat", "mistral-7b-instruct",
         "codellama-70b-instruct", "llava-v1.5-7b-wrapper", "llava-v1.6-34b", "mixtral-8x7b-instruct",
         "gemma-2b-it", "gemma-7b-it",
-        "mistral-medium", "related"
+        "mistral-medium", "related", "dbrx-instruct"
     ]
     model_aliases = {
         "mistralai/Mistral-7B-Instruct-v0.1": "mistral-7b-instruct",
         "mistralai/Mixtral-8x7B-Instruct-v0.1": "mixtral-8x7b-instruct",
         "codellama/CodeLlama-70b-Instruct-hf": "codellama-70b-instruct",
-        "llava-v1.5-7b": "llava-v1.5-7b-wrapper"
+        "llava-v1.5-7b": "llava-v1.5-7b-wrapper",
+        "databricks/dbrx-instruct": "dbrx-instruct"
     }

     @classmethod
diff --git a/g4f/Provider/You.py b/g4f/Provider/You.py
index be4ab523..3ebd40f2 100644
--- a/g4f/Provider/You.py
+++ b/g4f/Provider/You.py
@@ -65,6 +65,7 @@ class You(AsyncGeneratorProvider, ProviderModelMixin):
             timeout=(30, timeout)
         ) as session:
             cookies = await cls.get_cookies(session) if chat_mode != "default" else None
+
             upload = json.dumps([await cls.upload_file(session, cookies, to_bytes(image), image_name)]) if image else ""
             headers = {
                 "Accept": "text/event-stream",
@@ -131,6 +132,7 @@
     @classmethod
     async def get_cookies(cls, client: StreamSession) -> Cookies:
+
         if not cls._cookies or cls._cookies_used >= 5:
             cls._cookies = await cls.create_cookies(client)
             cls._cookies_used = 0
@@ -151,8 +153,8 @@
         }}).encode()).decode()

         def get_auth() -> str:
-            auth_uuid = "507a52ad-7e69-496b-aee0-1c9863c7c8"
-            auth_token = f"public-token-live-{auth_uuid}bb:public-token-live-{auth_uuid}19"
+            auth_uuid = "507a52ad-7e69-496b-aee0-1c9863c7c819"
+            auth_token = f"public-token-live-{auth_uuid}:public-token-live-{auth_uuid}"
             auth = base64.standard_b64encode(auth_token.encode()).decode()
             return f"Basic {auth}"
@@ -172,12 +174,12 @@
                 "dfp_telemetry_id": await get_dfp_telemetry_id(),
                 "email": f"{user_uuid}@gmail.com",
                 "password": f"{user_uuid}#{user_uuid}",
-                "dfp_telemetry_id": f"{uuid.uuid4()}",
                 "session_duration_minutes": 129600
             }
         ) as response:
             await raise_for_status(response)
             session = (await response.json())["data"]
+
             return {
                 "stytch_session": session["session_token"],
                 'stytch_session_jwt': session["session_jwt"],
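The get_auth() change is easiest to verify in isolation: the old code spliced extra "bb"/"19" suffixes around a truncated UUID, while the new code uses the same full public token on both sides of the Basic credential. Reproducing the corrected computation:

import base64

auth_uuid = "507a52ad-7e69-496b-aee0-1c9863c7c819"
auth_token = f"public-token-live-{auth_uuid}:public-token-live-{auth_uuid}"
auth = base64.standard_b64encode(auth_token.encode()).decode()
print(f"Basic {auth}")  # the Basic Authorization header value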
diff --git a/g4f/Provider/__init__.py b/g4f/Provider/__init__.py
index b818a752..ea64f80a 100644
--- a/g4f/Provider/__init__.py
+++ b/g4f/Provider/__init__.py
@@ -46,6 +46,7 @@ from .ReplicateImage import ReplicateImage
 from .Vercel import Vercel
 from .WhiteRabbitNeo import WhiteRabbitNeo
 from .You import You
+from .Cohere import Cohere

 import sys
diff --git a/g4f/Provider/needs_auth/Openai.py b/g4f/Provider/needs_auth/Openai.py
index 81ba5981..80318f6d 100644
--- a/g4f/Provider/needs_auth/Openai.py
+++ b/g4f/Provider/needs_auth/Openai.py
@@ -51,6 +51,7 @@ class Openai(AsyncGeneratorProvider, ProviderModelMixin):
                 stream=stream,
                 **extra_data
             )
+
             async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response:
                 await raise_for_status(response)
                 if not stream:
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index 3145161a..b34daa3e 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -44,7 +44,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
     supports_system_message = True
     default_model = None
     models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo"]
-    model_aliases = {"text-davinci-002-render-sha": "gpt-3.5-turbo", "": "gpt-3.5-turbo"}
+    model_aliases = {"text-davinci-002-render-sha": "gpt-3.5-turbo", "": "gpt-3.5-turbo", "gpt-4-turbo-preview": "gpt-4"}
     _api_key: str = None
     _headers: dict = None
     _cookies: Cookies = None
@@ -334,6 +334,7 @@
         Raises:
             RuntimeError: If an error occurs during processing.
         """
+
         async with StreamSession(
             proxies={"all": proxy},
             impersonate="chrome",
@@ -359,6 +360,7 @@
                 if debug.logging:
                     print("OpenaiChat: Load default_model failed")
                     print(f"{e.__class__.__name__}: {e}")
+
         arkose_token = None
         if cls.default_model is None:
@@ -369,6 +371,8 @@
             except NoValidHarFileError:
                 ...
             if cls._api_key is None:
+                if debug.logging:
+                    print("Getting access token with nodriver.")
                 await cls.nodriver_access_token()
             cls.default_model = cls.get_model(await cls.get_default_model(session, cls._headers))
@@ -384,6 +388,9 @@
                 blob = data["arkose"]["dx"]
                 need_arkose = data["arkose"]["required"]
                 chat_token = data["token"]
+
+                if debug.logging:
+                    print(f'Arkose: {need_arkose} Turnstile: {data["turnstile"]["required"]}')

             if need_arkose and arkose_token is None:
                 arkose_token, api_key, cookies = await getArkoseAndAccessToken(proxy)
@@ -582,6 +589,7 @@ this.fetch = async (url, options) => {
             user_data_dir = user_config_dir("g4f-nodriver")
         except:
             user_data_dir = None
+
         browser = await uc.start(user_data_dir=user_data_dir)
         page = await browser.get("https://chat.openai.com/")
         while await page.query_selector("#prompt-textarea") is None:
@@ -781,4 +789,4 @@ class Response():
     async def get_messages(self) -> list:
         messages = self._messages
         messages.append({"role": "assistant", "content": await self.message()})
-        return messages
\ No newline at end of file
+        return messages
diff --git a/g4f/Provider/you/har_file.py b/g4f/Provider/you/har_file.py
index 281f37e2..a6981296 100644
--- a/g4f/Provider/you/har_file.py
+++ b/g4f/Provider/you/har_file.py
@@ -4,6 +4,8 @@
 import json
 import os
 import random
 import uuid
+import asyncio
+import requests

 from ...requests import StreamSession, raise_for_status
@@ -65,8 +67,89 @@
     return await response.text()

 async def get_dfp_telemetry_id(proxy: str = None):
-    return str(uuid.uuid4())
+    return await telemetry_id_with_driver(proxy)
     global chatArks
     if chatArks is None:
         chatArks = readHAR()
-    return await sendRequest(random.choice(chatArks), proxy)
\ No newline at end of file
+    return await sendRequest(random.choice(chatArks), proxy)
+
+async def telemetry_id_with_driver(proxy: str = None):
+    from ...debug import logging
+    if logging:
+        print('getting telemetry_id for you.com with nodriver')
+    try:
+        import nodriver as uc
+        from nodriver import start, cdp, loop
+    except ImportError:
+        if logging:
+            print('nodriver not found, random uuid (may fail)')
+        return str(uuid.uuid4())
+
+    CAN_EVAL = False
+    payload_received = False
+    payload = None
+
+    try:
+        browser = await start()
+        tab = browser.main_tab
+
+        async def send_handler(event: cdp.network.RequestWillBeSent):
+            nonlocal CAN_EVAL, payload_received, payload
+            if 'telemetry.js' in event.request.url:
+                CAN_EVAL = True
+            if "/submit" in event.request.url:
+                payload = event.request.post_data
+                payload_received = True
+
+        tab.add_handler(cdp.network.RequestWillBeSent, send_handler)
+        await browser.get("https://you.com")
+
+        while not CAN_EVAL:
+            await tab.sleep(1)
+
+        await tab.evaluate('window.GetTelemetryID("public-token-live-507a52ad-7e69-496b-aee0-1c9863c7c819", "https://telemetry.stytch.com/submit");')
+
+        while not payload_received:
+            await tab.sleep(.1)
+
+    except Exception as e:
+        print(f"Error occurred: {str(e)}")
+
+    finally:
+        try:
+            await tab.close()
+        except Exception as e:
+            print(f"Error occurred while closing tab: {str(e)}")
+
+        try:
+            await browser.stop()
+        except Exception:
+            pass
+
+    headers = {
+        'Accept': '*/*',
+        'Accept-Language': 'en,fr-FR;q=0.9,fr;q=0.8,es-ES;q=0.7,es;q=0.6,en-US;q=0.5,am;q=0.4,de;q=0.3',
+        'Connection': 'keep-alive',
+        'Content-type': 'application/x-www-form-urlencoded',
+        'Origin': 'https://you.com',
+        'Referer': 'https://you.com/',
+        'Sec-Fetch-Dest': 'empty',
+        'Sec-Fetch-Mode': 'cors',
+        'Sec-Fetch-Site': 'cross-site',
+        'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36',
+        'sec-ch-ua': '"Google Chrome";v="123", "Not:A-Brand";v="8", "Chromium";v="123"',
+        'sec-ch-ua-mobile': '?0',
+        'sec-ch-ua-platform': '"macOS"',
+    }
+
+    proxies = {
+        'http': proxy,
+        'https': proxy,
+    } if proxy else None
+
+    response = requests.post('https://telemetry.stytch.com/submit',
+        headers=headers, data=payload, proxies=proxies)
+
+    if '-' in response.text:
+        print(f'telemetry generated: {response.text}')
+
+    return response.text
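The function above follows a reusable pattern: hook CDP network events, wait for a marker request, lift its POST body, and replay it outside the browser. A condensed sketch using only the nodriver calls that already appear in this diff (treat it as illustrative, not a stable API reference):

import nodriver as uc
from nodriver import cdp
from typing import Optional

async def capture_post_data(page_url: str, url_marker: str) -> Optional[str]:
    browser = await uc.start()
    tab = browser.main_tab
    payload = None

    async def on_request(event: cdp.network.RequestWillBeSent):
        nonlocal payload
        if url_marker in event.request.url:
            payload = event.request.post_data  # body of the matching request

    tab.add_handler(cdp.network.RequestWillBeSent, on_request)
    await browser.get(page_url)
    while payload is None:
        await tab.sleep(.1)  # poll until the page fires the request
    await browser.stop()
    return payload

# e.g. uc.loop().run_until_complete(capture_post_data("https://you.com", "/submit"))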
diff --git a/g4f/client/async_client.py b/g4f/client/async_client.py
index 51a9cf83..8e1ee33c 100644
--- a/g4f/client/async_client.py
+++ b/g4f/client/async_client.py
@@ -144,7 +144,7 @@ class Completions():
             proxy=self.client.get_proxy() if proxy is None else proxy,
             max_tokens=max_tokens,
             stop=stop,
-            api_key=self.client.api_key if api_key is None else api_key
+            api_key=self.client.api_key if api_key is None else api_key,
             **kwargs
         )
         response = iter_response(response, stream, response_format, max_tokens, stop)
@@ -207,4 +207,4 @@ class Images():
         result = iter_image_response(response)
         if result is None:
             raise NoImageResponseError()
-        return result
\ No newline at end of file
+        return result
diff --git a/g4f/client/service.py b/g4f/client/service.py
index f3565f6d..d25c923d 100644
--- a/g4f/client/service.py
+++ b/g4f/client/service.py
@@ -55,9 +55,10 @@ def get_model_and_provider(model : Union[Model, str],
             provider = convert_to_provider(provider)

     if isinstance(model, str):
+
         if model in ModelUtils.convert:
             model = ModelUtils.convert[model]
-
+
     if not provider:
         if isinstance(model, str):
             raise ModelNotFoundError(f'Model not found: {model}')
diff --git a/g4f/gui/client/static/js/chat.v1.js b/g4f/gui/client/static/js/chat.v1.js
index 8933b442..7f4011a2 100644
--- a/g4f/gui/client/static/js/chat.v1.js
+++ b/g4f/gui/client/static/js/chat.v1.js
@@ -1074,7 +1074,7 @@ async function load_version() {
 }
 setTimeout(load_version, 2000);

-for (const el of [imageInput, cameraInput]) {
+[imageInput, cameraInput].forEach((el) => {
     el.addEventListener('click', async () => {
         el.value = '';
         if (imageInput.dataset.src) {
@@ -1082,7 +1082,7 @@ for (const el of [imageInput, cameraInput]) {
             delete imageInput.dataset.src
         }
     });
-}
+});

 fileInput.addEventListener('click', async (event) => {
     fileInput.value = '';
@@ -1261,31 +1261,26 @@ if (SpeechRecognition) {
     recognition.interimResults = true;
     recognition.maxAlternatives = 1;

-    function may_stop() {
-        if (microLabel.classList.contains("recognition")) {
-            recognition.stop();
-        }
-    }
-
     let startValue;
-    let timeoutHandle;
+    let shouldStop;
     let lastDebounceTranscript;
     recognition.onstart = function() {
         microLabel.classList.add("recognition");
         startValue = messageInput.value;
+        shouldStop = false;
         lastDebounceTranscript = "";
-        timeoutHandle = window.setTimeout(may_stop, 10000);
     };
     recognition.onend = function() {
-        microLabel.classList.remove("recognition");
-        messageInput.focus();
+        if (shouldStop) {
+            messageInput.focus();
+        } else {
+            recognition.start();
+        }
     };
     recognition.onresult = function(event) {
         if (!event.results) {
             return;
         }
-        window.clearTimeout(timeoutHandle);
-
         let result = event.results[event.resultIndex];
         let isFinal = result.isFinal && (result[0].confidence > 0);
         let transcript = result[0].transcript;
@@ -1303,14 +1298,13 @@ if (SpeechRecognition) {
             messageInput.style.height = messageInput.scrollHeight + "px";
             messageInput.scrollTop = messageInput.scrollHeight;
         }
-
-        timeoutHandle = window.setTimeout(may_stop, transcript ? 10000 : 8000);
     };

     microLabel.addEventListener("click", () => {
         if (microLabel.classList.contains("recognition")) {
-            window.clearTimeout(timeoutHandle);
+            shouldStop = true;
             recognition.stop();
+            microLabel.classList.remove("recognition");
         } else {
             const lang = document.getElementById("recognition-language")?.value;
             recognition.lang = lang || navigator.language;
diff --git a/g4f/models.py b/g4f/models.py
index 4480dc10..fe99958c 100644
--- a/g4f/models.py
+++ b/g4f/models.py
@@ -20,6 +20,7 @@ from .Provider import (
     Vercel,
     Gemini,
     Koala,
+    Cohere,
     Bing,
     You,
     Pi,
@@ -77,6 +78,7 @@ gpt_35_turbo = Model(
         You,
         ChatgptNext,
         Koala,
+        OpenaiChat,
     ])
 )
@@ -161,11 +163,11 @@ mistral_7b_v02 = Model(
     best_provider = DeepInfra
 )

-# mixtral_8x22b = Model(
-#     name = "mistralai/Mixtral-8x22B-v0.1",
-#     base_provider = "huggingface",
-#     best_provider = DeepInfra
-# )
+mixtral_8x22b = Model(
+    name = "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
+    base_provider = "huggingface",
+    best_provider = RetryProvider([HuggingChat, DeepInfra])
+)

 # Misc models
 dolphin_mixtral_8x7b = Model(
@@ -265,6 +267,18 @@ pi = Model(
     best_provider = Pi
 )

+dbrx_instruct = Model(
+    name = 'databricks/dbrx-instruct',
+    base_provider = 'mistral',
+    best_provider = RetryProvider([DeepInfra, PerplexityLabs])
+)
+
+command_r_plus = Model(
+    name = 'CohereForAI/c4ai-command-r-plus',
+    base_provider = 'mistral',
+    best_provider = RetryProvider([HuggingChat, Cohere])
+)
+
 class ModelUtils:
     """
     Utility class for mapping string identifiers to Model instances.
@@ -299,20 +313,29 @@ class ModelUtils:
         'gigachat'     : gigachat,
         'gigachat_plus': gigachat_plus,
         'gigachat_pro' : gigachat_pro,
-
+
+        # Mistral open-source
         'mixtral-8x7b': mixtral_8x7b,
         'mistral-7b': mistral_7b,
         'mistral-7b-v02': mistral_7b_v02,
-        # 'mixtral-8x22b': mixtral_8x22b,
+        'mixtral-8x22b': mixtral_8x22b,
         'dolphin-mixtral-8x7b': dolphin_mixtral_8x7b,
-        'lzlv-70b': lzlv_70b,
-        'airoboros-70b': airoboros_70b,
-        'openchat_3.5': openchat_35,
+
+        # Google Gemini
         'gemini': gemini,
         'gemini-pro': gemini_pro,
+
+        # Anthropic
         'claude-v2': claude_v2,
         'claude-3-opus': claude_3_opus,
         'claude-3-sonnet': claude_3_sonnet,
+
+        # Other
+        'command-r+': command_r_plus,
+        'dbrx-instruct': dbrx_instruct,
+        'lzlv-70b': lzlv_70b,
+        'airoboros-70b': airoboros_70b,
+        'openchat_3.5': openchat_35,
         'pi': pi
     }
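The new registry keys can be exercised through g4f's top-level API; a small usage sketch (the prompt and printout are illustrative):

import g4f

for model in ("mixtral-8x22b", "dbrx-instruct", "command-r+"):
    # Each key resolves through ModelUtils.convert to the Model entries above,
    # and each Model's best_provider handles provider fallback.
    response = g4f.ChatCompletion.create(
        model=model,
        messages=[{"role": "user", "content": "Say hello in one word."}],
    )
    print(model, "->", response)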
diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py
index 52f473e9..d64e8471 100644
--- a/g4f/providers/retry_provider.py
+++ b/g4f/providers/retry_provider.py
@@ -12,46 +12,40 @@ class RetryProvider(BaseRetryProvider):
     def __init__(
         self,
         providers: List[Type[BaseProvider]],
-        shuffle: bool = True
+        shuffle: bool = True,
+        single_provider_retry: bool = False,
+        max_retries: int = 3,
     ) -> None:
         """
         Initialize the BaseRetryProvider.
-
         Args:
             providers (List[Type[BaseProvider]]): List of providers to use.
             shuffle (bool): Whether to shuffle the providers list.
+            single_provider_retry (bool): Whether to retry a single provider if it fails.
+            max_retries (int): Maximum number of retries for a single provider.
         """
         self.providers = providers
         self.shuffle = shuffle
+        self.single_provider_retry = single_provider_retry
+        self.max_retries = max_retries
         self.working = True
         self.last_provider: Type[BaseProvider] = None
-        """
-        A provider class to handle retries for creating completions with different providers.
-
-        Attributes:
-            providers (list): A list of provider instances.
-            shuffle (bool): A flag indicating whether to shuffle providers before use.
-            last_provider (BaseProvider): The last provider that was used.
-        """

     def create_completion(
         self,
         model: str,
         messages: Messages,
         stream: bool = False,
-        **kwargs
+        **kwargs,
     ) -> CreateResult:
         """
         Create a completion using available providers, with an option to stream the response.
-
         Args:
             model (str): The model to be used for completion.
             messages (Messages): The messages to be used for generating completion.
             stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.
-
         Yields:
             CreateResult: Tokens or results from the completion.
-
         Raises:
             Exception: Any exception encountered during the completion process.
         """
@@ -61,22 +55,42 @@
         exceptions = {}
         started: bool = False
-        for provider in providers:
+
+        if self.single_provider_retry and len(providers) == 1:
+            provider = providers[0]
             self.last_provider = provider
-            try:
-                if debug.logging:
-                    print(f"Using {provider.__name__} provider")
-                for token in provider.create_completion(model, messages, stream, **kwargs):
-                    yield token
-                    started = True
-                if started:
-                    return
-            except Exception as e:
-                exceptions[provider.__name__] = e
-                if debug.logging:
-                    print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
-                if started:
-                    raise e
+            for attempt in range(self.max_retries):
+                try:
+                    if debug.logging:
+                        print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
+                    for token in provider.create_completion(model, messages, stream, **kwargs):
+                        yield token
+                        started = True
+                    if started:
+                        return
+                except Exception as e:
+                    exceptions[provider.__name__] = e
+                    if debug.logging:
+                        print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+                    if started:
+                        raise e
+        else:
+            for provider in providers:
+                self.last_provider = provider
+                try:
+                    if debug.logging:
+                        print(f"Using {provider.__name__} provider")
+                    for token in provider.create_completion(model, messages, stream, **kwargs):
+                        yield token
+                        started = True
+                    if started:
+                        return
+                except Exception as e:
+                    exceptions[provider.__name__] = e
+                    if debug.logging:
+                        print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+                    if started:
+                        raise e

         raise_exceptions(exceptions)
@@ -84,18 +98,15 @@
         self,
         model: str,
         messages: Messages,
-        **kwargs
+        **kwargs,
     ) -> str:
         """
         Asynchronously create a completion using available providers.
-
         Args:
             model (str): The model to be used for completion.
             messages (Messages): The messages to be used for generating completion.
-
         Returns:
             str: The result of the asynchronous completion.
-
         Raises:
             Exception: Any exception encountered during the asynchronous completion process.
         """
@@ -104,17 +115,36 @@
             random.shuffle(providers)

         exceptions = {}
-        for provider in providers:
+
+        if self.single_provider_retry and len(providers) == 1:
+            provider = providers[0]
             self.last_provider = provider
-            try:
-                return await asyncio.wait_for(
-                    provider.create_async(model, messages, **kwargs),
-                    timeout=kwargs.get("timeout", 60)
-                )
-            except Exception as e:
-                exceptions[provider.__name__] = e
-                if debug.logging:
-                    print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+            for attempt in range(self.max_retries):
+                try:
+                    if debug.logging:
+                        print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
+                    return await asyncio.wait_for(
+                        provider.create_async(model, messages, **kwargs),
+                        timeout=kwargs.get("timeout", 60),
+                    )
+                except Exception as e:
+                    exceptions[provider.__name__] = e
+                    if debug.logging:
+                        print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+        else:
+            for provider in providers:
+                self.last_provider = provider
+                try:
+                    if debug.logging:
+                        print(f"Using {provider.__name__} provider")
+                    return await asyncio.wait_for(
+                        provider.create_async(model, messages, **kwargs),
+                        timeout=kwargs.get("timeout", 60),
+                    )
+                except Exception as e:
+                    exceptions[provider.__name__] = e
+                    if debug.logging:
+                        print(f"{provider.__name__}: {e.__class__.__name__}: {e}")

         raise_exceptions(exceptions)
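A usage sketch for the new retry mode (import paths follow the file paths in this diff; the payload is illustrative). With a single provider and single_provider_retry=True, RetryProvider retries that one provider up to max_retries times instead of falling through a provider list:

from g4f.Provider import DeepInfra
from g4f.providers.retry_provider import RetryProvider

# One provider, retried up to 3 times on failure.
provider = RetryProvider([DeepInfra], single_provider_retry=True, max_retries=3)

for token in provider.create_completion(
    model="databricks/dbrx-instruct",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
):
    print(token, end="", flush=True)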