summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/needs_auth
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/needs_auth')
-rw-r--r--g4f/Provider/needs_auth/OpenaiChat.py17
1 files changed, 11 insertions, 6 deletions
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index e507404b..1a6fd947 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -5,6 +5,7 @@ import uuid
import json
import os
import base64
+import time
from aiohttp import ClientWebSocketResponse
try:
@@ -47,7 +48,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
_api_key: str = None
_headers: dict = None
_cookies: Cookies = None
- _last_message: int = 0
+ _expires: int = None
@classmethod
async def create(
@@ -348,7 +349,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
timeout=timeout
) as session:
# Read api_key and cookies from cache / browser config
- if cls._headers is None:
+ if cls._headers is None or time.time() > cls._expires:
if api_key is None:
# Read api_key from cookies
cookies = get_cookies("chat.openai.com", False) if cookies is None else cookies
@@ -437,17 +438,20 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
await cls.delete_conversation(session, cls._headers, fields.conversation_id)
@staticmethod
- async def iter_messages_ws(ws: ClientWebSocketResponse) -> AsyncIterator:
+ async def iter_messages_ws(ws: ClientWebSocketResponse, conversation_id: str) -> AsyncIterator:
while True:
- yield base64.b64decode((await ws.receive_json())["body"])
+ message = await ws.receive_json()
+ if message["conversation_id"] == conversation_id:
+ yield base64.b64decode(message["body"])
@classmethod
async def iter_messages_chunk(cls, messages: AsyncIterator, session: StreamSession, fields: ResponseFields) -> AsyncIterator:
last_message: int = 0
async for message in messages:
if message.startswith(b'{"wss_url":'):
- async with session.ws_connect(json.loads(message)["wss_url"]) as ws:
- async for chunk in cls.iter_messages_chunk(cls.iter_messages_ws(ws), session, fields):
+ message = json.loads(message)
+ async with session.ws_connect(message["wss_url"]) as ws:
+ async for chunk in cls.iter_messages_chunk(cls.iter_messages_ws(ws, message["conversation_id"]), session, fields):
yield chunk
break
async for chunk in cls.iter_messages_line(session, message, fields):
@@ -589,6 +593,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
@classmethod
def _set_api_key(cls, api_key: str):
cls._api_key = api_key
+ cls._expires = int(time.time()) + 60 * 60 * 4
cls._headers["Authorization"] = f"Bearer {api_key}"
@classmethod