diff options
Diffstat (limited to 'g4f/providers')
-rw-r--r-- | g4f/providers/asyncio.py | 65 | ||||
-rw-r--r-- | g4f/providers/base_provider.py | 63 |
2 files changed, 70 insertions, 58 deletions
diff --git a/g4f/providers/asyncio.py b/g4f/providers/asyncio.py new file mode 100644 index 00000000..cf0ce1a0 --- /dev/null +++ b/g4f/providers/asyncio.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import asyncio +from asyncio import AbstractEventLoop, runners +from typing import Union, Callable, AsyncGenerator, Generator + +from ..errors import NestAsyncioError + +try: + import nest_asyncio + has_nest_asyncio = True +except ImportError: + has_nest_asyncio = False +try: + import uvloop + has_uvloop = True +except ImportError: + has_uvloop = False + +def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]: + try: + loop = asyncio.get_running_loop() + # Do not patch uvloop loop because its incompatible. + if has_uvloop: + if isinstance(loop, uvloop.Loop): + return loop + if not hasattr(loop.__class__, "_nest_patched"): + if has_nest_asyncio: + nest_asyncio.apply(loop) + elif check_nested: + raise NestAsyncioError('Install "nest_asyncio" package | pip install -U nest_asyncio') + return loop + except RuntimeError: + pass + +# Fix for RuntimeError: async generator ignored GeneratorExit +async def await_callback(callback: Callable): + return await callback() + +async def async_generator_to_list(generator: AsyncGenerator) -> list: + return [item async for item in generator] + +def to_sync_generator(generator: AsyncGenerator) -> Generator: + loop = get_running_loop(check_nested=False) + new_loop = False + if loop is None: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + new_loop = True + gen = generator.__aiter__() + try: + while True: + yield loop.run_until_complete(await_callback(gen.__anext__)) + except StopAsyncIteration: + pass + finally: + if new_loop: + try: + runners._cancel_all_tasks(loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + if hasattr(loop, "shutdown_default_executor"): + loop.run_until_complete(loop.shutdown_default_executor()) + finally: + asyncio.set_event_loop(None) + loop.close()
\ No newline at end of file diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py index c6d0d950..e2c2f46a 100644 --- a/g4f/providers/base_provider.py +++ b/g4f/providers/base_provider.py @@ -7,30 +7,14 @@ from asyncio import AbstractEventLoop from concurrent.futures import ThreadPoolExecutor from abc import abstractmethod from inspect import signature, Parameter -from typing import Callable, Union from ..typing import CreateResult, AsyncResult, Messages from .types import BaseProvider +from .asyncio import get_running_loop, to_sync_generator from .response import FinishReason, BaseConversation, SynthesizeData -from ..errors import NestAsyncioError, ModelNotSupportedError +from ..errors import ModelNotSupportedError from .. import debug -if sys.version_info < (3, 10): - NoneType = type(None) -else: - from types import NoneType - -try: - import nest_asyncio - has_nest_asyncio = True -except ImportError: - has_nest_asyncio = False -try: - import uvloop - has_uvloop = True -except ImportError: - has_uvloop = False - # Set Windows event loop policy for better compatibility with asyncio and curl_cffi if sys.platform == 'win32': try: @@ -41,26 +25,6 @@ if sys.platform == 'win32': except ImportError: pass -def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]: - try: - loop = asyncio.get_running_loop() - # Do not patch uvloop loop because its incompatible. - if has_uvloop: - if isinstance(loop, uvloop.Loop): - return loop - if not hasattr(loop.__class__, "_nest_patched"): - if has_nest_asyncio: - nest_asyncio.apply(loop) - elif check_nested: - raise NestAsyncioError('Install "nest_asyncio" package | pip install -U nest_asyncio') - return loop - except RuntimeError: - pass - -# Fix for RuntimeError: async generator ignored GeneratorExit -async def await_callback(callback: Callable): - return await callback() - class AbstractProvider(BaseProvider): """ Abstract class for providing asynchronous functionality to derived classes. @@ -136,7 +100,6 @@ class AbstractProvider(BaseProvider): return f"g4f.Provider.{cls.__name__} supports: ({args}\n)" - class AsyncProvider(AbstractProvider): """ Provides asynchronous functionality for creating completions. @@ -218,25 +181,9 @@ class AsyncGeneratorProvider(AsyncProvider): Returns: CreateResult: The result of the streaming completion creation. """ - loop = get_running_loop(check_nested=False) - new_loop = False - if loop is None: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - new_loop = True - - generator = cls.create_async_generator(model, messages, stream=stream, **kwargs) - gen = generator.__aiter__() - - try: - while True: - yield loop.run_until_complete(await_callback(gen.__anext__)) - except StopAsyncIteration: - pass - finally: - if new_loop: - loop.close() - asyncio.set_event_loop(None) + return to_sync_generator( + cls.create_async_generator(model, messages, stream=stream, **kwargs) + ) @classmethod async def create_async( |