From 91feb34054f529c37e10d98d2471c8c0c6780147 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Tue, 23 Jan 2024 19:44:48 +0100 Subject: Add ProviderModelMixin for model selection --- g4f/Provider/DeepInfra.py | 26 +++++++++----- g4f/Provider/HuggingChat.py | 14 +++----- g4f/Provider/Liaobots.py | 13 ++++--- g4f/Provider/Llama2.py | 26 +++++++------- g4f/Provider/PerplexityLabs.py | 21 +++++------ g4f/Provider/base_provider.py | 23 ++++++++++-- g4f/Provider/needs_auth/OpenaiChat.py | 66 +++++++++++++++-------------------- g4f/errors.py | 3 ++ 8 files changed, 104 insertions(+), 88 deletions(-) (limited to 'g4f') diff --git a/g4f/Provider/DeepInfra.py b/g4f/Provider/DeepInfra.py index acde1200..2f34b679 100644 --- a/g4f/Provider/DeepInfra.py +++ b/g4f/Provider/DeepInfra.py @@ -1,18 +1,27 @@ from __future__ import annotations import json -from ..typing import AsyncResult, Messages -from .base_provider import AsyncGeneratorProvider -from ..requests import StreamSession +import requests +from ..typing import AsyncResult, Messages +from .base_provider import AsyncGeneratorProvider, ProviderModelMixin +from ..requests import StreamSession -class DeepInfra(AsyncGeneratorProvider): +class DeepInfra(AsyncGeneratorProvider, ProviderModelMixin): url = "https://deepinfra.com" working = True supports_stream = True supports_message_history = True - + default_model = 'meta-llama/Llama-2-70b-chat-hf' + @staticmethod + def get_models(): + url = 'https://api.deepinfra.com/models/featured' + models = requests.get(url).json() + return [model['model_name'] for model in models] + + @classmethod async def create_async_generator( + cls, model: str, messages: Messages, stream: bool, @@ -21,8 +30,6 @@ class DeepInfra(AsyncGeneratorProvider): auth: str = None, **kwargs ) -> AsyncResult: - if not model: - model = 'meta-llama/Llama-2-70b-chat-hf' headers = { 'Accept-Encoding': 'gzip, deflate, br', 'Accept-Language': 'en-US', @@ -49,7 +56,7 @@ class DeepInfra(AsyncGeneratorProvider): impersonate="chrome110" ) as session: json_data = { - 'model' : model, + 'model' : cls.get_model(model), 'messages': messages, 'stream' : True } @@ -70,7 +77,8 @@ class DeepInfra(AsyncGeneratorProvider): if token: if first: token = token.lstrip() + if token: first = False - yield token + yield token except Exception: raise RuntimeError(f"Response: {line}") diff --git a/g4f/Provider/HuggingChat.py b/g4f/Provider/HuggingChat.py index 3b4a520c..d493da8f 100644 --- a/g4f/Provider/HuggingChat.py +++ b/g4f/Provider/HuggingChat.py @@ -5,11 +5,11 @@ import json, uuid from aiohttp import ClientSession from ..typing import AsyncResult, Messages -from .base_provider import AsyncGeneratorProvider +from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from .helper import format_prompt, get_cookies -class HuggingChat(AsyncGeneratorProvider): +class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin): url = "https://huggingface.co/chat" working = True default_model = "meta-llama/Llama-2-70b-chat-hf" @@ -21,7 +21,7 @@ class HuggingChat(AsyncGeneratorProvider): "mistralai/Mistral-7B-Instruct-v0.2", "openchat/openchat-3.5-0106" ] - model_map = { + model_aliases = { "openchat/openchat_3.5": "openchat/openchat-3.5-1210", "mistralai/Mixtral-8x7B-Instruct-v0.1": "mistralai/Mistral-7B-Instruct-v0.2" } @@ -37,12 +37,6 @@ class HuggingChat(AsyncGeneratorProvider): cookies: dict = None, **kwargs ) -> AsyncResult: - if not model: - model = cls.default_model - elif model in cls.model_map: - model = cls.model_map[model] - elif model not in cls.models: - raise ValueError(f"Model is not supported: {model}") if not cookies: cookies = get_cookies(".huggingface.co") @@ -53,7 +47,7 @@ class HuggingChat(AsyncGeneratorProvider): cookies=cookies, headers=headers ) as session: - async with session.post(f"{cls.url}/conversation", json={"model": model}, proxy=proxy) as response: + async with session.post(f"{cls.url}/conversation", json={"model": cls.get_model(model)}, proxy=proxy) as response: conversation_id = (await response.json())["conversationId"] send = { diff --git a/g4f/Provider/Liaobots.py b/g4f/Provider/Liaobots.py index 88f0c4ff..5151c115 100644 --- a/g4f/Provider/Liaobots.py +++ b/g4f/Provider/Liaobots.py @@ -5,7 +5,7 @@ import uuid from aiohttp import ClientSession from ..typing import AsyncResult, Messages -from .base_provider import AsyncGeneratorProvider +from .base_provider import AsyncGeneratorProvider, ProviderModelMixin models = { "gpt-4": { @@ -70,13 +70,17 @@ models = { } } - -class Liaobots(AsyncGeneratorProvider): +class Liaobots(AsyncGeneratorProvider, ProviderModelMixin): url = "https://liaobots.site" working = True supports_message_history = True supports_gpt_35_turbo = True supports_gpt_4 = True + default_model = "gpt-3.5-turbo" + models = [m for m in models] + model_aliases = { + "claude-v2": "claude-2" + } _auth_code = None _cookie_jar = None @@ -89,7 +93,6 @@ class Liaobots(AsyncGeneratorProvider): proxy: str = None, **kwargs ) -> AsyncResult: - model = model if model in models else "gpt-3.5-turbo" headers = { "authority": "liaobots.com", "content-type": "application/json", @@ -122,7 +125,7 @@ class Liaobots(AsyncGeneratorProvider): data = { "conversationId": str(uuid.uuid4()), - "model": models[model], + "model": models[cls.get_model(model)], "messages": messages, "key": "", "prompt": kwargs.get("system_message", "You are ChatGPT, a large language model trained by OpenAI. Follow the user's instructions carefully."), diff --git a/g4f/Provider/Llama2.py b/g4f/Provider/Llama2.py index 17969621..d1f8e194 100644 --- a/g4f/Provider/Llama2.py +++ b/g4f/Provider/Llama2.py @@ -3,18 +3,24 @@ from __future__ import annotations from aiohttp import ClientSession from ..typing import AsyncResult, Messages -from .base_provider import AsyncGeneratorProvider +from .base_provider import AsyncGeneratorProvider, ProviderModelMixin -models = { - "meta-llama/Llama-2-7b-chat-hf": "meta/llama-2-7b-chat", - "meta-llama/Llama-2-13b-chat-hf": "meta/llama-2-13b-chat", - "meta-llama/Llama-2-70b-chat-hf": "meta/llama-2-70b-chat", -} -class Llama2(AsyncGeneratorProvider): +class Llama2(AsyncGeneratorProvider, ProviderModelMixin): url = "https://www.llama2.ai" working = True supports_message_history = True + default_model = "meta/llama-2-70b-chat" + models = [ + "meta/llama-2-7b-chat", + "meta/llama-2-13b-chat", + "meta/llama-2-70b-chat", + ] + model_aliases = { + "meta-llama/Llama-2-7b-chat-hf": "meta/llama-2-7b-chat", + "meta-llama/Llama-2-13b-chat-hf": "meta/llama-2-13b-chat", + "meta-llama/Llama-2-70b-chat-hf": "meta/llama-2-70b-chat", + } @classmethod async def create_async_generator( @@ -24,10 +30,6 @@ class Llama2(AsyncGeneratorProvider): proxy: str = None, **kwargs ) -> AsyncResult: - if not model: - model = "meta/llama-2-70b-chat" - elif model in models: - model = models[model] headers = { "User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/118.0", "Accept": "*/*", @@ -48,7 +50,7 @@ class Llama2(AsyncGeneratorProvider): prompt = format_prompt(messages) data = { "prompt": prompt, - "model": model, + "model": cls.get_model(model), "systemPrompt": kwargs.get("system_message", "You are a helpful assistant."), "temperature": kwargs.get("temperature", 0.75), "topP": kwargs.get("top_p", 0.9), diff --git a/g4f/Provider/PerplexityLabs.py b/g4f/Provider/PerplexityLabs.py index c989b3da..90258da5 100644 --- a/g4f/Provider/PerplexityLabs.py +++ b/g4f/Provider/PerplexityLabs.py @@ -5,20 +5,21 @@ import json from aiohttp import ClientSession from ..typing import AsyncResult, Messages -from .base_provider import AsyncGeneratorProvider +from .base_provider import AsyncGeneratorProvider, ProviderModelMixin API_URL = "https://labs-api.perplexity.ai/socket.io/" WS_URL = "wss://labs-api.perplexity.ai/socket.io/" -class PerplexityLabs(AsyncGeneratorProvider): +class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin): url = "https://labs.perplexity.ai" working = True - supports_gpt_35_turbo = True - models = ['pplx-7b-online', 'pplx-70b-online', 'pplx-7b-chat', 'pplx-70b-chat', 'mistral-7b-instruct', + models = [ + 'pplx-7b-online', 'pplx-70b-online', 'pplx-7b-chat', 'pplx-70b-chat', 'mistral-7b-instruct', 'codellama-34b-instruct', 'llama-2-70b-chat', 'llava-7b-chat', 'mixtral-8x7b-instruct', - 'mistral-medium', 'related'] + 'mistral-medium', 'related' + ] default_model = 'pplx-70b-online' - model_map = { + model_aliases = { "mistralai/Mistral-7B-Instruct-v0.1": "mistral-7b-instruct", "meta-llama/Llama-2-70b-chat-hf": "llama-2-70b-chat", "mistralai/Mixtral-8x7B-Instruct-v0.1": "mixtral-8x7b-instruct", @@ -33,12 +34,6 @@ class PerplexityLabs(AsyncGeneratorProvider): proxy: str = None, **kwargs ) -> AsyncResult: - if not model: - model = cls.default_model - elif model in cls.model_map: - model = cls.model_map[model] - elif model not in cls.models: - raise ValueError(f"Model is not supported: {model}") headers = { "User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:121.0) Gecko/20100101 Firefox/121.0", "Accept": "*/*", @@ -78,7 +73,7 @@ class PerplexityLabs(AsyncGeneratorProvider): message_data = { 'version': '2.2', 'source': 'default', - 'model': model, + 'model': cls.get_model(model), 'messages': messages } await ws.send_str('42' + json.dumps(['perplexity_playground', message_data])) diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index bc47a1fa..e1dcd24d 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -8,7 +8,7 @@ from inspect import signature, Parameter from .helper import get_cookies, format_prompt from ..typing import CreateResult, AsyncResult, Messages, Union from ..base_provider import BaseProvider -from ..errors import NestAsyncioError +from ..errors import NestAsyncioError, ModelNotSupportedError if sys.version_info < (3, 10): NoneType = type(None) @@ -251,4 +251,23 @@ class AsyncGeneratorProvider(AsyncProvider): Returns: AsyncResult: An asynchronous generator yielding results. """ - raise NotImplementedError() \ No newline at end of file + raise NotImplementedError() + +class ProviderModelMixin: + default_model: str + models: list[str] = [] + model_aliases: dict[str, str] = {} + + @classmethod + def get_models(cls) -> list[str]: + return cls.models + + @classmethod + def get_model(cls, model: str) -> str: + if not model: + return cls.default_model + elif model in cls.model_aliases: + return cls.model_aliases[model] + elif model not in cls.get_models(): + raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}") + return model \ No newline at end of file diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index abf5b8d9..85866272 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -10,22 +10,15 @@ from selenium.webdriver.common.by import By from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support import expected_conditions as EC -from ..base_provider import AsyncGeneratorProvider +from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..helper import format_prompt, get_cookies from ...webdriver import get_browser, get_driver_cookies from ...typing import AsyncResult, Messages from ...requests import StreamSession from ...image import to_image, to_bytes, ImageType, ImageResponse -# Aliases for model names -MODELS = { - "gpt-3.5": "text-davinci-002-render-sha", - "gpt-3.5-turbo": "text-davinci-002-render-sha", - "gpt-4": "gpt-4", - "gpt-4-gizmo": "gpt-4-gizmo" -} -class OpenaiChat(AsyncGeneratorProvider): +class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): """A class for creating and managing conversations with OpenAI chat service""" url = "https://chat.openai.com" @@ -33,6 +26,11 @@ class OpenaiChat(AsyncGeneratorProvider): needs_auth = True supports_gpt_35_turbo = True supports_gpt_4 = True + default_model = None + models = ["text-davinci-002-render-sha", "gpt-4", "gpt-4-gizmo"] + model_aliases = { + "gpt-3.5-turbo": "text-davinci-002-render-sha", + } _cookies: dict = {} _default_model: str = None @@ -91,7 +89,7 @@ class OpenaiChat(AsyncGeneratorProvider): ) @classmethod - async def _upload_image( + async def upload_image( cls, session: StreamSession, headers: dict, @@ -150,7 +148,7 @@ class OpenaiChat(AsyncGeneratorProvider): return ImageResponse(download_url, image_data["file_name"], image_data) @classmethod - async def _get_default_model(cls, session: StreamSession, headers: dict): + async def get_default_model(cls, session: StreamSession, headers: dict): """ Get the default model name from the service @@ -161,20 +159,17 @@ class OpenaiChat(AsyncGeneratorProvider): Returns: The default model name as a string """ - # Check the cache for the default model - if cls._default_model: - return cls._default_model - # Get the models data from the service - async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response: - data = await response.json() - if "categories" in data: - cls._default_model = data["categories"][-1]["default_model"] - else: - raise RuntimeError(f"Response: {data}") - return cls._default_model + if not cls.default_model: + async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response: + data = await response.json() + if "categories" in data: + cls.default_model = data["categories"][-1]["default_model"] + else: + raise RuntimeError(f"Response: {data}") + return cls.default_model @classmethod - def _create_messages(cls, prompt: str, image_response: ImageResponse = None): + def create_messages(cls, prompt: str, image_response: ImageResponse = None): """ Create a list of messages for the user input @@ -222,7 +217,7 @@ class OpenaiChat(AsyncGeneratorProvider): return messages @classmethod - async def _get_generated_image(cls, session: StreamSession, headers: dict, line: dict) -> ImageResponse: + async def get_generated_image(cls, session: StreamSession, headers: dict, line: dict) -> ImageResponse: """ Retrieves the image response based on the message content. @@ -257,7 +252,7 @@ class OpenaiChat(AsyncGeneratorProvider): raise RuntimeError(f"Error in downloading image: {e}") @classmethod - async def _delete_conversation(cls, session: StreamSession, headers: dict, conversation_id: str): + async def delete_conversation(cls, session: StreamSession, headers: dict, conversation_id: str): """ Deletes a conversation by setting its visibility to False. @@ -322,7 +317,6 @@ class OpenaiChat(AsyncGeneratorProvider): Raises: RuntimeError: If an error occurs during processing. """ - model = MODELS.get(model, model) if not parent_id: parent_id = str(uuid.uuid4()) if not cookies: @@ -333,7 +327,7 @@ class OpenaiChat(AsyncGeneratorProvider): login_url = os.environ.get("G4F_LOGIN_URL") if login_url: yield f"Please login: [ChatGPT]({login_url})\n\n" - access_token, cookies = cls._browse_access_token(proxy) + access_token, cookies = cls.browse_access_token(proxy) cls._cookies = cookies headers = {"Authorization": f"Bearer {access_token}"} @@ -344,12 +338,10 @@ class OpenaiChat(AsyncGeneratorProvider): timeout=timeout, cookies=dict([(name, value) for name, value in cookies.items() if name == "_puid"]) ) as session: - if not model: - model = await cls._get_default_model(session, headers) try: image_response = None if image: - image_response = await cls._upload_image(session, headers, image) + image_response = await cls.upload_image(session, headers, image) yield image_response except Exception as e: yield e @@ -357,15 +349,15 @@ class OpenaiChat(AsyncGeneratorProvider): while not end_turn.is_end: data = { "action": action, - "arkose_token": await cls._get_arkose_token(session), + "arkose_token": await cls.get_arkose_token(session), "conversation_id": conversation_id, "parent_message_id": parent_id, - "model": model, + "model": cls.get_model(model or await cls.get_default_model(session, headers)), "history_and_training_disabled": history_disabled and not auto_continue, } if action != "continue": prompt = format_prompt(messages) if not conversation_id else messages[-1]["content"] - data["messages"] = cls._create_messages(prompt, image_response) + data["messages"] = cls.create_messages(prompt, image_response) async with session.post( f"{cls.url}/backend-api/conversation", json=data, @@ -391,7 +383,7 @@ class OpenaiChat(AsyncGeneratorProvider): if "message_type" not in line["message"]["metadata"]: continue try: - image_response = await cls._get_generated_image(session, headers, line) + image_response = await cls.get_generated_image(session, headers, line) if image_response: yield image_response except Exception as e: @@ -422,10 +414,10 @@ class OpenaiChat(AsyncGeneratorProvider): action = "continue" await asyncio.sleep(5) if history_disabled and auto_continue: - await cls._delete_conversation(session, headers, conversation_id) + await cls.delete_conversation(session, headers, conversation_id) @classmethod - def _browse_access_token(cls, proxy: str = None, timeout: int = 1200) -> tuple[str, dict]: + def browse_access_token(cls, proxy: str = None, timeout: int = 1200) -> tuple[str, dict]: """ Browse to obtain an access token. @@ -452,7 +444,7 @@ class OpenaiChat(AsyncGeneratorProvider): driver.quit() @classmethod - async def _get_arkose_token(cls, session: StreamSession) -> str: + async def get_arkose_token(cls, session: StreamSession) -> str: """ Obtain an Arkose token for the session. diff --git a/g4f/errors.py b/g4f/errors.py index c0e6ddfc..ddfe74db 100644 --- a/g4f/errors.py +++ b/g4f/errors.py @@ -26,4 +26,7 @@ class VersionNotFoundError(Exception): pass class NestAsyncioError(Exception): + pass + +class ModelNotSupportedError(Exception): pass \ No newline at end of file -- cgit v1.2.3