From 5c2809a22feb63efe72d6c8bc21657a030c1da37 Mon Sep 17 00:00:00 2001 From: zengrr <47846202+zeng-rr@users.noreply.github.com> Date: Tue, 22 Aug 2023 14:59:58 +0800 Subject: set encoding and temperature --- g4f/Provider/Acytoo.py | 7 ++++--- g4f/Provider/Aichat.py | 6 ++++-- g4f/Provider/H2o.py | 2 ++ 3 files changed, 10 insertions(+), 5 deletions(-) (limited to 'g4f') diff --git a/g4f/Provider/Acytoo.py b/g4f/Provider/Acytoo.py index 32c67c0c..2edd9efd 100644 --- a/g4f/Provider/Acytoo.py +++ b/g4f/Provider/Acytoo.py @@ -19,11 +19,12 @@ class Acytoo(BaseProvider): **kwargs: Any, ) -> CreateResult: headers = _create_header() - payload = _create_payload(messages) + payload = _create_payload(messages, kwargs.get('temperature', 0.5)) url = "https://chat.acytoo.com/api/completions" response = requests.post(url=url, headers=headers, json=payload) response.raise_for_status() + response.encoding = "utf-8" yield response.text @@ -34,7 +35,7 @@ def _create_header(): } -def _create_payload(messages: list[dict[str, str]]): +def _create_payload(messages: list[dict[str, str]], temperature): payload_messages = [ message | {"createdAt": int(time.time()) * 1000} for message in messages ] @@ -42,6 +43,6 @@ def _create_payload(messages: list[dict[str, str]]): "key": "", "model": "gpt-3.5-turbo", "messages": payload_messages, - "temperature": 1, + "temperature": temperature, "password": "", } diff --git a/g4f/Provider/Aichat.py b/g4f/Provider/Aichat.py index 6992d071..a1d90db7 100644 --- a/g4f/Provider/Aichat.py +++ b/g4f/Provider/Aichat.py @@ -40,9 +40,9 @@ class Aichat(BaseProvider): json_data = { "message": base, - "temperature": 1, + "temperature": kwargs.get('temperature', 0.5), "presence_penalty": 0, - "top_p": 1, + "top_p": kwargs.get('top_p', 1), "frequency_penalty": 0, } @@ -52,4 +52,6 @@ class Aichat(BaseProvider): json=json_data, ) response.raise_for_status() + if not response.json()['response']: + raise Exception("Error Response: " + response.json()) yield response.json()["message"] diff --git a/g4f/Provider/H2o.py b/g4f/Provider/H2o.py index c2492e59..f9b799bb 100644 --- a/g4f/Provider/H2o.py +++ b/g4f/Provider/H2o.py @@ -75,6 +75,8 @@ class H2o(BaseProvider): headers=headers, json=data, ) + response.raise_for_status() + response.encoding = "utf-8" generated_text = response.text.replace("\n", "").split("data:") generated_text = json.loads(generated_text[-1]) -- cgit v1.2.3