summaryrefslogtreecommitdiffstats
path: root/g4f/client/__init__.py
diff options
context:
space:
mode:
authorkqlio67 <kqlio67@users.noreply.github.com>2024-11-17 14:33:18 +0100
committerkqlio67 <kqlio67@users.noreply.github.com>2024-11-17 14:33:18 +0100
commit8e2723938a280c7b525bac1d847fe80a5c2022ef (patch)
tree3dcb72c39c6b6412d5e66ced45054a91f0e3c13a /g4f/client/__init__.py
parentFix api streaming, fix AsyncClient (#2357) (diff)
downloadgpt4free-8e2723938a280c7b525bac1d847fe80a5c2022ef.tar
gpt4free-8e2723938a280c7b525bac1d847fe80a5c2022ef.tar.gz
gpt4free-8e2723938a280c7b525bac1d847fe80a5c2022ef.tar.bz2
gpt4free-8e2723938a280c7b525bac1d847fe80a5c2022ef.tar.lz
gpt4free-8e2723938a280c7b525bac1d847fe80a5c2022ef.tar.xz
gpt4free-8e2723938a280c7b525bac1d847fe80a5c2022ef.tar.zst
gpt4free-8e2723938a280c7b525bac1d847fe80a5c2022ef.zip
Diffstat (limited to '')
-rw-r--r--g4f/client/__init__.py73
1 files changed, 34 insertions, 39 deletions
diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py
index 5ffe9288..3adb18ef 100644
--- a/g4f/client/__init__.py
+++ b/g4f/client/__init__.py
@@ -247,7 +247,7 @@ class Images:
"""
Synchronous generate method that runs the async_generate method in an event loop.
"""
- return asyncio.run(self.async_generate(prompt, model, provider, response_format=response_format, proxy=proxy **kwargs))
+ return asyncio.run(self.async_generate(prompt, model, provider, response_format=response_format, proxy=proxy, **kwargs))
async def async_generate(self, prompt: str, model: str = None, provider: ProviderType = None, response_format: str = "url", proxy: str = None, **kwargs) -> ImagesResponse:
if provider is None:
@@ -261,7 +261,7 @@ class Images:
if isinstance(provider_handler, IterListProvider):
if provider_handler.providers:
- provider_handler = provider.providers[0]
+ provider_handler = provider_handler.providers[0]
else:
raise ValueError(f"IterListProvider for model {model} has no providers")
@@ -287,44 +287,39 @@ class Images:
raise NoImageResponseError(f"Unexpected response type: {type(response)}")
async def _process_image_response(self, response: ImageResponse, response_format: str, proxy: str = None, model: str = None, provider: str = None) -> ImagesResponse:
- async def process_image_item(session: aiohttp.ClientSession, image_data: str):
- if image_data.startswith('http://') or image_data.startswith('https://'):
- if response_format == "url":
- return Image(url=image_data, revised_prompt=response.alt)
- elif response_format == "b64_json":
- # Fetch the image data and convert it to base64
- image_content = await self._fetch_image(session, image_data)
- file_name = self._save_image(image_data_bytes)
- b64_json = base64.b64encode(image_content).decode('utf-8')
- return Image(b64_json=b64_json, url=file_name, revised_prompt=response.alt)
- else:
- # Assume image_data is base64 data or binary
- if response_format == "url":
- if image_data.startswith('data:image'):
- # Remove the data URL scheme and get the base64 data
- base64_data = image_data.split(',', 1)[-1]
- else:
- base64_data = image_data
- # Decode the base64 data
- image_data_bytes = base64.b64decode(base64_data)
- # Convert bytes to an image
+ async def process_image_item(session: aiohttp.ClientSession, image_data: str):
+ image_data_bytes = None
+ if image_data.startswith("http://") or image_data.startswith("https://"):
+ if response_format == "url":
+ return Image(url=image_data, revised_prompt=response.alt)
+ elif response_format == "b64_json":
+ # Fetch the image data and convert it to base64
+ image_data_bytes = await self._fetch_image(session, image_data)
+ b64_json = base64.b64encode(image_data_bytes).decode("utf-8")
+ return Image(b64_json=b64_json, url=image_data, revised_prompt=response.alt)
+ else:
+ # Assume image_data is base64 data or binary
+ if response_format == "url":
+ if image_data.startswith("data:image"):
+ # Remove the data URL scheme and get the base64 data
+ base64_data = image_data.split(",", 1)[-1]
+ else:
+ base64_data = image_data
+ # Decode the base64 data
+ image_data_bytes = base64.b64decode(base64_data)
+ if image_data_bytes:
file_name = self._save_image(image_data_bytes)
return Image(url=file_name, revised_prompt=response.alt)
- elif response_format == "b64_json":
- if isinstance(image_data, bytes):
- file_name = self._save_image(image_data_bytes)
- b64_json = base64.b64encode(image_data).decode('utf-8')
- else:
- b64_json = image_data # If already base64-encoded string
- return Image(b64_json=b64_json, url=file_name, revised_prompt=response.alt)
-
- last_provider = get_last_provider(True)
- async with aiohttp.ClientSession(cookies=response.get("cookies"), connector=get_connector(proxy=proxy)) as session:
- return ImagesResponse(
- await asyncio.gather(*[process_image_item(session, image_data) for image_data in response.get_list()]),
- model=last_provider.get("model") if model is None else model,
- provider=last_provider.get("name") if provider is None else provider
- )
+ else:
+ raise ValueError("Unable to process image data")
+
+ last_provider = get_last_provider(True)
+ async with aiohttp.ClientSession(cookies=response.get("cookies"), connector=get_connector(proxy=proxy)) as session:
+ return ImagesResponse(
+ await asyncio.gather(*[process_image_item(session, image_data) for image_data in response.get_list()]),
+ model=last_provider.get("model") if model is None else model,
+ provider=last_provider.get("name") if provider is None else provider
+ )
async def _fetch_image(self, session: aiohttp.ClientSession, url: str) -> bytes:
# Asynchronously fetch image data from the URL
@@ -465,4 +460,4 @@ class AsyncImages(Images):
async def create_variation(self, image: Union[str, bytes], model: str = None, provider: ProviderType = None, response_format: str = "url", **kwargs) -> ImagesResponse:
return await self.async_create_variation(
image, model, provider, response_format, **kwargs
- ) \ No newline at end of file
+ )