diff options
author | H Lohaus <hlohaus@users.noreply.github.com> | 2024-02-28 09:48:57 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-28 09:48:57 +0100 |
commit | 96db520ff030cd0beae8b469876013b8f18b793a (patch) | |
tree | 0d2d6cf85371cc454279bd454ae851ca0fee930a /g4f | |
parent | Merge pull request #1635 from hlohaus/flow (diff) | |
parent | Add websocket support in OpenaiChat (diff) | |
download | gpt4free-96db520ff030cd0beae8b469876013b8f18b793a.tar gpt4free-96db520ff030cd0beae8b469876013b8f18b793a.tar.gz gpt4free-96db520ff030cd0beae8b469876013b8f18b793a.tar.bz2 gpt4free-96db520ff030cd0beae8b469876013b8f18b793a.tar.lz gpt4free-96db520ff030cd0beae8b469876013b8f18b793a.tar.xz gpt4free-96db520ff030cd0beae8b469876013b8f18b793a.tar.zst gpt4free-96db520ff030cd0beae8b469876013b8f18b793a.zip |
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/AiChatOnline.py | 2 | ||||
-rw-r--r-- | g4f/Provider/Aura.py | 5 | ||||
-rw-r--r-- | g4f/Provider/ChatgptAi.py | 2 | ||||
-rw-r--r-- | g4f/Provider/ChatgptDemo.py | 2 | ||||
-rw-r--r-- | g4f/Provider/ChatgptNext.py | 5 | ||||
-rw-r--r-- | g4f/Provider/Chatxyz.py | 2 | ||||
-rw-r--r-- | g4f/Provider/FlowGpt.py | 8 | ||||
-rw-r--r-- | g4f/Provider/GeminiPro.py | 3 | ||||
-rw-r--r-- | g4f/Provider/You.py | 19 | ||||
-rw-r--r-- | g4f/Provider/__init__.py | 2 | ||||
-rw-r--r-- | g4f/Provider/needs_auth/OpenaiChat.py | 291 | ||||
-rw-r--r-- | g4f/__init__.py | 4 | ||||
-rw-r--r-- | g4f/client.py | 24 | ||||
-rw-r--r-- | g4f/providers/base_provider.py | 2 | ||||
-rw-r--r-- | g4f/providers/retry_provider.py | 130 | ||||
-rw-r--r-- | g4f/providers/types.py | 21 | ||||
-rw-r--r-- | g4f/typing.py | 8 |
17 files changed, 346 insertions, 184 deletions
diff --git a/g4f/Provider/AiChatOnline.py b/g4f/Provider/AiChatOnline.py index dc774fe0..cc3b5b8e 100644 --- a/g4f/Provider/AiChatOnline.py +++ b/g4f/Provider/AiChatOnline.py @@ -9,7 +9,7 @@ from .helper import get_random_string class AiChatOnline(AsyncGeneratorProvider): url = "https://aichatonline.org" - working = True + working = False supports_gpt_35_turbo = True supports_message_history = False diff --git a/g4f/Provider/Aura.py b/g4f/Provider/Aura.py index 126c8d0f..d8f3471c 100644 --- a/g4f/Provider/Aura.py +++ b/g4f/Provider/Aura.py @@ -6,9 +6,8 @@ from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider class Aura(AsyncGeneratorProvider): - url = "https://openchat.team" - working = True - supports_gpt_35_turbo = True + url = "https://openchat.team" + working = True @classmethod async def create_async_generator( diff --git a/g4f/Provider/ChatgptAi.py b/g4f/Provider/ChatgptAi.py index f2785364..a38aea5e 100644 --- a/g4f/Provider/ChatgptAi.py +++ b/g4f/Provider/ChatgptAi.py @@ -9,7 +9,7 @@ from .base_provider import AsyncGeneratorProvider class ChatgptAi(AsyncGeneratorProvider): url = "https://chatgpt.ai" - working = True + working = False supports_message_history = True supports_gpt_35_turbo = True _system = None diff --git a/g4f/Provider/ChatgptDemo.py b/g4f/Provider/ChatgptDemo.py index 2f25477a..666b5753 100644 --- a/g4f/Provider/ChatgptDemo.py +++ b/g4f/Provider/ChatgptDemo.py @@ -10,7 +10,7 @@ from .helper import format_prompt class ChatgptDemo(AsyncGeneratorProvider): url = "https://chat.chatgptdemo.net" supports_gpt_35_turbo = True - working = True + working = False @classmethod async def create_async_generator( diff --git a/g4f/Provider/ChatgptNext.py b/g4f/Provider/ChatgptNext.py index c107a0bf..1ae37bd5 100644 --- a/g4f/Provider/ChatgptNext.py +++ b/g4f/Provider/ChatgptNext.py @@ -4,8 +4,7 @@ import json from aiohttp import ClientSession from ..typing import AsyncResult, Messages -from .base_provider import AsyncGeneratorProvider -from .helper import format_prompt +from ..providers.base_provider import AsyncGeneratorProvider class ChatgptNext(AsyncGeneratorProvider): @@ -24,7 +23,7 @@ class ChatgptNext(AsyncGeneratorProvider): if not model: model = "gpt-3.5-turbo" headers = { - "User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/119.0", + "User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:122.0) Gecko/20100101 Firefox/122.0", "Accept": "text/event-stream", "Accept-Language": "de,en-US;q=0.7,en;q=0.3", "Accept-Encoding": "gzip, deflate, br", diff --git a/g4f/Provider/Chatxyz.py b/g4f/Provider/Chatxyz.py index feb09be9..dd1216aa 100644 --- a/g4f/Provider/Chatxyz.py +++ b/g4f/Provider/Chatxyz.py @@ -8,7 +8,7 @@ from .base_provider import AsyncGeneratorProvider class Chatxyz(AsyncGeneratorProvider): url = "https://chat.3211000.xyz" - working = True + working = False supports_gpt_35_turbo = True supports_message_history = True diff --git a/g4f/Provider/FlowGpt.py b/g4f/Provider/FlowGpt.py index 39192bf9..b466a2e6 100644 --- a/g4f/Provider/FlowGpt.py +++ b/g4f/Provider/FlowGpt.py @@ -51,12 +51,16 @@ class FlowGpt(AsyncGeneratorProvider, ProviderModelMixin): "TE": "trailers" } async with ClientSession(headers=headers) as session: + history = [message for message in messages[:-1] if message["role"] != "system"] + system_message = "\n".join([message["content"] for message in messages if message["role"] == "system"]) + if not system_message: + system_message = "You are helpful assistant. Follow the user's instructions carefully." data = { "model": model, "nsfw": False, "question": messages[-1]["content"], - "history": [{"role": "assistant", "content": "Hello, how can I help you today?"}, *messages[:-1]], - "system": kwargs.get("system_message", "You are helpful assistant. Follow the user's instructions carefully."), + "history": [{"role": "assistant", "content": "Hello, how can I help you today?"}, *history], + "system": system_message, "temperature": kwargs.get("temperature", 0.7), "promptId": f"model-{model}", "documentIds": [], diff --git a/g4f/Provider/GeminiPro.py b/g4f/Provider/GeminiPro.py index a2e3538d..1c5487b1 100644 --- a/g4f/Provider/GeminiPro.py +++ b/g4f/Provider/GeminiPro.py @@ -27,6 +27,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): proxy: str = None, api_key: str = None, api_base: str = None, + use_auth_header: bool = True, image: ImageType = None, connector: BaseConnector = None, **kwargs @@ -38,7 +39,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): raise MissingAuthError('Missing "api_key"') headers = params = None - if api_base: + if api_base and use_auth_header: headers = {"Authorization": f"Bearer {api_key}"} else: params = {"key": api_key} diff --git a/g4f/Provider/You.py b/g4f/Provider/You.py index 34130c47..b21fd582 100644 --- a/g4f/Provider/You.py +++ b/g4f/Provider/You.py @@ -3,9 +3,9 @@ from __future__ import annotations import json import base64 import uuid -from aiohttp import ClientSession, FormData +from aiohttp import ClientSession, FormData, BaseConnector -from ..typing import AsyncGenerator, Messages, ImageType, Cookies +from ..typing import AsyncResult, Messages, ImageType, Cookies from .base_provider import AsyncGeneratorProvider from ..providers.helper import get_connector, format_prompt from ..image import to_bytes @@ -26,12 +26,13 @@ class You(AsyncGeneratorProvider): messages: Messages, image: ImageType = None, image_name: str = None, + connector: BaseConnector = None, proxy: str = None, chat_mode: str = "default", **kwargs, - ) -> AsyncGenerator: + ) -> AsyncResult: async with ClientSession( - connector=get_connector(kwargs.get("connector"), proxy), + connector=get_connector(connector, proxy), headers=DEFAULT_HEADERS ) as client: if image: @@ -72,13 +73,13 @@ class You(AsyncGeneratorProvider): response.raise_for_status() async for line in response.content: if line.startswith(b'event: '): - event = line[7:-1] + event = line[7:-1].decode() elif line.startswith(b'data: '): - if event == b"youChatUpdate" or event == b"youChatToken": + if event in ["youChatUpdate", "youChatToken"]: data = json.loads(line[6:-1]) - if event == b"youChatToken" and "youChatToken" in data: - yield data["youChatToken"] - elif event == b"youChatUpdate" and "t" in data: + if event == "youChatToken" and event in data: + yield data[event] + elif event == "youChatUpdate" and "t" in data: yield data["t"] @classmethod diff --git a/g4f/Provider/__init__.py b/g4f/Provider/__init__.py index 6cdc8806..52ba0274 100644 --- a/g4f/Provider/__init__.py +++ b/g4f/Provider/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from ..providers.types import BaseProvider, ProviderType -from ..providers.retry_provider import RetryProvider +from ..providers.retry_provider import RetryProvider, IterProvider from ..providers.base_provider import AsyncProvider, AsyncGeneratorProvider from ..providers.create_images import CreateImagesProvider diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index 556c3d9b..0fa433a4 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -4,6 +4,8 @@ import asyncio import uuid import json import os +import base64 +from aiohttp import ClientWebSocketResponse try: from py_arkose_generator.arkose import get_values_for_request @@ -20,9 +22,9 @@ except ImportError: pass from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin -from ..helper import format_prompt, get_cookies -from ...webdriver import get_browser, get_driver_cookies -from ...typing import AsyncResult, Messages, Cookies, ImageType +from ..helper import get_cookies +from ...webdriver import get_browser +from ...typing import AsyncResult, Messages, Cookies, ImageType, Union, AsyncIterator from ...requests import get_args_from_browser from ...requests.aiohttp import StreamSession from ...image import to_image, to_bytes, ImageResponse, ImageRequest @@ -37,10 +39,15 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): needs_auth = True supports_gpt_35_turbo = True supports_gpt_4 = True + supports_message_history = True + supports_system_message = True default_model = None models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo"] - model_aliases = {"text-davinci-002-render-sha": "gpt-3.5-turbo"} - _args: dict = None + model_aliases = {"text-davinci-002-render-sha": "gpt-3.5-turbo", "": "gpt-3.5-turbo"} + _api_key: str = None + _headers: dict = None + _cookies: Cookies = None + _last_message: int = 0 @classmethod async def create( @@ -170,6 +177,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): """ if not cls.default_model: async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response: + cls._update_request_args(session) response.raise_for_status() data = await response.json() if "categories" in data: @@ -179,7 +187,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): return cls.default_model @classmethod - def create_messages(cls, prompt: str, image_request: ImageRequest = None): + def create_messages(cls, messages: Messages, image_request: ImageRequest = None): """ Create a list of messages for the user input @@ -190,31 +198,27 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): Returns: A list of messages with the user input and the image, if any """ + # Create a message object with the user role and the content + messages = [{ + "id": str(uuid.uuid4()), + "author": {"role": message["role"]}, + "content": {"content_type": "text", "parts": [message["content"]]}, + } for message in messages] + # Check if there is an image response - if not image_request: - # Create a content object with the text type and the prompt - content = {"content_type": "text", "parts": [prompt]} - else: - # Create a content object with the multimodal text type and the image and the prompt - content = { + if image_request: + # Change content in last user message + messages[-1]["content"] = { "content_type": "multimodal_text", "parts": [{ "asset_pointer": f"file-service://{image_request.get('file_id')}", "height": image_request.get("height"), "size_bytes": image_request.get("file_size"), "width": image_request.get("width"), - }, prompt] + }, messages[-1]["content"]["parts"][0]] } - # Create a message object with the user role and the content - messages = [{ - "id": str(uuid.uuid4()), - "author": {"role": "user"}, - "content": content, - }] - # Check if there is an image response - if image_request: # Add the metadata object with the attachments - messages[0]["metadata"] = { + messages[-1]["metadata"] = { "attachments": [{ "height": image_request.get("height"), "id": image_request.get("file_id"), @@ -225,7 +229,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): }] } return messages - + @classmethod async def get_generated_image(cls, session: StreamSession, headers: dict, line: dict) -> ImageResponse: """ @@ -301,6 +305,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): conversation_id: str = None, parent_id: str = None, image: ImageType = None, + image_name: str = None, response_fields: bool = False, **kwargs ) -> AsyncResult: @@ -333,50 +338,65 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): raise MissingRequirementsError('Install "py-arkose-generator" and "async_property" package') if not parent_id: parent_id = str(uuid.uuid4()) - if cls._args is None and cookies is None: - cookies = get_cookies("chat.openai.com", False) + + # Read api_key from arguments api_key = kwargs["access_token"] if "access_token" in kwargs else api_key - if api_key is None and cookies is not None: - api_key = cookies["access_token"] if "access_token" in cookies else api_key - if cls._args is None: - cls._args = { - "headers": {"Cookie": "; ".join(f"{k}={v}" for k, v in cookies.items() if k != "access_token")}, - "cookies": {} if cookies is None else cookies - } - if api_key is not None: - cls._args["headers"]["Authorization"] = f"Bearer {api_key}" + async with StreamSession( proxies={"https": proxy}, impersonate="chrome", - timeout=timeout, - headers=cls._args["headers"] + timeout=timeout ) as session: - if api_key is not None: + # Read api_key and cookies from cache / browser config + if cls._headers is None: + if api_key is None: + # Read api_key from cookies + cookies = get_cookies("chat.openai.com", False) if cookies is None else cookies + api_key = cookies["access_token"] if "access_token" in cookies else api_key + cls._create_request_args(cookies) + else: + api_key = cls._api_key if api_key is None else api_key + # Read api_key with session cookies + if api_key is None and cookies: + api_key = await cls.fetch_access_token(session, cls._headers) + # Load default model + if cls.default_model is None and api_key is not None: try: - cls.default_model = await cls.get_default_model(session, cls._args["headers"]) + if not model: + cls._set_api_key(api_key) + cls.default_model = cls.get_model(await cls.get_default_model(session, cls._headers)) + else: + cls.default_model = cls.get_model(model) except Exception as e: if debug.logging: + print("OpenaiChat: Load default_model failed") print(f"{e.__class__.__name__}: {e}") - if cls.default_model is None: + # Browse api_key and default model + if api_key is None or cls.default_model is None: login_url = os.environ.get("G4F_LOGIN_URL") if login_url: yield f"Please login: [ChatGPT]({login_url})\n\n" try: - cls._args = cls.browse_access_token(proxy) + cls.browse_access_token(proxy) except MissingRequirementsError: - raise MissingAuthError(f'Missing or invalid "access_token". Add a new "api_key" please') - cls.default_model = await cls.get_default_model(session, cls._args["headers"]) + raise MissingAuthError(f'Missing "access_token". Add a "api_key" please') + cls.default_model = cls.get_model(await cls.get_default_model(session, cls._headers)) + else: + cls._set_api_key(api_key) + try: - image_response = None - if image: - image_response = await cls.upload_image(session, cls._args["headers"], image, kwargs.get("image_name")) + image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None except Exception as e: - yield e - end_turn = EndTurn() - model = cls.get_model(model) - model = "text-davinci-002-render-sha" if model == "gpt-3.5-turbo" else model - while not end_turn.is_end: + if debug.logging: + print("OpenaiChat: Upload image failed") + print(f"{e.__class__.__name__}: {e}") + + model = cls.get_model(model).replace("gpt-3.5-turbo", "text-davinci-002-render-sha") + fields = ResponseFields() + while fields.finish_reason is None: arkose_token = await cls.get_arkose_token(session) + conversation_id = conversation_id if fields.conversation_id is None else fields.conversation_id + parent_id = parent_id if fields.message_id is None else fields.message_id data = { "action": action, "arkose_token": arkose_token, @@ -389,13 +409,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): "history_and_training_disabled": history_disabled and not auto_continue, } if action != "continue": - prompt = format_prompt(messages) if not conversation_id else messages[-1]["content"] - data["messages"] = cls.create_messages(prompt, image_response) - - # Update cookies before next request - for c in session.cookie_jar if hasattr(session, "cookie_jar") else session.cookies.jar: - cls._args["cookies"][c.name if hasattr(c, "name") else c.key] = c.value - cls._args["headers"]["Cookie"] = "; ".join(f"{k}={v}" for k, v in cls._args["cookies"].items()) + messages = messages if conversation_id is None else [messages[-1]] + data["messages"] = cls.create_messages(messages, image_request) async with session.post( f"{cls.url}/backend-api/conversation", @@ -403,61 +418,88 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): headers={ "Accept": "text/event-stream", "OpenAI-Sentinel-Arkose-Token": arkose_token, - **cls._args["headers"] + **cls._headers } ) as response: + cls._update_request_args(session) if not response.ok: raise RuntimeError(f"Response {response.status}: {await response.text()}") - last_message: int = 0 - async for line in response.iter_lines(): - if not line.startswith(b"data: "): - continue - elif line.startswith(b"data: [DONE]"): - break - try: - line = json.loads(line[6:]) - except: - continue - if "message" not in line: - continue - if "error" in line and line["error"]: - raise RuntimeError(line["error"]) - if "message_type" not in line["message"]["metadata"]: - continue - try: - image_response = await cls.get_generated_image(session, cls._args["headers"], line) - if image_response is not None: - yield image_response - except Exception as e: - yield e - if line["message"]["author"]["role"] != "assistant": - continue - if line["message"]["content"]["content_type"] != "text": - continue - if line["message"]["metadata"]["message_type"] not in ("next", "continue", "variant"): - continue - conversation_id = line["conversation_id"] - parent_id = line["message"]["id"] + async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, fields): if response_fields: response_fields = False - yield ResponseFields(conversation_id, parent_id, end_turn) - if "parts" in line["message"]["content"]: - new_message = line["message"]["content"]["parts"][0] - if len(new_message) > last_message: - yield new_message[last_message:] - last_message = len(new_message) - if "finish_details" in line["message"]["metadata"]: - if line["message"]["metadata"]["finish_details"]["type"] == "stop": - end_turn.end() + yield fields + yield chunk if not auto_continue: break action = "continue" await asyncio.sleep(5) if history_disabled and auto_continue: - await cls.delete_conversation(session, cls._args["headers"], conversation_id) + await cls.delete_conversation(session, cls._headers, conversation_id) + + @staticmethod + async def iter_messages_ws(ws: ClientWebSocketResponse) -> AsyncIterator: + while True: + yield base64.b64decode((await ws.receive_json())["body"]) @classmethod - def browse_access_token(cls, proxy: str = None, timeout: int = 1200) -> tuple[str, dict]: + async def iter_messages_chunk(cls, messages: AsyncIterator, session: StreamSession, fields: ResponseFields) -> AsyncIterator: + last_message: int = 0 + async for message in messages: + if message.startswith(b'{"wss_url":'): + async with session.ws_connect(json.loads(message)["wss_url"]) as ws: + async for chunk in cls.iter_messages_chunk(cls.iter_messages_ws(ws), session, fields): + yield chunk + break + async for chunk in cls.iter_messages_line(session, message, fields): + if fields.finish_reason is not None: + break + elif isinstance(chunk, str): + if len(chunk) > last_message: + yield chunk[last_message:] + last_message = len(chunk) + else: + yield chunk + if fields.finish_reason is not None: + break + + @classmethod + async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: ResponseFields) -> AsyncIterator: + if not line.startswith(b"data: "): + return + elif line.startswith(b"data: [DONE]"): + return + try: + line = json.loads(line[6:]) + except: + return + if "message" not in line: + return + if "error" in line and line["error"]: + raise RuntimeError(line["error"]) + if "message_type" not in line["message"]["metadata"]: + return + try: + image_response = await cls.get_generated_image(session, cls._headers, line) + if image_response is not None: + yield image_response + except Exception as e: + yield e + if line["message"]["author"]["role"] != "assistant": + return + if line["message"]["content"]["content_type"] != "text": + return + if line["message"]["metadata"]["message_type"] not in ("next", "continue", "variant"): + return + if fields.conversation_id is None: + fields.conversation_id = line["conversation_id"] + fields.message_id = line["message"]["id"] + if "parts" in line["message"]["content"]: + yield line["message"]["content"]["parts"][0] + if "finish_details" in line["message"]["metadata"]: + fields.finish_reason = line["message"]["metadata"]["finish_details"]["type"] + + @classmethod + def browse_access_token(cls, proxy: str = None, timeout: int = 1200) -> None: """ Browse to obtain an access token. @@ -475,14 +517,15 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): "let session = await fetch('/api/auth/session');" "let data = await session.json();" "let accessToken = data['accessToken'];" - "let expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 4);" + "let expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 4 * 1000);" "document.cookie = 'access_token=' + accessToken + ';expires=' + expires.toUTCString() + ';path=/';" "return accessToken;" ) args = get_args_from_browser(f"{cls.url}/", driver, do_bypass_cloudflare=False) - args["headers"]["Authorization"] = f"Bearer {access_token}" - args["headers"]["Cookie"] = "; ".join(f"{k}={v}" for k, v in args["cookies"].items() if k != "access_token") - return args + cls._headers = args["headers"] + cls._cookies = args["cookies"] + cls._update_cookie_header() + cls._set_api_key(access_token) finally: driver.close() @@ -516,6 +559,42 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): return decoded_json["token"] raise RuntimeError(f"Response: {decoded_json}") + @classmethod + async def fetch_access_token(cls, session: StreamSession, headers: dict): + async with session.get( + f"{cls.url}/api/auth/session", + headers=headers + ) as response: + if response.ok: + data = await response.json() + if "accessToken" in data: + return data["accessToken"] + + @staticmethod + def _format_cookies(cookies: Cookies): + return "; ".join(f"{k}={v}" for k, v in cookies.items() if k != "access_token") + + @classmethod + def _create_request_args(cls, cookies: Union[Cookies, None]): + cls._headers = {} + cls._cookies = {} if cookies is None else cookies + cls._update_cookie_header() + + @classmethod + def _update_request_args(cls, session: StreamSession): + for c in session.cookie_jar if hasattr(session, "cookie_jar") else session.cookies.jar: + cls._cookies[c.name if hasattr(c, "name") else c.key] = c.value + cls._update_cookie_header() + + @classmethod + def _set_api_key(cls, api_key: str): + cls._api_key = api_key + cls._headers["Authorization"] = f"Bearer {api_key}" + + @classmethod + def _update_cookie_header(cls): + cls._headers["Cookie"] = cls._format_cookies(cls._cookies) + class EndTurn: """ Class to represent the end of a conversation turn. @@ -530,10 +609,10 @@ class ResponseFields: """ Class to encapsulate response fields. """ - def __init__(self, conversation_id: str, message_id: str, end_turn: EndTurn): + def __init__(self, conversation_id: str = None, message_id: str = None, finish_reason: str = None): self.conversation_id = conversation_id self.message_id = message_id - self._end_turn = end_turn + self.finish_reason = finish_reason class Response(): """ @@ -567,7 +646,7 @@ class Response(): self._message = "".join(chunks) if not self._fields: raise RuntimeError("Missing response fields") - self.is_end = self._fields._end_turn.is_end + self.is_end = self._fields.end_turn def __aiter__(self): return self.generator() diff --git a/g4f/__init__.py b/g4f/__init__.py index 5df942ae..441225f1 100644 --- a/g4f/__init__.py +++ b/g4f/__init__.py @@ -10,7 +10,7 @@ from .cookies import get_cookies, set_cookies from . import debug, version from .providers.types import BaseRetryProvider, ProviderType from .providers.base_provider import ProviderModelMixin -from .providers.retry_provider import RetryProvider +from .providers.retry_provider import IterProvider def get_model_and_provider(model : Union[Model, str], provider : Union[ProviderType, str, None], @@ -48,7 +48,7 @@ def get_model_and_provider(model : Union[Model, str], provider_list = [ProviderUtils.convert[p] for p in provider.split() if p in ProviderUtils.convert] if not provider_list: raise ProviderNotFoundError(f'Providers not found: {provider}') - provider = RetryProvider(provider_list, False) + provider = IterProvider(provider_list) elif provider in ProviderUtils.convert: provider = ProviderUtils.convert[provider] else: diff --git a/g4f/client.py b/g4f/client.py index 3e11fac1..750c623f 100644 --- a/g4f/client.py +++ b/g4f/client.py @@ -7,7 +7,7 @@ import random import string from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse -from .typing import Union, Generator, Messages, ImageType +from .typing import Union, Iterator, Messages, ImageType from .providers.types import BaseProvider, ProviderType from .image import ImageResponse as ImageProviderResponse from .Provider.BingCreateImages import BingCreateImages @@ -17,7 +17,7 @@ from . import get_model_and_provider, get_last_provider ImageProvider = Union[BaseProvider, object] Proxies = Union[dict, str] -IterResponse = Generator[Union[ChatCompletion, ChatCompletionChunk], None, None] +IterResponse = Iterator[Union[ChatCompletion, ChatCompletionChunk]] def read_json(text: str) -> dict: """ @@ -110,6 +110,12 @@ class Client(): elif "https" in self.proxies: return self.proxies["https"] +def filter_none(**kwargs): + for key in list(kwargs.keys()): + if kwargs[key] is None: + del kwargs[key] + return kwargs + class Completions(): def __init__(self, client: Client, provider: ProviderType = None): self.client: Client = client @@ -126,7 +132,7 @@ class Completions(): stop: Union[list[str], str] = None, api_key: str = None, **kwargs - ) -> Union[ChatCompletion, Generator[ChatCompletionChunk]]: + ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: model, provider = get_model_and_provider( model, self.provider if provider is None else provider, @@ -135,11 +141,13 @@ class Completions(): ) stop = [stop] if isinstance(stop, str) else stop response = provider.create_completion( - model, messages, stream, - proxy=self.client.get_proxy(), - max_tokens=max_tokens, - stop=stop, - api_key=self.client.api_key if api_key is None else api_key, + model, messages, stream, + **filter_none( + proxy=self.client.get_proxy(), + max_tokens=max_tokens, + stop=stop, + api_key=self.client.api_key if api_key is None else api_key + ), **kwargs ) response = iter_response(response, stream, response_format, max_tokens, stop) diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py index b8649ba5..17c45875 100644 --- a/g4f/providers/base_provider.py +++ b/g4f/providers/base_provider.py @@ -270,7 +270,7 @@ class ProviderModelMixin: @classmethod def get_model(cls, model: str) -> str: - if not model: + if not model and cls.default_model is not None: model = cls.default_model elif model in cls.model_aliases: model = cls.model_aliases[model] diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py index a7ab2881..52f473e9 100644 --- a/g4f/providers/retry_provider.py +++ b/g4f/providers/retry_provider.py @@ -3,22 +3,37 @@ from __future__ import annotations import asyncio import random -from ..typing import CreateResult, Messages -from .types import BaseRetryProvider +from ..typing import Type, List, CreateResult, Messages, Iterator +from .types import BaseProvider, BaseRetryProvider from .. import debug from ..errors import RetryProviderError, RetryNoProviderError class RetryProvider(BaseRetryProvider): + def __init__( + self, + providers: List[Type[BaseProvider]], + shuffle: bool = True + ) -> None: + """ + Initialize the BaseRetryProvider. + + Args: + providers (List[Type[BaseProvider]]): List of providers to use. + shuffle (bool): Whether to shuffle the providers list. + """ + self.providers = providers + self.shuffle = shuffle + self.working = True + self.last_provider: Type[BaseProvider] = None + """ A provider class to handle retries for creating completions with different providers. Attributes: providers (list): A list of provider instances. shuffle (bool): A flag indicating whether to shuffle providers before use. - exceptions (dict): A dictionary to store exceptions encountered during retries. last_provider (BaseProvider): The last provider that was used. """ - def create_completion( self, model: str, @@ -44,7 +59,7 @@ class RetryProvider(BaseRetryProvider): if self.shuffle: random.shuffle(providers) - self.exceptions = {} + exceptions = {} started: bool = False for provider in providers: self.last_provider = provider @@ -57,13 +72,13 @@ class RetryProvider(BaseRetryProvider): if started: return except Exception as e: - self.exceptions[provider.__name__] = e + exceptions[provider.__name__] = e if debug.logging: print(f"{provider.__name__}: {e.__class__.__name__}: {e}") if started: raise e - self.raise_exceptions() + raise_exceptions(exceptions) async def create_async( self, @@ -88,7 +103,7 @@ class RetryProvider(BaseRetryProvider): if self.shuffle: random.shuffle(providers) - self.exceptions = {} + exceptions = {} for provider in providers: self.last_provider = provider try: @@ -97,23 +112,94 @@ class RetryProvider(BaseRetryProvider): timeout=kwargs.get("timeout", 60) ) except Exception as e: - self.exceptions[provider.__name__] = e + exceptions[provider.__name__] = e if debug.logging: print(f"{provider.__name__}: {e.__class__.__name__}: {e}") - self.raise_exceptions() + raise_exceptions(exceptions) - def raise_exceptions(self) -> None: - """ - Raise a combined exception if any occurred during retries. +class IterProvider(BaseRetryProvider): + __name__ = "IterProvider" - Raises: - RetryProviderError: If any provider encountered an exception. - RetryNoProviderError: If no provider is found. - """ - if self.exceptions: - raise RetryProviderError("RetryProvider failed:\n" + "\n".join([ - f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in self.exceptions.items() - ])) + def __init__( + self, + providers: List[BaseProvider], + ) -> None: + providers.reverse() + self.providers: List[BaseProvider] = providers + self.working: bool = True + self.last_provider: BaseProvider = None + + def create_completion( + self, + model: str, + messages: Messages, + stream: bool = False, + **kwargs + ) -> CreateResult: + exceptions: dict = {} + started: bool = False + for provider in self.iter_providers(): + if stream and not provider.supports_stream: + continue + try: + for token in provider.create_completion(model, messages, stream, **kwargs): + yield token + started = True + if started: + return + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + if started: + raise e + raise_exceptions(exceptions) + + async def create_async( + self, + model: str, + messages: Messages, + **kwargs + ) -> str: + exceptions: dict = {} + for provider in self.iter_providers(): + try: + return await asyncio.wait_for( + provider.create_async(model, messages, **kwargs), + timeout=kwargs.get("timeout", 60) + ) + except Exception as e: + exceptions[provider.__name__] = e + if debug.logging: + print(f"{provider.__name__}: {e.__class__.__name__}: {e}") + raise_exceptions(exceptions) + + def iter_providers(self) -> Iterator[BaseProvider]: + used_provider = [] + try: + while self.providers: + provider = self.providers.pop() + used_provider.append(provider) + self.last_provider = provider + if debug.logging: + print(f"Using {provider.__name__} provider") + yield provider + finally: + used_provider.reverse() + self.providers = [*used_provider, *self.providers] + +def raise_exceptions(exceptions: dict) -> None: + """ + Raise a combined exception if any occurred during retries. + + Raises: + RetryProviderError: If any provider encountered an exception. + RetryNoProviderError: If no provider is found. + """ + if exceptions: + raise RetryProviderError("RetryProvider failed:\n" + "\n".join([ + f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in exceptions.items() + ])) - raise RetryNoProviderError("No provider found")
\ No newline at end of file + raise RetryNoProviderError("No provider found")
\ No newline at end of file diff --git a/g4f/providers/types.py b/g4f/providers/types.py index 7b11ec43..67340958 100644 --- a/g4f/providers/types.py +++ b/g4f/providers/types.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Union, List, Dict, Type +from typing import Union, Dict, Type from ..typing import Messages, CreateResult class BaseProvider(ABC): @@ -26,6 +26,7 @@ class BaseProvider(ABC): supports_gpt_35_turbo: bool = False supports_gpt_4: bool = False supports_message_history: bool = False + supports_system_message: bool = False params: str @classmethod @@ -96,22 +97,4 @@ class BaseRetryProvider(BaseProvider): __name__: str = "RetryProvider" supports_stream: bool = True - def __init__( - self, - providers: List[Type[BaseProvider]], - shuffle: bool = True - ) -> None: - """ - Initialize the BaseRetryProvider. - - Args: - providers (List[Type[BaseProvider]]): List of providers to use. - shuffle (bool): Whether to shuffle the providers list. - """ - self.providers = providers - self.shuffle = shuffle - self.working = True - self.exceptions: Dict[str, Exception] = {} - self.last_provider: Type[BaseProvider] = None - ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
\ No newline at end of file diff --git a/g4f/typing.py b/g4f/typing.py index 386b3dfc..5d1bc959 100644 --- a/g4f/typing.py +++ b/g4f/typing.py @@ -1,5 +1,5 @@ import sys -from typing import Any, AsyncGenerator, Generator, NewType, Tuple, Union, List, Dict, Type, IO, Optional +from typing import Any, AsyncGenerator, Generator, AsyncIterator, Iterator, NewType, Tuple, Union, List, Dict, Type, IO, Optional try: from PIL.Image import Image @@ -12,8 +12,8 @@ else: from typing_extensions import TypedDict SHA256 = NewType('sha_256_hash', str) -CreateResult = Generator[str, None, None] -AsyncResult = AsyncGenerator[str, None] +CreateResult = Iterator[str] +AsyncResult = AsyncIterator[str] Messages = List[Dict[str, str]] Cookies = Dict[str, str] ImageType = Union[str, bytes, IO, Image, None] @@ -22,6 +22,8 @@ __all__ = [ 'Any', 'AsyncGenerator', 'Generator', + 'AsyncIterator', + 'Iterator' 'Tuple', 'Union', 'List', |