summaryrefslogtreecommitdiffstats
path: root/g4f/client.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/client.py24
1 files changed, 16 insertions, 8 deletions
diff --git a/g4f/client.py b/g4f/client.py
index 595beaf9..750c623f 100644
--- a/g4f/client.py
+++ b/g4f/client.py
@@ -7,7 +7,7 @@ import random
import string
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
-from .typing import Union, Generator, Messages, ImageType
+from .typing import Union, Iterator, Messages, ImageType
from .providers.types import BaseProvider, ProviderType
from .image import ImageResponse as ImageProviderResponse
from .Provider.BingCreateImages import BingCreateImages
@@ -17,7 +17,7 @@ from . import get_model_and_provider, get_last_provider
ImageProvider = Union[BaseProvider, object]
Proxies = Union[dict, str]
-IterResponse = Generator[Union[ChatCompletion, ChatCompletionChunk], None, None]
+IterResponse = Iterator[Union[ChatCompletion, ChatCompletionChunk]]
def read_json(text: str) -> dict:
"""
@@ -110,6 +110,12 @@ class Client():
elif "https" in self.proxies:
return self.proxies["https"]
+def filter_none(**kwargs):
+ for key in list(kwargs.keys()):
+ if kwargs[key] is None:
+ del kwargs[key]
+ return kwargs
+
class Completions():
def __init__(self, client: Client, provider: ProviderType = None):
self.client: Client = client
@@ -126,7 +132,7 @@ class Completions():
stop: Union[list[str], str] = None,
api_key: str = None,
**kwargs
- ) -> Union[ChatCompletion, Generator[ChatCompletionChunk, None, None]]:
+ ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
model, provider = get_model_and_provider(
model,
self.provider if provider is None else provider,
@@ -135,11 +141,13 @@ class Completions():
)
stop = [stop] if isinstance(stop, str) else stop
response = provider.create_completion(
- model, messages, stream,
- proxy=self.client.get_proxy(),
- max_tokens=max_tokens,
- stop=stop,
- api_key=self.client.api_key if api_key is None else api_key,
+ model, messages, stream,
+ **filter_none(
+ proxy=self.client.get_proxy(),
+ max_tokens=max_tokens,
+ stop=stop,
+ api_key=self.client.api_key if api_key is None else api_key
+ ),
**kwargs
)
response = iter_response(response, stream, response_format, max_tokens, stop)