From f7bb30036e5e5482611627a040f54254ac162f72 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Sat, 7 Oct 2023 10:17:43 +0200 Subject: Improve code by AI --- g4f/Provider/base_provider.py | 20 ++++++----- g4f/Provider/retry_provider.py | 32 ++++++++--------- g4f/__init__.py | 72 ++++++++++++++++++------------------- g4f/requests.py | 81 +++++++++++++++++++++--------------------- 4 files changed, 103 insertions(+), 102 deletions(-) (limited to 'g4f') diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index a21dc871..35764081 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -10,11 +10,11 @@ from ..typing import AsyncGenerator, CreateResult class BaseProvider(ABC): url: str - working = False - needs_auth = False - supports_stream = False - supports_gpt_35_turbo = False - supports_gpt_4 = False + working: bool = False + needs_auth: bool = False + supports_stream: bool = False + supports_gpt_35_turbo: bool = False + supports_gpt_4: bool = False @staticmethod @abstractmethod @@ -38,13 +38,15 @@ class BaseProvider(ABC): ) -> str: if not loop: loop = get_event_loop() - def create_func(): + + def create_func() -> str: return "".join(cls.create_completion( model, messages, False, **kwargs )) + return await loop.run_in_executor( executor, create_func @@ -52,7 +54,7 @@ class BaseProvider(ABC): @classmethod @property - def params(cls): + def params(cls) -> str: params = [ ("model", "str"), ("messages", "list[dict[str, str]]"), @@ -103,7 +105,7 @@ class AsyncGeneratorProvider(AsyncProvider): stream=stream, **kwargs ) - gen = generator.__aiter__() + gen = generator.__aiter__() while True: try: yield loop.run_until_complete(gen.__anext__()) @@ -125,7 +127,7 @@ class AsyncGeneratorProvider(AsyncProvider): **kwargs ) ]) - + @staticmethod @abstractmethod def create_async_generator( diff --git a/g4f/Provider/retry_provider.py b/g4f/Provider/retry_provider.py index c1672aba..b49020b2 100644 --- a/g4f/Provider/retry_provider.py +++ b/g4f/Provider/retry_provider.py @@ -1,33 +1,33 @@ from __future__ import annotations import random - +from typing import List, Type, Dict from ..typing import CreateResult from .base_provider import BaseProvider, AsyncProvider from ..debug import logging class RetryProvider(AsyncProvider): - __name__ = "RetryProvider" - working = True - needs_auth = False - supports_stream = True - supports_gpt_35_turbo = False - supports_gpt_4 = False + __name__: str = "RetryProvider" + working: bool = True + needs_auth: bool = False + supports_stream: bool = True + supports_gpt_35_turbo: bool = False + supports_gpt_4: bool = False def __init__( self, - providers: list[type[BaseProvider]], + providers: List[Type[BaseProvider]], shuffle: bool = True ) -> None: - self.providers = providers - self.shuffle = shuffle + self.providers: List[Type[BaseProvider]] = providers + self.shuffle: bool = shuffle def create_completion( self, model: str, - messages: list[dict[str, str]], + messages: List[Dict[str, str]], stream: bool = False, **kwargs ) -> CreateResult: @@ -38,8 +38,8 @@ class RetryProvider(AsyncProvider): if self.shuffle: random.shuffle(providers) - self.exceptions = {} - started = False + self.exceptions: Dict[str, Exception] = {} + started: bool = False for provider in providers: try: if logging: @@ -61,14 +61,14 @@ class RetryProvider(AsyncProvider): async def create_async( self, model: str, - messages: list[dict[str, str]], + messages: List[Dict[str, str]], **kwargs ) -> str: providers = [provider for provider in self.providers] if self.shuffle: random.shuffle(providers) - self.exceptions = {} + self.exceptions: Dict[str, Exception] = {} for provider in providers: try: return await provider.create_async(model, messages, **kwargs) @@ -79,7 +79,7 @@ class RetryProvider(AsyncProvider): self.raise_exceptions() - def raise_exceptions(self): + def raise_exceptions(self) -> None: if self.exceptions: raise RuntimeError("\n".join(["All providers failed:"] + [ f"{p}: {self.exceptions[p].__class__.__name__}: {self.exceptions[p]}" for p in self.exceptions diff --git a/g4f/__init__.py b/g4f/__init__.py index 268b8aab..5d0b47d8 100644 --- a/g4f/__init__.py +++ b/g4f/__init__.py @@ -1,30 +1,30 @@ from __future__ import annotations -from g4f import models -from .Provider import BaseProvider -from .typing import CreateResult, Union -from .debug import logging -from requests import get +from requests import get +from g4f.models import Model, ModelUtils +from .Provider import BaseProvider +from .typing import CreateResult, Union +from .debug import logging version = '0.1.5.4' -def check_pypi_version(): +def check_pypi_version() -> None: try: - response = get(f"https://pypi.org/pypi/g4f/json").json() + response = get("https://pypi.org/pypi/g4f/json").json() latest_version = response["info"]["version"] - + if version != latest_version: print(f'New pypi version: {latest_version} (current: {version}) | pip install -U g4f') - + except Exception as e: print(f'Failed to check g4f pypi version: {e}') check_pypi_version() -def get_model_and_provider(model: Union[models.Model, str], provider: type[BaseProvider], stream: bool): +def get_model_and_provider(model: Union[Model, str], provider: Union[type[BaseProvider], None], stream: bool) -> tuple[Model, type[BaseProvider]]: if isinstance(model, str): - if model in models.ModelUtils.convert: - model = models.ModelUtils.convert[model] + if model in ModelUtils.convert: + model = ModelUtils.convert[model] else: raise Exception(f'The model: {model} does not exist') @@ -33,14 +33,13 @@ def get_model_and_provider(model: Union[models.Model, str], provider: type[BaseP if not provider: raise Exception(f'No provider found for model: {model}') - + if not provider.working: raise Exception(f'{provider.__name__} is not working') - + if not provider.supports_stream and stream: - raise Exception( - f'ValueError: {provider.__name__} does not support "stream" argument') - + raise Exception(f'ValueError: {provider.__name__} does not support "stream" argument') + if logging: print(f'Using {provider.__name__} provider') @@ -49,11 +48,11 @@ def get_model_and_provider(model: Union[models.Model, str], provider: type[BaseP class ChatCompletion: @staticmethod def create( - model : Union[models.Model, str], - messages : list[dict[str, str]], - provider : Union[type[BaseProvider], None] = None, - stream : bool = False, - auth : Union[str, None] = None, + model: Union[Model, str], + messages: list[dict[str, str]], + provider: Union[type[BaseProvider], None] = None, + stream: bool = False, + auth: Union[str, None] = None, **kwargs ) -> Union[CreateResult, str]: @@ -62,7 +61,7 @@ class ChatCompletion: if provider.needs_auth and not auth: raise Exception( f'ValueError: {provider.__name__} requires authentication (use auth=\'cookie or token or jwt ...\' param)') - + if provider.needs_auth: kwargs['auth'] = auth @@ -71,9 +70,9 @@ class ChatCompletion: @staticmethod async def create_async( - model : Union[models.Model, str], - messages : list[dict[str, str]], - provider : Union[type[BaseProvider], None] = None, + model: Union[Model, str], + messages: list[dict[str, str]], + provider: Union[type[BaseProvider], None] = None, **kwargs ) -> str: model, provider = get_model_and_provider(model, provider, False) @@ -83,11 +82,13 @@ class ChatCompletion: class Completion: @staticmethod def create( - model : Union[models.Model, str], - prompt : str, - provider : Union[type[BaseProvider], None] = None, - stream : bool = False, **kwargs) -> Union[CreateResult, str]: - + model: str, + prompt: str, + provider: Union[type[BaseProvider], None] = None, + stream: bool = False, + **kwargs + ) -> Union[CreateResult, str]: + allowed_models = [ 'code-davinci-002', 'text-ada-001', @@ -96,13 +97,12 @@ class Completion: 'text-davinci-002', 'text-davinci-003' ] - + if model not in allowed_models: raise Exception(f'ValueError: Can\'t use {model} with Completion.create()') - + model, provider = get_model_and_provider(model, provider, stream) - result = provider.create_completion(model.name, - [{"role": "user", "content": prompt}], stream, **kwargs) + result = provider.create_completion(model.name, [{"role": "user", "content": prompt}], stream, **kwargs) - return result if stream else ''.join(result) + return result if stream else ''.join(result) \ No newline at end of file diff --git a/g4f/requests.py b/g4f/requests.py index c51d9804..3a4a3f54 100644 --- a/g4f/requests.py +++ b/g4f/requests.py @@ -1,47 +1,44 @@ from __future__ import annotations -import warnings, json, asyncio - +import warnings +import json +import asyncio from functools import partialmethod from asyncio import Future, Queue from typing import AsyncGenerator from curl_cffi.requests import AsyncSession, Response - import curl_cffi -is_newer_0_5_8 = hasattr(AsyncSession, "_set_cookies") or hasattr(curl_cffi.requests.Cookies, "get_cookies_for_curl") -is_newer_0_5_9 = hasattr(curl_cffi.AsyncCurl, "remove_handle") -is_newer_0_5_10 = hasattr(AsyncSession, "release_curl") +is_newer_0_5_8: bool = hasattr(AsyncSession, "_set_cookies") or hasattr(curl_cffi.requests.Cookies, "get_cookies_for_curl") +is_newer_0_5_9: bool = hasattr(curl_cffi.AsyncCurl, "remove_handle") +is_newer_0_5_10: bool = hasattr(AsyncSession, "release_curl") + class StreamResponse: - def __init__(self, inner: Response, queue: Queue): - self.inner = inner - self.queue = queue + def __init__(self, inner: Response, queue: Queue[bytes]) -> None: + self.inner: Response = inner + self.queue: Queue[bytes] = queue self.request = inner.request - self.status_code = inner.status_code - self.reason = inner.reason - self.ok = inner.ok + self.status_code: int = inner.status_code + self.reason: str = inner.reason + self.ok: bool = inner.ok self.headers = inner.headers self.cookies = inner.cookies async def text(self) -> str: - content = await self.read() + content: bytes = await self.read() return content.decode() - def raise_for_status(self): + def raise_for_status(self) -> None: if not self.ok: raise RuntimeError(f"HTTP Error {self.status_code}: {self.reason}") - async def json(self, **kwargs): + async def json(self, **kwargs) -> dict: return json.loads(await self.read(), **kwargs) - - async def iter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None) -> AsyncGenerator[bytes]: - """ - Copied from: https://requests.readthedocs.io/en/latest/_modules/requests/models/ - which is under the License: Apache 2.0 - """ - pending = None + + async def iter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None) -> AsyncGenerator[bytes, None]: + pending: bytes = None async for chunk in self.iter_content( chunk_size=chunk_size, decode_unicode=decode_unicode @@ -63,7 +60,7 @@ class StreamResponse: if pending is not None: yield pending - async def iter_content(self, chunk_size=None, decode_unicode=False) -> As: + async def iter_content(self, chunk_size=None, decode_unicode=False) -> AsyncGenerator[bytes, None]: if chunk_size: warnings.warn("chunk_size is ignored, there is no way to tell curl that.") if decode_unicode: @@ -77,22 +74,23 @@ class StreamResponse: async def read(self) -> bytes: return b"".join([chunk async for chunk in self.iter_content()]) + class StreamRequest: - def __init__(self, session: AsyncSession, method: str, url: str, **kwargs): - self.session = session - self.loop = session.loop if session.loop else asyncio.get_running_loop() - self.queue = Queue() - self.method = method - self.url = url - self.options = kwargs - self.handle = None - - def _on_content(self, data): + def __init__(self, session: AsyncSession, method: str, url: str, **kwargs) -> None: + self.session: AsyncSession = session + self.loop: asyncio.AbstractEventLoop = session.loop if session.loop else asyncio.get_running_loop() + self.queue: Queue[bytes] = Queue() + self.method: str = method + self.url: str = url + self.options: dict = kwargs + self.handle: curl_cffi.AsyncCurl = None + + def _on_content(self, data: bytes) -> None: if not self.enter.done(): self.enter.set_result(None) self.queue.put_nowait(data) - def _on_done(self, task: Future): + def _on_done(self, task: Future) -> None: if not self.enter.done(): self.enter.set_result(None) self.queue.put_nowait(None) @@ -102,8 +100,8 @@ class StreamRequest: async def fetch(self) -> StreamResponse: if self.handle: raise RuntimeError("Request already started") - self.curl = await self.session.pop_curl() - self.enter = self.loop.create_future() + self.curl: curl_cffi.AsyncCurl = await self.session.pop_curl() + self.enter: asyncio.Future = self.loop.create_future() if is_newer_0_5_10: request, _, header_buffer, _, _ = self.session._set_curl_options( self.curl, @@ -121,7 +119,7 @@ class StreamRequest: **self.options ) if is_newer_0_5_9: - self.handle = self.session.acurl.add_handle(self.curl) + self.handle = self.session.acurl.add_handle(self.curl) else: await self.session.acurl.add_handle(self.curl, False) self.handle = self.session.acurl._curl2future[self.curl] @@ -140,14 +138,14 @@ class StreamRequest: response, self.queue ) - + async def __aenter__(self) -> StreamResponse: return await self.fetch() - async def __aexit__(self, *args): + async def __aexit__(self, *args) -> None: self.release_curl() - def release_curl(self): + def release_curl(self) -> None: if is_newer_0_5_10: self.session.release_curl(self.curl) return @@ -162,6 +160,7 @@ class StreamRequest: self.session.push_curl(self.curl) self.curl = None + class StreamSession(AsyncSession): def request( self, @@ -170,7 +169,7 @@ class StreamSession(AsyncSession): **kwargs ) -> StreamRequest: return StreamRequest(self, method, url, **kwargs) - + head = partialmethod(request, "HEAD") get = partialmethod(request, "GET") post = partialmethod(request, "POST") -- cgit v1.2.3