diff options
author | kqlio67 <kqlio67@users.noreply.github.com> | 2024-11-10 19:31:11 +0100 |
---|---|---|
committer | kqlio67 <kqlio67@users.noreply.github.com> | 2024-11-10 19:31:11 +0100 |
commit | 9a0346199bdcee36e5ffb4b9ef818f60dfeb68f1 (patch) | |
tree | aec6f13029ceeb0e3621843356489e0cc98e8153 /g4f/Provider/airforce/AirforceChat.py | |
parent | Update (g4f/Provider/airforce/AirforceImage.py) (diff) | |
download | gpt4free-9a0346199bdcee36e5ffb4b9ef818f60dfeb68f1.tar gpt4free-9a0346199bdcee36e5ffb4b9ef818f60dfeb68f1.tar.gz gpt4free-9a0346199bdcee36e5ffb4b9ef818f60dfeb68f1.tar.bz2 gpt4free-9a0346199bdcee36e5ffb4b9ef818f60dfeb68f1.tar.lz gpt4free-9a0346199bdcee36e5ffb4b9ef818f60dfeb68f1.tar.xz gpt4free-9a0346199bdcee36e5ffb4b9ef818f60dfeb68f1.tar.zst gpt4free-9a0346199bdcee36e5ffb4b9ef818f60dfeb68f1.zip |
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/airforce/AirforceChat.py | 186 |
1 files changed, 93 insertions, 93 deletions
diff --git a/g4f/Provider/airforce/AirforceChat.py b/g4f/Provider/airforce/AirforceChat.py index 63a0460f..fc375270 100644 --- a/g4f/Provider/airforce/AirforceChat.py +++ b/g4f/Provider/airforce/AirforceChat.py @@ -1,14 +1,15 @@ from __future__ import annotations import re -from aiohttp import ClientSession import json -from typing import List +from aiohttp import ClientSession import requests +from typing import List from ...typing import AsyncResult, Messages from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..helper import format_prompt +# Helper function to clean the response def clean_response(text: str) -> str: """Clean response from unwanted patterns.""" patterns = [ @@ -16,35 +17,27 @@ def clean_response(text: str) -> str: r"Rate limit \(\d+\/minute\) exceeded\. Join our discord for more: .+https:\/\/discord\.com\/invite\/\S+", r"Rate limit \(\d+\/hour\) exceeded\. Join our discord for more: https:\/\/discord\.com\/invite\/\S+", r"</s>", # zephyr-7b-beta + r"\[ERROR\] '\w{8}-\w{4}-\w{4}-\w{4}-\w{12}'", # Matches [ERROR] 'UUID' ] - for pattern in patterns: text = re.sub(pattern, '', text) return text.strip() -def split_message(message: dict, chunk_size: int = 995) -> List[dict]: - """Split a message into chunks of specified size.""" - content = message.get('content', '') - if len(content) <= chunk_size: - return [message] - +def split_message(message: str, max_length: int = 1000) -> List[str]: + """Splits the message into chunks of a given length (max_length)""" + # Split the message into smaller chunks to avoid exceeding the limit chunks = [] - while content: - chunk = content[:chunk_size] - content = content[chunk_size:] - chunks.append({ - 'role': message['role'], - 'content': chunk - }) + while len(message) > max_length: + # Find the last space or punctuation before max_length to avoid cutting words + split_point = message.rfind(' ', 0, max_length) + if split_point == -1: # No space found, split at max_length + split_point = max_length + chunks.append(message[:split_point]) + message = message[split_point:].strip() + if message: + chunks.append(message) # Append the remaining part of the message return chunks -def split_messages(messages: Messages, chunk_size: int = 995) -> Messages: - """Split all messages that exceed chunk_size into smaller messages.""" - result = [] - for message in messages: - result.extend(split_message(message, chunk_size)) - return result - class AirforceChat(AsyncGeneratorProvider, ProviderModelMixin): label = "AirForce Chat" api_endpoint = "https://api.airforce/chat/completions" @@ -57,45 +50,44 @@ class AirforceChat(AsyncGeneratorProvider, ProviderModelMixin): data = response.json() text_models = [model['id'] for model in data['data']] - models = [*text_models] - + model_aliases = { - # openchat - "openchat-3.5": "openchat-3.5-0106", - - # deepseek-ai - "deepseek-coder": "deepseek-coder-6.7b-instruct", - - # NousResearch - "hermes-2-dpo": "Nous-Hermes-2-Mixtral-8x7B-DPO", - "hermes-2-pro": "hermes-2-pro-mistral-7b", - - # teknium - "openhermes-2.5": "openhermes-2.5-mistral-7b", - - # liquid - "lfm-40b": "lfm-40b-moe", - - # DiscoResearch - "german-7b": "discolm-german-7b-v1", - - # meta-llama - "llama-2-7b": "llama-2-7b-chat-int8", - "llama-2-7b": "llama-2-7b-chat-fp16", - "llama-3.1-70b": "llama-3.1-70b-chat", - "llama-3.1-8b": "llama-3.1-8b-chat", - "llama-3.1-70b": "llama-3.1-70b-turbo", - "llama-3.1-8b": "llama-3.1-8b-turbo", - - # inferless - "neural-7b": "neural-chat-7b-v3-1", - - # HuggingFaceH4 - "zephyr-7b": "zephyr-7b-beta", - - # llmplayground.net - #"any-uncensored": "any-uncensored", + # openchat + "openchat-3.5": "openchat-3.5-0106", + + # deepseek-ai + "deepseek-coder": "deepseek-coder-6.7b-instruct", + + # NousResearch + "hermes-2-dpo": "Nous-Hermes-2-Mixtral-8x7B-DPO", + "hermes-2-pro": "hermes-2-pro-mistral-7b", + + # teknium + "openhermes-2.5": "openhermes-2.5-mistral-7b", + + # liquid + "lfm-40b": "lfm-40b-moe", + + # DiscoResearch + "german-7b": "discolm-german-7b-v1", + + # meta-llama + "llama-2-7b": "llama-2-7b-chat-int8", + "llama-2-7b": "llama-2-7b-chat-fp16", + "llama-3.1-70b": "llama-3.1-70b-chat", + "llama-3.1-8b": "llama-3.1-8b-chat", + "llama-3.1-70b": "llama-3.1-70b-turbo", + "llama-3.1-8b": "llama-3.1-8b-turbo", + + # inferless + "neural-7b": "neural-chat-7b-v3-1", + + # HuggingFaceH4 + "zephyr-7b": "zephyr-7b-beta", + + # llmplayground.net + #"any-uncensored": "any-uncensored", } @classmethod @@ -112,8 +104,6 @@ class AirforceChat(AsyncGeneratorProvider, ProviderModelMixin): ) -> AsyncResult: model = cls.get_model(model) - chunked_messages = split_messages(messages) - headers = { 'accept': '*/*', 'accept-language': 'en-US,en;q=0.9', @@ -133,36 +123,46 @@ class AirforceChat(AsyncGeneratorProvider, ProviderModelMixin): 'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36' } - data = { - "messages": chunked_messages, - "model": model, - "max_tokens": max_tokens, - "temperature": temperature, - "top_p": top_p, - "stream": stream - } + # Format the messages for the API + formatted_messages = format_prompt(messages) + message_chunks = split_message(formatted_messages) + + full_response = "" + for chunk in message_chunks: + data = { + "messages": [{"role": "user", "content": chunk}], + "model": model, + "max_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + "stream": stream + } + + async with ClientSession(headers=headers) as session: + async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response: + response.raise_for_status() - async with ClientSession(headers=headers) as session: - async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response: - response.raise_for_status() - text = "" - if stream: - async for line in response.content: - line = line.decode('utf-8') - if line.startswith('data: '): - json_str = line[6:] - try: - chunk = json.loads(json_str) - if 'choices' in chunk and chunk['choices']: - content = chunk['choices'][0].get('delta', {}).get('content', '') - text += content - except json.JSONDecodeError as e: - print(f"Error decoding JSON: {json_str}, Error: {e}") - elif line.strip() == "[DONE]": - break - yield clean_response(text) - else: - response_json = await response.json() - text = response_json["choices"][0]["message"]["content"] - yield clean_response(text) + text = "" + if stream: + async for line in response.content: + line = line.decode('utf-8').strip() + if line.startswith('data: '): + json_str = line[6:] + try: + if json_str and json_str != "[DONE]": + chunk = json.loads(json_str) + if 'choices' in chunk and chunk['choices']: + content = chunk['choices'][0].get('delta', {}).get('content', '') + text += content + except json.JSONDecodeError as e: + print(f"Error decoding JSON: {json_str}, Error: {e}") + elif line == "[DONE]": + break + full_response += clean_response(text) + else: + response_json = await response.json() + text = response_json["choices"][0]["message"]["content"] + full_response += clean_response(text) + # Return the complete response after all chunks + yield full_response |