diff options
author | Heiner Lohaus <heiner.lohaus@netformic.com> | 2023-08-25 06:41:32 +0200 |
---|---|---|
committer | Heiner Lohaus <heiner.lohaus@netformic.com> | 2023-08-25 06:41:32 +0200 |
commit | 126496d3cacd06a4fa8cbb4e5bde417ce6bb5b4a (patch) | |
tree | 00f989c070b0c001860c39507450aaf30e4302b1 /g4f/Provider/base_provider.py | |
parent | Add create_async method (diff) | |
download | gpt4free-126496d3cacd06a4fa8cbb4e5bde417ce6bb5b4a.tar gpt4free-126496d3cacd06a4fa8cbb4e5bde417ce6bb5b4a.tar.gz gpt4free-126496d3cacd06a4fa8cbb4e5bde417ce6bb5b4a.tar.bz2 gpt4free-126496d3cacd06a4fa8cbb4e5bde417ce6bb5b4a.tar.lz gpt4free-126496d3cacd06a4fa8cbb4e5bde417ce6bb5b4a.tar.xz gpt4free-126496d3cacd06a4fa8cbb4e5bde417ce6bb5b4a.tar.zst gpt4free-126496d3cacd06a4fa8cbb4e5bde417ce6bb5b4a.zip |
Diffstat (limited to 'g4f/Provider/base_provider.py')
-rw-r--r-- | g4f/Provider/base_provider.py | 85 |
1 files changed, 83 insertions, 2 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index 98ad3514..56d79ee6 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -1,7 +1,11 @@ from abc import ABC, abstractmethod -from ..typing import Any, CreateResult +from ..typing import Any, CreateResult, AsyncGenerator, Union +import browser_cookie3 +import asyncio +from time import time +import math class BaseProvider(ABC): url: str @@ -30,4 +34,81 @@ class BaseProvider(ABC): ("stream", "bool"), ] param = ", ".join([": ".join(p) for p in params]) - return f"g4f.provider.{cls.__name__} supports: ({param})"
\ No newline at end of file + return f"g4f.provider.{cls.__name__} supports: ({param})" + + +_cookies = {} + +def get_cookies(cookie_domain: str) -> dict: + if cookie_domain not in _cookies: + _cookies[cookie_domain] = {} + for cookie in browser_cookie3.load(cookie_domain): + _cookies[cookie_domain][cookie.name] = cookie.value + return _cookies[cookie_domain] + + +class AsyncProvider(BaseProvider): + @classmethod + def create_completion( + cls, + model: str, + messages: list[dict[str, str]], + stream: bool = False, + **kwargs: Any + ) -> CreateResult: + yield asyncio.run(cls.create_async(model, messages, **kwargs)) + + @staticmethod + @abstractmethod + async def create_async( + model: str, + messages: list[dict[str, str]], + **kwargs: Any, + ) -> str: + raise NotImplementedError() + + +class AsyncGeneratorProvider(AsyncProvider): + @classmethod + def create_completion( + cls, + model: str, + messages: list[dict[str, str]], + stream: bool = True, + **kwargs: Any + ) -> CreateResult: + if stream: + yield from run_generator(cls.create_async_generator(model, messages, **kwargs)) + else: + yield from AsyncProvider.create_completion(cls=cls, model=model, messages=messages, **kwargs) + + @classmethod + async def create_async( + cls, + model: str, + messages: list[dict[str, str]], + **kwargs: Any, + ) -> str: + chunks = [chunk async for chunk in cls.create_async_generator(model, messages, **kwargs)] + if chunks: + return "".join(chunks) + + @staticmethod + @abstractmethod + def create_async_generator( + model: str, + messages: list[dict[str, str]], + ) -> AsyncGenerator: + raise NotImplementedError() + + +def run_generator(generator: AsyncGenerator[Union[Any, str], Any]): + loop = asyncio.new_event_loop() + gen = generator.__aiter__() + + while True: + try: + yield loop.run_until_complete(gen.__anext__()) + + except StopAsyncIteration: + break |