diff options
Diffstat (limited to 'g4f/gui/server')
-rw-r--r-- | g4f/gui/server/api.py | 99 | ||||
-rw-r--r-- | g4f/gui/server/backend.py | 7 | ||||
-rw-r--r-- | g4f/gui/server/internet.py | 20 |
3 files changed, 57 insertions, 69 deletions
diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py index 29fc34e2..00eb7182 100644 --- a/g4f/gui/server/api.py +++ b/g4f/gui/server/api.py @@ -2,34 +2,22 @@ from __future__ import annotations import logging import os -import uuid import asyncio -import time -from aiohttp import ClientSession -from typing import Iterator, Optional +from typing import Iterator from flask import send_from_directory +from inspect import signature from g4f import version, models from g4f import get_last_provider, ChatCompletion from g4f.errors import VersionNotFoundError -from g4f.typing import Cookies -from g4f.image import ImagePreview, ImageResponse, is_accepted_format, extract_data_uri -from g4f.requests.aiohttp import get_connector +from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir from g4f.Provider import ProviderType, __providers__, __map__ -from g4f.providers.base_provider import ProviderModelMixin, FinishReason -from g4f.providers.conversation import BaseConversation +from g4f.providers.base_provider import ProviderModelMixin +from g4f.providers.response import BaseConversation, FinishReason +from g4f.client.service import convert_to_provider from g4f import debug logger = logging.getLogger(__name__) - -# Define the directory for generated images -images_dir = "./generated_images" - -# Function to ensure the images directory exists -def ensure_images_dir(): - if not os.path.exists(images_dir): - os.makedirs(images_dir) - conversations: dict[dict[str, BaseConversation]] = {} class Api: @@ -42,7 +30,10 @@ class Api: if provider in __map__: provider: ProviderType = __map__[provider] if issubclass(provider, ProviderModelMixin): - models = provider.get_models() if api_key is None else provider.get_models(api_key=api_key) + if api_key is not None and "api_key" in signature(provider.get_models).parameters: + models = provider.get_models(api_key=api_key) + else: + models = provider.get_models() return [ { "model": model, @@ -90,7 +81,7 @@ class Api: def get_providers() -> list[str]: return { provider.__name__: (provider.label if hasattr(provider, "label") else provider.__name__) - + (" (Image Generation)" if hasattr(provider, "image_models") else "") + + (" (Image Generation)" if getattr(provider, "image_models", None) else "") + (" (Image Upload)" if getattr(provider, "default_vision_model", None) else "") + (" (WebDriver)" if "webdriver" in provider.get_parameters() else "") + (" (Auth)" if provider.needs_auth else "") @@ -120,16 +111,23 @@ class Api: api_key = json_data.get("api_key") if api_key is not None: kwargs["api_key"] = api_key - if json_data.get('web_search'): - if provider: - kwargs['web_search'] = True - else: - from .internet import get_search_message - messages[-1]["content"] = get_search_message(messages[-1]["content"]) + do_web_search = json_data.get('web_search') + if do_web_search and provider: + provider_handler = convert_to_provider(provider) + if hasattr(provider_handler, "get_parameters"): + if "web_search" in provider_handler.get_parameters(): + kwargs['web_search'] = True + do_web_search = False + if do_web_search: + from .internet import get_search_message + messages[-1]["content"] = get_search_message(messages[-1]["content"]) + if json_data.get("auto_continue"): + kwargs['auto_continue'] = True conversation_id = json_data.get("conversation_id") - if conversation_id and provider in conversations and conversation_id in conversations[provider]: - kwargs["conversation"] = conversations[provider][conversation_id] + if conversation_id and provider: + if provider in conversations and conversation_id in conversations[provider]: + kwargs["conversation"] = conversations[provider][conversation_id] return { "model": model, @@ -141,7 +139,7 @@ class Api: **kwargs } - def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str) -> Iterator: + def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_images: bool = True) -> Iterator: if debug.logging: debug.logs = [] print_callback = debug.log_handler @@ -163,18 +161,22 @@ class Api: first = False yield self._format_json("provider", get_last_provider(True)) if isinstance(chunk, BaseConversation): - if provider not in conversations: - conversations[provider] = {} - conversations[provider][conversation_id] = chunk - yield self._format_json("conversation", conversation_id) + if provider: + if provider not in conversations: + conversations[provider] = {} + conversations[provider][conversation_id] = chunk + yield self._format_json("conversation", conversation_id) elif isinstance(chunk, Exception): logger.exception(chunk) yield self._format_json("message", get_error_message(chunk)) elif isinstance(chunk, ImagePreview): yield self._format_json("preview", chunk.to_string()) elif isinstance(chunk, ImageResponse): - images = asyncio.run(self._copy_images(chunk.get_list(), chunk.options.get("cookies"))) - yield self._format_json("content", str(ImageResponse(images, chunk.alt))) + images = chunk + if download_images: + images = asyncio.run(copy_images(chunk.get_list(), chunk.options.get("cookies"))) + images = ImageResponse(images, chunk.alt) + yield self._format_json("content", str(images)) elif not isinstance(chunk, FinishReason): yield self._format_json("content", str(chunk)) if debug.logs: @@ -185,31 +187,6 @@ class Api: logger.exception(e) yield self._format_json('error', get_error_message(e)) - async def _copy_images(self, images: list[str], cookies: Optional[Cookies] = None): - ensure_images_dir() - async with ClientSession( - connector=get_connector(None, os.environ.get("G4F_PROXY")), - cookies=cookies - ) as session: - async def copy_image(image: str) -> str: - target = os.path.join(images_dir, f"{int(time.time())}_{str(uuid.uuid4())}") - if image.startswith("data:"): - with open(target, "wb") as f: - f.write(extract_data_uri(image)) - else: - async with session.get(image) as response: - with open(target, "wb") as f: - async for chunk in response.content.iter_any(): - f.write(chunk) - with open(target, "rb") as f: - extension = is_accepted_format(f.read(12)).split("/")[-1] - extension = "jpg" if extension == "jpeg" else extension - new_target = f"{target}.{extension}" - os.rename(target, new_target) - return f"/images/{os.path.basename(new_target)}" - - return await asyncio.gather(*[copy_image(image) for image in images]) - def _format_json(self, response_type: str, content): return { 'type': response_type, @@ -221,4 +198,4 @@ def get_error_message(exception: Exception) -> str: provider = get_last_provider() if provider is None: return message - return f"{provider.__name__}: {message}" + return f"{provider.__name__}: {message}"
\ No newline at end of file diff --git a/g4f/gui/server/backend.py b/g4f/gui/server/backend.py index 020e49ef..917d779e 100644 --- a/g4f/gui/server/backend.py +++ b/g4f/gui/server/backend.py @@ -89,7 +89,12 @@ class Backend_Api(Api): kwargs = self._prepare_conversation_kwargs(json_data, kwargs) return self.app.response_class( - self._create_response_stream(kwargs, json_data.get("conversation_id"), json_data.get("provider")), + self._create_response_stream( + kwargs, + json_data.get("conversation_id"), + json_data.get("provider"), + json_data.get("download_images", True), + ), mimetype='text/event-stream' ) diff --git a/g4f/gui/server/internet.py b/g4f/gui/server/internet.py index b41b5eae..bafa3af7 100644 --- a/g4f/gui/server/internet.py +++ b/g4f/gui/server/internet.py @@ -8,12 +8,14 @@ try: except ImportError: has_requirements = False from ...errors import MissingRequirementsError - +from ... import debug + import asyncio class SearchResults(): - def __init__(self, results: list): + def __init__(self, results: list, used_words: int): self.results = results + self.used_words = used_words def __iter__(self): yield from self.results @@ -104,7 +106,8 @@ async def search(query: str, n_results: int = 5, max_words: int = 2500, add_text region="wt-wt", safesearch="moderate", timelimit="y", - max_results=n_results + max_results=n_results, + backend="html" ): results.append(SearchResultEntry( result["title"], @@ -120,6 +123,7 @@ async def search(query: str, n_results: int = 5, max_words: int = 2500, add_text texts = await asyncio.gather(*requests) formatted_results = [] + used_words = 0 left_words = max_words for i, entry in enumerate(results): if add_text: @@ -132,13 +136,14 @@ async def search(query: str, n_results: int = 5, max_words: int = 2500, add_text left_words -= entry.snippet.count(" ") if 0 > left_words: break + used_words = max_words - left_words formatted_results.append(entry) - return SearchResults(formatted_results) + return SearchResults(formatted_results, used_words) -def get_search_message(prompt) -> str: +def get_search_message(prompt, n_results: int = 5, max_words: int = 2500) -> str: try: - search_results = asyncio.run(search(prompt)) + search_results = asyncio.run(search(prompt, n_results, max_words)) message = f""" {search_results} @@ -149,7 +154,8 @@ Make sure to add the sources of cites using [[Number]](Url) notation after the r User request: {prompt} """ + debug.log(f"Web search: '{prompt.strip()[:50]}...' {search_results.used_words} Words") return message except Exception as e: - print("Couldn't do web search:", e) + debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}") return prompt
\ No newline at end of file |