diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/DDG.py | 108 |
1 files changed, 60 insertions, 48 deletions
diff --git a/g4f/Provider/DDG.py b/g4f/Provider/DDG.py index 43cc39c0..c4be0ea8 100644 --- a/g4f/Provider/DDG.py +++ b/g4f/Provider/DDG.py @@ -2,12 +2,31 @@ from __future__ import annotations import json import aiohttp -from aiohttp import ClientSession +from aiohttp import ClientSession, BaseConnector from ..typing import AsyncResult, Messages -from .base_provider import AsyncGeneratorProvider, ProviderModelMixin +from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, BaseConversation from .helper import format_prompt +from ..requests.aiohttp import get_connector +from ..requests.raise_for_status import raise_for_status +from .. import debug +MODELS = [ + {"model":"gpt-4o","modelName":"GPT-4o","modelVariant":None,"modelStyleId":"gpt-4o-mini","createdBy":"OpenAI","moderationLevel":"HIGH","isAvailable":1,"inputCharLimit":16e3,"settingId":"4"}, + {"model":"gpt-4o-mini","modelName":"GPT-4o","modelVariant":"mini","modelStyleId":"gpt-4o-mini","createdBy":"OpenAI","moderationLevel":"HIGH","isAvailable":0,"inputCharLimit":16e3,"settingId":"3"}, + {"model":"claude-3-5-sonnet-20240620","modelName":"Claude 3.5","modelVariant":"Sonnet","modelStyleId":"claude-3-haiku","createdBy":"Anthropic","moderationLevel":"HIGH","isAvailable":1,"inputCharLimit":16e3,"settingId":"7"}, + {"model":"claude-3-opus-20240229","modelName":"Claude 3","modelVariant":"Opus","modelStyleId":"claude-3-haiku","createdBy":"Anthropic","moderationLevel":"HIGH","isAvailable":1,"inputCharLimit":16e3,"settingId":"2"}, + {"model":"claude-3-haiku-20240307","modelName":"Claude 3","modelVariant":"Haiku","modelStyleId":"claude-3-haiku","createdBy":"Anthropic","moderationLevel":"HIGH","isAvailable":0,"inputCharLimit":16e3,"settingId":"1"}, + {"model":"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo","modelName":"Llama 3.1","modelVariant":"70B","modelStyleId":"llama-3","createdBy":"Meta","moderationLevel":"MEDIUM","isAvailable":0,"isOpenSource":0,"inputCharLimit":16e3,"settingId":"5"}, + {"model":"mistralai/Mixtral-8x7B-Instruct-v0.1","modelName":"Mixtral","modelVariant":"8x7B","modelStyleId":"mixtral","createdBy":"Mistral AI","moderationLevel":"LOW","isAvailable":0,"isOpenSource":0,"inputCharLimit":16e3,"settingId":"6"} +] + +class Conversation(BaseConversation): + vqd: str = None + message_history: Messages = [] + + def __init__(self, model: str): + self.model = model class DDG(AsyncGeneratorProvider, ProviderModelMixin): url = "https://duckduckgo.com" @@ -18,81 +37,74 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin): supports_message_history = True default_model = "gpt-4o-mini" - models = [ - "gpt-4o-mini", - "claude-3-haiku-20240307", - "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", - "mistralai/Mixtral-8x7B-Instruct-v0.1" - ] + models = [model.get("model") for model in MODELS] model_aliases = { "claude-3-haiku": "claude-3-haiku-20240307", "llama-3.1-70b": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", - "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1" + "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1", + "gpt-4": "gpt-4o-mini" } @classmethod - def get_model(cls, model: str) -> str: - return cls.model_aliases.get(model, model) if model in cls.model_aliases else cls.default_model - - @classmethod - async def get_vqd(cls): + async def get_vqd(cls, proxy: str, connector: BaseConnector = None): status_url = "https://duckduckgo.com/duckchat/v1/status" - headers = { 'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36', 'Accept': 'text/event-stream', 'x-vqd-accept': '1' } - - async with aiohttp.ClientSession() as session: - try: - async with session.get(status_url, headers=headers) as response: - if response.status == 200: - return response.headers.get("x-vqd-4") - else: - print(f"Error: Status code {response.status}") - return None - except Exception as e: - print(f"Error getting VQD: {e}") - return None + async with aiohttp.ClientSession(connector=get_connector(connector, proxy)) as session: + async with session.get(status_url, headers=headers) as response: + await raise_for_status(response) + return response.headers.get("x-vqd-4") @classmethod async def create_async_generator( cls, model: str, messages: Messages, - conversation: dict = None, + conversation: Conversation = None, + return_conversation: bool = False, proxy: str = None, + connector: BaseConnector = None, **kwargs ) -> AsyncResult: model = cls.get_model(model) + is_new_conversation = False + if conversation is None: + conversation = Conversation(model) + is_new_conversation = True + debug.last_model = model + if conversation.vqd is None: + conversation.vqd = await cls.get_vqd(proxy, connector) + if not conversation.vqd: + raise Exception("Failed to obtain VQD token") + headers = { 'accept': 'text/event-stream', 'content-type': 'application/json', 'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36', + 'x-vqd-4': conversation.vqd, } - - vqd = conversation.get('vqd') if conversation else await cls.get_vqd() - if not vqd: - raise Exception("Failed to obtain VQD token") - - headers['x-vqd-4'] = vqd - - if conversation: - message_history = conversation.get('messages', []) - message_history.append({"role": "user", "content": format_prompt(messages)}) - else: - message_history = [{"role": "user", "content": format_prompt(messages)}] - - async with ClientSession(headers=headers) as session: + async with ClientSession(headers=headers, connector=get_connector(connector, proxy)) as session: + if is_new_conversation: + conversation.message_history = [{"role": "user", "content": format_prompt(messages)}] + else: + conversation.message_history = [ + *conversation.message_history, + messages[-2], + messages[-1] + ] + if return_conversation: + yield conversation data = { - "model": model, - "messages": message_history + "model": conversation.model, + "messages": conversation.message_history } - - async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response: - response.raise_for_status() + async with session.post(cls.api_endpoint, json=data) as response: + conversation.vqd = response.headers.get("x-vqd-4") + await raise_for_status(response) async for line in response.content: if line: decoded_line = line.decode('utf-8') @@ -105,4 +117,4 @@ class DDG(AsyncGeneratorProvider, ProviderModelMixin): if 'message' in json_data: yield json_data['message'] except json.JSONDecodeError: - pass + pass
\ No newline at end of file |