path: root/g4f/Provider/airforce/AirforceChat.py
Diffstat
-rw-r--r--  g4f/Provider/airforce/AirforceChat.py  186
1 file changed, 93 insertions(+), 93 deletions(-)
diff --git a/g4f/Provider/airforce/AirforceChat.py b/g4f/Provider/airforce/AirforceChat.py
index 63a0460f..fc375270 100644
--- a/g4f/Provider/airforce/AirforceChat.py
+++ b/g4f/Provider/airforce/AirforceChat.py
@@ -1,14 +1,15 @@
from __future__ import annotations
import re
-from aiohttp import ClientSession
import json
-from typing import List
+from aiohttp import ClientSession
import requests
+from typing import List
from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt
+# Helper function to clean the response
def clean_response(text: str) -> str:
"""Clean response from unwanted patterns."""
patterns = [
@@ -16,35 +17,27 @@ def clean_response(text: str) -> str:
r"Rate limit \(\d+\/minute\) exceeded\. Join our discord for more: .+https:\/\/discord\.com\/invite\/\S+",
r"Rate limit \(\d+\/hour\) exceeded\. Join our discord for more: https:\/\/discord\.com\/invite\/\S+",
r"</s>", # zephyr-7b-beta
+ r"\[ERROR\] '\w{8}-\w{4}-\w{4}-\w{4}-\w{12}'", # Matches [ERROR] 'UUID'
]
-
for pattern in patterns:
text = re.sub(pattern, '', text)
return text.strip()
-def split_message(message: dict, chunk_size: int = 995) -> List[dict]:
- """Split a message into chunks of specified size."""
- content = message.get('content', '')
- if len(content) <= chunk_size:
- return [message]
-
+def split_message(message: str, max_length: int = 1000) -> List[str]:
+ """Splits the message into chunks of a given length (max_length)"""
+ # Split the message into smaller chunks to avoid exceeding the limit
chunks = []
- while content:
- chunk = content[:chunk_size]
- content = content[chunk_size:]
- chunks.append({
- 'role': message['role'],
- 'content': chunk
- })
+ while len(message) > max_length:
+        # Find the last space before max_length so words are not cut in half
+ split_point = message.rfind(' ', 0, max_length)
+ if split_point == -1: # No space found, split at max_length
+ split_point = max_length
+ chunks.append(message[:split_point])
+ message = message[split_point:].strip()
+ if message:
+ chunks.append(message) # Append the remaining part of the message
return chunks
-def split_messages(messages: Messages, chunk_size: int = 995) -> Messages:
- """Split all messages that exceed chunk_size into smaller messages."""
- result = []
- for message in messages:
- result.extend(split_message(message, chunk_size))
- return result
-
class AirforceChat(AsyncGeneratorProvider, ProviderModelMixin):
label = "AirForce Chat"
api_endpoint = "https://api.airforce/chat/completions"
@@ -57,45 +50,44 @@ class AirforceChat(AsyncGeneratorProvider, ProviderModelMixin):
data = response.json()
text_models = [model['id'] for model in data['data']]
-
models = [*text_models]
-
+
model_aliases = {
- # openchat
- "openchat-3.5": "openchat-3.5-0106",
-
- # deepseek-ai
- "deepseek-coder": "deepseek-coder-6.7b-instruct",
-
- # NousResearch
- "hermes-2-dpo": "Nous-Hermes-2-Mixtral-8x7B-DPO",
- "hermes-2-pro": "hermes-2-pro-mistral-7b",
-
- # teknium
- "openhermes-2.5": "openhermes-2.5-mistral-7b",
-
- # liquid
- "lfm-40b": "lfm-40b-moe",
-
- # DiscoResearch
- "german-7b": "discolm-german-7b-v1",
-
- # meta-llama
- "llama-2-7b": "llama-2-7b-chat-int8",
- "llama-2-7b": "llama-2-7b-chat-fp16",
- "llama-3.1-70b": "llama-3.1-70b-chat",
- "llama-3.1-8b": "llama-3.1-8b-chat",
- "llama-3.1-70b": "llama-3.1-70b-turbo",
- "llama-3.1-8b": "llama-3.1-8b-turbo",
-
- # inferless
- "neural-7b": "neural-chat-7b-v3-1",
-
- # HuggingFaceH4
- "zephyr-7b": "zephyr-7b-beta",
-
- # llmplayground.net
- #"any-uncensored": "any-uncensored",
+ # openchat
+ "openchat-3.5": "openchat-3.5-0106",
+
+ # deepseek-ai
+ "deepseek-coder": "deepseek-coder-6.7b-instruct",
+
+ # NousResearch
+ "hermes-2-dpo": "Nous-Hermes-2-Mixtral-8x7B-DPO",
+ "hermes-2-pro": "hermes-2-pro-mistral-7b",
+
+ # teknium
+ "openhermes-2.5": "openhermes-2.5-mistral-7b",
+
+ # liquid
+ "lfm-40b": "lfm-40b-moe",
+
+ # DiscoResearch
+ "german-7b": "discolm-german-7b-v1",
+
+        # meta-llama
+        # Note: a dict literal keeps only the last value for a repeated key,
+        # so each alias is listed once with the value that previously won.
+        "llama-2-7b": "llama-2-7b-chat-fp16",
+        "llama-3.1-70b": "llama-3.1-70b-turbo",
+        "llama-3.1-8b": "llama-3.1-8b-turbo",
+
+ # inferless
+ "neural-7b": "neural-chat-7b-v3-1",
+
+ # HuggingFaceH4
+ "zephyr-7b": "zephyr-7b-beta",
+
+ # llmplayground.net
+ #"any-uncensored": "any-uncensored",
}
@classmethod
@@ -112,8 +104,6 @@ class AirforceChat(AsyncGeneratorProvider, ProviderModelMixin):
) -> AsyncResult:
model = cls.get_model(model)
- chunked_messages = split_messages(messages)
-
headers = {
'accept': '*/*',
'accept-language': 'en-US,en;q=0.9',
@@ -133,36 +123,46 @@ class AirforceChat(AsyncGeneratorProvider, ProviderModelMixin):
'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36'
}
- data = {
- "messages": chunked_messages,
- "model": model,
- "max_tokens": max_tokens,
- "temperature": temperature,
- "top_p": top_p,
- "stream": stream
- }
+ # Format the messages for the API
+ formatted_messages = format_prompt(messages)
+ message_chunks = split_message(formatted_messages)
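+    # NOTE: each chunk is sent as a separate request below; the per-chunk
+    # responses are concatenated and yielded as one string at the end.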
+
+ full_response = ""
+ for chunk in message_chunks:
+ data = {
+ "messages": [{"role": "user", "content": chunk}],
+ "model": model,
+ "max_tokens": max_tokens,
+ "temperature": temperature,
+ "top_p": top_p,
+ "stream": stream
+ }
+
+ async with ClientSession(headers=headers) as session:
+ async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
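+                # Abort on HTTP errors before attempting to parse the body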
+ response.raise_for_status()
- async with ClientSession(headers=headers) as session:
- async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
- response.raise_for_status()
- text = ""
- if stream:
- async for line in response.content:
- line = line.decode('utf-8')
- if line.startswith('data: '):
- json_str = line[6:]
- try:
- chunk = json.loads(json_str)
- if 'choices' in chunk and chunk['choices']:
- content = chunk['choices'][0].get('delta', {}).get('content', '')
- text += content
- except json.JSONDecodeError as e:
- print(f"Error decoding JSON: {json_str}, Error: {e}")
- elif line.strip() == "[DONE]":
- break
- yield clean_response(text)
- else:
- response_json = await response.json()
- text = response_json["choices"][0]["message"]["content"]
- yield clean_response(text)
+ text = ""
+ if stream:
+                    async for line in response.content:
+                        line = line.decode('utf-8').strip()
+                        if line.startswith('data: '):
+                            json_str = line[6:]
+                            # The stream is terminated by a "data: [DONE]" line
+                            if json_str == "[DONE]":
+                                break
+                            try:
+                                chunk = json.loads(json_str)
+                                if 'choices' in chunk and chunk['choices']:
+                                    content = chunk['choices'][0].get('delta', {}).get('content', '')
+                                    text += content
+                            except json.JSONDecodeError as e:
+                                print(f"Error decoding JSON: {json_str}, Error: {e}")
+ full_response += clean_response(text)
+ else:
+ response_json = await response.json()
+ text = response_json["choices"][0]["message"]["content"]
+ full_response += clean_response(text)
+        # Yield the complete response once all chunks have been processed
+ yield full_response
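
For reference, a minimal sketch of how the new helpers behave. This is illustrative only, not part of the commit: the sample strings are invented, and importing the provider module fetches the live model list over the network at class-definition time, so it assumes network access.

# Illustrative usage of the helpers added in this commit (hypothetical driver code).
from g4f.Provider.airforce.AirforceChat import split_message, clean_response

sample = "word " * 300                             # ~1500 chars, forces a split
chunks = split_message(sample, max_length=1000)
assert all(len(c) <= 1000 for c in chunks)         # chunks respect the limit
assert all(not c.startswith(" ") for c in chunks)  # split on spaces, then stripped

print(clean_response("Hello</s>"))                 # -> "Hello" (zephyr stop token removed)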