diff options
author | kqlio67 <kqlio67@users.noreply.github.com> | 2024-10-22 13:50:33 +0200 |
---|---|---|
committer | kqlio67 <kqlio67@users.noreply.github.com> | 2024-10-22 13:50:33 +0200 |
commit | 144c7b492256083990b06a70d8b0bc9562ec230c (patch) | |
tree | 8255b1f8d8f1b4f328e7803e126e06d46ad10350 | |
parent | Updated (g4f/Provider/nexra/) (diff) | |
download | gpt4free-144c7b492256083990b06a70d8b0bc9562ec230c.tar gpt4free-144c7b492256083990b06a70d8b0bc9562ec230c.tar.gz gpt4free-144c7b492256083990b06a70d8b0bc9562ec230c.tar.bz2 gpt4free-144c7b492256083990b06a70d8b0bc9562ec230c.tar.lz gpt4free-144c7b492256083990b06a70d8b0bc9562ec230c.tar.xz gpt4free-144c7b492256083990b06a70d8b0bc9562ec230c.tar.zst gpt4free-144c7b492256083990b06a70d8b0bc9562ec230c.zip |
-rw-r--r-- | g4f/Provider/nexra/NexraSD15.py | 70 | ||||
-rw-r--r-- | g4f/models.py | 9 |
2 files changed, 45 insertions, 34 deletions
diff --git a/g4f/Provider/nexra/NexraSD15.py b/g4f/Provider/nexra/NexraSD15.py index 03b35013..860a132f 100644 --- a/g4f/Provider/nexra/NexraSD15.py +++ b/g4f/Provider/nexra/NexraSD15.py @@ -1,18 +1,16 @@ from __future__ import annotations import json -from aiohttp import ClientSession +import requests +from ...typing import CreateResult, Messages +from ..base_provider import ProviderModelMixin, AbstractProvider from ...image import ImageResponse -from ...typing import AsyncResult, Messages -from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin - - -class NexraSD15(AsyncGeneratorProvider, ProviderModelMixin): +class NexraSD15(AbstractProvider, ProviderModelMixin): label = "Nexra Stable Diffusion 1.5" url = "https://nexra.aryahcr.cc/documentation/stable-diffusion/en" api_endpoint = "https://nexra.aryahcr.cc/api/image/complements" - working = False + working = True default_model = 'stablediffusion-1.5' models = [default_model] @@ -29,42 +27,46 @@ class NexraSD15(AsyncGeneratorProvider, ProviderModelMixin): return cls.model_aliases[model] else: return cls.default_model - + @classmethod - async def create_async_generator( + def create_completion( cls, model: str, messages: Messages, proxy: str = None, response: str = "url", # base64 or url **kwargs - ) -> AsyncResult: + ) -> CreateResult: model = cls.get_model(model) - + headers = { - "Content-Type": "application/json", + 'Content-Type': 'application/json' } - async with ClientSession(headers=headers) as session: - data = { - "prompt": messages, - "model": model, - "response": response - } - async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response: - response.raise_for_status() - text_response = await response.text() - - # Clean the response by removing unexpected characters - cleaned_response = text_response.strip('__') + + data = { + "prompt": messages[-1]["content"], + "model": model, + "response": response + } + + response = requests.post(cls.api_endpoint, headers=headers, json=data) - if not cleaned_response.strip(): - raise ValueError("Received an empty response from the server.") + result = cls.process_response(response) + yield result - try: - json_response = json.loads(cleaned_response) - image_url = json_response.get("images", [])[0] - # Create an ImageResponse object - image_response = ImageResponse(images=image_url, alt="Generated Image") - yield image_response - except json.JSONDecodeError: - raise ValueError("Unable to decode JSON from the received text response.") + @classmethod + def process_response(cls, response): + if response.status_code == 200: + try: + content = response.text.strip() + content = content.lstrip('_') + data = json.loads(content) + if data.get('status') and data.get('images'): + image_url = data['images'][0] + return ImageResponse(images=[image_url], alt="Generated Image") + else: + return "Error: No image URL found in the response" + except json.JSONDecodeError as e: + return f"Error: Unable to decode JSON response. Details: {str(e)}" + else: + return f"Error: {response.status_code}, Response: {response.text}" diff --git a/g4f/models.py b/g4f/models.py index 6fa2fca1..6f36892c 100644 --- a/g4f/models.py +++ b/g4f/models.py @@ -52,6 +52,7 @@ from .Provider import ( NexraGeminiPro, NexraMidjourney, NexraQwen, + NexraSD15, OpenaiChat, PerplexityLabs, Pi, @@ -740,6 +741,13 @@ sdxl = Model( ) +sd_1_5 = Model( + name = 'sd-1.5', + base_provider = 'Stability AI', + best_provider = NexraSD15 + +) + sd_3 = Model( name = 'sd-3', base_provider = 'Stability AI', @@ -1095,6 +1103,7 @@ class ModelUtils: ### Stability AI ### 'sdxl': sdxl, +'sd-1.5': sd_1_5, 'sd-3': sd_3, |