From 775a0c43a0856f57dbd847a73b9d20b7cddb5063 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Sat, 24 Feb 2024 01:31:17 +0100 Subject: Add help me coding guide Add MissingAuthError in GeminiPro --- g4f/Provider/GeminiPro.py | 19 ++++++++++--------- g4f/__init__.py | 2 +- 2 files changed, 11 insertions(+), 10 deletions(-) (limited to 'g4f') diff --git a/g4f/Provider/GeminiPro.py b/g4f/Provider/GeminiPro.py index b296f253..e1738dc8 100644 --- a/g4f/Provider/GeminiPro.py +++ b/g4f/Provider/GeminiPro.py @@ -7,7 +7,7 @@ from aiohttp import ClientSession from ..typing import AsyncResult, Messages, ImageType from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..image import to_bytes, is_accepted_format - +from ..errors import MissingAuthError class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): url = "https://ai.google.dev" @@ -29,7 +29,8 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): ) -> AsyncResult: model = "gemini-pro-vision" if not model and image else model model = cls.get_model(model) - api_key = api_key if api_key else kwargs.get("access_token") + if not api_key: + raise MissingAuthError('Missing "api_key" for auth') headers = { "Content-Type": "application/json", } @@ -53,13 +54,13 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): }) data = { "contents": contents, - # "generationConfig": { - # "stopSequences": kwargs.get("stop"), - # "temperature": kwargs.get("temperature"), - # "maxOutputTokens": kwargs.get("max_tokens"), - # "topP": kwargs.get("top_p"), - # "topK": kwargs.get("top_k"), - # } + "generationConfig": { + "stopSequences": kwargs.get("stop"), + "temperature": kwargs.get("temperature"), + "maxOutputTokens": kwargs.get("max_tokens"), + "topP": kwargs.get("top_p"), + "topK": kwargs.get("top_k"), + } } async with session.post(url, params={"key": api_key}, json=data, proxy=proxy) as response: if not response.ok: diff --git a/g4f/__init__.py b/g4f/__init__.py index 6716c727..5df942ae 100644 --- a/g4f/__init__.py +++ b/g4f/__init__.py @@ -42,7 +42,7 @@ def get_model_and_provider(model : Union[Model, str], if debug.version_check: debug.version_check = False version.utils.check_version() - + if isinstance(provider, str): if " " in provider: provider_list = [ProviderUtils.convert[p] for p in provider.split() if p in ProviderUtils.convert] -- cgit v1.2.3