summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/GeminiPro.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/GeminiPro.py18
1 files changed, 8 insertions, 10 deletions
diff --git a/g4f/Provider/GeminiPro.py b/g4f/Provider/GeminiPro.py
index 1c5487b1..a22304d5 100644
--- a/g4f/Provider/GeminiPro.py
+++ b/g4f/Provider/GeminiPro.py
@@ -26,38 +26,35 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
stream: bool = False,
proxy: str = None,
api_key: str = None,
- api_base: str = None,
- use_auth_header: bool = True,
+ api_base: str = "https://generativelanguage.googleapis.com/v1beta",
+ use_auth_header: bool = False,
image: ImageType = None,
connector: BaseConnector = None,
**kwargs
) -> AsyncResult:
- model = "gemini-pro-vision" if not model and image else model
+ model = "gemini-pro-vision" if model is None and image is not None else model
model = cls.get_model(model)
if not api_key:
raise MissingAuthError('Missing "api_key"')
headers = params = None
- if api_base and use_auth_header:
+ if use_auth_header:
headers = {"Authorization": f"Bearer {api_key}"}
else:
params = {"key": api_key}
- if not api_base:
- api_base = f"https://generativelanguage.googleapis.com/v1beta"
-
method = "streamGenerateContent" if stream else "generateContent"
url = f"{api_base.rstrip('/')}/models/{model}:{method}"
async with ClientSession(headers=headers, connector=get_connector(connector, proxy)) as session:
contents = [
{
- "role": "model" if message["role"] == "assistant" else message["role"],
+ "role": "model" if message["role"] == "assistant" else "user",
"parts": [{"text": message["content"]}]
}
for message in messages
]
- if image:
+ if image is not None:
image = to_bytes(image)
contents[-1]["parts"].append({
"inline_data": {
@@ -87,7 +84,8 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
lines = [b"{\n"]
elif chunk == b",\r\n" or chunk == b"]":
try:
- data = json.loads(b"".join(lines))
+ data = b"".join(lines)
+ data = json.loads(data)
yield data["candidates"][0]["content"]["parts"][0]["text"]
except:
data = data.decode() if isinstance(data, bytes) else data