diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/nexra/NexraChatGPT4o.py | 114 |
1 file changed, 63 insertions, 51 deletions
diff --git a/g4f/Provider/nexra/NexraChatGPT4o.py b/g4f/Provider/nexra/NexraChatGPT4o.py index 62144163..126d32b8 100644 --- a/g4f/Provider/nexra/NexraChatGPT4o.py +++ b/g4f/Provider/nexra/NexraChatGPT4o.py @@ -1,74 +1,86 @@ from __future__ import annotations -from aiohttp import ClientSession +import json +import requests -from ...typing import AsyncResult, Messages -from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin +from ...typing import CreateResult, Messages +from ..base_provider import ProviderModelMixin, AbstractProvider from ..helper import format_prompt -import json -class NexraChatGPT4o(AsyncGeneratorProvider, ProviderModelMixin): +class NexraChatGPT4o(AbstractProvider, ProviderModelMixin): label = "Nexra ChatGPT4o" url = "https://nexra.aryahcr.cc/documentation/chatgpt/en" api_endpoint = "https://nexra.aryahcr.cc/api/chat/complements" working = True - supports_gpt_4 = True - supports_stream = False + supports_stream = True - default_model = 'gpt-4o' + default_model = "gpt-4o" models = [default_model] - + @classmethod def get_model(cls, model: str) -> str: return cls.default_model - + @classmethod - async def create_async_generator( + def create_completion( cls, model: str, messages: Messages, + stream: bool, proxy: str = None, + markdown: bool = False, **kwargs - ) -> AsyncResult: + ) -> CreateResult: model = cls.get_model(model) - + headers = { - "Content-Type": "application/json", + 'Content-Type': 'application/json' } - async with ClientSession(headers=headers) as session: - data = { - "messages": [ - { - "role": "user", - "content": format_prompt(messages) - } - ], - "stream": False, - "markdown": False, - "model": model - } - async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response: - response.raise_for_status() - buffer = "" - last_message = "" - async for chunk in response.content.iter_any(): - chunk_str = chunk.decode() - buffer += chunk_str - while '{' in buffer and '}' in buffer: - start = 
buffer.index('{') - end = buffer.index('}', start) + 1 - json_str = buffer[start:end] - buffer = buffer[end:] - try: - json_obj = json.loads(json_str) - if json_obj.get("finish"): - if last_message: - yield last_message - return - elif json_obj.get("message"): - last_message = json_obj["message"] - except json.JSONDecodeError: - pass - - if last_message: - yield last_message + + data = { + "messages": [ + { + "role": "user", + "content": format_prompt(messages) + } + ], + "stream": stream, + "markdown": markdown, + "model": model + } + + response = requests.post(cls.api_endpoint, headers=headers, json=data, stream=stream) + + if stream: + return cls.process_streaming_response(response) + else: + return cls.process_non_streaming_response(response) + + @classmethod + def process_non_streaming_response(cls, response): + if response.status_code == 200: + try: + content = response.text.lstrip('') + data = json.loads(content) + return data.get('message', '') + except json.JSONDecodeError: + return "Error: Unable to decode JSON response" + else: + return f"Error: {response.status_code}" + + @classmethod + def process_streaming_response(cls, response): + full_message = "" + for line in response.iter_lines(decode_unicode=True): + if line: + try: + line = line.lstrip('') + data = json.loads(line) + if data.get('finish'): + break + message = data.get('message', '') + if message and message != full_message: + yield message[len(full_message):] + full_message = message + except json.JSONDecodeError: + pass |