diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/Koala.py | 36 |
1 files changed, 24 insertions, 12 deletions
diff --git a/g4f/Provider/Koala.py b/g4f/Provider/Koala.py index 849bcdbe..c708bcb9 100644 --- a/g4f/Provider/Koala.py +++ b/g4f/Provider/Koala.py @@ -1,7 +1,8 @@ from __future__ import annotations import json -from aiohttp import ClientSession, BaseConnector +from typing import AsyncGenerator, Optional, List, Dict, Union, Any +from aiohttp import ClientSession, BaseConnector, ClientResponse from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider @@ -19,12 +20,13 @@ class Koala(AsyncGeneratorProvider): cls, model: str, messages: Messages, - proxy: str = None, - connector: BaseConnector = None, - **kwargs - ) -> AsyncResult: + proxy: Optional[str] = None, + connector: Optional[BaseConnector] = None, + **kwargs: Any + ) -> AsyncGenerator[Dict[str, Union[str, int, float, List[Dict[str, Any]], None]], None]: if not model: model = "gpt-3.5-turbo" + headers = { "User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:122.0) Gecko/20100101 Firefox/122.0", "Accept": "text/event-stream", @@ -40,13 +42,17 @@ class Koala(AsyncGeneratorProvider): "Sec-Fetch-Site": "same-origin", "TE": "trailers", } + async with ClientSession(headers=headers, connector=get_connector(connector, proxy)) as session: - input = messages[-1]["content"] - system_messages = [message["content"] for message in messages if message["role"] == "system"] + input_text = messages[-1]["content"] + system_messages = " ".join( + message["content"] for message in messages if message["role"] == "system" + ) if system_messages: - input += " ".join(system_messages) + input_text += f" {system_messages}" + data = { - "input": input, + "input": input_text, "inputHistory": [ message["content"] for message in messages[:-1] @@ -59,8 +65,14 @@ class Koala(AsyncGeneratorProvider): ], "model": model, } + async with session.post(f"{cls.url}/api/gpt/", json=data, proxy=proxy) as response: await raise_for_status(response) - async for chunk in response.content: - if chunk.startswith(b"data: "): - yield json.loads(chunk[6:])
\ No newline at end of file + async for chunk in cls._parse_event_stream(response): + yield chunk + + @staticmethod + async def _parse_event_stream(response: ClientResponse) -> AsyncGenerator[Dict[str, Any], None]: + async for chunk in response.content: + if chunk.startswith(b"data: "): + yield json.loads(chunk[6:]) |