summaryrefslogtreecommitdiffstats
path: root/g4f/image.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/image.py16
1 files changed, 14 insertions, 2 deletions
diff --git a/g4f/image.py b/g4f/image.py
index 24ded915..61081ea1 100644
--- a/g4f/image.py
+++ b/g4f/image.py
@@ -4,9 +4,9 @@ import base64
from .typing import ImageType, Union
from PIL import Image
-ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp'}
+ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'}
-def to_image(image: ImageType) -> Image.Image:
+def to_image(image: ImageType, is_svg: bool = False) -> Image.Image:
"""
Converts the input image to a PIL Image object.
@@ -16,6 +16,16 @@ def to_image(image: ImageType) -> Image.Image:
Returns:
Image.Image: The converted PIL Image object.
"""
+ if is_svg:
+ try:
+ import cairosvg
+ except ImportError:
+ raise RuntimeError('Install "cairosvg" package for open svg images')
+ if not isinstance(image, bytes):
+ image = image.read()
+ buffer = BytesIO()
+ cairosvg.svg2png(image, write_to=buffer)
+ image = Image.open(buffer)
if isinstance(image, str):
is_data_uri_an_image(image)
image = extract_data_uri(image)
@@ -153,6 +163,8 @@ def to_base64(image: Image.Image, compression_rate: float) -> str:
str: The base64-encoded image.
"""
output_buffer = BytesIO()
+ if image.mode != "RGB":
+ image = image.convert('RGB')
image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
return base64.b64encode(output_buffer.getvalue()).decode()