From 78c20c08a087f696f38246015d7bbe230276de07 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Fri, 6 Dec 2024 21:54:13 +0100 Subject: Improve image generation in OpenaiChat and Gemini --- g4f/Provider/needs_auth/Gemini.py | 44 ++++++++++++---------------- g4f/Provider/needs_auth/OpenaiChat.py | 55 +++++++++++++++++++---------------- g4f/api/__init__.py | 8 ++--- g4f/client/__init__.py | 37 +++++++++++++---------- g4f/gui/server/api.py | 6 ++-- g4f/image.py | 27 ++++++++++------- g4f/requests/raise_for_status.py | 6 ++-- 7 files changed, 98 insertions(+), 85 deletions(-) (limited to 'g4f') diff --git a/g4f/Provider/needs_auth/Gemini.py b/g4f/Provider/needs_auth/Gemini.py index e7c9de23..3c842f3c 100644 --- a/g4f/Provider/needs_auth/Gemini.py +++ b/g4f/Provider/needs_auth/Gemini.py @@ -76,9 +76,8 @@ class Gemini(AsyncGeneratorProvider): page = await browser.get(f"{cls.url}/app") await page.select("div.ql-editor.textarea", 240) cookies = {} - for c in await page.browser.cookies.get_all(): - if c.domain.endswith(".google.com"): - cookies[c.name] = c.value + for c in await page.send(nodriver.cdp.network.get_cookies([cls.url])): + cookies[c.name] = c.value await page.close() cls._cookies = cookies @@ -92,7 +91,6 @@ class Gemini(AsyncGeneratorProvider): connector: BaseConnector = None, image: ImageType = None, image_name: str = None, - response_format: str = None, return_conversation: bool = False, conversation: Conversation = None, language: str = "en", @@ -113,7 +111,7 @@ class Gemini(AsyncGeneratorProvider): async for chunk in cls.nodriver_login(proxy): yield chunk except Exception as e: - raise MissingAuthError('Missing "__Secure-1PSID" cookie', e) + raise MissingAuthError('Missing or invalid "__Secure-1PSID" cookie', e) if not cls._snlm0e: if cls._cookies is None or "__Secure-1PSID" not in cls._cookies: raise MissingAuthError('Missing "__Secure-1PSID" cookie') @@ -153,7 +151,7 @@ class Gemini(AsyncGeneratorProvider): ) as response: await raise_for_status(response) 
image_prompt = response_part = None - last_content_len = 0 + last_content = "" async for line in response.content: try: try: @@ -171,32 +169,26 @@ class Gemini(AsyncGeneratorProvider): yield Conversation(response_part[1][0], response_part[1][1], response_part[4][0][0]) content = response_part[4][0][1][0] except (ValueError, KeyError, TypeError, IndexError) as e: - print(f"{cls.__name__}:{e.__class__.__name__}:{e}") + debug.log(f"{cls.__name__}:{e.__class__.__name__}:{e}") continue match = re.search(r'\[Imagen of (.*?)\]', content) if match: image_prompt = match.group(1) content = content.replace(match.group(0), '') - yield content[last_content_len:] - last_content_len = len(content) - if image_prompt: - try: - images = [image[0][3][3] for image in response_part[4][0][12][7][0]] - if response_format == "b64_json": + pattern = r"http://googleusercontent.com/image_generation_content/\d+" + content = re.sub(pattern, "", content) + if last_content and content.startswith(last_content): + yield content[len(last_content):] + else: + yield content + last_content = content + if image_prompt: + try: + images = [image[0][3][3] for image in response_part[4][0][12][7][0]] + image_prompt = image_prompt.replace("a fake image", "") yield ImageResponse(images, image_prompt, {"cookies": cls._cookies}) - else: - resolved_images = [] - preview = [] - for image in images: - async with client.get(image, allow_redirects=False) as fetch: - image = fetch.headers["location"] - async with client.get(image, allow_redirects=False) as fetch: - image = fetch.headers["location"] - resolved_images.append(image) - preview.append(image.replace('=s512', '=s200')) - yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images, "preview": preview}) - except TypeError: - pass + except TypeError: + pass @classmethod async def synthesize(cls, params: dict, proxy: str = None) -> AsyncIterator[bytes]: diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py 
index 9c0b8768..c93ba7be 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -184,7 +184,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): "content": {"content_type": "text", "parts": [message["content"]]}, "id": str(uuid.uuid4()), "create_time": int(time.time()), - "id": str(uuid.uuid4()), "metadata": {"serialization_metadata": {"custom_symbol_offsets": []}, "system_hints": system_hints}, } for message in messages] @@ -295,8 +294,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): Raises: RuntimeError: If an error occurs during processing. """ - if model == cls.default_image_model: - model = cls.default_model if cls.needs_auth: await cls.login(proxy) @@ -308,9 +305,9 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): if not cls.needs_auth: cls._create_request_args(cookies) RequestConfig.proof_token = get_config(cls._headers.get("user-agent")) - async with session.get(cls.url, headers=INIT_HEADERS) as response: - cls._update_request_args(session) - await raise_for_status(response) + async with session.get(cls.url, headers=INIT_HEADERS) as response: + cls._update_request_args(session) + await raise_for_status(response) try: image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None except Exception as e: @@ -318,6 +315,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): debug.log("OpenaiChat: Upload image failed") debug.log(f"{e.__class__.__name__}: {e}") model = cls.get_model(model) + if model == cls.default_image_model: + model = cls.default_vision_model if conversation is None: conversation = Conversation(conversation_id, str(uuid.uuid4()) if parent_id is None else parent_id) else: @@ -363,13 +362,22 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): "messages": None, "parent_message_id": conversation.message_id, "model": model, - "paragen_cot_summary_display_override": "allow", + 
"timezone_offset_min":-60, + "timezone":"Europe/Berlin", "history_and_training_disabled": history_disabled and not auto_continue and not return_conversation, - "conversation_mode": {"kind":"primary_assistant"}, + "conversation_mode":{"kind":"primary_assistant","plugin_ids":None}, + "force_paragen":False, + "force_paragen_model_slug":"", + "force_rate_limit":False, + "reset_rate_limits":False, "websocket_request_id": str(uuid.uuid4()), - "supported_encodings": ["v1"], - "supports_buffering": True, - "system_hints": ["search"] if web_search else None + "system_hints": ["search"] if web_search else None, + "supported_encodings":["v1"], + "conversation_origin":None, + "client_contextual_info":{"is_dark_mode":False,"time_since_loaded":14,"page_height":578,"page_width":1850,"pixel_ratio":1,"screen_height":1080,"screen_width":1920}, + "paragen_stream_type_override":None, + "paragen_cot_summary_display_override":"allow", + "supports_buffering":True } if conversation.conversation_id is not None: data["conversation_id"] = conversation.conversation_id @@ -408,7 +416,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): async for line in response.iter_lines(): async for chunk in cls.iter_messages_line(session, line, conversation): yield chunk - if not history_disabled: + if not history_disabled and RequestConfig.access_token is not None: yield SynthesizeData(cls.__name__, { "conversation_id": conversation.conversation_id, "message_id": conversation.message_id, @@ -493,9 +501,12 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): await get_request_config(proxy) cls._create_request_args(RequestConfig.cookies, RequestConfig.headers) cls._set_api_key(RequestConfig.access_token) + if RequestConfig.proof_token is None: + RequestConfig.proof_token = get_config(cls._headers.get("user-agent")) except NoValidHarFileError: if has_nodriver: - await cls.nodriver_auth(proxy) + if RequestConfig.access_token is None: + await cls.nodriver_auth(proxy) else: raise @@ 
-527,23 +538,17 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): await page.send(nodriver.cdp.network.enable()) page.add_handler(nodriver.cdp.network.RequestWillBeSent, on_request) page = await browser.get(cls.url) - try: - if RequestConfig.access_request_id is not None: - body = await page.send(get_response_body(RequestConfig.access_request_id)) - if isinstance(body, tuple) and body: - body = body[0] - if body: - match = re.search(r'"accessToken":"(.*?)"', body) - if match: - RequestConfig.access_token = match.group(1) - except KeyError: - pass + body = await page.evaluate("JSON.stringify(window.__remixContext)") + if body: + match = re.search(r'"accessToken":"(.*?)"', body) + if match: + RequestConfig.access_token = match.group(1) for c in await page.send(nodriver.cdp.network.get_cookies([cls.url])): RequestConfig.cookies[c.name] = c.value user_agent = await page.evaluate("window.navigator.userAgent") await page.select("#prompt-textarea", 240) while True: - if RequestConfig.proof_token: + if RequestConfig.access_token: break await asyncio.sleep(1) await page.close() diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py index 0ca19360..74db35c8 100644 --- a/g4f/api/__init__.py +++ b/g4f/api/__init__.py @@ -112,7 +112,7 @@ class ImageGenerationConfig(BaseModel): prompt: str model: Optional[str] = None provider: Optional[str] = None - response_format: str = "url" + response_format: Optional[str] = None api_key: Optional[str] = None proxy: Optional[str] = None @@ -370,9 +370,9 @@ class Api: model=config.model, provider=AppConfig.image_provider if config.provider is None else config.provider, **filter_none( - response_format = config.response_format, - api_key = config.api_key, - proxy = config.proxy + response_format=config.response_format, + api_key=config.api_key, + proxy=config.proxy ) ) for image in response.data: diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py index 177401dd..d95618f1 100644 --- a/g4f/client/__init__.py +++ 
b/g4f/client/__init__.py @@ -292,6 +292,7 @@ class Images: if proxy is None: proxy = self.client.proxy + e = None response = None if isinstance(provider_handler, IterListProvider): for provider in provider_handler.providers: @@ -300,7 +301,7 @@ class Images: if response is not None: provider_name = provider.__name__ break - except (MissingAuthError, NoValidHarFileError) as e: + except Exception as e: debug.log(f"Image provider {provider.__name__}: {e}") else: response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs) @@ -314,6 +315,8 @@ class Images: provider_name ) if response is None: + if e is not None: + raise e raise NoImageResponseError(f"No image response from {provider_name}") raise NoImageResponseError(f"Unexpected response type: {type(response)}") @@ -362,7 +365,7 @@ class Images: image: ImageType, model: str = None, provider: Optional[ProviderType] = None, - response_format: str = "url", + response_format: Optional[str] = None, **kwargs ) -> ImagesResponse: return asyncio.run(self.async_create_variation( @@ -374,7 +377,7 @@ class Images: image: ImageType, model: Optional[str] = None, provider: Optional[ProviderType] = None, - response_format: str = "url", + response_format: Optional[str] = None, proxy: Optional[str] = None, **kwargs ) -> ImagesResponse: @@ -384,6 +387,7 @@ class Images: proxy = self.client.proxy prompt = "create a variation of this image" + e = None response = None if isinstance(provider_handler, IterListProvider): # File pointer can be read only once, so we need to convert it to bytes @@ -394,7 +398,7 @@ class Images: if response is not None: provider_name = provider.__name__ break - except (MissingAuthError, NoValidHarFileError) as e: + except Exception as e: debug.log(f"Image provider {provider.__name__}: {e}") else: response = await self._generate_image_response(provider_handler, provider_name, model, prompt, image=image, **kwargs) @@ -402,10 +406,11 @@ class Images: if 
isinstance(response, ImageResponse): return await self._process_image_response(response, response_format, proxy, model, provider_name) if response is None: + if e is not None: + raise e raise NoImageResponseError(f"No image response from {provider_name}") raise NoImageResponseError(f"Unexpected response type: {type(response)}") - async def _process_image_response( self, response: ImageResponse, @@ -414,21 +419,21 @@ class Images: model: Optional[str] = None, provider: Optional[str] = None ) -> ImagesResponse: + last_provider = get_last_provider(True) if response_format == "url": # Return original URLs without saving locally images = [Image.construct(url=image, revised_prompt=response.alt) for image in response.get_list()] - elif response_format == "b64_json": - images = await copy_images(response.get_list(), response.options.get("cookies"), proxy) - async def process_image_item(image_file: str) -> Image: - with open(os.path.join(images_dir, os.path.basename(image_file)), "rb") as file: - image_data = base64.b64encode(file.read()).decode() - return Image.construct(b64_json=image_data, revised_prompt=response.alt) - images = await asyncio.gather(*[process_image_item(image) for image in images]) else: # Save locally for None (default) case - images = await copy_images(response.get_list(), response.options.get("cookies"), proxy) - images = [Image.construct(url=f"/images/{os.path.basename(image)}", revised_prompt=response.alt) for image in images] - last_provider = get_last_provider(True) + images = await copy_images(response.get_list(), response.get("cookies"), proxy) + if response_format == "b64_json": + async def process_image_item(image_file: str) -> Image: + with open(os.path.join(images_dir, os.path.basename(image_file)), "rb") as file: + image_data = base64.b64encode(file.read()).decode() + return Image.construct(b64_json=image_data, revised_prompt=response.alt) + images = await asyncio.gather(*[process_image_item(image) for image in images]) + else: + images = 
[Image.construct(url=f"/images/{os.path.basename(image)}", revised_prompt=response.alt) for image in images] return ImagesResponse.construct( created=int(time.time()), data=images, @@ -529,7 +534,7 @@ class AsyncImages(Images): image: ImageType, model: str = None, provider: ProviderType = None, - response_format: str = "url", + response_format: Optional[str] = None, **kwargs ) -> ImagesResponse: return await self.async_create_variation( diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py index 692d9e5c..ccad27df 100644 --- a/g4f/gui/server/api.py +++ b/g4f/gui/server/api.py @@ -110,8 +110,10 @@ class Api: def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_images: bool = True) -> Iterator: def log_handler(text: str): debug.logs.append(text) - print(text) + if debug.logging: + print(text) debug.log_handler = log_handler + proxy = os.environ.get("G4F_PROXY") try: result = ChatCompletion.create(**kwargs) first = True @@ -139,7 +141,7 @@ class Api: elif isinstance(chunk, ImageResponse): images = chunk if download_images: - images = asyncio.run(copy_images(chunk.get_list(), chunk.options.get("cookies"))) + images = asyncio.run(copy_images(chunk.get_list(), chunk.get("cookies"), proxy)) images = ImageResponse(images, chunk.alt) yield self._format_json("content", str(images)) elif isinstance(chunk, SynthesizeData): diff --git a/g4f/image.py b/g4f/image.py index e9abcb6e..4a1d740c 100644 --- a/g4f/image.py +++ b/g4f/image.py @@ -7,8 +7,7 @@ import uuid from io import BytesIO import base64 import asyncio -from aiohttp import ClientSession - +from aiohttp import ClientSession, ClientError try: from PIL.Image import open as open_image, new as new_image from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90 @@ -20,6 +19,7 @@ from .typing import ImageType, Union, Image, Optional, Cookies from .errors import MissingRequirementsError from .providers.response import ResponseType from .requests.aiohttp import 
get_connector +from . import debug ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'} @@ -277,12 +277,14 @@ def ensure_images_dir(): if not os.path.exists(images_dir): os.makedirs(images_dir) -async def copy_images(images: list[str], cookies: Optional[Cookies] = None, proxy: Optional[str] = None): +async def copy_images( + images: list[str], + cookies: Optional[Cookies] = None, + proxy: Optional[str] = None +): ensure_images_dir() async with ClientSession( - connector=get_connector( - proxy=os.environ.get("G4F_PROXY") if proxy is None else proxy - ), + connector=get_connector(proxy=proxy), cookies=cookies ) as session: async def copy_image(image: str) -> str: @@ -291,10 +293,15 @@ async def copy_images(images: list[str], cookies: Optional[Cookies] = None, prox with open(target, "wb") as f: f.write(extract_data_uri(image)) else: - async with session.get(image) as response: - with open(target, "wb") as f: - async for chunk in response.content.iter_chunked(4096): - f.write(chunk) + try: + async with session.get(image) as response: + response.raise_for_status() + with open(target, "wb") as f: + async for chunk in response.content.iter_chunked(4096): + f.write(chunk) + except ClientError as e: + debug.log(f"copy_images failed: {e.__class__.__name__}: {e}") + return image with open(target, "rb") as f: extension = is_accepted_format(f.read(12)).split("/")[-1] extension = "jpg" if extension == "jpeg" else extension diff --git a/g4f/requests/raise_for_status.py b/g4f/requests/raise_for_status.py index 0cd09a2a..3566ead2 100644 --- a/g4f/requests/raise_for_status.py +++ b/g4f/requests/raise_for_status.py @@ -18,7 +18,7 @@ def is_cloudflare(text: str) -> bool: return '
<div id="cf-please-wait">' in text or "<title>Just a moment...</title>" in text def is_openai(text: str) -> bool: - return "<p>Unable to load site</p>" in text + return "<p>Unable to load site</p>" in text or 'id="challenge-error-text"' in text async def raise_for_status_async(response: Union[StreamResponse, ClientResponse], message: str = None): if response.status in (429, 402): @@ -27,8 +27,10 @@ async def raise_for_status_async(response: Union[StreamResponse, ClientResponse] if response.status == 403 and is_cloudflare(message): raise CloudflareError(f"Response {response.status}: Cloudflare detected") elif response.status == 403 and is_openai(message): - raise ResponseStatusError(f"Response {response.status}: Bot are detected") + raise ResponseStatusError(f"Response {response.status}: OpenAI Bot detected") elif not response.ok: + if "<html>" in message: + message = "HTML content" raise ResponseStatusError(f"Response {response.status}: {message}") def raise_for_status(response: Union[Response, StreamResponse, ClientResponse, RequestsResponse], message: str = None): -- cgit v1.2.3