From ae46cf72d4650d648e25397926954171d8c2d5d5 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Mon, 1 Jan 2024 23:23:45 +0100 Subject: Fix DeepInfra Provider --- g4f/Provider/DeepInfra.py | 88 ++++++++++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 40 deletions(-) (limited to 'g4f/Provider') diff --git a/g4f/Provider/DeepInfra.py b/g4f/Provider/DeepInfra.py index 1639bbd2..96e3a680 100644 --- a/g4f/Provider/DeepInfra.py +++ b/g4f/Provider/DeepInfra.py @@ -1,30 +1,34 @@ from __future__ import annotations -import requests, json -from ..typing import CreateResult, Messages -from .base_provider import AbstractProvider +import json +from ..typing import AsyncResult, Messages +from .base_provider import AsyncGeneratorProvider +from ..requests import StreamSession -class DeepInfra(AbstractProvider): - url: str = "https://deepinfra.com" - working: bool = True - supports_stream: bool = True - supports_message_history: bool = True +class DeepInfra(AsyncGeneratorProvider): + url = "https://deepinfra.com" + working = True + supports_stream = True + supports_message_history = True @staticmethod - def create_completion(model: str, - messages: Messages, - stream: bool, - auth: str = None, - **kwargs) -> CreateResult: + async def create_async_generator( + model: str, + messages: Messages, + stream: bool, + proxy: str = None, + timeout: int = 120, + auth: str = None, + **kwargs + ) -> AsyncResult: if not model: model = 'meta-llama/Llama-2-70b-chat-hf' headers = { - 'Accept-Language': 'en,fr-FR;q=0.9,fr;q=0.8,es-ES;q=0.7,es;q=0.6,en-US;q=0.5,am;q=0.4,de;q=0.3', - 'Cache-Control': 'no-cache', + 'Accept-Encoding': 'gzip, deflate, br', + 'Accept-Language': 'en-US', 'Connection': 'keep-alive', 'Content-Type': 'application/json', 'Origin': 'https://deepinfra.com', - 'Pragma': 'no-cache', 'Referer': 'https://deepinfra.com/', 'Sec-Fetch-Dest': 'empty', 'Sec-Fetch-Mode': 'cors', @@ -38,28 +42,32 @@ class DeepInfra(AbstractProvider): } if auth: headers['Authorization'] = f"bearer {auth}" - - json_data = json.dumps({ - 'model' : model, - 'messages': messages, - 'stream' : True}, separators=(',', ':')) - - response = requests.post('https://api.deepinfra.com/v1/openai/chat/completions', - headers=headers, data=json_data, stream=True) - - response.raise_for_status() - first = True - for line in response.content: - if line.startswith(b"data: [DONE]"): - break - elif line.startswith(b"data: "): - try: - chunk = json.loads(line[6:])["choices"][0]["delta"].get("content") - except Exception: - raise RuntimeError(f"Response: {line}") - if chunk: - if first: - chunk = chunk.lstrip() + + async with StreamSession(headers=headers, + timeout=timeout, + proxies={"https": proxy}, + impersonate="chrome110" + ) as session: + json_data = { + 'model' : model, + 'messages': messages, + 'stream' : True + } + async with session.post('https://api.deepinfra.com/v1/openai/chat/completions', + json=json_data) as response: + response.raise_for_status() + first = True + async for line in response.iter_lines(): + try: + if line.startswith(b"data: [DONE]"): + break + elif line.startswith(b"data: "): + chunk = json.loads(line[6:])["choices"][0]["delta"].get("content") if chunk: - first = False - yield chunk \ No newline at end of file + if first: + chunk = chunk.lstrip() + if chunk: + first = False + yield chunk + except Exception: + raise RuntimeError(f"Response: {line}") \ No newline at end of file -- cgit v1.2.3