From 69d0b09816113f0cfc46dad7f45d6e74390eb4e2 Mon Sep 17 00:00:00 2001 From: hlohaus <983577+hlohaus@users.noreply.github.com> Date: Fri, 21 Feb 2025 13:19:28 +0100 Subject: Use gradio api in flux dev --- g4f/Provider/hf_space/BlackForestLabsFlux1Dev.py | 107 ++++++++++++++--------- g4f/Provider/hf_space/G4F.py | 1 + 2 files changed, 69 insertions(+), 39 deletions(-) (limited to 'g4f/Provider/hf_space') diff --git a/g4f/Provider/hf_space/BlackForestLabsFlux1Dev.py b/g4f/Provider/hf_space/BlackForestLabsFlux1Dev.py index dbd9fa0a..63fbfbd7 100644 --- a/g4f/Provider/hf_space/BlackForestLabsFlux1Dev.py +++ b/g4f/Provider/hf_space/BlackForestLabsFlux1Dev.py @@ -1,10 +1,11 @@ from __future__ import annotations import json -from aiohttp import ClientSession +import uuid from ...typing import AsyncResult, Messages -from ...providers.response import ImageResponse, ImagePreview, JsonConversation +from ...providers.response import ImageResponse, ImagePreview, JsonConversation, Reasoning +from ...requests import StreamSession from ...errors import ResponseError from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..helper import format_image_prompt @@ -14,7 +15,7 @@ from .raise_for_status import raise_for_status class BlackForestLabsFlux1Dev(AsyncGeneratorProvider, ProviderModelMixin): url = "https://black-forest-labs-flux-1-dev.hf.space" space = "black-forest-labs/FLUX.1-dev" - api_endpoint = "/gradio_api/call/infer" + referer = f"{url}/?__theme=light" working = True @@ -24,6 +25,29 @@ class BlackForestLabsFlux1Dev(AsyncGeneratorProvider, ProviderModelMixin): image_models = [default_image_model, *model_aliases.keys()] models = image_models + @classmethod + def run(cls, method: str, session: StreamSession, conversation: JsonConversation, data: list = None): + headers = { + "accept": "application/json", + "content-type": "application/json", + "x-zerogpu-token": conversation.zerogpu_token, + "x-zerogpu-uuid": conversation.zerogpu_uuid, + "referer": cls.referer, + } + if method == "post": + return session.post(f"{cls.url}/gradio_api/queue/join?__theme=light", **{ + "headers": {k: v for k, v in headers.items() if v is not None}, + "json": {"data": data,"event_data":None,"fn_index":2,"trigger_id":4,"session_hash":conversation.session_hash} + + }) + return session.get(f"{cls.url}/gradio_api/queue/data?session_hash={conversation.session_hash}", **{ + "headers": { + "accept": "text/event-stream", + "content-type": "application/json", + "referer": cls.referer, + } + }) + @classmethod async def create_async_generator( cls, @@ -43,44 +67,49 @@ class BlackForestLabsFlux1Dev(AsyncGeneratorProvider, ProviderModelMixin): **kwargs ) -> AsyncResult: model = cls.get_model(model) - headers = { - "Content-Type": "application/json", - "Accept": "application/json", - } - async with ClientSession(headers=headers) as session: + async with StreamSession(impersonate="chrome", proxy=proxy) as session: prompt = format_image_prompt(messages, prompt) - data = { - "data": [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps] - } - if zerogpu_token is None: - zerogpu_uuid, zerogpu_token = await get_zerogpu_token(cls.space, session, JsonConversation(), cookies) - headers = { - "x-zerogpu-token": zerogpu_token, - "x-zerogpu-uuid": zerogpu_uuid, - } - headers = {k: v for k, v in headers.items() if v is not None} - async with session.post(f"{cls.url}{cls.api_endpoint}", json=data, proxy=proxy, headers=headers) as response: + data = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps] + conversation = JsonConversation(zerogpu_token=zerogpu_token, zerogpu_uuid=zerogpu_uuid, session_hash=uuid.uuid4().hex) + if conversation.zerogpu_token is None: + conversation.zerogpu_uuid, conversation.zerogpu_token = await get_zerogpu_token(cls.space, session, conversation, cookies) + async with cls.run(f"post", session, conversation, data) as response: await raise_for_status(response) - event_id = (await response.json()).get("event_id") - async with session.get(f"{cls.url}{cls.api_endpoint}/{event_id}") as event_response: + assert (await response.json()).get("event_id") + async with cls.run("get", session, conversation) as event_response: await raise_for_status(event_response) - event = None - async for chunk in event_response.content: - if chunk.startswith(b"event: "): - event = chunk[7:].decode(errors="replace").strip() + async for chunk in event_response.iter_lines(): if chunk.startswith(b"data: "): - if event == "error": - raise ResponseError(f"GPU token limit exceeded: {chunk.decode(errors='replace')}") - if event in ("complete", "generating"): - try: - data = json.loads(chunk[6:]) - if data is None: - continue - url = data[0]["url"] - except (json.JSONDecodeError, KeyError, TypeError) as e: - raise RuntimeError(f"Failed to parse image URL: {chunk.decode(errors='replace')}", e) - if event == "generating": - yield ImagePreview(url, prompt) - else: - yield ImageResponse(url, prompt) + try: + json_data = json.loads(chunk[6:]) + if json_data is None: + continue + if json_data.get('msg') == 'log': + yield Reasoning(status=json_data["log"]) + + if json_data.get('msg') == 'progress': + if 'progress_data' in json_data: + if json_data['progress_data']: + progress = json_data['progress_data'][0] + yield Reasoning(status=f"{progress['desc']} {progress['index']}/{progress['length']}") + else: + yield Reasoning(status=f"Generating") + + elif json_data.get('msg') == 'process_generating': + for item in json_data['output']['data'][0]: + if isinstance(item, dict) and "url" in item: + yield ImagePreview(item["url"], prompt) + elif isinstance(item, list) and len(item) > 2 and "url" in item[1]: + yield ImagePreview(item[2], prompt) + + elif json_data.get('msg') == 'process_completed': + if 'output' in json_data and 'error' in json_data['output']: + json_data['output']['error'] = json_data['output']['error'].split(" 0: + yield ImageResponse(json_data['output']['data'][0]["url"], prompt) break + except (json.JSONDecodeError, KeyError, TypeError) as e: + raise RuntimeError(f"Failed to parse message: {chunk.decode(errors='replace')}", e) \ No newline at end of file diff --git a/g4f/Provider/hf_space/G4F.py b/g4f/Provider/hf_space/G4F.py index 5eab0cbf..c0539c01 100644 --- a/g4f/Provider/hf_space/G4F.py +++ b/g4f/Provider/hf_space/G4F.py @@ -14,6 +14,7 @@ from .raise_for_status import raise_for_status class FluxDev(BlackForestLabsFlux1Dev): url = "https://roxky-flux-1-dev.hf.space" space = "roxky/FLUX.1-dev" + referer = f"{url}/?__theme=light" class G4F(Janus_Pro_7B): label = "G4F framework" -- cgit v1.2.3