diff options
Diffstat (limited to 'g4f/Provider/needs_auth/Gemini.py')
-rw-r--r-- | g4f/Provider/needs_auth/Gemini.py | 36 |
1 files changed, 19 insertions, 17 deletions
diff --git a/g4f/Provider/needs_auth/Gemini.py b/g4f/Provider/needs_auth/Gemini.py index e468f64a..f9b1c4a5 100644 --- a/g4f/Provider/needs_auth/Gemini.py +++ b/g4f/Provider/needs_auth/Gemini.py @@ -4,6 +4,7 @@ import os import json import random import re +import base64 from aiohttp import ClientSession, BaseConnector @@ -22,7 +23,7 @@ from ..base_provider import AsyncGeneratorProvider from ..helper import format_prompt, get_cookies from ...requests.raise_for_status import raise_for_status from ...errors import MissingAuthError, MissingRequirementsError -from ...image import to_bytes, to_data_uri, ImageResponse +from ...image import to_bytes, ImageResponse, ImageDataResponse from ...webdriver import get_browser, get_driver_cookies REQUEST_HEADERS = { @@ -122,6 +123,7 @@ class Gemini(AsyncGeneratorProvider): connector: BaseConnector = None, image: ImageType = None, image_name: str = None, + response_format: str = None, **kwargs ) -> AsyncResult: prompt = format_prompt(messages) @@ -192,22 +194,22 @@ class Gemini(AsyncGeneratorProvider): if image_prompt: images = [image[0][3][3] for image in response_part[4][0][12][7][0]] resolved_images = [] - preview = [] - for image in images: - async with client.get(image, allow_redirects=False) as fetch: - image = fetch.headers["location"] - async with client.get(image, allow_redirects=False) as fetch: - image = fetch.headers["location"] - resolved_images.append(image) - preview.append(image.replace('=s512', '=s200')) - # preview_url = image.replace('=s512', '=s200') - # async with client.get(preview_url) as fetch: - # preview_data = to_data_uri(await fetch.content.read()) - # async with client.get(image) as fetch: - # data = to_data_uri(await fetch.content.read()) - # preview.append(preview_data) - # resolved_images.append(data) - yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images, "preview": preview}) + if response_format == "b64_json": + for image in images: + async with client.get(image) as response: + data = base64.b64encode(await response.content.read()).decode() + resolved_images.append(data) + yield ImageDataResponse(resolved_images, image_prompt) + else: + preview = [] + for image in images: + async with client.get(image, allow_redirects=False) as fetch: + image = fetch.headers["location"] + async with client.get(image, allow_redirects=False) as fetch: + image = fetch.headers["location"] + resolved_images.append(image) + preview.append(image.replace('=s512', '=s200')) + yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images, "preview": preview}) def build_request( prompt: str, |