diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/client/async_client.py | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/g4f/client/async_client.py b/g4f/client/async_client.py index 8e1ee33c..07ad3357 100644 --- a/g4f/client/async_client.py +++ b/g4f/client/async_client.py @@ -11,10 +11,9 @@ from .types import AsyncIterResponse, ImageProvider from .image_models import ImageModels from .helper import filter_json, find_stop, filter_none, cast_iter_async from .service import get_last_provider, get_model_and_provider -from ..typing import Union, Iterator, Messages, AsyncIterator, ImageType +from ..typing import Union, Messages, AsyncIterator, ImageType from ..errors import NoImageResponseError from ..image import ImageResponse as ImageProviderResponse -from ..providers.base_provider import AsyncGeneratorProvider try: anext @@ -88,7 +87,7 @@ def create_response( api_key: str = None, **kwargs ): - has_asnyc = isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider) + has_asnyc = hasattr(provider, "create_async_generator") if has_asnyc: create = provider.create_async_generator else: @@ -157,7 +156,7 @@ class Chat(): def __init__(self, client: AsyncClient, provider: ProviderType = None): self.completions = Completions(client, provider) -async def iter_image_response(response: Iterator) -> Union[ImagesResponse, None]: +async def iter_image_response(response: AsyncIterator) -> Union[ImagesResponse, None]: async for chunk in response: if isinstance(chunk, ImageProviderResponse): return ImagesResponse([Image(image) for image in chunk.get_list()]) @@ -182,7 +181,7 @@ class Images(): async def generate(self, prompt, model: str = "", **kwargs) -> ImagesResponse: provider = self.models.get(model, self.provider) - if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider): + if hasattr(provider, "create_async_generator"): response = create_image(self.client, provider, prompt, **kwargs) else: response = await provider.create_async(prompt) @@ -195,7 +194,7 @@ class Images(): async def create_variation(self, image: ImageType, model: str = None, **kwargs): provider = self.models.get(model, self.provider) result = None - if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider): + if hasattr(provider, "create_async_generator"): response = provider.create_async_generator( "", [{"role": "user", "content": "create a image like this"}], |