diff options
Diffstat (limited to 'g4f/Provider/AIUncensored.py')
-rw-r--r-- | g4f/Provider/AIUncensored.py | 148 |
1 files changed, 76 insertions, 72 deletions
diff --git a/g4f/Provider/AIUncensored.py b/g4f/Provider/AIUncensored.py index d653191c..ce492b38 100644 --- a/g4f/Provider/AIUncensored.py +++ b/g4f/Provider/AIUncensored.py @@ -2,33 +2,49 @@ from __future__ import annotations import json from aiohttp import ClientSession +from itertools import cycle from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from .helper import format_prompt from ..image import ImageResponse + class AIUncensored(AsyncGeneratorProvider, ProviderModelMixin): url = "https://www.aiuncensored.info" + api_endpoints_text = [ + "https://twitterclone-i0wr.onrender.com/api/chat", + "https://twitterclone-4e8t.onrender.com/api/chat", + "https://twitterclone-8wd1.onrender.com/api/chat", + ] + api_endpoints_image = [ + "https://twitterclone-4e8t.onrender.com/api/image", + "https://twitterclone-i0wr.onrender.com/api/image", + "https://twitterclone-8wd1.onrender.com/api/image", + ] + api_endpoints_cycle_text = cycle(api_endpoints_text) + api_endpoints_cycle_image = cycle(api_endpoints_image) working = True supports_stream = True supports_system_message = True supports_message_history = True - default_model = 'ai_uncensored' - chat_models = [default_model] - image_models = ['ImageGenerator'] - models = [*chat_models, *image_models] - - api_endpoints = { - 'ai_uncensored': "https://twitterclone-i0wr.onrender.com/api/chat", - 'ImageGenerator': "https://twitterclone-4e8t.onrender.com/api/image" + default_model = 'TextGenerations' + text_models = [default_model] + image_models = ['ImageGenerations'] + models = [*text_models, *image_models] + + model_aliases = { + #"": "TextGenerations", + "flux": "ImageGenerations", } @classmethod def get_model(cls, model: str) -> str: if model in cls.models: return model + elif model in cls.model_aliases: + return cls.model_aliases[model] else: return cls.default_model @@ -38,75 +54,63 @@ class AIUncensored(AsyncGeneratorProvider, ProviderModelMixin): model: str, messages: Messages, proxy: str = None, - stream: bool = False, **kwargs ) -> AsyncResult: model = cls.get_model(model) - if model in cls.chat_models: - async with ClientSession(headers={"content-type": "application/json"}) as session: + headers = { + 'accept': '*/*', + 'accept-language': 'en-US,en;q=0.9', + 'cache-control': 'no-cache', + 'content-type': 'application/json', + 'origin': 'https://www.aiuncensored.info', + 'pragma': 'no-cache', + 'priority': 'u=1, i', + 'referer': 'https://www.aiuncensored.info/', + 'sec-ch-ua': '"Not?A_Brand";v="99", "Chromium";v="130"', + 'sec-ch-ua-mobile': '?0', + 'sec-ch-ua-platform': '"Linux"', + 'sec-fetch-dest': 'empty', + 'sec-fetch-mode': 'cors', + 'sec-fetch-site': 'cross-site', + 'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36' + } + + async with ClientSession(headers=headers) as session: + if model in cls.image_models: + prompt = messages[-1]['content'] data = { - "messages": [ - {"role": "user", "content": format_prompt(messages)} - ], - "stream": stream + "prompt": prompt, } - async with session.post(cls.api_endpoints[model], json=data, proxy=proxy) as response: + api_endpoint = next(cls.api_endpoints_cycle_image) + async with session.post(api_endpoint, json=data, proxy=proxy) as response: response.raise_for_status() - if stream: - async for chunk in cls._handle_streaming_response(response): - yield chunk - else: - yield await cls._handle_non_streaming_response(response) - elif model in cls.image_models: - headers = { - "accept": "*/*", - "accept-language": "en-US,en;q=0.9", - "cache-control": "no-cache", - "content-type": "application/json", - "origin": cls.url, - "pragma": "no-cache", - "priority": "u=1, i", - "referer": f"{cls.url}/", - "sec-ch-ua": '"Chromium";v="129", "Not=A?Brand";v="8"', - "sec-ch-ua-mobile": "?0", - "sec-ch-ua-platform": '"Linux"', - "sec-fetch-dest": "empty", - "sec-fetch-mode": "cors", - "sec-fetch-site": "cross-site", - "user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36" - } - async with ClientSession(headers=headers) as session: - prompt = messages[0]['content'] - data = {"prompt": prompt} - async with session.post(cls.api_endpoints[model], json=data, proxy=proxy) as response: + response_data = await response.json() + image_url = response_data['image_url'] + image_response = ImageResponse(images=image_url, alt=prompt) + yield image_response + elif model in cls.text_models: + data = { + "messages": [ + { + "role": "user", + "content": format_prompt(messages) + } + ] + } + api_endpoint = next(cls.api_endpoints_cycle_text) + async with session.post(api_endpoint, json=data, proxy=proxy) as response: response.raise_for_status() - result = await response.json() - image_url = result.get('image_url', '') - if image_url: - yield ImageResponse(image_url, alt=prompt) - else: - yield "Failed to generate image. Please try again." - - @classmethod - async def _handle_streaming_response(cls, response): - async for line in response.content: - line = line.decode('utf-8').strip() - if line.startswith("data: "): - if line == "data: [DONE]": - break - try: - json_data = json.loads(line[6:]) - if 'data' in json_data: - yield json_data['data'] - except json.JSONDecodeError: - pass - - @classmethod - async def _handle_non_streaming_response(cls, response): - response_json = await response.json() - return response_json.get('content', "Sorry, I couldn't generate a response.") - - @classmethod - def validate_response(cls, response: str) -> str: - return response + full_response = "" + async for line in response.content: + line = line.decode('utf-8') + if line.startswith("data: "): + try: + json_str = line[6:] + if json_str != "[DONE]": + data = json.loads(json_str) + if "data" in data: + full_response += data["data"] + yield data["data"] + except json.JSONDecodeError: + continue |