summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/create_images.py
blob: 852ba01ae9ddf8e7663e60e8aac4878cd4d0a272 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from __future__ import annotations

import re
from ..typing import CreateResult, Messages
from ..base_provider import BaseProvider, ProviderType
from .. import debug

system_message = """
You can generate custom images with the DALL-E 3 image generator.
To generate a image with a prompt, do this:
<img data-prompt=\"keywords for the image\">
Don't use images with data uri. It is important to use a prompt instead.
<img data-prompt=\"image caption\">
"""

class CreateImagesProvider(BaseProvider):
    def __init__(
        self,
        provider: ProviderType,
        create_images: callable,
        system_message: str = system_message
    ) -> None:
        self.provider = provider
        self.create_images = create_images
        self.system_message = system_message
        self. __name__ = provider.__name__
        if hasattr(provider, "url"):
            self.url = provider.url
        self.working = provider.working
        self.supports_stream = provider.supports_stream

    def create_completion(
        self,
        model: str,
        messages: Messages,
        stream: bool = False,
        **kwargs
    ) -> CreateResult:
        messages.insert(0, {"role": "system", "content": self.system_message})
        image_placeholder = ""
        for chunk in self.provider.create_completion(model, messages, stream, **kwargs):
            if image_placeholder or "<" in chunk:
                image_placeholder += chunk
                if ">" in image_placeholder:
                    result = re.search(r'<img data-prompt="(.*?)"', image_placeholder)
                    if result:
                        prompt = result.group(1)
                        if debug.logging:
                            print(f"Create images with prompt: {prompt}")
                        yield from self.create_images(prompt)
                    else:
                        yield image_placeholder
                    image_placeholder = ""
            else:
                yield chunk

    async def create_async(
        self,
        model: str,
        messages: Messages,
        **kwargs
    ) -> str:
        messages.insert(0, {"role": "system", "content": self.system_message})
        response = await self.provider.create_async(model, messages, **kwargs)
        result = re.search(r'<img data-prompt="(.*?)">', response)
        if result:
            search = result.group(0)
            prompt = result.group(1)
            images = "".join([*self.create_images(prompt)])
            return response.replace(search, images)
        return response