diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/RubiksAI.py | 124 |
1 files changed, 47 insertions, 77 deletions
diff --git a/g4f/Provider/RubiksAI.py b/g4f/Provider/RubiksAI.py index 7e76d558..c06e6c3d 100644 --- a/g4f/Provider/RubiksAI.py +++ b/g4f/Provider/RubiksAI.py @@ -1,7 +1,6 @@ + from __future__ import annotations -import asyncio -import aiohttp import random import string import json @@ -11,34 +10,24 @@ from aiohttp import ClientSession from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider, ProviderModelMixin -from .helper import format_prompt - +from ..requests.raise_for_status import raise_for_status class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin): label = "Rubiks AI" url = "https://rubiks.ai" - api_endpoint = "https://rubiks.ai/search/api.php" + api_endpoint = "https://rubiks.ai/search/api/" working = True supports_stream = True supports_system_message = True supports_message_history = True - default_model = 'llama-3.1-70b-versatile' - models = [default_model, 'gpt-4o-mini'] + default_model = 'gpt-4o-mini' + models = [default_model, 'gpt-4o', 'o1-mini', 'claude-3.5-sonnet', 'grok-beta', 'gemini-1.5-pro', 'nova-pro'] model_aliases = { "llama-3.1-70b": "llama-3.1-70b-versatile", } - @classmethod - def get_model(cls, model: str) -> str: - if model in cls.models: - return model - elif model in cls.model_aliases: - return cls.model_aliases[model] - else: - return cls.default_model - @staticmethod def generate_mid() -> str: """ @@ -70,7 +59,8 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin): model: str, messages: Messages, proxy: str = None, - websearch: bool = False, + web_search: bool = False, + temperature: float = 0.6, **kwargs ) -> AsyncResult: """ @@ -80,20 +70,18 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin): - model (str): The model to use in the request. - messages (Messages): The messages to send as a prompt. - proxy (str, optional): Proxy URL, if needed. - - websearch (bool, optional): Indicates whether to include search sources in the response. Defaults to False. + - web_search (bool, optional): Indicates whether to include search sources in the response. Defaults to False. """ model = cls.get_model(model) - prompt = format_prompt(messages) - q_value = prompt mid_value = cls.generate_mid() - referer = cls.create_referer(q=q_value, mid=mid_value, model=model) - - url = cls.api_endpoint - params = { - 'q': q_value, - 'model': model, - 'id': '', - 'mid': mid_value + referer = cls.create_referer(q=messages[-1]["content"], mid=mid_value, model=model) + + data = { + "messages": messages, + "model": model, + "search": web_search, + "stream": True, + "temperature": temperature } headers = { @@ -111,52 +99,34 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin): 'sec-ch-ua-mobile': '?0', 'sec-ch-ua-platform': '"Linux"' } - - try: - timeout = aiohttp.ClientTimeout(total=None) - async with ClientSession(timeout=timeout) as session: - async with session.get(url, headers=headers, params=params, proxy=proxy) as response: - if response.status != 200: - yield f"Request ended with status code {response.status}" - return - - assistant_text = '' - sources = [] - - async for line in response.content: - decoded_line = line.decode('utf-8').strip() - if not decoded_line.startswith('data: '): - continue - data = decoded_line[6:] - if data in ('[DONE]', '{"done": ""}'): - break - try: - json_data = json.loads(data) - except json.JSONDecodeError: - continue - - if 'url' in json_data and 'title' in json_data: - if websearch: - sources.append({'title': json_data['title'], 'url': json_data['url']}) - - elif 'choices' in json_data: - for choice in json_data['choices']: - delta = choice.get('delta', {}) - content = delta.get('content', '') - role = delta.get('role', '') - if role == 'assistant': - continue - assistant_text += content - - if websearch and sources: - sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)]) - assistant_text += f"\n\n**Source:**\n{sources_text}" - - yield assistant_text - - except asyncio.CancelledError: - yield "The request was cancelled." - except aiohttp.ClientError as e: - yield f"An error occurred during the request: {e}" - except Exception as e: - yield f"An unexpected error occurred: {e}" + async with ClientSession() as session: + async with session.post(cls.api_endpoint, headers=headers, json=data, proxy=proxy) as response: + await raise_for_status(response) + + sources = [] + async for line in response.content: + decoded_line = line.decode('utf-8').strip() + if not decoded_line.startswith('data: '): + continue + data = decoded_line[6:] + if data in ('[DONE]', '{"done": ""}'): + break + try: + json_data = json.loads(data) + except json.JSONDecodeError: + continue + + if 'url' in json_data and 'title' in json_data: + if web_search: + sources.append({'title': json_data['title'], 'url': json_data['url']}) + + elif 'choices' in json_data: + for choice in json_data['choices']: + delta = choice.get('delta', {}) + content = delta.get('content', '') + if content: + yield content + + if web_search and sources: + sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)]) + yield f"\n\n**Source:**\n{sources_text}"
\ No newline at end of file |