summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/ReplicateHome.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/ReplicateHome.py (renamed from g4f/Provider/ReplicateImage.py)86
1 files changed, 65 insertions, 21 deletions
diff --git a/g4f/Provider/ReplicateImage.py b/g4f/Provider/ReplicateHome.py
index cc3943d7..e6c8d2d3 100644
--- a/g4f/Provider/ReplicateImage.py
+++ b/g4f/Provider/ReplicateHome.py
@@ -1,32 +1,67 @@
from __future__ import annotations
-
+from typing import Generator, Optional, Dict, Any, Union, List
import random
import asyncio
+import base64
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..typing import AsyncResult, Messages
from ..requests import StreamSession, raise_for_status
-from ..image import ImageResponse
from ..errors import ResponseError
+from ..image import ImageResponse
-class ReplicateImage(AsyncGeneratorProvider, ProviderModelMixin):
+class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://replicate.com"
parent = "Replicate"
working = True
- default_model = 'stability-ai/sdxl'
- default_versions = [
- "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
- "2b017d9b67edd2ee1401238df49d75da53c523f36e363881e057f5dc3ed3c5b2"
+ default_model = 'stability-ai/stable-diffusion-3'
+ models = [
+ # Models for image generation
+ 'stability-ai/stable-diffusion-3',
+ 'bytedance/sdxl-lightning-4step',
+ 'playgroundai/playground-v2.5-1024px-aesthetic',
+
+ # Models for image generation
+ 'meta/meta-llama-3-70b-instruct',
+ 'mistralai/mixtral-8x7b-instruct-v0.1',
+ 'google-deepmind/gemma-2b-it',
]
- image_models = [default_model]
+
+ versions = {
+ # Model versions for generating images
+ 'stability-ai/stable-diffusion-3': [
+ "527d2a6296facb8e47ba1eaf17f142c240c19a30894f437feee9b91cc29d8e4f"
+ ],
+ 'bytedance/sdxl-lightning-4step': [
+ "5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f"
+ ],
+ 'playgroundai/playground-v2.5-1024px-aesthetic': [
+ "a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24"
+ ],
+
+
+ # Model versions for text generation
+ 'meta/meta-llama-3-70b-instruct': [
+ "dp-cf04fe09351e25db628e8b6181276547"
+ ],
+ 'mistralai/mixtral-8x7b-instruct-v0.1': [
+ "dp-89e00f489d498885048e94f9809fbc76"
+ ],
+ 'google-deepmind/gemma-2b-it': [
+ "dff94eaf770e1fc211e425a50b51baa8e4cac6c39ef074681f9e39d778773626"
+ ]
+ }
+
+ image_models = {"stability-ai/stable-diffusion-3", "bytedance/sdxl-lightning-4step", "playgroundai/playground-v2.5-1024px-aesthetic"}
+ text_models = {"meta/meta-llama-3-70b-instruct", "mistralai/mixtral-8x7b-instruct-v0.1", "google-deepmind/gemma-2b-it"}
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
- **kwargs
- ) -> AsyncResult:
+ **kwargs: Any
+ ) -> Generator[Union[str, ImageResponse], None, None]:
yield await cls.create_async(messages[-1]["content"], model, **kwargs)
@classmethod
@@ -34,13 +69,13 @@ class ReplicateImage(AsyncGeneratorProvider, ProviderModelMixin):
cls,
prompt: str,
model: str,
- api_key: str = None,
- proxy: str = None,
+ api_key: Optional[str] = None,
+ proxy: Optional[str] = None,
timeout: int = 180,
- version: str = None,
- extra_data: dict = {},
- **kwargs
- ) -> ImageResponse:
+ version: Optional[str] = None,
+ extra_data: Dict[str, Any] = {},
+ **kwargs: Any
+ ) -> Union[str, ImageResponse]:
headers = {
'Accept-Encoding': 'gzip, deflate, br',
'Accept-Language': 'en-US',
@@ -55,10 +90,12 @@ class ReplicateImage(AsyncGeneratorProvider, ProviderModelMixin):
'sec-ch-ua-mobile': '?0',
'sec-ch-ua-platform': '"macOS"',
}
+
if version is None:
- version = random.choice(cls.default_versions)
+ version = random.choice(cls.versions.get(model, []))
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
+
async with StreamSession(
proxies={"all": proxy},
headers=headers,
@@ -81,6 +118,7 @@ class ReplicateImage(AsyncGeneratorProvider, ProviderModelMixin):
result = await response.json()
if "id" not in result:
raise ResponseError(f"Invalid response: {result}")
+
while True:
if api_key is None:
url = f"https://homepage.replicate.com/api/poll?id={result['id']}"
@@ -92,7 +130,13 @@ class ReplicateImage(AsyncGeneratorProvider, ProviderModelMixin):
if "status" not in result:
raise ResponseError(f"Invalid response: {result}")
if result["status"] == "succeeded":
- images = result['output']
- images = images[0] if len(images) == 1 else images
- return ImageResponse(images, prompt)
- await asyncio.sleep(0.5) \ No newline at end of file
+ output = result['output']
+ if model in cls.text_models:
+ return ''.join(output) if isinstance(output, list) else output
+ elif model in cls.image_models:
+ images: List[Any] = output
+ images = images[0] if len(images) == 1 else images
+ return ImageResponse(images, prompt)
+ elif result["status"] == "failed":
+ raise ResponseError(f"Prediction failed: {result}")
+ await asyncio.sleep(0.5)