From 21eecea02fe2c6389b49594853920f1ff2959e00 Mon Sep 17 00:00:00 2001 From: kqlio67 <> Date: Fri, 7 Feb 2025 16:50:35 +0200 Subject: Optimization and bug fixes for PollinationsAI provider: improved error handling, model validation, and HTTP request processing --- g4f/Provider/PollinationsAI.py | 137 +++++++++++++++++++---------------------- 1 file changed, 63 insertions(+), 74 deletions(-) diff --git a/g4f/Provider/PollinationsAI.py b/g4f/Provider/PollinationsAI.py index 58788bc6..6a5c6ec5 100644 --- a/g4f/Provider/PollinationsAI.py +++ b/g4f/Provider/PollinationsAI.py @@ -73,24 +73,23 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin): @classmethod def get_models(cls, **kwargs): if not cls.text_models or not cls.image_models: - image_url = "https://image.pollinations.ai/models" - image_response = requests.get(image_url) - raise_for_status(image_response) - new_image_models = image_response.json() - - cls.image_models = list(dict.fromkeys([*cls.extra_image_models, *new_image_models])) - cls.extra_image_models = cls.image_models.copy() - - text_url = "https://text.pollinations.ai/models" - text_response = requests.get(text_url) - raise_for_status(text_response) - original_text_models = [model.get("name") for model in text_response.json()] - - combined_text = cls.extra_text_models + [ - model for model in original_text_models - if model not in cls.extra_text_models - ] - cls.text_models = list(dict.fromkeys(combined_text)) + try: + image_response = requests.get("https://image.pollinations.ai/models") + image_response.raise_for_status() + new_image_models = image_response.json() + cls.image_models = list(dict.fromkeys([*cls.extra_image_models, *new_image_models])) + + text_response = requests.get("https://text.pollinations.ai/models") + text_response.raise_for_status() + original_text_models = [model.get("name") for model in text_response.json()] + + combined_text = cls.extra_text_models + [ + model for model in original_text_models + if model not in cls.extra_text_models + ] + cls.text_models = list(dict.fromkeys(combined_text)) + except Exception as e: + raise RuntimeError(f"Failed to fetch models: {e}") from e return cls.text_models + cls.image_models @@ -122,13 +121,14 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin): try: model = cls.get_model(model) except ModelNotFoundError: - if model not in cls.extra_image_models: + if model not in cls.image_models: raise + if not cache and seed is None: seed = random.randint(0, 10000) - if model in cls.image_models or model in cls.extra_image_models: - async for chunk in cls._generate_image( + if model in cls.image_models: + async for chunk in cls._generate_image( model=model, prompt=format_image_prompt(messages, prompt), proxy=proxy, @@ -172,25 +172,25 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin): safe: bool ) -> AsyncResult: params = { - "seed": seed, - "width": width, - "height": height, + "seed": str(seed) if seed is not None else None, + "width": str(width), + "height": str(height), "model": model, - "nologo": nologo, - "private": private, - "enhance": enhance, - "safe": safe + "nologo": str(nologo).lower(), + "private": str(private).lower(), + "enhance": str(enhance).lower(), + "safe": str(safe).lower() } - params = {k: json.dumps(v) if isinstance(v, bool) else str(v) for k, v in params.items() if v is not None} - params = "&".join( "%s=%s" % (key, quote_plus(params[key])) - for key in params.keys()) - url = f"{cls.image_api_endpoint}prompt/{quote_plus(prompt)}?{params}" + params = {k: v for k, v in params.items() if v is not None} + query = "&".join(f"{k}={quote_plus(v)}" for k, v in params.items()) + url = f"{cls.image_api_endpoint}prompt/{quote_plus(prompt)}?{query}" yield ImagePreview(url, prompt) + async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session: - async with session.head(url) as response: - if response.status != 500: - await raise_for_status(response) - yield ImageResponse(str(response.url), prompt) + async with session.get(url, allow_redirects=True) as response: + await raise_for_status(response) + image_url = str(response.url) + yield ImageResponse(image_url, prompt) @classmethod async def _generate_text( @@ -207,60 +207,49 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin): seed: Optional[int], cache: bool ) -> AsyncResult: - jsonMode = False - if response_format is not None and "type" in response_format: - if response_format["type"] == "json_object": - jsonMode = True + json_mode = False + if response_format and response_format.get("type") == "json_object": + json_mode = True - if images is not None and messages: + if images and messages: last_message = messages[-1].copy() - last_message["content"] = [ - *[{ + image_content = [ + { "type": "image_url", "image_url": {"url": to_data_uri(image)} - } for image, _ in images], - { - "type": "text", - "text": messages[-1]["content"] } + for image, _ in images ] + last_message["content"] = image_content + [{"type": "text", "text": last_message["content"]}] messages[-1] = last_message async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session: - data = { + data = filter_none(**{ "messages": messages, "model": model, "temperature": temperature, "presence_penalty": presence_penalty, "top_p": top_p, "frequency_penalty": frequency_penalty, - "jsonMode": jsonMode, + "jsonMode": json_mode, "stream": False, "seed": seed, "cache": cache - } - async with session.post(cls.text_api_endpoint, json=filter_none(**data)) as response: + }) + + async with session.post(cls.text_api_endpoint, json=data) as response: await raise_for_status(response) - async for line in response.content: - decoded_chunk = line.decode(errors="replace") - if "data: [DONE]" in decoded_chunk: - break - try: - json_str = decoded_chunk.replace("data:", "").strip() - data = json.loads(json_str) - choice = data["choices"][0] - message = choice.get("message") or choice.get("delta", {}) - - if "usage" in data: - yield Usage(**data["usage"]) - content = message.get("content", "") - if content: - yield content.replace("\\(", "(").replace("\\)", ")") - if "finish_reason" in choice and choice["finish_reason"]: - yield FinishReason(choice["finish_reason"]) - break - except json.JSONDecodeError: - yield decoded_chunk.strip() - except Exception as e: - yield FinishReason("error") - break + result = await response.json() + choice = result["choices"][0] + message = choice.get("message", {}) + content = message.get("content", "") + + if content: + yield content.replace("\\(", "(").replace("\\)", ")") + + if "usage" in result: + yield Usage(**result["usage"]) + + finish_reason = choice.get("finish_reason") + if finish_reason: + yield FinishReason(finish_reason) -- cgit v1.2.3