diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/needs_auth/OpenaiChat.py | 77 |
1 files changed, 21 insertions, 56 deletions
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index a202f45e..d8ea4fad 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -26,7 +26,7 @@ from ...webdriver import get_browser from ...typing import AsyncResult, Messages, Cookies, ImageType, AsyncIterator from ...requests import get_args_from_browser, raise_for_status from ...requests.aiohttp import StreamSession -from ...image import to_image, to_bytes, ImageResponse, ImageRequest +from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format from ...errors import MissingAuthError, ResponseError from ...providers.conversation import BaseConversation from ..helper import format_cookies @@ -138,23 +138,22 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): An ImageRequest object that contains the download URL, file name, and other data """ # Convert the image to a PIL Image object and get the extension - image = to_image(image) - extension = image.format.lower() - # Convert the image to a bytes object and get the size data_bytes = to_bytes(image) + image = to_image(data_bytes) + extension = image.format.lower() data = { - "file_name": image_name if image_name else f"{image.width}x{image.height}.{extension}", + "file_name": "" if image_name is None else image_name, "file_size": len(data_bytes), "use_case": "multimodal" } # Post the image data to the service and get the image data async with session.post(f"{cls.url}/backend-api/files", json=data, headers=headers) as response: - cls._update_request_args() + cls._update_request_args(session) await raise_for_status(response) image_data = { **data, **await response.json(), - "mime_type": f"image/{extension}", + "mime_type": is_accepted_format(data_bytes), "extension": extension, "height": image.height, "width": image.width @@ -275,7 +274,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): first_part = line["message"]["content"]["parts"][0] if "asset_pointer" not in first_part or "metadata" not in first_part: return - if first_part["metadata"] is None: + if first_part["metadata"] is None or first_part["metadata"]["dalle"] is None: return prompt = first_part["metadata"]["dalle"]["prompt"] file_id = first_part["asset_pointer"].split("file-service://", 1)[1] @@ -365,49 +364,17 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): ) as session: if cls._expires is not None and cls._expires < time.time(): cls._headers = cls._api_key = None - if cls._headers is None or cookies is not None: - cls._create_request_args(cookies) - api_key = kwargs["access_token"] if "access_token" in kwargs else api_key - if api_key is not None: - cls._set_api_key(api_key) - - if cls.default_model is None and (not cls.needs_auth or cls._api_key is not None): - if cls._api_key is None: - cls._create_request_args(cookies) - async with session.get( - f"{cls.url}/", - headers=DEFAULT_HEADERS - ) as response: - cls._update_request_args(session) - await raise_for_status(response) - try: - if not model: - cls.default_model = cls.get_model(await cls.get_default_model(session, cls._headers)) - else: - cls.default_model = cls.get_model(model) - except MissingAuthError: - pass - except Exception as e: - api_key = cls._api_key = None - cls._create_request_args() - if debug.logging: - print("OpenaiChat: Load default model failed") - print(f"{e.__class__.__name__}: {e}") - arkose_token = None proofTokens = None - if cls.default_model is None: - error = None - try: - arkose_token, api_key, cookies, headers, proofTokens = await getArkoseAndAccessToken(proxy) - cls._create_request_args(cookies, headers) - cls._set_api_key(api_key) - except NoValidHarFileError as e: - error = e - if cls._api_key is None: - await cls.nodriver_access_token(proxy) + try: + arkose_token, api_key, cookies, headers, proofTokens = await getArkoseAndAccessToken(proxy) + cls._create_request_args(cookies, headers) + cls._set_api_key(api_key) + except NoValidHarFileError as e: if cls._api_key is None and cls.needs_auth: - raise error + raise e + + if cls.default_model is None: cls.default_model = cls.get_model(await cls.get_default_model(session, cls._headers)) try: @@ -461,7 +428,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): ) ws = None if need_arkose: - async with session.post("https://chatgpt.com/backend-api/register-websocket", headers=cls._headers) as response: + async with session.post(f"{cls.url}/backend-api/register-websocket", headers=cls._headers) as response: wss_url = (await response.json()).get("wss_url") if wss_url: ws = await session.ws_connect(wss_url) @@ -490,7 +457,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): if proofofwork is not None: headers["Openai-Sentinel-Proof-Token"] = proofofwork async with session.post( - f"{cls.url}/backend-anon/conversation" if cls._api_key is None else + f"{cls.url}/backend-anon/conversation" + if cls._api_key is None else f"{cls.url}/backend-api/conversation", json=data, headers=headers @@ -580,12 +548,9 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): raise RuntimeError(line["error"]) if "message_type" not in line["message"]["metadata"]: return - try: - image_response = await cls.get_generated_image(session, cls._headers, line) - if image_response is not None: - yield image_response - except Exception as e: - yield e + image_response = await cls.get_generated_image(session, cls._headers, line) + if image_response is not None: + yield image_response if line["message"]["author"]["role"] != "assistant": return if line["message"]["content"]["content_type"] != "text": |