summaryrefslogtreecommitdiffstats
path: root/g4f/providers
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/providers')
-rw-r--r--g4f/providers/asyncio.py65
-rw-r--r--g4f/providers/base_provider.py63
2 files changed, 70 insertions, 58 deletions
diff --git a/g4f/providers/asyncio.py b/g4f/providers/asyncio.py
new file mode 100644
index 00000000..cf0ce1a0
--- /dev/null
+++ b/g4f/providers/asyncio.py
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+import asyncio
+from asyncio import AbstractEventLoop, runners
+from typing import Union, Callable, AsyncGenerator, Generator
+
+from ..errors import NestAsyncioError
+
+try:
+ import nest_asyncio
+ has_nest_asyncio = True
+except ImportError:
+ has_nest_asyncio = False
+try:
+ import uvloop
+ has_uvloop = True
+except ImportError:
+ has_uvloop = False
+
+def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]:
+ try:
+ loop = asyncio.get_running_loop()
+ # Do not patch uvloop loop because its incompatible.
+ if has_uvloop:
+ if isinstance(loop, uvloop.Loop):
+ return loop
+ if not hasattr(loop.__class__, "_nest_patched"):
+ if has_nest_asyncio:
+ nest_asyncio.apply(loop)
+ elif check_nested:
+ raise NestAsyncioError('Install "nest_asyncio" package | pip install -U nest_asyncio')
+ return loop
+ except RuntimeError:
+ pass
+
+# Fix for RuntimeError: async generator ignored GeneratorExit
+async def await_callback(callback: Callable):
+ return await callback()
+
+async def async_generator_to_list(generator: AsyncGenerator) -> list:
+ return [item async for item in generator]
+
+def to_sync_generator(generator: AsyncGenerator) -> Generator:
+ loop = get_running_loop(check_nested=False)
+ new_loop = False
+ if loop is None:
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ new_loop = True
+ gen = generator.__aiter__()
+ try:
+ while True:
+ yield loop.run_until_complete(await_callback(gen.__anext__))
+ except StopAsyncIteration:
+ pass
+ finally:
+ if new_loop:
+ try:
+ runners._cancel_all_tasks(loop)
+ loop.run_until_complete(loop.shutdown_asyncgens())
+ if hasattr(loop, "shutdown_default_executor"):
+ loop.run_until_complete(loop.shutdown_default_executor())
+ finally:
+ asyncio.set_event_loop(None)
+ loop.close() \ No newline at end of file
diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py
index c6d0d950..e2c2f46a 100644
--- a/g4f/providers/base_provider.py
+++ b/g4f/providers/base_provider.py
@@ -7,30 +7,14 @@ from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
from abc import abstractmethod
from inspect import signature, Parameter
-from typing import Callable, Union
from ..typing import CreateResult, AsyncResult, Messages
from .types import BaseProvider
+from .asyncio import get_running_loop, to_sync_generator
from .response import FinishReason, BaseConversation, SynthesizeData
-from ..errors import NestAsyncioError, ModelNotSupportedError
+from ..errors import ModelNotSupportedError
from .. import debug
-if sys.version_info < (3, 10):
- NoneType = type(None)
-else:
- from types import NoneType
-
-try:
- import nest_asyncio
- has_nest_asyncio = True
-except ImportError:
- has_nest_asyncio = False
-try:
- import uvloop
- has_uvloop = True
-except ImportError:
- has_uvloop = False
-
# Set Windows event loop policy for better compatibility with asyncio and curl_cffi
if sys.platform == 'win32':
try:
@@ -41,26 +25,6 @@ if sys.platform == 'win32':
except ImportError:
pass
-def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]:
- try:
- loop = asyncio.get_running_loop()
- # Do not patch uvloop loop because its incompatible.
- if has_uvloop:
- if isinstance(loop, uvloop.Loop):
- return loop
- if not hasattr(loop.__class__, "_nest_patched"):
- if has_nest_asyncio:
- nest_asyncio.apply(loop)
- elif check_nested:
- raise NestAsyncioError('Install "nest_asyncio" package | pip install -U nest_asyncio')
- return loop
- except RuntimeError:
- pass
-
-# Fix for RuntimeError: async generator ignored GeneratorExit
-async def await_callback(callback: Callable):
- return await callback()
-
class AbstractProvider(BaseProvider):
"""
Abstract class for providing asynchronous functionality to derived classes.
@@ -136,7 +100,6 @@ class AbstractProvider(BaseProvider):
return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
-
class AsyncProvider(AbstractProvider):
"""
Provides asynchronous functionality for creating completions.
@@ -218,25 +181,9 @@ class AsyncGeneratorProvider(AsyncProvider):
Returns:
CreateResult: The result of the streaming completion creation.
"""
- loop = get_running_loop(check_nested=False)
- new_loop = False
- if loop is None:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- new_loop = True
-
- generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
- gen = generator.__aiter__()
-
- try:
- while True:
- yield loop.run_until_complete(await_callback(gen.__anext__))
- except StopAsyncIteration:
- pass
- finally:
- if new_loop:
- loop.close()
- asyncio.set_event_loop(None)
+ return to_sync_generator(
+ cls.create_async_generator(model, messages, stream=stream, **kwargs)
+ )
@classmethod
async def create_async(