diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/gui/server/api.py | 99 |
1 files changed, 38 insertions, 61 deletions
diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py index 29fc34e2..00eb7182 100644 --- a/g4f/gui/server/api.py +++ b/g4f/gui/server/api.py @@ -2,34 +2,22 @@ from __future__ import annotations import logging import os -import uuid import asyncio -import time -from aiohttp import ClientSession -from typing import Iterator, Optional +from typing import Iterator from flask import send_from_directory +from inspect import signature from g4f import version, models from g4f import get_last_provider, ChatCompletion from g4f.errors import VersionNotFoundError -from g4f.typing import Cookies -from g4f.image import ImagePreview, ImageResponse, is_accepted_format, extract_data_uri -from g4f.requests.aiohttp import get_connector +from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir from g4f.Provider import ProviderType, __providers__, __map__ -from g4f.providers.base_provider import ProviderModelMixin, FinishReason -from g4f.providers.conversation import BaseConversation +from g4f.providers.base_provider import ProviderModelMixin +from g4f.providers.response import BaseConversation, FinishReason +from g4f.client.service import convert_to_provider from g4f import debug logger = logging.getLogger(__name__) - -# Define the directory for generated images -images_dir = "./generated_images" - -# Function to ensure the images directory exists -def ensure_images_dir(): - if not os.path.exists(images_dir): - os.makedirs(images_dir) - conversations: dict[dict[str, BaseConversation]] = {} class Api: @@ -42,7 +30,10 @@ class Api: if provider in __map__: provider: ProviderType = __map__[provider] if issubclass(provider, ProviderModelMixin): - models = provider.get_models() if api_key is None else provider.get_models(api_key=api_key) + if api_key is not None and "api_key" in signature(provider.get_models).parameters: + models = provider.get_models(api_key=api_key) + else: + models = provider.get_models() return [ { "model": model, @@ -90,7 +81,7 @@ class Api: def get_providers() -> list[str]: return { provider.__name__: (provider.label if hasattr(provider, "label") else provider.__name__) - + (" (Image Generation)" if hasattr(provider, "image_models") else "") + + (" (Image Generation)" if getattr(provider, "image_models", None) else "") + (" (Image Upload)" if getattr(provider, "default_vision_model", None) else "") + (" (WebDriver)" if "webdriver" in provider.get_parameters() else "") + (" (Auth)" if provider.needs_auth else "") @@ -120,16 +111,23 @@ class Api: api_key = json_data.get("api_key") if api_key is not None: kwargs["api_key"] = api_key - if json_data.get('web_search'): - if provider: - kwargs['web_search'] = True - else: - from .internet import get_search_message - messages[-1]["content"] = get_search_message(messages[-1]["content"]) + do_web_search = json_data.get('web_search') + if do_web_search and provider: + provider_handler = convert_to_provider(provider) + if hasattr(provider_handler, "get_parameters"): + if "web_search" in provider_handler.get_parameters(): + kwargs['web_search'] = True + do_web_search = False + if do_web_search: + from .internet import get_search_message + messages[-1]["content"] = get_search_message(messages[-1]["content"]) + if json_data.get("auto_continue"): + kwargs['auto_continue'] = True conversation_id = json_data.get("conversation_id") - if conversation_id and provider in conversations and conversation_id in conversations[provider]: - kwargs["conversation"] = conversations[provider][conversation_id] + if conversation_id and provider: + if provider in conversations and conversation_id in conversations[provider]: + kwargs["conversation"] = conversations[provider][conversation_id] return { "model": model, @@ -141,7 +139,7 @@ class Api: **kwargs } - def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str) -> Iterator: + def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_images: bool = True) -> Iterator: if debug.logging: debug.logs = [] print_callback = debug.log_handler @@ -163,18 +161,22 @@ class Api: first = False yield self._format_json("provider", get_last_provider(True)) if isinstance(chunk, BaseConversation): - if provider not in conversations: - conversations[provider] = {} - conversations[provider][conversation_id] = chunk - yield self._format_json("conversation", conversation_id) + if provider: + if provider not in conversations: + conversations[provider] = {} + conversations[provider][conversation_id] = chunk + yield self._format_json("conversation", conversation_id) elif isinstance(chunk, Exception): logger.exception(chunk) yield self._format_json("message", get_error_message(chunk)) elif isinstance(chunk, ImagePreview): yield self._format_json("preview", chunk.to_string()) elif isinstance(chunk, ImageResponse): - images = asyncio.run(self._copy_images(chunk.get_list(), chunk.options.get("cookies"))) - yield self._format_json("content", str(ImageResponse(images, chunk.alt))) + images = chunk + if download_images: + images = asyncio.run(copy_images(chunk.get_list(), chunk.options.get("cookies"))) + images = ImageResponse(images, chunk.alt) + yield self._format_json("content", str(images)) elif not isinstance(chunk, FinishReason): yield self._format_json("content", str(chunk)) if debug.logs: @@ -185,31 +187,6 @@ class Api: logger.exception(e) yield self._format_json('error', get_error_message(e)) - async def _copy_images(self, images: list[str], cookies: Optional[Cookies] = None): - ensure_images_dir() - async with ClientSession( - connector=get_connector(None, os.environ.get("G4F_PROXY")), - cookies=cookies - ) as session: - async def copy_image(image: str) -> str: - target = os.path.join(images_dir, f"{int(time.time())}_{str(uuid.uuid4())}") - if image.startswith("data:"): - with open(target, "wb") as f: - f.write(extract_data_uri(image)) - else: - async with session.get(image) as response: - with open(target, "wb") as f: - async for chunk in response.content.iter_any(): - f.write(chunk) - with open(target, "rb") as f: - extension = is_accepted_format(f.read(12)).split("/")[-1] - extension = "jpg" if extension == "jpeg" else extension - new_target = f"{target}.{extension}" - os.rename(target, new_target) - return f"/images/{os.path.basename(new_target)}" - - return await asyncio.gather(*[copy_image(image) for image in images]) - def _format_json(self, response_type: str, content): return { 'type': response_type, @@ -221,4 +198,4 @@ def get_error_message(exception: Exception) -> str: provider = get_last_provider() if provider is None: return message - return f"{provider.__name__}: {message}" + return f"{provider.__name__}: {message}"
\ No newline at end of file |