From 9cbe9c1ccb2381e37402a36297f11a0f96b1b557 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Sun, 21 Jan 2024 02:20:23 +0100 Subject: Improve tests --- g4f/Provider/Bing.py | 13 +++++-------- g4f/Provider/base_provider.py | 41 +++++++++++++++++++++++------------------ g4f/gui/server/backend.py | 3 +-- g4f/image.py | 14 +++++++------- g4f/typing.py | 7 +++++++ 5 files changed, 43 insertions(+), 35 deletions(-) (limited to 'g4f') diff --git a/g4f/Provider/Bing.py b/g4f/Provider/Bing.py index 34687866..b869a6ef 100644 --- a/g4f/Provider/Bing.py +++ b/g4f/Provider/Bing.py @@ -64,12 +64,7 @@ class Bing(AsyncGeneratorProvider): prompt = messages[-1]["content"] context = create_context(messages[:-1]) - if not cookies: - cookies = Defaults.cookies - else: - for key, value in Defaults.cookies.items(): - if key not in cookies: - cookies[key] = value + cookies = {**Defaults.cookies, **cookies} if cookies else Defaults.cookies gpt4_turbo = True if model.startswith("gpt-4-turbo") else False @@ -207,10 +202,12 @@ def create_message( request_id = str(uuid.uuid4()) struct = { 'arguments': [{ - 'source': 'cib', 'optionsSets': options_sets, + 'source': 'cib', + 'optionsSets': options_sets, 'allowedMessageTypes': Defaults.allowedMessageTypes, 'sliceIds': Defaults.sliceIds, - 'traceId': os.urandom(16).hex(), 'isStartOfSession': True, + 'traceId': os.urandom(16).hex(), + 'isStartOfSession': True, 'requestId': request_id, 'message': { **Defaults.location, diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index 95f1b0b2..bc47a1fa 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -5,8 +5,8 @@ from asyncio import AbstractEventLoop from concurrent.futures import ThreadPoolExecutor from abc import abstractmethod from inspect import signature, Parameter -from .helper import get_event_loop, get_cookies, format_prompt -from ..typing import CreateResult, AsyncResult, Messages +from .helper import get_cookies, format_prompt +from ..typing import CreateResult, AsyncResult, Messages, Union from ..base_provider import BaseProvider from ..errors import NestAsyncioError @@ -20,6 +20,17 @@ if sys.platform == 'win32': if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) +def get_running_loop() -> Union[AbstractEventLoop, None]: + try: + loop = asyncio.get_running_loop() + if not hasattr(loop.__class__, "_nest_patched"): + raise NestAsyncioError( + 'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.' + ) + return loop + except RuntimeError: + pass + class AbstractProvider(BaseProvider): """ Abstract class for providing asynchronous functionality to derived classes. @@ -56,7 +67,7 @@ class AbstractProvider(BaseProvider): return await asyncio.wait_for( loop.run_in_executor(executor, create_func), - timeout=kwargs.get("timeout", 0) + timeout=kwargs.get("timeout") ) @classmethod @@ -118,14 +129,7 @@ class AsyncProvider(AbstractProvider): Returns: CreateResult: The result of the completion creation. """ - try: - loop = asyncio.get_running_loop() - if not hasattr(loop.__class__, "_nest_patched"): - raise NestAsyncioError( - 'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.' - ) - except RuntimeError: - pass + get_running_loop() yield asyncio.run(cls.create_async(model, messages, **kwargs)) @staticmethod @@ -180,15 +184,12 @@ class AsyncGeneratorProvider(AsyncProvider): Returns: CreateResult: The result of the streaming completion creation. """ - try: - loop = asyncio.get_running_loop() - if not hasattr(loop.__class__, "_nest_patched"): - raise NestAsyncioError( - 'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.' - ) - except RuntimeError: + loop = get_running_loop() + new_loop = False + if not loop: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) + new_loop = True generator = cls.create_async_generator(model, messages, stream=stream, **kwargs) gen = generator.__aiter__() @@ -199,6 +200,10 @@ class AsyncGeneratorProvider(AsyncProvider): except StopAsyncIteration: break + if new_loop: + loop.close() + asyncio.set_event_loop(None) + @classmethod async def create_async( cls, diff --git a/g4f/gui/server/backend.py b/g4f/gui/server/backend.py index b4c8f56c..d5c59ed1 100644 --- a/g4f/gui/server/backend.py +++ b/g4f/gui/server/backend.py @@ -2,7 +2,7 @@ import logging import json from flask import request, Flask from typing import Generator -from g4f import debug, version, models +from g4f import version, models from g4f import _all_models, get_last_provider, ChatCompletion from g4f.image import is_allowed_extension, to_image from g4f.errors import VersionNotFoundError @@ -10,7 +10,6 @@ from g4f.Provider import __providers__ from g4f.Provider.bing.create_images import patch_provider from .internet import get_search_message -debug.logging = True class Backend_Api: """ diff --git a/g4f/image.py b/g4f/image.py index cfa22ab1..24ded915 100644 --- a/g4f/image.py +++ b/g4f/image.py @@ -112,7 +112,7 @@ def get_orientation(image: Image.Image) -> int: """ exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif() if exif_data is not None: - orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF + orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF if orientation is not None: return orientation @@ -156,23 +156,23 @@ def to_base64(image: Image.Image, compression_rate: float) -> str: image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100)) return base64.b64encode(output_buffer.getvalue()).decode() -def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=200") -> str: +def format_images_markdown(images, alt: str, preview: str="{image}?w=200&h=200") -> str: """ Formats the given images as a markdown string. Args: images: The images to format. - prompt (str): The prompt for the images. + alt (str): The alt for the images. preview (str, optional): The preview URL format. Defaults to "{image}?w=200&h=200". Returns: str: The formatted markdown string. """ - if isinstance(images, list): - images = [f"[![#{idx+1} {prompt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)] - images = "\n".join(images) + if isinstance(images, str): + images = f"[![{alt}]({preview.replace('{image}', images)})]({images})" else: - images = f"[![{prompt}]({images})]({images})" + images = [f"[![#{idx+1} {alt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)] + images = "\n".join(images) start_flag = "\n" end_flag = "\n" return f"\n{start_flag}{images}\n{end_flag}\n" diff --git a/g4f/typing.py b/g4f/typing.py index c972f505..a6a62e3f 100644 --- a/g4f/typing.py +++ b/g4f/typing.py @@ -18,7 +18,14 @@ __all__ = [ 'AsyncGenerator', 'Generator', 'Tuple', + 'Union', + 'List', + 'Dict', + 'Type', 'TypedDict', 'SHA256', 'CreateResult', + 'AsyncResult', + 'Messages', + 'ImageType' ] -- cgit v1.2.3