diff options
author | Heiner Lohaus <hlohaus@users.noreply.github.com> | 2024-04-21 22:39:00 +0200 |
---|---|---|
committer | Heiner Lohaus <hlohaus@users.noreply.github.com> | 2024-04-21 22:39:00 +0200 |
commit | 3a23e81de93c4c9a83aa22b70ea13066f06541e3 (patch) | |
tree | caa0c5f892c3ab8df393c1821bbdab780c5d83de /g4f/Provider | |
parent | Add image model list (diff) | |
download | gpt4free-3a23e81de93c4c9a83aa22b70ea13066f06541e3.tar gpt4free-3a23e81de93c4c9a83aa22b70ea13066f06541e3.tar.gz gpt4free-3a23e81de93c4c9a83aa22b70ea13066f06541e3.tar.bz2 gpt4free-3a23e81de93c4c9a83aa22b70ea13066f06541e3.tar.lz gpt4free-3a23e81de93c4c9a83aa22b70ea13066f06541e3.tar.xz gpt4free-3a23e81de93c4c9a83aa22b70ea13066f06541e3.tar.zst gpt4free-3a23e81de93c4c9a83aa22b70ea13066f06541e3.zip |
Diffstat (limited to 'g4f/Provider')
-rw-r--r-- | g4f/Provider/MetaAI.py | 2 | ||||
-rw-r--r-- | g4f/Provider/Replicate.py | 84 | ||||
-rw-r--r-- | g4f/Provider/__init__.py | 2 | ||||
-rw-r--r-- | g4f/Provider/needs_auth/OpenaiChat.py | 15 | ||||
-rw-r--r-- | g4f/Provider/unfinished/Replicate.py | 78 |
5 files changed, 94 insertions, 87 deletions
diff --git a/g4f/Provider/MetaAI.py b/g4f/Provider/MetaAI.py index 045255e7..caed7778 100644 --- a/g4f/Provider/MetaAI.py +++ b/g4f/Provider/MetaAI.py @@ -89,7 +89,7 @@ class MetaAI(AsyncGeneratorProvider): headers = {} headers = { 'content-type': 'application/x-www-form-urlencoded', - 'cookie': format_cookies(cookies), + 'cookie': format_cookies(self.cookies), 'origin': 'https://www.meta.ai', 'referer': 'https://www.meta.ai/', 'x-asbd-id': '129477', diff --git a/g4f/Provider/Replicate.py b/g4f/Provider/Replicate.py new file mode 100644 index 00000000..593fd04d --- /dev/null +++ b/g4f/Provider/Replicate.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from .base_provider import AsyncGeneratorProvider, ProviderModelMixin +from .helper import format_prompt, filter_none +from ..typing import AsyncResult, Messages +from ..requests import raise_for_status +from ..requests.aiohttp import StreamSession +from ..errors import ResponseError, MissingAuthError + +class Replicate(AsyncGeneratorProvider, ProviderModelMixin): + url = "https://replicate.com" + working = True + default_model = "meta/meta-llama-3-70b-instruct" + + @classmethod + async def create_async_generator( + cls, + model: str, + messages: Messages, + api_key: str = None, + proxy: str = None, + timeout: int = 180, + system_prompt: str = None, + max_new_tokens: int = None, + temperature: float = None, + top_p: float = None, + top_k: float = None, + stop: list = None, + extra_data: dict = {}, + headers: dict = { + "accept": "application/json", + }, + **kwargs + ) -> AsyncResult: + model = cls.get_model(model) + if cls.needs_auth and api_key is None: + raise MissingAuthError("api_key is missing") + if api_key is not None: + headers["Authorization"] = f"Bearer {api_key}" + api_base = "https://api.replicate.com/v1/models/" + else: + api_base = "https://replicate.com/api/models/" + async with StreamSession( + proxy=proxy, + headers=headers, + timeout=timeout + ) as session: + data = { + "stream": True, + "input": { + "prompt": format_prompt(messages), + **filter_none( + system_prompt=system_prompt, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stop_sequences=",".join(stop) if stop else None + ), + **extra_data + }, + } + url = f"{api_base.rstrip('/')}/{model}/predictions" + async with session.post(url, json=data) as response: + message = "Model not found" if response.status == 404 else None + await raise_for_status(response, message) + result = await response.json() + if "id" not in result: + raise ResponseError(f"Invalid response: {result}") + async with session.get(result["urls"]["stream"], headers={"Accept": "text/event-stream"}) as response: + await raise_for_status(response) + event = None + async for line in response.iter_lines(): + if line.startswith(b"event: "): + event = line[7:] + if event == b"done": + break + elif event == b"output": + if line.startswith(b"data: "): + new_text = line[6:].decode() + if new_text: + yield new_text + else: + yield "\n"
\ No newline at end of file diff --git a/g4f/Provider/__init__.py b/g4f/Provider/__init__.py index 27c14672..d2d9bfda 100644 --- a/g4f/Provider/__init__.py +++ b/g4f/Provider/__init__.py @@ -9,7 +9,6 @@ from .deprecated import * from .not_working import * from .selenium import * from .needs_auth import * -from .unfinished import * from .Aichatos import Aichatos from .Aura import Aura @@ -46,6 +45,7 @@ from .MetaAI import MetaAI from .MetaAIAccount import MetaAIAccount from .PerplexityLabs import PerplexityLabs from .Pi import Pi +from .Replicate import Replicate from .ReplicateImage import ReplicateImage from .Vercel import Vercel from .WhiteRabbitNeo import WhiteRabbitNeo diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index 7952d606..3d6e9858 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -340,9 +340,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): Raises: RuntimeError: If an error occurs during processing. """ - async with StreamSession( - proxies={"all": proxy}, + proxy=proxy, impersonate="chrome", timeout=timeout ) as session: @@ -364,26 +363,27 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): api_key = cls._api_key = None cls._create_request_args() if debug.logging: - print("OpenaiChat: Load default_model failed") + print("OpenaiChat: Load default model failed") print(f"{e.__class__.__name__}: {e}") arkose_token = None if cls.default_model is None: + error = None try: arkose_token, api_key, cookies, headers = await getArkoseAndAccessToken(proxy) cls._create_request_args(cookies, headers) cls._set_api_key(api_key) except NoValidHarFileError as e: - ... + error = e if cls._api_key is None: await cls.nodriver_access_token() if cls._api_key is None and cls.needs_auth: - raise e + raise error cls.default_model = cls.get_model(await cls.get_default_model(session, cls._headers)) async with session.post( f"{cls.url}/backend-anon/sentinel/chat-requirements" - if not cls._api_key else + if cls._api_key is None else f"{cls.url}/backend-api/sentinel/chat-requirements", json={"conversation_mode_kind": "primary_assistant"}, headers=cls._headers @@ -412,7 +412,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): print("OpenaiChat: Upload image failed") print(f"{e.__class__.__name__}: {e}") - model = cls.get_model(model).replace("gpt-3.5-turbo", "text-davinci-002-render-sha") + model = cls.get_model(model) + model = "text-davinci-002-render-sha" if model == "gpt-3.5-turbo" else model if conversation is None: conversation = Conversation(conversation_id, str(uuid.uuid4()) if parent_id is None else parent_id) else: diff --git a/g4f/Provider/unfinished/Replicate.py b/g4f/Provider/unfinished/Replicate.py deleted file mode 100644 index aaaf31b3..00000000 --- a/g4f/Provider/unfinished/Replicate.py +++ /dev/null @@ -1,78 +0,0 @@ -from __future__ import annotations - -import asyncio - -from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin -from ..helper import format_prompt, filter_none -from ...typing import AsyncResult, Messages -from ...requests import StreamSession, raise_for_status -from ...image import ImageResponse -from ...errors import ResponseError, MissingAuthError - -class Replicate(AsyncGeneratorProvider, ProviderModelMixin): - url = "https://replicate.com" - working = True - default_model = "mistralai/mixtral-8x7b-instruct-v0.1" - api_base = "https://api.replicate.com/v1/models/" - - @classmethod - async def create_async_generator( - cls, - model: str, - messages: Messages, - api_key: str = None, - proxy: str = None, - timeout: int = 180, - system_prompt: str = None, - max_new_tokens: int = None, - temperature: float = None, - top_p: float = None, - top_k: float = None, - stop: list = None, - extra_data: dict = {}, - headers: dict = {}, - **kwargs - ) -> AsyncResult: - model = cls.get_model(model) - if api_key is None: - raise MissingAuthError("api_key is missing") - headers["Authorization"] = f"Bearer {api_key}" - async with StreamSession( - proxies={"all": proxy}, - headers=headers, - timeout=timeout - ) as session: - data = { - "stream": True, - "input": { - "prompt": format_prompt(messages), - **filter_none( - system_prompt=system_prompt, - max_new_tokens=max_new_tokens, - temperature=temperature, - top_p=top_p, - top_k=top_k, - stop_sequences=",".join(stop) if stop else None - ), - **extra_data - }, - } - url = f"{cls.api_base.rstrip('/')}/{model}/predictions" - async with session.post(url, json=data) as response: - await raise_for_status(response) - result = await response.json() - if "id" not in result: - raise ResponseError(f"Invalid response: {result}") - async with session.get(result["urls"]["stream"], headers={"Accept": "text/event-stream"}) as response: - await raise_for_status(response) - event = None - async for line in response.iter_lines(): - if line.startswith(b"event: "): - event = line[7:] - elif event == b"output": - if line.startswith(b"data: "): - yield line[6:].decode() - elif not line.startswith(b"id: "): - continue#yield "+"+line.decode() - elif event == b"done": - break
\ No newline at end of file |