summaryrefslogtreecommitdiffstats
path: root/g4f/__init__.py
diff options
context:
space:
mode:
authorHeiner Lohaus <hlohaus@users.noreply.github.com>2023-11-20 14:02:51 +0100
committerHeiner Lohaus <hlohaus@users.noreply.github.com>2023-11-20 14:02:51 +0100
commita9f15815cd3a7ce4567c924868414e94174af222 (patch)
tree27a10131031a4e737d25764735513ee5278a2690 /g4f/__init__.py
parentAdd translate readme module (diff)
downloadgpt4free-a9f15815cd3a7ce4567c924868414e94174af222.tar
gpt4free-a9f15815cd3a7ce4567c924868414e94174af222.tar.gz
gpt4free-a9f15815cd3a7ce4567c924868414e94174af222.tar.bz2
gpt4free-a9f15815cd3a7ce4567c924868414e94174af222.tar.lz
gpt4free-a9f15815cd3a7ce4567c924868414e94174af222.tar.xz
gpt4free-a9f15815cd3a7ce4567c924868414e94174af222.tar.zst
gpt4free-a9f15815cd3a7ce4567c924868414e94174af222.zip
Diffstat (limited to '')
-rw-r--r--g4f/__init__.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/g4f/__init__.py b/g4f/__init__.py
index faef7923..2c9ef7d7 100644
--- a/g4f/__init__.py
+++ b/g4f/__init__.py
@@ -1,8 +1,8 @@
from __future__ import annotations
from requests import get
from .models import Model, ModelUtils, _all_models
-from .Provider import BaseProvider, RetryProvider
-from .typing import Messages, CreateResult, Union, List
+from .Provider import BaseProvider, AsyncGeneratorProvider, RetryProvider
+from .typing import Messages, CreateResult, AsyncResult, Union, List
from . import debug
version = '0.1.8.7'
@@ -80,13 +80,15 @@ class ChatCompletion:
messages : Messages,
provider : Union[type[BaseProvider], None] = None,
stream : bool = False,
- ignored : List[str] = None, **kwargs) -> str:
-
- if stream:
- raise ValueError('"create_async" does not support "stream" argument')
-
+ ignored : List[str] = None,
+ **kwargs) -> Union[AsyncResult, str]:
model, provider = get_model_and_provider(model, provider, False, ignored)
+ if stream:
+ if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
+ return await provider.create_async_generator(model.name, messages, **kwargs)
+ raise ValueError(f'{provider.__name__} does not support "stream" argument')
+
return await provider.create_async(model.name, messages, **kwargs)
class Completion: