diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/client.py | 24 |
1 files changed, 16 insertions, 8 deletions
diff --git a/g4f/client.py b/g4f/client.py index 595beaf9..750c623f 100644 --- a/g4f/client.py +++ b/g4f/client.py @@ -7,7 +7,7 @@ import random import string from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse -from .typing import Union, Generator, Messages, ImageType +from .typing import Union, Iterator, Messages, ImageType from .providers.types import BaseProvider, ProviderType from .image import ImageResponse as ImageProviderResponse from .Provider.BingCreateImages import BingCreateImages @@ -17,7 +17,7 @@ from . import get_model_and_provider, get_last_provider ImageProvider = Union[BaseProvider, object] Proxies = Union[dict, str] -IterResponse = Generator[Union[ChatCompletion, ChatCompletionChunk], None, None] +IterResponse = Iterator[Union[ChatCompletion, ChatCompletionChunk]] def read_json(text: str) -> dict: """ @@ -110,6 +110,12 @@ class Client(): elif "https" in self.proxies: return self.proxies["https"] +def filter_none(**kwargs): + for key in list(kwargs.keys()): + if kwargs[key] is None: + del kwargs[key] + return kwargs + class Completions(): def __init__(self, client: Client, provider: ProviderType = None): self.client: Client = client @@ -126,7 +132,7 @@ class Completions(): stop: Union[list[str], str] = None, api_key: str = None, **kwargs - ) -> Union[ChatCompletion, Generator[ChatCompletionChunk, None, None]]: + ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: model, provider = get_model_and_provider( model, self.provider if provider is None else provider, @@ -135,11 +141,13 @@ class Completions(): ) stop = [stop] if isinstance(stop, str) else stop response = provider.create_completion( - model, messages, stream, - proxy=self.client.get_proxy(), - max_tokens=max_tokens, - stop=stop, - api_key=self.client.api_key if api_key is None else api_key, + model, messages, stream, + **filter_none( + proxy=self.client.get_proxy(), + max_tokens=max_tokens, + stop=stop, + api_key=self.client.api_key if api_key is None else api_key + ), **kwargs ) response = iter_response(response, stream, response_format, max_tokens, stop) |