diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/client/__init__.py | 92 |
1 file changed, 49 insertions, 43 deletions
diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py index f6a0f5e8..86a81049 100644 --- a/g4f/client/__init__.py +++ b/g4f/client/__init__.py @@ -16,7 +16,7 @@ from ..providers.response import ResponseType, FinishReason, BaseConversation, S from ..errors import NoImageResponseError, ModelNotFoundError from ..providers.retry_provider import IterListProvider from ..providers.asyncio import get_running_loop, to_sync_generator, async_generator_to_list -from ..Provider.needs_auth.BingCreateImages import BingCreateImages +from ..Provider.needs_auth import BingCreateImages, OpenaiAccount from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse from .image_models import ImageModels from .types import IterResponse, ImageProvider, Client as BaseClient @@ -73,7 +73,7 @@ def iter_response( finish_reason = "stop" if stream: - yield ChatCompletionChunk(chunk, None, completion_id, int(time.time())) + yield ChatCompletionChunk.model_construct(chunk, None, completion_id, int(time.time())) if finish_reason is not None: break @@ -83,12 +83,12 @@ def iter_response( finish_reason = "stop" if finish_reason is None else finish_reason if stream: - yield ChatCompletionChunk(None, finish_reason, completion_id, int(time.time())) + yield ChatCompletionChunk.model_construct(None, finish_reason, completion_id, int(time.time())) else: if response_format is not None and "type" in response_format: if response_format["type"] == "json_object": content = filter_json(content) - yield ChatCompletion(content, finish_reason, completion_id, int(time.time())) + yield ChatCompletion.model_construct(content, finish_reason, completion_id, int(time.time())) # Synchronous iter_append_model_and_provider function def iter_append_model_and_provider(response: ChatCompletionResponseType) -> ChatCompletionResponseType: @@ -137,7 +137,7 @@ async def async_iter_response( finish_reason = "stop" if stream: - yield ChatCompletionChunk(chunk, None, completion_id, int(time.time())) + yield 
ChatCompletionChunk.model_construct(chunk, None, completion_id, int(time.time())) if finish_reason is not None: break @@ -145,15 +145,14 @@ async def async_iter_response( finish_reason = "stop" if finish_reason is None else finish_reason if stream: - yield ChatCompletionChunk(None, finish_reason, completion_id, int(time.time())) + yield ChatCompletionChunk.model_construct(None, finish_reason, completion_id, int(time.time())) else: if response_format is not None and "type" in response_format: if response_format["type"] == "json_object": content = filter_json(content) - yield ChatCompletion(content, finish_reason, completion_id, int(time.time())) + yield ChatCompletion.model_construct(content, finish_reason, completion_id, int(time.time())) finally: - if hasattr(response, 'aclose'): - await safe_aclose(response) + await safe_aclose(response) async def async_iter_append_model_and_provider( response: AsyncChatCompletionResponseType @@ -167,8 +166,7 @@ async def async_iter_append_model_and_provider( chunk.provider = last_provider.get("name") yield chunk finally: - if hasattr(response, 'aclose'): - await safe_aclose(response) + await safe_aclose(response) class Client(BaseClient): def __init__( @@ -266,33 +264,39 @@ class Images: """ return asyncio.run(self.async_generate(prompt, model, provider, response_format, proxy, **kwargs)) - async def async_generate( - self, - prompt: str, - model: Optional[str] = None, - provider: Optional[ProviderType] = None, - response_format: Optional[str] = "url", - proxy: Optional[str] = None, - **kwargs - ) -> ImagesResponse: + async def get_provider_handler(self, model: Optional[str], provider: Optional[ImageProvider], default: ImageProvider) -> ImageProvider: if provider is None: - provider_handler = self.models.get(model, provider or self.provider or BingCreateImages) + provider_handler = self.provider + if provider_handler is None: + provider_handler = self.models.get(model, default) elif isinstance(provider, str): provider_handler = 
convert_to_provider(provider) else: provider_handler = provider if provider_handler is None: - raise ModelNotFoundError(f"Unknown model: {model}") + return default if isinstance(provider_handler, IterListProvider): if provider_handler.providers: provider_handler = provider_handler.providers[0] else: raise ModelNotFoundError(f"IterListProvider for model {model} has no providers") + return provider_handler + + async def async_generate( + self, + prompt: str, + model: Optional[str] = None, + provider: Optional[ProviderType] = None, + response_format: Optional[str] = "url", + proxy: Optional[str] = None, + **kwargs + ) -> ImagesResponse: + provider_handler = await self.get_provider_handler(model, provider, BingCreateImages) if proxy is None: proxy = self.client.proxy response = None - if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider): + if hasattr(provider_handler, "create_async_generator"): messages = [{"role": "user", "content": f"Generate a image: {prompt}"}] async for item in provider_handler.create_async_generator(model, messages, prompt=prompt, **kwargs): if isinstance(item, ImageResponse): @@ -313,7 +317,7 @@ class Images: response = item break else: - raise ValueError(f"Provider {provider} does not support image generation") + raise ValueError(f"Provider {getattr(provider_handler, '__name__')} does not support image generation") if isinstance(response, ImageResponse): return await self._process_image_response( response, @@ -322,6 +326,8 @@ class Images: model, getattr(provider_handler, "__name__", None) ) + if response is None: + raise NoImageResponseError(f"No image response from {getattr(provider_handler, '__name__')}") raise NoImageResponseError(f"Unexpected response type: {type(response)}") def create_variation( @@ -345,32 +351,26 @@ class Images: proxy: Optional[str] = None, **kwargs ) -> ImagesResponse: - if provider is None: - provider = self.models.get(model, provider or self.provider or BingCreateImages) - if provider is 
None: - raise ModelNotFoundError(f"Unknown model: {model}") - if isinstance(provider, str): - provider = convert_to_provider(provider) + provider_handler = await self.get_provider_handler(model, provider, OpenaiAccount) if proxy is None: proxy = self.client.proxy - if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider): + if hasattr(provider_handler, "create_async_generator"): messages = [{"role": "user", "content": "create a variation of this image"}] generator = None try: - generator = provider.create_async_generator(model, messages, image=image, response_format=response_format, proxy=proxy, **kwargs) + generator = provider_handler.create_async_generator(model, messages, image=image, response_format=response_format, proxy=proxy, **kwargs) async for chunk in generator: if isinstance(chunk, ImageResponse): response = chunk break finally: - if generator and hasattr(generator, 'aclose'): - await safe_aclose(generator) - elif hasattr(provider, 'create_variation'): - if asyncio.iscoroutinefunction(provider.create_variation): - response = await provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs) + await safe_aclose(generator) + elif hasattr(provider_handler, 'create_variation'): + if asyncio.iscoroutinefunction(provider.provider_handler): + response = await provider_handler.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs) else: - response = provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs) + response = provider_handler.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs) else: raise NoImageResponseError(f"Provider {provider} does not support image variation") @@ -378,6 +378,8 @@ class Images: response = ImageResponse([response]) if isinstance(response, ImageResponse): return self._process_image_response(response, response_format, proxy, model, getattr(provider, 
"__name__", None)) + if response is None: + raise NoImageResponseError(f"No image response from {getattr(provider, '__name__')}") raise NoImageResponseError(f"Unexpected response type: {type(response)}") async def _process_image_response( @@ -394,13 +396,13 @@ class Images: if response_format == "b64_json": with open(os.path.join(images_dir, os.path.basename(image_file)), "rb") as file: image_data = base64.b64encode(file.read()).decode() - return Image(url=image_file, b64_json=image_data, revised_prompt=response.alt) - return Image(url=image_file, revised_prompt=response.alt) + return Image.model_construct(url=image_file, b64_json=image_data, revised_prompt=response.alt) + return Image.model_construct(url=image_file, revised_prompt=response.alt) images = await asyncio.gather(*[process_image_item(image) for image in images]) else: - images = [Image(url=image, revised_prompt=response.alt) for image in response.get_list()] + images = [Image.model_construct(url=image, revised_prompt=response.alt) for image in response.get_list()] last_provider = get_last_provider(True) - return ImagesResponse( + return ImagesResponse.model_construct( images, model=last_provider.get("model") if model is None else model, provider=last_provider.get("name") if provider is None else provider @@ -454,7 +456,11 @@ class AsyncCompletions: ) stop = [stop] if isinstance(stop, str) else stop - response = provider.create_completion( + if hasattr(provider, "create_async_generator"): + create_handler = provider.create_async_generator + else: + create_handler = provider.create_completion + response = create_handler( model, messages, stream=stream, |