From 307c8f53e74f1668282a6842ee0de672857b49fa Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Sun, 25 Feb 2024 09:41:39 +0100 Subject: Custom api_base for GeminiPro --- g4f/Provider/GeminiPro.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/g4f/Provider/GeminiPro.py b/g4f/Provider/GeminiPro.py index e1738dc8..792cd5d1 100644 --- a/g4f/Provider/GeminiPro.py +++ b/g4f/Provider/GeminiPro.py @@ -13,6 +13,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): url = "https://ai.google.dev" working = True supports_message_history = True + needs_auth = True default_model = "gemini-pro" models = ["gemini-pro", "gemini-pro-vision"] @@ -24,19 +25,24 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): stream: bool = False, proxy: str = None, api_key: str = None, + api_base: str = None, image: ImageType = None, **kwargs ) -> AsyncResult: model = "gemini-pro-vision" if not model and image else model model = cls.get_model(model) - if not api_key: - raise MissingAuthError('Missing "api_key" for auth') - headers = { - "Content-Type": "application/json", - } - async with ClientSession(headers=headers) as session: - method = "streamGenerateContent" if stream else "generateContent" - url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:{method}" + + if not api_key and not api_base: + raise MissingAuthError('Missing "api_key" or "api_base"') + if not api_base: + api_base = f"https://generativelanguage.googleapis.com/v1beta" + + method = "streamGenerateContent" if stream else "generateContent" + url = f"{api_base.rstrip('/')}/models/{model}:{method}" + if api_key: + url += f"?key={api_key}" + + async with ClientSession() as session: contents = [ { "role": "model" if message["role"] == "assistant" else message["role"], @@ -62,7 +68,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): "topK": kwargs.get("top_k"), } } - async with session.post(url, params={"key": api_key}, json=data, proxy=proxy) as response: + async with session.post(url, json=data, proxy=proxy) as response: if not response.ok: data = await response.json() raise RuntimeError(data[0]["error"]["message"]) @@ -78,7 +84,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): yield data["candidates"][0]["content"]["parts"][0]["text"] except: data = data.decode() if isinstance(data, bytes) else data - raise RuntimeError(f"Read text failed. data: {data}") + raise RuntimeError(f"Read chunk failed. data: {data}") lines = [] else: lines.append(chunk) -- cgit v1.2.3