diff options
Diffstat (limited to 'g4f/Provider/needs_auth/OpenaiChat.py')
-rw-r--r-- | g4f/Provider/needs_auth/OpenaiChat.py | 65 |
1 files changed, 13 insertions, 52 deletions
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index 15a87f38..97515ec4 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -65,6 +65,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): default_vision_model = "gpt-4o" fallback_models = ["auto", "gpt-4", "gpt-4o", "gpt-4o-mini", "gpt-4o-canmore", "o1-preview", "o1-mini"] vision_models = fallback_models + image_models = fallback_models _api_key: str = None _headers: dict = None @@ -330,7 +331,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): api_key: str = None, cookies: Cookies = None, auto_continue: bool = False, - history_disabled: bool = True, + history_disabled: bool = False, action: str = "next", conversation_id: str = None, conversation: Conversation = None, @@ -425,12 +426,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): f"Arkose: {'False' if not need_arkose else RequestConfig.arkose_token[:12]+'...'}", f"Proofofwork: {'False' if proofofwork is None else proofofwork[:12]+'...'}", )] - ws = None - if need_arkose: - async with session.post(f"{cls.url}/backend-api/register-websocket", headers=cls._headers) as response: - wss_url = (await response.json()).get("wss_url") - if wss_url: - ws = await session.ws_connect(wss_url) data = { "action": action, "messages": None, @@ -474,7 +469,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): await asyncio.sleep(5) continue await raise_for_status(response) - async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, conversation, ws): + async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, conversation): if return_conversation: history_disabled = False return_conversation = False @@ -489,44 +484,16 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): if history_disabled and auto_continue: await cls.delete_conversation(session, cls._headers, conversation.conversation_id) - @staticmethod - async def iter_messages_ws(ws: ClientWebSocketResponse, conversation_id: str, is_curl: bool) -> AsyncIterator: - while True: - if is_curl: - message = json.loads(ws.recv()[0]) - else: - message = await ws.receive_json() - if message["conversation_id"] == conversation_id: - yield base64.b64decode(message["body"]) - @classmethod async def iter_messages_chunk( cls, messages: AsyncIterator, session: StreamSession, fields: Conversation, - ws = None ) -> AsyncIterator: async for message in messages: - if message.startswith(b'{"wss_url":'): - message = json.loads(message) - ws = await session.ws_connect(message["wss_url"]) if ws is None else ws - try: - async for chunk in cls.iter_messages_chunk( - cls.iter_messages_ws(ws, message["conversation_id"], hasattr(ws, "recv")), - session, fields - ): - yield chunk - finally: - await ws.aclose() if hasattr(ws, "aclose") else await ws.close() - break async for chunk in cls.iter_messages_line(session, message, fields): - if fields.finish_reason is not None: - break - else: - yield chunk - if fields.finish_reason is not None: - break + yield chunk @classmethod async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: Conversation) -> AsyncIterator: @@ -542,9 +509,9 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): return if isinstance(line, dict) and "v" in line: v = line.get("v") - if isinstance(v, str): + if isinstance(v, str) and fields.is_recipient: yield v - elif isinstance(v, list): + elif isinstance(v, list) and fields.is_recipient: for m in v: if m.get("p") == "/message/content/parts/0": yield m.get("v") @@ -556,25 +523,20 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): fields.conversation_id = v.get("conversation_id") debug.log(f"OpenaiChat: New conversation: {fields.conversation_id}") m = v.get("message", {}) - if m.get("author", {}).get("role") == "assistant": - fields.message_id = v.get("message", {}).get("id") + fields.is_recipient = m.get("recipient") == "all" + if fields.is_recipient: c = m.get("content", {}) if c.get("content_type") == "multimodal_text": generated_images = [] for element in c.get("parts"): - if isinstance(element, str): - debug.log(f"No image or text: {line}") - elif element.get("content_type") == "image_asset_pointer": + if isinstance(element, dict) and element.get("content_type") == "image_asset_pointer": generated_images.append( cls.get_generated_image(session, cls._headers, element) ) - elif element.get("content_type") == "text": - for part in element.get("parts", []): - yield part for image_response in await asyncio.gather(*generated_images): yield image_response - else: - debug.log(f"OpenaiChat: {line}") + if m.get("author", {}).get("role") == "assistant": + fields.message_id = v.get("message", {}).get("id") return if "error" in line and line.get("error"): raise RuntimeError(line.get("error")) @@ -652,7 +614,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): cls._headers = cls.get_default_headers() if headers is None else headers if user_agent is not None: cls._headers["user-agent"] = user_agent - cls._cookies = {} if cookies is None else {k: v for k, v in cookies.items() if k != "access_token"} + cls._cookies = {} if cookies is None else cookies cls._update_cookie_header() @classmethod @@ -671,8 +633,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): @classmethod def _update_cookie_header(cls): cls._headers["cookie"] = format_cookies(cls._cookies) - if "oai-did" in cls._cookies: - cls._headers["oai-device-id"] = cls._cookies["oai-did"] class Conversation(BaseConversation): """ @@ -682,6 +642,7 @@ class Conversation(BaseConversation): self.conversation_id = conversation_id self.message_id = message_id self.finish_reason = finish_reason + self.is_recipient = False class Response(): """ |