diff options
author | H Lohaus <hlohaus@users.noreply.github.com> | 2024-02-11 09:28:50 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-11 09:28:50 +0100 |
commit | 27812c57e76f8b189efe9a26451374f9da160b6d (patch) | |
tree | 80ab5e213794a598e70b5c5020d05768f9f443ef | |
parent | Merge pull request #1574 from Simatwa/main (diff) | |
parent | Improve preview in image generation of Gemini (diff) | |
download | gpt4free-0.2.1.3.tar gpt4free-0.2.1.3.tar.gz gpt4free-0.2.1.3.tar.bz2 gpt4free-0.2.1.3.tar.lz gpt4free-0.2.1.3.tar.xz gpt4free-0.2.1.3.tar.zst gpt4free-0.2.1.3.zip |
-rw-r--r-- | g4f/Provider/create_images.py | 2 | ||||
-rw-r--r-- | g4f/Provider/needs_auth/Gemini.py | 4 | ||||
-rw-r--r-- | g4f/image.py | 7 |
3 files changed, 8 insertions, 5 deletions
diff --git a/g4f/Provider/create_images.py b/g4f/Provider/create_images.py index b8bcbde3..9a9e3f08 100644 --- a/g4f/Provider/create_images.py +++ b/g4f/Provider/create_images.py @@ -87,7 +87,7 @@ class CreateImagesProvider(BaseProvider): messages.insert(0, {"role": "system", "content": self.system_message}) buffer = "" for chunk in self.provider.create_completion(model, messages, stream, **kwargs): - if buffer or "<" in chunk: + if isinstance(chunk, str) and buffer or "<" in chunk: buffer += chunk if ">" in buffer: match = re.search(r'<img data-prompt="(.*?)">', buffer) diff --git a/g4f/Provider/needs_auth/Gemini.py b/g4f/Provider/needs_auth/Gemini.py index 32510505..0650942e 100644 --- a/g4f/Provider/needs_auth/Gemini.py +++ b/g4f/Provider/needs_auth/Gemini.py @@ -142,13 +142,15 @@ class Gemini(AsyncGeneratorProvider): if image_prompt: images = [image[0][3][3] for image in response_part[4][0][12][7][0]] resolved_images = [] + preview = [] for image in images: async with session.get(image, allow_redirects=False) as fetch: image = fetch.headers["location"] async with session.get(image, allow_redirects=False) as fetch: image = fetch.headers["location"] resolved_images.append(image) - yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images}) + preview.append(image.replace('=s512', '=s200')) + yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images, "preview": preview}) def build_request( prompt: str, diff --git a/g4f/image.py b/g4f/image.py index 93922c2e..01d6ae50 100644 --- a/g4f/image.py +++ b/g4f/image.py @@ -187,7 +187,7 @@ def to_base64_jpg(image: Image, compression_rate: float) -> str: image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100)) return base64.b64encode(output_buffer.getvalue()).decode() -def format_images_markdown(images, alt: str, preview: str = None) -> str: +def format_images_markdown(images: Union[str, list], alt: str, preview: Union[str, list] = None) -> str: """ Formats the given images as a markdown string. @@ -202,9 +202,10 @@ def format_images_markdown(images, alt: str, preview: str = None) -> str: if isinstance(images, str): images = f"[![{alt}]({preview.replace('{image}', images) if preview else images})]({images})" else: + if not isinstance(preview, list): + preview = [preview.replace('{image}', image) if preview else image for image in images] images = [ - f"[![#{idx+1} {alt}]({preview.replace('{image}', image) if preview else image})]({image})" - for idx, image in enumerate(images) + f"[![#{idx+1} {alt}]({preview[idx]})]({image})" for idx, image in enumerate(images) ] images = "\n".join(images) start_flag = "<!-- generated images start -->\n" |