From d44b39b31c83c6a4bc636bea931275702c700feb Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Sat, 6 Apr 2024 01:05:00 +0200 Subject: Add Groq and Openai interfaces, Add integration tests --- etc/unittest/__main__.py | 1 + etc/unittest/integration.py | 25 +++++++++++++ g4f/Provider/base_provider.py | 1 + g4f/Provider/needs_auth/Groq.py | 23 ++++++++++++ g4f/Provider/needs_auth/Openai.py | 74 +++++++++++++++++++++++++++++++++++++ g4f/Provider/needs_auth/__init__.py | 4 +- g4f/client.py | 5 ++- g4f/providers/base_provider.py | 37 ++++++++++--------- g4f/providers/types.py | 6 ++- g4f/requests/aiohttp.py | 16 ++++++-- 10 files changed, 167 insertions(+), 25 deletions(-) create mode 100644 etc/unittest/integration.py create mode 100644 g4f/Provider/needs_auth/Groq.py create mode 100644 g4f/Provider/needs_auth/Openai.py diff --git a/etc/unittest/__main__.py b/etc/unittest/__main__.py index 06b2dff5..3a459dba 100644 --- a/etc/unittest/__main__.py +++ b/etc/unittest/__main__.py @@ -5,5 +5,6 @@ from .main import * from .model import * from .client import * from .include import * +from .integration import * unittest.main() \ No newline at end of file diff --git a/etc/unittest/integration.py b/etc/unittest/integration.py new file mode 100644 index 00000000..808a8d1d --- /dev/null +++ b/etc/unittest/integration.py @@ -0,0 +1,25 @@ +import unittest +import json + +from g4f.client import Client, ChatCompletion +from g4f.Provider import Bing, OpenaiChat + +DEFAULT_MESSAGES = [{"role": "system", "content": 'Response in json, Example: {"success: True"}'}, + {"role": "user", "content": "Say success true in json"}] + +class TestProviderIntegration(unittest.TestCase): + + def test_bing(self): + client = Client(provider=Bing) + response = client.chat.completions.create(DEFAULT_MESSAGES, "", response_format={"type": "json_object"}) + self.assertIsInstance(response, ChatCompletion) + self.assertIn("success", json.loads(response.choices[0].message.content)) + + def test_openai(self): + client = Client(provider=OpenaiChat) + response = client.chat.completions.create(DEFAULT_MESSAGES, "", response_format={"type": "json_object"}) + self.assertIsInstance(response, ChatCompletion) + self.assertIn("success", json.loads(response.choices[0].message.content)) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index 8e761dba..4c0157f3 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -1,2 +1,3 @@ from ..providers.base_provider import * +from ..providers.types import FinishReason from .helper import get_cookies, format_prompt \ No newline at end of file diff --git a/g4f/Provider/needs_auth/Groq.py b/g4f/Provider/needs_auth/Groq.py new file mode 100644 index 00000000..87e87e60 --- /dev/null +++ b/g4f/Provider/needs_auth/Groq.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from .Openai import Openai +from ...typing import AsyncResult, Messages + +class Groq(Openai): + url = "https://console.groq.com/playground" + working = True + default_model = "mixtral-8x7b-32768" + models = ["mixtral-8x7b-32768", "llama2-70b-4096", "gemma-7b-it"] + model_aliases = {"mixtral-8x7b": "mixtral-8x7b-32768", "llama2-70b": "llama2-70b-4096"} + + @classmethod + def create_async_generator( + cls, + model: str, + messages: Messages, + api_base: str = "https://api.groq.com/openai/v1", + **kwargs + ) -> AsyncResult: + return super().create_async_generator( + model, messages, api_base=api_base, **kwargs + ) \ No newline at end of file diff --git a/g4f/Provider/needs_auth/Openai.py b/g4f/Provider/needs_auth/Openai.py new file mode 100644 index 00000000..b876cd0b --- /dev/null +++ b/g4f/Provider/needs_auth/Openai.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import json + +from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason +from ...typing import AsyncResult, Messages +from ...requests.raise_for_status import raise_for_status +from ...requests import StreamSession +from ...errors import MissingAuthError + +class Openai(AsyncGeneratorProvider, ProviderModelMixin): + url = "https://openai.com" + working = True + needs_auth = True + supports_message_history = True + supports_system_message = True + + @classmethod + async def create_async_generator( + cls, + model: str, + messages: Messages, + proxy: str = None, + timeout: int = 120, + api_key: str = None, + api_base: str = "https://api.openai.com/v1", + temperature: float = None, + max_tokens: int = None, + top_p: float = None, + stop: str = None, + stream: bool = False, + **kwargs + ) -> AsyncResult: + if api_key is None: + raise MissingAuthError('Add a "api_key"') + async with StreamSession( + proxies={"all": proxy}, + headers=cls.get_headers(api_key), + timeout=timeout + ) as session: + data = { + "messages": messages, + "model": cls.get_model(model), + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "stop": stop, + "stream": stream, + } + async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response: + await raise_for_status(response) + async for line in response.iter_lines(): + if line.startswith(b"data: ") or not stream: + async for chunk in cls.read_line(line[6:] if stream else line, stream): + yield chunk + + @staticmethod + async def read_line(line: str, stream: bool): + if line == b"[DONE]": + return + choice = json.loads(line)["choices"][0] + if stream and "content" in choice["delta"] and choice["delta"]["content"]: + yield choice["delta"]["content"] + elif not stream and "content" in choice["message"]: + yield choice["message"]["content"] + if "finish_reason" in choice and choice["finish_reason"] is not None: + yield FinishReason(choice["finish_reason"]) + + @staticmethod + def get_headers(api_key: str) -> dict: + return { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } \ No newline at end of file diff --git a/g4f/Provider/needs_auth/__init__.py b/g4f/Provider/needs_auth/__init__.py index 5eb1b2eb..92fa165b 100644 --- a/g4f/Provider/needs_auth/__init__.py +++ b/g4f/Provider/needs_auth/__init__.py @@ -4,4 +4,6 @@ from .Theb import Theb from .ThebApi import ThebApi from .OpenaiChat import OpenaiChat from .OpenAssistant import OpenAssistant -from .Poe import Poe \ No newline at end of file +from .Poe import Poe +from .Openai import Openai +from .Groq import Groq \ No newline at end of file diff --git a/g4f/client.py b/g4f/client.py index d7ceb009..2c4fe788 100644 --- a/g4f/client.py +++ b/g4f/client.py @@ -8,7 +8,7 @@ import string from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse from .typing import Union, Iterator, Messages, ImageType -from .providers.types import BaseProvider, ProviderType +from .providers.types import BaseProvider, ProviderType, FinishReason from .image import ImageResponse as ImageProviderResponse from .errors import NoImageResponseError, RateLimitError, MissingAuthError from . import get_model_and_provider, get_last_provider @@ -47,6 +47,9 @@ def iter_response( finish_reason = None completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28)) for idx, chunk in enumerate(response): + if isinstance(chunk, FinishReason): + finish_reason = chunk.reason + break content += str(chunk) if max_tokens is not None and idx + 1 >= max_tokens: finish_reason = "length" diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py index ee5bcbb8..37f4af15 100644 --- a/g4f/providers/base_provider.py +++ b/g4f/providers/base_provider.py @@ -6,9 +6,10 @@ from asyncio import AbstractEventLoop from concurrent.futures import ThreadPoolExecutor from abc import abstractmethod from inspect import signature, Parameter -from ..typing import CreateResult, AsyncResult, Messages, Union -from .types import BaseProvider -from ..errors import NestAsyncioError, ModelNotSupportedError +from typing import Callable, Union +from ..typing import CreateResult, AsyncResult, Messages +from .types import BaseProvider, FinishReason +from ..errors import NestAsyncioError, ModelNotSupportedError, MissingRequirementsError from .. import debug if sys.version_info < (3, 10): @@ -21,17 +22,23 @@ if sys.platform == 'win32': if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) -def get_running_loop() -> Union[AbstractEventLoop, None]: +def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]: try: loop = asyncio.get_running_loop() - if not hasattr(loop.__class__, "_nest_patched"): - raise NestAsyncioError( - 'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.' - ) + if check_nested and not hasattr(loop.__class__, "_nest_patched"): + try: + import nest_asyncio + nest_asyncio.apply(loop) + except ImportError: + raise MissingRequirementsError('Install "nest_asyncio" package') return loop except RuntimeError: pass +# Fix for RuntimeError: async generator ignored GeneratorExit +async def await_callback(callback: Callable): + return await callback() + class AbstractProvider(BaseProvider): """ Abstract class for providing asynchronous functionality to derived classes. @@ -132,7 +139,7 @@ class AsyncProvider(AbstractProvider): Returns: CreateResult: The result of the completion creation. """ - get_running_loop() + get_running_loop(check_nested=True) yield asyncio.run(cls.create_async(model, messages, **kwargs)) @staticmethod @@ -158,7 +165,6 @@ class AsyncProvider(AbstractProvider): """ raise NotImplementedError() - class AsyncGeneratorProvider(AsyncProvider): """ Provides asynchronous generator functionality for streaming results. @@ -187,9 +193,9 @@ class AsyncGeneratorProvider(AsyncProvider): Returns: CreateResult: The result of the streaming completion creation. """ - loop = get_running_loop() + loop = get_running_loop(check_nested=True) new_loop = False - if not loop: + if loop is None: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) new_loop = True @@ -197,16 +203,11 @@ class AsyncGeneratorProvider(AsyncProvider): generator = cls.create_async_generator(model, messages, stream=stream, **kwargs) gen = generator.__aiter__() - # Fix for RuntimeError: async generator ignored GeneratorExit - async def await_callback(callback): - return await callback() - try: while True: yield loop.run_until_complete(await_callback(gen.__anext__)) except StopAsyncIteration: ... - # Fix for: ResourceWarning: unclosed event loop finally: if new_loop: loop.close() @@ -233,7 +234,7 @@ class AsyncGeneratorProvider(AsyncProvider): """ return "".join([ chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs) - if not isinstance(chunk, Exception) + if not isinstance(chunk, (Exception, FinishReason)) ]) @staticmethod diff --git a/g4f/providers/types.py b/g4f/providers/types.py index 67340958..a3eeb99e 100644 --- a/g4f/providers/types.py +++ b/g4f/providers/types.py @@ -97,4 +97,8 @@ class BaseRetryProvider(BaseProvider): __name__: str = "RetryProvider" supports_stream: bool = True -ProviderType = Union[Type[BaseProvider], BaseRetryProvider] \ No newline at end of file +ProviderType = Union[Type[BaseProvider], BaseRetryProvider] + +class FinishReason(): + def __init__(self, reason: str): + self.reason = reason \ No newline at end of file diff --git a/g4f/requests/aiohttp.py b/g4f/requests/aiohttp.py index 16b052eb..71e7bde7 100644 --- a/g4f/requests/aiohttp.py +++ b/g4f/requests/aiohttp.py @@ -15,11 +15,19 @@ class StreamResponse(ClientResponse): async for chunk in self.content.iter_any(): yield chunk - async def json(self) -> Any: - return await super().json(content_type=None) + async def json(self, content_type: str = None) -> Any: + return await super().json(content_type=content_type) class StreamSession(ClientSession): - def __init__(self, headers: dict = {}, timeout: int = None, proxies: dict = {}, impersonate = None, **kwargs): + def __init__( + self, + headers: dict = {}, + timeout: int = None, + connector: BaseConnector = None, + proxies: dict = {}, + impersonate = None, + **kwargs + ): if impersonate: headers = { **DEFAULT_HEADERS, @@ -29,7 +37,7 @@ class StreamSession(ClientSession): **kwargs, timeout=ClientTimeout(timeout) if timeout else None, response_class=StreamResponse, - connector=get_connector(kwargs.get("connector"), proxies.get("https")), + connector=get_connector(connector, proxies.get("all", proxies.get("https"))), headers=headers ) -- cgit v1.2.3