summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/HuggingChat.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/HuggingChat.py')
-rw-r--r--g4f/Provider/HuggingChat.py36
1 files changed, 26 insertions, 10 deletions
diff --git a/g4f/Provider/HuggingChat.py b/g4f/Provider/HuggingChat.py
index 668ce4b1..527f0a56 100644
--- a/g4f/Provider/HuggingChat.py
+++ b/g4f/Provider/HuggingChat.py
@@ -6,12 +6,14 @@ from aiohttp import ClientSession, BaseConnector
from ..typing import AsyncResult, Messages
from ..requests.raise_for_status import raise_for_status
+from ..providers.conversation import BaseConversation
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from .helper import format_prompt, get_connector
+from .helper import format_prompt, get_connector, get_cookies
class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://huggingface.co/chat"
working = True
+ needs_auth = True
default_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
models = [
"HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
@@ -22,9 +24,6 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
'mistralai/Mistral-7B-Instruct-v0.2',
'meta-llama/Meta-Llama-3-70B-Instruct'
]
- model_aliases = {
- "openchat/openchat_3.5": "openchat/openchat-3.5-0106",
- }
@classmethod
def get_models(cls):
@@ -45,9 +44,16 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
connector: BaseConnector = None,
web_search: bool = False,
cookies: dict = None,
+ conversation: Conversation = None,
+ return_conversation: bool = False,
+ delete_conversation: bool = True,
**kwargs
) -> AsyncResult:
options = {"model": cls.get_model(model)}
+ if cookies is None:
+ cookies = get_cookies("huggingface.co", False)
+ if return_conversation:
+ delete_conversation = False
system_prompt = "\n".join([message["content"] for message in messages if message["role"] == "system"])
if system_prompt:
@@ -61,9 +67,14 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
headers=headers,
connector=get_connector(connector, proxy)
) as session:
- async with session.post(f"{cls.url}/conversation", json=options) as response:
- await raise_for_status(response)
- conversation_id = (await response.json())["conversationId"]
+ if conversation is None:
+ async with session.post(f"{cls.url}/conversation", json=options) as response:
+ await raise_for_status(response)
+ conversation_id = (await response.json())["conversationId"]
+ if return_conversation:
+ yield Conversation(conversation_id)
+ else:
+ conversation_id = conversation.conversation_id
async with session.get(f"{cls.url}/conversation/{conversation_id}/__data.json") as response:
await raise_for_status(response)
data: list = (await response.json())["nodes"][1]["data"]
@@ -72,7 +83,7 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
message_id: str = data[message_keys["id"]]
options = {
"id": message_id,
- "inputs": format_prompt(messages),
+ "inputs": format_prompt(messages) if conversation is None else messages[-1]["content"],
"is_continue": False,
"is_retry": False,
"web_search": web_search
@@ -92,5 +103,10 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
yield token
elif line["type"] == "finalAnswer":
break
- async with session.delete(f"{cls.url}/conversation/{conversation_id}") as response:
- await raise_for_status(response)
+ if delete_conversation:
+ async with session.delete(f"{cls.url}/conversation/{conversation_id}") as response:
+ await raise_for_status(response)
+
+class Conversation(BaseConversation):
+ def __init__(self, conversation_id: str) -> None:
+ self.conversation_id = conversation_id \ No newline at end of file