summaryrefslogtreecommitdiffstats
path: root/g4f/Provider
diff options
context:
space:
mode:
authorH Lohaus <hlohaus@users.noreply.github.com>2024-02-28 09:48:57 +0100
committerGitHub <noreply@github.com>2024-02-28 09:48:57 +0100
commit96db520ff030cd0beae8b469876013b8f18b793a (patch)
tree0d2d6cf85371cc454279bd454ae851ca0fee930a /g4f/Provider
parentMerge pull request #1635 from hlohaus/flow (diff)
parentAdd websocket support in OpenaiChat (diff)
downloadgpt4free-96db520ff030cd0beae8b469876013b8f18b793a.tar
gpt4free-96db520ff030cd0beae8b469876013b8f18b793a.tar.gz
gpt4free-96db520ff030cd0beae8b469876013b8f18b793a.tar.bz2
gpt4free-96db520ff030cd0beae8b469876013b8f18b793a.tar.lz
gpt4free-96db520ff030cd0beae8b469876013b8f18b793a.tar.xz
gpt4free-96db520ff030cd0beae8b469876013b8f18b793a.tar.zst
gpt4free-96db520ff030cd0beae8b469876013b8f18b793a.zip
Diffstat (limited to 'g4f/Provider')
-rw-r--r--g4f/Provider/AiChatOnline.py2
-rw-r--r--g4f/Provider/Aura.py5
-rw-r--r--g4f/Provider/ChatgptAi.py2
-rw-r--r--g4f/Provider/ChatgptDemo.py2
-rw-r--r--g4f/Provider/ChatgptNext.py5
-rw-r--r--g4f/Provider/Chatxyz.py2
-rw-r--r--g4f/Provider/FlowGpt.py8
-rw-r--r--g4f/Provider/GeminiPro.py3
-rw-r--r--g4f/Provider/You.py19
-rw-r--r--g4f/Provider/__init__.py2
-rw-r--r--g4f/Provider/needs_auth/OpenaiChat.py291
11 files changed, 212 insertions, 129 deletions
diff --git a/g4f/Provider/AiChatOnline.py b/g4f/Provider/AiChatOnline.py
index dc774fe0..cc3b5b8e 100644
--- a/g4f/Provider/AiChatOnline.py
+++ b/g4f/Provider/AiChatOnline.py
@@ -9,7 +9,7 @@ from .helper import get_random_string
class AiChatOnline(AsyncGeneratorProvider):
url = "https://aichatonline.org"
- working = True
+ working = False
supports_gpt_35_turbo = True
supports_message_history = False
diff --git a/g4f/Provider/Aura.py b/g4f/Provider/Aura.py
index 126c8d0f..d8f3471c 100644
--- a/g4f/Provider/Aura.py
+++ b/g4f/Provider/Aura.py
@@ -6,9 +6,8 @@ from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider
class Aura(AsyncGeneratorProvider):
- url = "https://openchat.team"
- working = True
- supports_gpt_35_turbo = True
+ url = "https://openchat.team"
+ working = True
@classmethod
async def create_async_generator(
diff --git a/g4f/Provider/ChatgptAi.py b/g4f/Provider/ChatgptAi.py
index f2785364..a38aea5e 100644
--- a/g4f/Provider/ChatgptAi.py
+++ b/g4f/Provider/ChatgptAi.py
@@ -9,7 +9,7 @@ from .base_provider import AsyncGeneratorProvider
class ChatgptAi(AsyncGeneratorProvider):
url = "https://chatgpt.ai"
- working = True
+ working = False
supports_message_history = True
supports_gpt_35_turbo = True
_system = None
diff --git a/g4f/Provider/ChatgptDemo.py b/g4f/Provider/ChatgptDemo.py
index 2f25477a..666b5753 100644
--- a/g4f/Provider/ChatgptDemo.py
+++ b/g4f/Provider/ChatgptDemo.py
@@ -10,7 +10,7 @@ from .helper import format_prompt
class ChatgptDemo(AsyncGeneratorProvider):
url = "https://chat.chatgptdemo.net"
supports_gpt_35_turbo = True
- working = True
+ working = False
@classmethod
async def create_async_generator(
diff --git a/g4f/Provider/ChatgptNext.py b/g4f/Provider/ChatgptNext.py
index c107a0bf..1ae37bd5 100644
--- a/g4f/Provider/ChatgptNext.py
+++ b/g4f/Provider/ChatgptNext.py
@@ -4,8 +4,7 @@ import json
from aiohttp import ClientSession
from ..typing import AsyncResult, Messages
-from .base_provider import AsyncGeneratorProvider
-from .helper import format_prompt
+from ..providers.base_provider import AsyncGeneratorProvider
class ChatgptNext(AsyncGeneratorProvider):
@@ -24,7 +23,7 @@ class ChatgptNext(AsyncGeneratorProvider):
if not model:
model = "gpt-3.5-turbo"
headers = {
- "User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/119.0",
+ "User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:122.0) Gecko/20100101 Firefox/122.0",
"Accept": "text/event-stream",
"Accept-Language": "de,en-US;q=0.7,en;q=0.3",
"Accept-Encoding": "gzip, deflate, br",
diff --git a/g4f/Provider/Chatxyz.py b/g4f/Provider/Chatxyz.py
index feb09be9..dd1216aa 100644
--- a/g4f/Provider/Chatxyz.py
+++ b/g4f/Provider/Chatxyz.py
@@ -8,7 +8,7 @@ from .base_provider import AsyncGeneratorProvider
class Chatxyz(AsyncGeneratorProvider):
url = "https://chat.3211000.xyz"
- working = True
+ working = False
supports_gpt_35_turbo = True
supports_message_history = True
diff --git a/g4f/Provider/FlowGpt.py b/g4f/Provider/FlowGpt.py
index 39192bf9..b466a2e6 100644
--- a/g4f/Provider/FlowGpt.py
+++ b/g4f/Provider/FlowGpt.py
@@ -51,12 +51,16 @@ class FlowGpt(AsyncGeneratorProvider, ProviderModelMixin):
"TE": "trailers"
}
async with ClientSession(headers=headers) as session:
+ history = [message for message in messages[:-1] if message["role"] != "system"]
+ system_message = "\n".join([message["content"] for message in messages if message["role"] == "system"])
+ if not system_message:
+ system_message = "You are helpful assistant. Follow the user's instructions carefully."
data = {
"model": model,
"nsfw": False,
"question": messages[-1]["content"],
- "history": [{"role": "assistant", "content": "Hello, how can I help you today?"}, *messages[:-1]],
- "system": kwargs.get("system_message", "You are helpful assistant. Follow the user's instructions carefully."),
+ "history": [{"role": "assistant", "content": "Hello, how can I help you today?"}, *history],
+ "system": system_message,
"temperature": kwargs.get("temperature", 0.7),
"promptId": f"model-{model}",
"documentIds": [],
diff --git a/g4f/Provider/GeminiPro.py b/g4f/Provider/GeminiPro.py
index a2e3538d..1c5487b1 100644
--- a/g4f/Provider/GeminiPro.py
+++ b/g4f/Provider/GeminiPro.py
@@ -27,6 +27,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
proxy: str = None,
api_key: str = None,
api_base: str = None,
+ use_auth_header: bool = True,
image: ImageType = None,
connector: BaseConnector = None,
**kwargs
@@ -38,7 +39,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
raise MissingAuthError('Missing "api_key"')
headers = params = None
- if api_base:
+ if api_base and use_auth_header:
headers = {"Authorization": f"Bearer {api_key}"}
else:
params = {"key": api_key}
diff --git a/g4f/Provider/You.py b/g4f/Provider/You.py
index 34130c47..b21fd582 100644
--- a/g4f/Provider/You.py
+++ b/g4f/Provider/You.py
@@ -3,9 +3,9 @@ from __future__ import annotations
import json
import base64
import uuid
-from aiohttp import ClientSession, FormData
+from aiohttp import ClientSession, FormData, BaseConnector
-from ..typing import AsyncGenerator, Messages, ImageType, Cookies
+from ..typing import AsyncResult, Messages, ImageType, Cookies
from .base_provider import AsyncGeneratorProvider
from ..providers.helper import get_connector, format_prompt
from ..image import to_bytes
@@ -26,12 +26,13 @@ class You(AsyncGeneratorProvider):
messages: Messages,
image: ImageType = None,
image_name: str = None,
+ connector: BaseConnector = None,
proxy: str = None,
chat_mode: str = "default",
**kwargs,
- ) -> AsyncGenerator:
+ ) -> AsyncResult:
async with ClientSession(
- connector=get_connector(kwargs.get("connector"), proxy),
+ connector=get_connector(connector, proxy),
headers=DEFAULT_HEADERS
) as client:
if image:
@@ -72,13 +73,13 @@ class You(AsyncGeneratorProvider):
response.raise_for_status()
async for line in response.content:
if line.startswith(b'event: '):
- event = line[7:-1]
+ event = line[7:-1].decode()
elif line.startswith(b'data: '):
- if event == b"youChatUpdate" or event == b"youChatToken":
+ if event in ["youChatUpdate", "youChatToken"]:
data = json.loads(line[6:-1])
- if event == b"youChatToken" and "youChatToken" in data:
- yield data["youChatToken"]
- elif event == b"youChatUpdate" and "t" in data:
+ if event == "youChatToken" and event in data:
+ yield data[event]
+ elif event == "youChatUpdate" and "t" in data:
yield data["t"]
@classmethod
diff --git a/g4f/Provider/__init__.py b/g4f/Provider/__init__.py
index 6cdc8806..52ba0274 100644
--- a/g4f/Provider/__init__.py
+++ b/g4f/Provider/__init__.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from ..providers.types import BaseProvider, ProviderType
-from ..providers.retry_provider import RetryProvider
+from ..providers.retry_provider import RetryProvider, IterProvider
from ..providers.base_provider import AsyncProvider, AsyncGeneratorProvider
from ..providers.create_images import CreateImagesProvider
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index 556c3d9b..0fa433a4 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -4,6 +4,8 @@ import asyncio
import uuid
import json
import os
+import base64
+from aiohttp import ClientWebSocketResponse
try:
from py_arkose_generator.arkose import get_values_for_request
@@ -20,9 +22,9 @@ except ImportError:
pass
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from ..helper import format_prompt, get_cookies
-from ...webdriver import get_browser, get_driver_cookies
-from ...typing import AsyncResult, Messages, Cookies, ImageType
+from ..helper import get_cookies
+from ...webdriver import get_browser
+from ...typing import AsyncResult, Messages, Cookies, ImageType, Union, AsyncIterator
from ...requests import get_args_from_browser
from ...requests.aiohttp import StreamSession
from ...image import to_image, to_bytes, ImageResponse, ImageRequest
@@ -37,10 +39,15 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
needs_auth = True
supports_gpt_35_turbo = True
supports_gpt_4 = True
+ supports_message_history = True
+ supports_system_message = True
default_model = None
models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo"]
- model_aliases = {"text-davinci-002-render-sha": "gpt-3.5-turbo"}
- _args: dict = None
+ model_aliases = {"text-davinci-002-render-sha": "gpt-3.5-turbo", "": "gpt-3.5-turbo"}
+ _api_key: str = None
+ _headers: dict = None
+ _cookies: Cookies = None
+ _last_message: int = 0
@classmethod
async def create(
@@ -170,6 +177,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
"""
if not cls.default_model:
async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response:
+ cls._update_request_args(session)
response.raise_for_status()
data = await response.json()
if "categories" in data:
@@ -179,7 +187,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
return cls.default_model
@classmethod
- def create_messages(cls, prompt: str, image_request: ImageRequest = None):
+ def create_messages(cls, messages: Messages, image_request: ImageRequest = None):
"""
Create a list of messages for the user input
@@ -190,31 +198,27 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
Returns:
A list of messages with the user input and the image, if any
"""
+ # Create a message object with the user role and the content
+ messages = [{
+ "id": str(uuid.uuid4()),
+ "author": {"role": message["role"]},
+ "content": {"content_type": "text", "parts": [message["content"]]},
+ } for message in messages]
+
# Check if there is an image response
- if not image_request:
- # Create a content object with the text type and the prompt
- content = {"content_type": "text", "parts": [prompt]}
- else:
- # Create a content object with the multimodal text type and the image and the prompt
- content = {
+ if image_request:
+ # Change content in last user message
+ messages[-1]["content"] = {
"content_type": "multimodal_text",
"parts": [{
"asset_pointer": f"file-service://{image_request.get('file_id')}",
"height": image_request.get("height"),
"size_bytes": image_request.get("file_size"),
"width": image_request.get("width"),
- }, prompt]
+ }, messages[-1]["content"]["parts"][0]]
}
- # Create a message object with the user role and the content
- messages = [{
- "id": str(uuid.uuid4()),
- "author": {"role": "user"},
- "content": content,
- }]
- # Check if there is an image response
- if image_request:
# Add the metadata object with the attachments
- messages[0]["metadata"] = {
+ messages[-1]["metadata"] = {
"attachments": [{
"height": image_request.get("height"),
"id": image_request.get("file_id"),
@@ -225,7 +229,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
}]
}
return messages
-
+
@classmethod
async def get_generated_image(cls, session: StreamSession, headers: dict, line: dict) -> ImageResponse:
"""
@@ -301,6 +305,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
conversation_id: str = None,
parent_id: str = None,
image: ImageType = None,
+ image_name: str = None,
response_fields: bool = False,
**kwargs
) -> AsyncResult:
@@ -333,50 +338,65 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
raise MissingRequirementsError('Install "py-arkose-generator" and "async_property" package')
if not parent_id:
parent_id = str(uuid.uuid4())
- if cls._args is None and cookies is None:
- cookies = get_cookies("chat.openai.com", False)
+
+ # Read api_key from arguments
api_key = kwargs["access_token"] if "access_token" in kwargs else api_key
- if api_key is None and cookies is not None:
- api_key = cookies["access_token"] if "access_token" in cookies else api_key
- if cls._args is None:
- cls._args = {
- "headers": {"Cookie": "; ".join(f"{k}={v}" for k, v in cookies.items() if k != "access_token")},
- "cookies": {} if cookies is None else cookies
- }
- if api_key is not None:
- cls._args["headers"]["Authorization"] = f"Bearer {api_key}"
+
async with StreamSession(
proxies={"https": proxy},
impersonate="chrome",
- timeout=timeout,
- headers=cls._args["headers"]
+ timeout=timeout
) as session:
- if api_key is not None:
+ # Read api_key and cookies from cache / browser config
+ if cls._headers is None:
+ if api_key is None:
+ # Read api_key from cookies
+ cookies = get_cookies("chat.openai.com", False) if cookies is None else cookies
+ api_key = cookies["access_token"] if "access_token" in cookies else api_key
+ cls._create_request_args(cookies)
+ else:
+ api_key = cls._api_key if api_key is None else api_key
+ # Read api_key with session cookies
+ if api_key is None and cookies:
+ api_key = await cls.fetch_access_token(session, cls._headers)
+ # Load default model
+ if cls.default_model is None and api_key is not None:
try:
- cls.default_model = await cls.get_default_model(session, cls._args["headers"])
+ if not model:
+ cls._set_api_key(api_key)
+ cls.default_model = cls.get_model(await cls.get_default_model(session, cls._headers))
+ else:
+ cls.default_model = cls.get_model(model)
except Exception as e:
if debug.logging:
+ print("OpenaiChat: Load default_model failed")
print(f"{e.__class__.__name__}: {e}")
- if cls.default_model is None:
+ # Browse api_key and default model
+ if api_key is None or cls.default_model is None:
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
yield f"Please login: [ChatGPT]({login_url})\n\n"
try:
- cls._args = cls.browse_access_token(proxy)
+ cls.browse_access_token(proxy)
except MissingRequirementsError:
- raise MissingAuthError(f'Missing or invalid "access_token". Add a new "api_key" please')
- cls.default_model = await cls.get_default_model(session, cls._args["headers"])
+ raise MissingAuthError(f'Missing "access_token". Add a "api_key" please')
+ cls.default_model = cls.get_model(await cls.get_default_model(session, cls._headers))
+ else:
+ cls._set_api_key(api_key)
+
try:
- image_response = None
- if image:
- image_response = await cls.upload_image(session, cls._args["headers"], image, kwargs.get("image_name"))
+ image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None
except Exception as e:
- yield e
- end_turn = EndTurn()
- model = cls.get_model(model)
- model = "text-davinci-002-render-sha" if model == "gpt-3.5-turbo" else model
- while not end_turn.is_end:
+ if debug.logging:
+ print("OpenaiChat: Upload image failed")
+ print(f"{e.__class__.__name__}: {e}")
+
+ model = cls.get_model(model).replace("gpt-3.5-turbo", "text-davinci-002-render-sha")
+ fields = ResponseFields()
+ while fields.finish_reason is None:
arkose_token = await cls.get_arkose_token(session)
+ conversation_id = conversation_id if fields.conversation_id is None else fields.conversation_id
+ parent_id = parent_id if fields.message_id is None else fields.message_id
data = {
"action": action,
"arkose_token": arkose_token,
@@ -389,13 +409,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
"history_and_training_disabled": history_disabled and not auto_continue,
}
if action != "continue":
- prompt = format_prompt(messages) if not conversation_id else messages[-1]["content"]
- data["messages"] = cls.create_messages(prompt, image_response)
-
- # Update cookies before next request
- for c in session.cookie_jar if hasattr(session, "cookie_jar") else session.cookies.jar:
- cls._args["cookies"][c.name if hasattr(c, "name") else c.key] = c.value
- cls._args["headers"]["Cookie"] = "; ".join(f"{k}={v}" for k, v in cls._args["cookies"].items())
+ messages = messages if conversation_id is None else [messages[-1]]
+ data["messages"] = cls.create_messages(messages, image_request)
async with session.post(
f"{cls.url}/backend-api/conversation",
@@ -403,61 +418,88 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
headers={
"Accept": "text/event-stream",
"OpenAI-Sentinel-Arkose-Token": arkose_token,
- **cls._args["headers"]
+ **cls._headers
}
) as response:
+ cls._update_request_args(session)
if not response.ok:
raise RuntimeError(f"Response {response.status}: {await response.text()}")
- last_message: int = 0
- async for line in response.iter_lines():
- if not line.startswith(b"data: "):
- continue
- elif line.startswith(b"data: [DONE]"):
- break
- try:
- line = json.loads(line[6:])
- except:
- continue
- if "message" not in line:
- continue
- if "error" in line and line["error"]:
- raise RuntimeError(line["error"])
- if "message_type" not in line["message"]["metadata"]:
- continue
- try:
- image_response = await cls.get_generated_image(session, cls._args["headers"], line)
- if image_response is not None:
- yield image_response
- except Exception as e:
- yield e
- if line["message"]["author"]["role"] != "assistant":
- continue
- if line["message"]["content"]["content_type"] != "text":
- continue
- if line["message"]["metadata"]["message_type"] not in ("next", "continue", "variant"):
- continue
- conversation_id = line["conversation_id"]
- parent_id = line["message"]["id"]
+ async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, fields):
if response_fields:
response_fields = False
- yield ResponseFields(conversation_id, parent_id, end_turn)
- if "parts" in line["message"]["content"]:
- new_message = line["message"]["content"]["parts"][0]
- if len(new_message) > last_message:
- yield new_message[last_message:]
- last_message = len(new_message)
- if "finish_details" in line["message"]["metadata"]:
- if line["message"]["metadata"]["finish_details"]["type"] == "stop":
- end_turn.end()
+ yield fields
+ yield chunk
if not auto_continue:
break
action = "continue"
await asyncio.sleep(5)
if history_disabled and auto_continue:
- await cls.delete_conversation(session, cls._args["headers"], conversation_id)
+ await cls.delete_conversation(session, cls._headers, conversation_id)
+
+ @staticmethod
+ async def iter_messages_ws(ws: ClientWebSocketResponse) -> AsyncIterator:
+ while True:
+ yield base64.b64decode((await ws.receive_json())["body"])
@classmethod
- def browse_access_token(cls, proxy: str = None, timeout: int = 1200) -> tuple[str, dict]:
+ async def iter_messages_chunk(cls, messages: AsyncIterator, session: StreamSession, fields: ResponseFields) -> AsyncIterator:
+ last_message: int = 0
+ async for message in messages:
+ if message.startswith(b'{"wss_url":'):
+ async with session.ws_connect(json.loads(message)["wss_url"]) as ws:
+ async for chunk in cls.iter_messages_chunk(cls.iter_messages_ws(ws), session, fields):
+ yield chunk
+ break
+ async for chunk in cls.iter_messages_line(session, message, fields):
+ if fields.finish_reason is not None:
+ break
+ elif isinstance(chunk, str):
+ if len(chunk) > last_message:
+ yield chunk[last_message:]
+ last_message = len(chunk)
+ else:
+ yield chunk
+ if fields.finish_reason is not None:
+ break
+
+ @classmethod
+ async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: ResponseFields) -> AsyncIterator:
+ if not line.startswith(b"data: "):
+ return
+ elif line.startswith(b"data: [DONE]"):
+ return
+ try:
+ line = json.loads(line[6:])
+ except:
+ return
+ if "message" not in line:
+ return
+ if "error" in line and line["error"]:
+ raise RuntimeError(line["error"])
+ if "message_type" not in line["message"]["metadata"]:
+ return
+ try:
+ image_response = await cls.get_generated_image(session, cls._headers, line)
+ if image_response is not None:
+ yield image_response
+ except Exception as e:
+ yield e
+ if line["message"]["author"]["role"] != "assistant":
+ return
+ if line["message"]["content"]["content_type"] != "text":
+ return
+ if line["message"]["metadata"]["message_type"] not in ("next", "continue", "variant"):
+ return
+ if fields.conversation_id is None:
+ fields.conversation_id = line["conversation_id"]
+ fields.message_id = line["message"]["id"]
+ if "parts" in line["message"]["content"]:
+ yield line["message"]["content"]["parts"][0]
+ if "finish_details" in line["message"]["metadata"]:
+ fields.finish_reason = line["message"]["metadata"]["finish_details"]["type"]
+
+ @classmethod
+ def browse_access_token(cls, proxy: str = None, timeout: int = 1200) -> None:
"""
Browse to obtain an access token.
@@ -475,14 +517,15 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
"let session = await fetch('/api/auth/session');"
"let data = await session.json();"
"let accessToken = data['accessToken'];"
- "let expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 4);"
+ "let expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 4 * 1000);"
"document.cookie = 'access_token=' + accessToken + ';expires=' + expires.toUTCString() + ';path=/';"
"return accessToken;"
)
args = get_args_from_browser(f"{cls.url}/", driver, do_bypass_cloudflare=False)
- args["headers"]["Authorization"] = f"Bearer {access_token}"
- args["headers"]["Cookie"] = "; ".join(f"{k}={v}" for k, v in args["cookies"].items() if k != "access_token")
- return args
+ cls._headers = args["headers"]
+ cls._cookies = args["cookies"]
+ cls._update_cookie_header()
+ cls._set_api_key(access_token)
finally:
driver.close()
@@ -516,6 +559,42 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
return decoded_json["token"]
raise RuntimeError(f"Response: {decoded_json}")
+ @classmethod
+ async def fetch_access_token(cls, session: StreamSession, headers: dict):
+ async with session.get(
+ f"{cls.url}/api/auth/session",
+ headers=headers
+ ) as response:
+ if response.ok:
+ data = await response.json()
+ if "accessToken" in data:
+ return data["accessToken"]
+
+ @staticmethod
+ def _format_cookies(cookies: Cookies):
+ return "; ".join(f"{k}={v}" for k, v in cookies.items() if k != "access_token")
+
+ @classmethod
+ def _create_request_args(cls, cookies: Union[Cookies, None]):
+ cls._headers = {}
+ cls._cookies = {} if cookies is None else cookies
+ cls._update_cookie_header()
+
+ @classmethod
+ def _update_request_args(cls, session: StreamSession):
+ for c in session.cookie_jar if hasattr(session, "cookie_jar") else session.cookies.jar:
+ cls._cookies[c.name if hasattr(c, "name") else c.key] = c.value
+ cls._update_cookie_header()
+
+ @classmethod
+ def _set_api_key(cls, api_key: str):
+ cls._api_key = api_key
+ cls._headers["Authorization"] = f"Bearer {api_key}"
+
+ @classmethod
+ def _update_cookie_header(cls):
+ cls._headers["Cookie"] = cls._format_cookies(cls._cookies)
+
class EndTurn:
"""
Class to represent the end of a conversation turn.
@@ -530,10 +609,10 @@ class ResponseFields:
"""
Class to encapsulate response fields.
"""
- def __init__(self, conversation_id: str, message_id: str, end_turn: EndTurn):
+ def __init__(self, conversation_id: str = None, message_id: str = None, finish_reason: str = None):
self.conversation_id = conversation_id
self.message_id = message_id
- self._end_turn = end_turn
+ self.finish_reason = finish_reason
class Response():
"""
@@ -567,7 +646,7 @@ class Response():
self._message = "".join(chunks)
if not self._fields:
raise RuntimeError("Missing response fields")
- self.is_end = self._fields._end_turn.is_end
+ self.is_end = self._fields.end_turn
def __aiter__(self):
return self.generator()