 g4f/Provider/needs_auth/OpenaiChat.py | 64 ++++++++++++++++++++++++++++++++++++++++++++++------------------
 1 file changed, 46 insertions(+), 18 deletions(-)
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index 587c0a23..797455fe 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -7,7 +7,6 @@ import json
 import base64
 import time
 import requests
-from aiohttp import ClientWebSocketResponse
 from copy import copy

 try:
@@ -28,7 +27,7 @@ from ...requests.raise_for_status import raise_for_status
 from ...requests.aiohttp import StreamSession
 from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format
 from ...errors import MissingAuthError, ResponseError
-from ...providers.response import BaseConversation
+from ...providers.response import BaseConversation, FinishReason, SynthesizeData
 from ..helper import format_cookies
 from ..openai.har_file import get_request_config, NoValidHarFileError
 from ..openai.har_file import RequestConfig, arkReq, arkose_url, start_url, conversation_url, backend_url, backend_anon_url
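
Note: the two names added to this import are constructed later in this diff as SynthesizeData(cls.__name__, {...}) and FinishReason(conversation.finish_reason). A minimal sketch of plausible shapes, inferred only from that usage; the real definitions live in g4f/providers/response.py and may differ:

    # Hypothetical shapes, inferred only from how this diff constructs them.
    class SynthesizeData:
        """Marker yielded mid-stream to announce synthesizable audio."""
        def __init__(self, provider: str, data: dict):
            self.provider = provider  # provider class name, e.g. "OpenaiChat"
            self.data = data          # params later passed to synthesize()

    class FinishReason:
        """Marker yielded once streaming ends."""
        def __init__(self, reason: str):
            self.reason = reason      # e.g. "stop" or "max_tokens"
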
@@ -367,19 +366,13 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
         Raises:
             RuntimeError: If an error occurs during processing.
         """
+        await cls.login(proxy)
+
         async with StreamSession(
             proxy=proxy,
             impersonate="chrome",
             timeout=timeout
         ) as session:
-            if cls._expires is not None and cls._expires < time.time():
-                cls._headers = cls._api_key = None
-            try:
-                await get_request_config(proxy)
-                cls._create_request_args(RequestConfig.cookies, RequestConfig.headers)
-                cls._set_api_key(RequestConfig.access_token)
-            except NoValidHarFileError as e:
-                await cls.nodriver_auth(proxy)
             try:
                 image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None
             except Exception as e:
@@ -469,18 +462,25 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                         await asyncio.sleep(5)
                         continue
                     await raise_for_status(response)
-                    async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, conversation):
-                        if return_conversation:
-                            history_disabled = False
-                            return_conversation = False
-                            yield conversation
-                        yield chunk
+                    if return_conversation:
+                        history_disabled = False
+                        yield conversation
+                    async for line in response.iter_lines():
+                        async for chunk in cls.iter_messages_line(session, line, conversation):
+                            yield chunk
+                    if not history_disabled:
+                        yield SynthesizeData(cls.__name__, {
+                            "conversation_id": conversation.conversation_id,
+                            "message_id": conversation.message_id,
+                            "voice": "maple",
+                        })
                 if auto_continue and conversation.finish_reason == "max_tokens":
                     conversation.finish_reason = None
                     action = "continue"
                     await asyncio.sleep(5)
                 else:
                     break

+            yield FinishReason(conversation.finish_reason)
             if history_disabled and auto_continue:
                 await cls.delete_conversation(session, cls._headers, conversation.conversation_id)
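
Note: after this hunk the generator's event order is: the conversation object (when return_conversation is set), the streamed message chunks, an optional SynthesizeData marker when history is enabled, and a final FinishReason. A consumer-side sketch, assuming the hypothetical marker shapes outlined above:

    # Sketch of consuming the reworked stream; attribute access on the
    # markers assumes the shapes sketched earlier, not confirmed by this diff.
    from g4f.providers.response import BaseConversation, FinishReason, SynthesizeData

    async def consume(generator):
        async for item in generator:
            if isinstance(item, BaseConversation):
                pass  # conversation handle, yielded first when requested
            elif isinstance(item, SynthesizeData):
                params = item.data  # feed to OpenaiChat.synthesize(params)
            elif isinstance(item, FinishReason):
                print("finished:", item.reason)
            else:
                print(item, end="")  # plain text chunk
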
@@ -542,9 +542,37 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                 raise RuntimeError(line.get("error"))

     @classmethod
+    async def synthesize(cls, params: dict) -> AsyncIterator[bytes]:
+        await cls.login()
+        async with StreamSession(
+            impersonate="chrome",
+            timeout=900
+        ) as session:
+            async with session.get(
+                f"{cls.url}/backend-api/synthesize",
+                params=params,
+                headers=cls._headers
+            ) as response:
+                await raise_for_status(response)
+                async for chunk in response.iter_content():
+                    yield chunk
+
+    @classmethod
+    async def login(cls, proxy: str = None):
+        if cls._expires is not None and cls._expires < time.time():
+            cls._headers = cls._api_key = None
+        try:
+            await get_request_config(proxy)
+            cls._create_request_args(RequestConfig.cookies, RequestConfig.headers)
+            cls._set_api_key(RequestConfig.access_token)
+        except NoValidHarFileError:
+            if has_nodriver:
+                await cls.nodriver_auth(proxy)
+            else:
+                raise
+
+    @classmethod
     async def nodriver_auth(cls, proxy: str = None):
-        if not has_nodriver:
-            return
         if has_platformdirs:
             user_data_dir = user_config_dir("g4f-nodriver")
         else:
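
Note: login() now centralizes the HAR-file/nodriver authentication that was previously inlined in the request path, and synthesize() streams raw audio bytes from the backend-api/synthesize endpoint using the same cached credentials. A possible usage sketch; the import path and output file name are assumptions, not taken from this diff:

    # Pair the params carried by a SynthesizeData marker with the new
    # synthesize() classmethod and write the audio stream to disk.
    import asyncio
    from g4f.Provider.needs_auth import OpenaiChat

    async def save_speech(params: dict, path: str = "reply.mp3") -> None:
        with open(path, "wb") as f:
            async for chunk in OpenaiChat.synthesize(params):
                f.write(chunk)

    # e.g. asyncio.run(save_speech(synthesize_data.data))
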