summaryrefslogtreecommitdiffstats
path: root/g4f/image.py
diff options
context:
space:
mode:
authorHeiner Lohaus <hlohaus@users.noreply.github.com>2024-01-29 18:14:46 +0100
committerHeiner Lohaus <hlohaus@users.noreply.github.com>2024-01-29 18:14:46 +0100
commita28bab938704a15c825c1b45a8983c72e8c90ace (patch)
tree0a3ee26bb1f9e7bf7c1a5a739ed59015248acbfb /g4f/image.py
parentMerge pull request #1523 from u66u/which-webdriver (diff)
downloadgpt4free-a28bab938704a15c825c1b45a8983c72e8c90ace.tar
gpt4free-a28bab938704a15c825c1b45a8983c72e8c90ace.tar.gz
gpt4free-a28bab938704a15c825c1b45a8983c72e8c90ace.tar.bz2
gpt4free-a28bab938704a15c825c1b45a8983c72e8c90ace.tar.lz
gpt4free-a28bab938704a15c825c1b45a8983c72e8c90ace.tar.xz
gpt4free-a28bab938704a15c825c1b45a8983c72e8c90ace.tar.zst
gpt4free-a28bab938704a15c825c1b45a8983c72e8c90ace.zip
Diffstat (limited to 'g4f/image.py')
-rw-r--r--g4f/image.py39
1 files changed, 24 insertions, 15 deletions
diff --git a/g4f/image.py b/g4f/image.py
index 68767155..1a4692b3 100644
--- a/g4f/image.py
+++ b/g4f/image.py
@@ -3,14 +3,13 @@ from __future__ import annotations
import re
from io import BytesIO
import base64
-from .typing import ImageType, Union
+from .typing import ImageType, Union, Image
try:
- from PIL.Image import open as open_image, new as new_image, Image
+ from PIL.Image import open as open_image, new as new_image
from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90
has_requirements = True
except ImportError:
- Image = type
has_requirements = False
from .errors import MissingRequirementsError
@@ -29,6 +28,9 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image:
"""
if not has_requirements:
raise MissingRequirementsError('Install "pillow" package for images')
+ if isinstance(image, str):
+ is_data_uri_an_image(image)
+ image = extract_data_uri(image)
if is_svg:
try:
import cairosvg
@@ -39,9 +41,6 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image:
buffer = BytesIO()
cairosvg.svg2png(image, write_to=buffer)
return open_image(buffer)
- if isinstance(image, str):
- is_data_uri_an_image(image)
- image = extract_data_uri(image)
if isinstance(image, bytes):
is_accepted_format(image)
return open_image(BytesIO(image))
@@ -79,9 +78,9 @@ def is_data_uri_an_image(data_uri: str) -> bool:
if not re.match(r'data:image/(\w+);base64,', data_uri):
raise ValueError("Invalid data URI image.")
# Extract the image format from the data URI
- image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1)
+ image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1).lower()
# Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
- if image_format.lower() not in ALLOWED_EXTENSIONS:
+ if image_format not in ALLOWED_EXTENSIONS and image_format != "svg+xml":
raise ValueError("Invalid image format (from mime file type).")
def is_accepted_format(binary_data: bytes) -> bool:
@@ -187,7 +186,7 @@ def to_base64_jpg(image: Image, compression_rate: float) -> str:
image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
return base64.b64encode(output_buffer.getvalue()).decode()
-def format_images_markdown(images, alt: str, preview: str="{image}?w=200&h=200") -> str:
+def format_images_markdown(images, alt: str, preview: str = None) -> str:
"""
Formats the given images as a markdown string.
@@ -200,9 +199,12 @@ def format_images_markdown(images, alt: str, preview: str="{image}?w=200&h=200")
str: The formatted markdown string.
"""
if isinstance(images, str):
- images = f"[![{alt}]({preview.replace('{image}', images)})]({images})"
+ images = f"[![{alt}]({preview.replace('{image}', images) if preview else images})]({images})"
else:
- images = [f"[![#{idx+1} {alt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)]
+ images = [
+ f"[![#{idx+1} {alt}]({preview.replace('{image}', image) if preview else image})]({image})"
+ for idx, image in enumerate(images)
+ ]
images = "\n".join(images)
start_flag = "<!-- generated images start -->\n"
end_flag = "<!-- generated images end -->\n"
@@ -223,7 +225,7 @@ def to_bytes(image: Image) -> bytes:
image.seek(0)
return bytes_io.getvalue()
-class ImageResponse():
+class ImageResponse:
def __init__(
self,
images: Union[str, list],
@@ -235,10 +237,17 @@ class ImageResponse():
self.options = options
def __str__(self) -> str:
- return format_images_markdown(self.images, self.alt)
+ return format_images_markdown(self.images, self.alt, self.get("preview"))
def get(self, key: str):
return self.options.get(key)
-class ImageRequest(ImageResponse):
- pass \ No newline at end of file
+class ImageRequest:
+ def __init__(
+ self,
+ options: dict = {}
+ ):
+ self.options = options
+
+ def get(self, key: str):
+ return self.options.get(key) \ No newline at end of file