diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/create_images.py | 61 |
1 files changed, 60 insertions, 1 deletions
diff --git a/g4f/Provider/create_images.py b/g4f/Provider/create_images.py index f8a0442d..b8bcbde3 100644 --- a/g4f/Provider/create_images.py +++ b/g4f/Provider/create_images.py @@ -8,13 +8,31 @@ from ..base_provider import BaseProvider, ProviderType system_message = """ You can generate custom images with the DALL-E 3 image generator. -To generate a image with a prompt, do this: +To generate an image with a prompt, do this: <img data-prompt=\"keywords for the image\"> Don't use images with data uri. It is important to use a prompt instead. <img data-prompt=\"image caption\"> """ class CreateImagesProvider(BaseProvider): + """ + Provider class for creating images based on text prompts. + + This provider handles image creation requests embedded within message content, + using provided image creation functions. + + Attributes: + provider (ProviderType): The underlying provider to handle non-image related tasks. + create_images (callable): A function to create images synchronously. + create_images_async (callable): A function to create images asynchronously. + system_message (str): A message that explains the image creation capability. + include_placeholder (bool): Flag to determine whether to include the image placeholder in the output. + __name__ (str): Name of the provider. + url (str): URL of the provider. + working (bool): Indicates if the provider is operational. + supports_stream (bool): Indicates if the provider supports streaming. + """ + def __init__( self, provider: ProviderType, @@ -23,6 +41,16 @@ class CreateImagesProvider(BaseProvider): system_message: str = system_message, include_placeholder: bool = True ) -> None: + """ + Initializes the CreateImagesProvider. + + Args: + provider (ProviderType): The underlying provider. + create_images (callable): Function to create images synchronously. + create_async (callable): Function to create images asynchronously. + system_message (str, optional): System message to be prefixed to messages. Defaults to a predefined message. + include_placeholder (bool, optional): Whether to include image placeholders in the output. Defaults to True. + """ self.provider = provider self.create_images = create_images self.create_images_async = create_async @@ -40,6 +68,22 @@ class CreateImagesProvider(BaseProvider): stream: bool = False, **kwargs ) -> CreateResult: + """ + Creates a completion result, processing any image creation prompts found within the messages. + + Args: + model (str): The model to use for creation. + messages (Messages): The messages to process, which may contain image prompts. + stream (bool, optional): Indicates whether to stream the results. Defaults to False. + **kwargs: Additional keywordarguments for the provider. + + Yields: + CreateResult: Yields chunks of the processed messages, including image data if applicable. + + Note: + This method processes messages to detect image creation prompts. When such a prompt is found, + it calls the synchronous image creation function and includes the resulting image in the output. + """ messages.insert(0, {"role": "system", "content": self.system_message}) buffer = "" for chunk in self.provider.create_completion(model, messages, stream, **kwargs): @@ -71,6 +115,21 @@ class CreateImagesProvider(BaseProvider): messages: Messages, **kwargs ) -> str: + """ + Asynchronously creates a response, processing any image creation prompts found within the messages. + + Args: + model (str): The model to use for creation. + messages (Messages): The messages to process, which may contain image prompts. + **kwargs: Additional keyword arguments for the provider. + + Returns: + str: The processed response string, including asynchronously generated image data if applicable. + + Note: + This method processes messages to detect image creation prompts. When such a prompt is found, + it calls the asynchronous image creation function and includes the resulting image in the output. + """ messages.insert(0, {"role": "system", "content": self.system_message}) response = await self.provider.create_async(model, messages, **kwargs) matches = re.findall(r'(<img data-prompt="(.*?)">)', response) |