summaryrefslogtreecommitdiffstats
path: root/g4f/client
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/client')
-rw-r--r--g4f/client/async_client.py90
-rw-r--r--g4f/client/service.py6
-rw-r--r--g4f/client/stubs.py23
3 files changed, 96 insertions, 23 deletions
diff --git a/g4f/client/async_client.py b/g4f/client/async_client.py
index 07ad3357..1508e566 100644
--- a/g4f/client/async_client.py
+++ b/g4f/client/async_client.py
@@ -3,6 +3,9 @@ from __future__ import annotations
import time
import random
import string
+import asyncio
+import base64
+from aiohttp import ClientSession, BaseConnector
from .types import Client as BaseClient
from .types import ProviderType, FinishReason
@@ -11,9 +14,11 @@ from .types import AsyncIterResponse, ImageProvider
from .image_models import ImageModels
from .helper import filter_json, find_stop, filter_none, cast_iter_async
from .service import get_last_provider, get_model_and_provider
+from ..Provider import ProviderUtils
from ..typing import Union, Messages, AsyncIterator, ImageType
-from ..errors import NoImageResponseError
-from ..image import ImageResponse as ImageProviderResponse
+from ..errors import NoImageResponseError, ProviderNotFoundError
+from ..requests.aiohttp import get_connector
+from ..image import ImageResponse as ImageProviderResponse, ImageDataResponse
try:
anext
@@ -156,12 +161,28 @@ class Chat():
def __init__(self, client: AsyncClient, provider: ProviderType = None):
self.completions = Completions(client, provider)
-async def iter_image_response(response: AsyncIterator) -> Union[ImagesResponse, None]:
+async def iter_image_response(
+ response: AsyncIterator,
+ response_format: str = None,
+ connector: BaseConnector = None,
+ proxy: str = None
+) -> Union[ImagesResponse, None]:
async for chunk in response:
if isinstance(chunk, ImageProviderResponse):
- return ImagesResponse([Image(image) for image in chunk.get_list()])
+ if response_format == "b64_json":
+ async with ClientSession(
+ connector=get_connector(connector, proxy)
+ ) as session:
+ async def fetch_image(image):
+ async with session.get(image) as response:
+ return base64.b64encode(await response.content.read()).decode()
+ images = await asyncio.gather(*[fetch_image(image) for image in chunk.get_list()])
+ return ImagesResponse([Image(None, image, chunk.alt) for image in images], int(time.time()))
+ return ImagesResponse([Image(image, None, chunk.alt) for image in chunk.get_list()], int(time.time()))
+ elif isinstance(chunk, ImageDataResponse):
+ return ImagesResponse([Image(None, image, chunk.alt) for image in chunk.get_list()], int(time.time()))
-def create_image(client: AsyncClient, provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator:
+def create_image(provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator:
prompt = f"create a image with: {prompt}"
if provider.__name__ == "You":
kwargs["chat_mode"] = "create"
@@ -169,7 +190,6 @@ def create_image(client: AsyncClient, provider: ProviderType, prompt: str, model
model,
[{"role": "user", "content": prompt}],
stream=True,
- proxy=client.get_proxy(),
**kwargs
)
@@ -179,31 +199,71 @@ class Images():
self.provider: ImageProvider = provider
self.models: ImageModels = ImageModels(client)
- async def generate(self, prompt, model: str = "", **kwargs) -> ImagesResponse:
- provider = self.models.get(model, self.provider)
+ def get_provider(self, model: str, provider: ProviderType = None):
+ if isinstance(provider, str):
+ if provider in ProviderUtils.convert:
+ provider = ProviderUtils.convert[provider]
+ else:
+ raise ProviderNotFoundError(f'Provider not found: {provider}')
+ else:
+ provider = self.models.get(model, self.provider)
+ return provider
+
+ async def generate(
+ self,
+ prompt,
+ model: str = "",
+ provider: ProviderType = None,
+ response_format: str = None,
+ connector: BaseConnector = None,
+ proxy: str = None,
+ **kwargs
+ ) -> ImagesResponse:
+ provider = self.get_provider(model, provider)
if hasattr(provider, "create_async_generator"):
- response = create_image(self.client, provider, prompt, **kwargs)
+ response = create_image(
+ provider,
+ prompt,
+ **filter_none(
+ response_format=response_format,
+ connector=connector,
+ proxy=self.client.get_proxy() if proxy is None else proxy,
+ ),
+ **kwargs
+ )
else:
response = await provider.create_async(prompt)
return ImagesResponse([Image(image) for image in response.get_list()])
- image = await iter_image_response(response)
+ image = await iter_image_response(response, response_format, connector, proxy)
if image is None:
raise NoImageResponseError()
return image
- async def create_variation(self, image: ImageType, model: str = None, **kwargs):
- provider = self.models.get(model, self.provider)
+ async def create_variation(
+ self,
+ image: ImageType,
+ model: str = None,
+ response_format: str = None,
+ connector: BaseConnector = None,
+ proxy: str = None,
+ **kwargs
+ ):
+ provider = self.get_provider(model, provider)
result = None
if hasattr(provider, "create_async_generator"):
response = provider.create_async_generator(
"",
[{"role": "user", "content": "create a image like this"}],
- True,
+ stream=True,
image=image,
- proxy=self.client.get_proxy(),
+ **filter_none(
+ response_format=response_format,
+ connector=connector,
+ proxy=self.client.get_proxy() if proxy is None else proxy,
+ ),
**kwargs
)
- result = iter_image_response(response)
+ result = iter_image_response(response, response_format, connector, proxy)
if result is None:
raise NoImageResponseError()
return result
diff --git a/g4f/client/service.py b/g4f/client/service.py
index dd6bf4b6..5fdb150c 100644
--- a/g4f/client/service.py
+++ b/g4f/client/service.py
@@ -4,7 +4,7 @@ from typing import Union
from .. import debug, version
from ..errors import ProviderNotFoundError, ModelNotFoundError, ProviderNotWorkingError, StreamNotSupportedError
-from ..models import Model, ModelUtils
+from ..models import Model, ModelUtils, default
from ..Provider import ProviderUtils
from ..providers.types import BaseRetryProvider, ProviderType
from ..providers.retry_provider import IterProvider
@@ -60,7 +60,9 @@ def get_model_and_provider(model : Union[Model, str],
model = ModelUtils.convert[model]
if not provider:
- if isinstance(model, str):
+ if not model:
+ model = default
+ elif isinstance(model, str):
raise ModelNotFoundError(f'Model not found: {model}')
provider = model.best_provider
diff --git a/g4f/client/stubs.py b/g4f/client/stubs.py
index ceb51b83..8cf2bcba 100644
--- a/g4f/client/stubs.py
+++ b/g4f/client/stubs.py
@@ -96,13 +96,24 @@ class ChatCompletionDeltaChoice(Model):
}
class Image(Model):
- url: str
+ def __init__(self, url: str = None, b64_json: str = None, revised_prompt: str = None) -> None:
+ if url is not None:
+ self.url = url
+ if b64_json is not None:
+ self.b64_json = b64_json
+ if revised_prompt is not None:
+ self.revised_prompt = revised_prompt
- def __init__(self, url: str) -> None:
- self.url = url
+ def to_json(self):
+ return self.__dict__
class ImagesResponse(Model):
- data: list[Image]
-
- def __init__(self, data: list) -> None:
+ def __init__(self, data: list[Image], created: int = 0) -> None:
self.data = data
+ self.created = created
+
+ def to_json(self):
+ return {
+ **self.__dict__,
+ "data": [image.to_json() for image in self.data]
+ } \ No newline at end of file