From 307c8f53e74f1668282a6842ee0de672857b49fa Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Sun, 25 Feb 2024 09:41:39 +0100 Subject: Custom api_base for GeminiPro --- g4f/Provider/GeminiPro.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) (limited to 'g4f/Provider') diff --git a/g4f/Provider/GeminiPro.py b/g4f/Provider/GeminiPro.py index e1738dc8..792cd5d1 100644 --- a/g4f/Provider/GeminiPro.py +++ b/g4f/Provider/GeminiPro.py @@ -13,6 +13,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): url = "https://ai.google.dev" working = True supports_message_history = True + needs_auth = True default_model = "gemini-pro" models = ["gemini-pro", "gemini-pro-vision"] @@ -24,19 +25,24 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): stream: bool = False, proxy: str = None, api_key: str = None, + api_base: str = None, image: ImageType = None, **kwargs ) -> AsyncResult: model = "gemini-pro-vision" if not model and image else model model = cls.get_model(model) - if not api_key: - raise MissingAuthError('Missing "api_key" for auth') - headers = { - "Content-Type": "application/json", - } - async with ClientSession(headers=headers) as session: - method = "streamGenerateContent" if stream else "generateContent" - url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:{method}" + + if not api_key and not api_base: + raise MissingAuthError('Missing "api_key" or "api_base"') + if not api_base: + api_base = f"https://generativelanguage.googleapis.com/v1beta" + + method = "streamGenerateContent" if stream else "generateContent" + url = f"{api_base.rstrip('/')}/models/{model}:{method}" + if api_key: + url += f"?key={api_key}" + + async with ClientSession() as session: contents = [ { "role": "model" if message["role"] == "assistant" else message["role"], @@ -62,7 +68,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): "topK": kwargs.get("top_k"), } } - async with session.post(url, params={"key": api_key}, json=data, proxy=proxy) as response: + async with session.post(url, json=data, proxy=proxy) as response: if not response.ok: data = await response.json() raise RuntimeError(data[0]["error"]["message"]) @@ -78,7 +84,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): yield data["candidates"][0]["content"]["parts"][0]["text"] except: data = data.decode() if isinstance(data, bytes) else data - raise RuntimeError(f"Read text failed. data: {data}") + raise RuntimeError(f"Read chunk failed. data: {data}") lines = [] else: lines.append(chunk) -- cgit v1.2.3 From b4b74c991be143c701795836cfc9fce56da4a497 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Sun, 25 Feb 2024 15:48:03 +0100 Subject: gui: remove cursor on errors Add auth header to GeminiPro provider --- g4f/Provider/GeminiPro.py | 16 +++++++++------- g4f/Provider/Liaobots.py | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) (limited to 'g4f/Provider') diff --git a/g4f/Provider/GeminiPro.py b/g4f/Provider/GeminiPro.py index 792cd5d1..87ded3ac 100644 --- a/g4f/Provider/GeminiPro.py +++ b/g4f/Provider/GeminiPro.py @@ -32,17 +32,20 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): model = "gemini-pro-vision" if not model and image else model model = cls.get_model(model) - if not api_key and not api_base: - raise MissingAuthError('Missing "api_key" or "api_base"') + if not api_key: + raise MissingAuthError('Missing "api_key"') if not api_base: api_base = f"https://generativelanguage.googleapis.com/v1beta" method = "streamGenerateContent" if stream else "generateContent" url = f"{api_base.rstrip('/')}/models/{model}:{method}" - if api_key: + headers = None + if api_base: + headers = {f"Authorization": "Bearer {api_key}"} + else: url += f"?key={api_key}" - async with ClientSession() as session: + async with ClientSession(headers=headers) as session: contents = [ { "role": "model" if message["role"] == "assistant" else message["role"], @@ -79,12 +82,11 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): lines = [b"{\n"] elif chunk == b",\r\n" or chunk == b"]": try: - data = b"".join(lines) - data = json.loads(data) + data = json.loads(b"".join(lines)) yield data["candidates"][0]["content"]["parts"][0]["text"] except: data = data.decode() if isinstance(data, bytes) else data - raise RuntimeError(f"Read chunk failed. data: {data}") + raise RuntimeError(f"Read chunk failed: {data}") lines = [] else: lines.append(chunk) diff --git a/g4f/Provider/Liaobots.py b/g4f/Provider/Liaobots.py index e93642ba..54bf7f2e 100644 --- a/g4f/Provider/Liaobots.py +++ b/g4f/Provider/Liaobots.py @@ -78,7 +78,7 @@ class Liaobots(AsyncGeneratorProvider, ProviderModelMixin): supports_gpt_35_turbo = True supports_gpt_4 = True default_model = "gpt-3.5-turbo" - models = [m for m in models] + models = list(models) model_aliases = { "claude-v2": "claude-2" } -- cgit v1.2.3