2024-01-10 12:34:56 +03:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import re
|
|
|
|
import asyncio
|
2024-01-13 17:37:36 +03:00
|
|
|
from .. import debug
|
2024-01-10 12:34:56 +03:00
|
|
|
from ..typing import CreateResult, Messages
|
|
|
|
from ..base_provider import BaseProvider, ProviderType
|
|
|
|
|
|
|
|
system_message = """
|
|
|
|
You can generate custom images with the DALL-E 3 image generator.
|
2024-01-14 17:04:37 +03:00
|
|
|
To generate an image with a prompt, do this:
|
2024-01-10 12:34:56 +03:00
|
|
|
<img data-prompt=\"keywords for the image\">
|
|
|
|
Don't use images with data uri. It is important to use a prompt instead.
|
|
|
|
<img data-prompt=\"image caption\">
|
|
|
|
"""
|
|
|
|
|
|
|
|
class CreateImagesProvider(BaseProvider):
    """
    Provider class for creating images based on text prompts.

    This provider wraps another provider. It prepends a system message that
    tells the model how to request an image, then scans the model's output for
    ``<img data-prompt="...">`` tags and replaces each tag with the output of
    the supplied image creation functions.

    Attributes:
        provider (ProviderType): The underlying provider to handle non-image related tasks.
        create_images (callable): A function to create images synchronously.
        create_images_async (callable): A function to create images asynchronously.
        system_message (str): A message that explains the image creation capability.
        include_placeholder (bool): Flag to determine whether to include the image placeholder in the output.
        __name__ (str): Name of the provider.
        url (str): URL of the provider.
        working (bool): Indicates if the provider is operational.
        supports_stream (bool): Indicates if the provider supports streaming.
    """

    # Matches the tag the system message asks the model to emit; group 1 is the prompt.
    _image_tag_pattern = re.compile(r'<img data-prompt="(.*?)">')

    def __init__(
        self,
        provider: ProviderType,
        create_images: callable,
        create_async: callable,
        system_message: str = system_message,
        include_placeholder: bool = True
    ) -> None:
        """
        Initializes the CreateImagesProvider.

        Args:
            provider (ProviderType): The underlying provider.
            create_images (callable): Function to create images synchronously; called
                with a prompt and iterated, each item being yielded as output.
            create_async (callable): Function to create images asynchronously; called
                with a prompt and awaited, its result replacing the tag in the response.
            system_message (str, optional): System message to be prefixed to messages. Defaults to a predefined message.
            include_placeholder (bool, optional): Whether to include image placeholders in the output. Defaults to True.
        """
        self.provider = provider
        self.create_images = create_images
        self.create_images_async = create_async
        self.system_message = system_message
        self.include_placeholder = include_placeholder
        # Mirror the wrapped provider's metadata so this wrapper can stand in for it.
        self.__name__ = provider.__name__
        self.url = provider.url
        self.working = provider.working
        self.supports_stream = provider.supports_stream

    def _with_system_message(self, messages: Messages) -> Messages:
        """Return a new message list with the system message first; ``messages`` is not modified."""
        return [{"role": "system", "content": self.system_message}, *messages]

    def _expand_image_tags(self, text: str) -> CreateResult:
        """
        Yield ``text`` with every complete image tag replaced by generated images.

        Text before, between and after the tags is yielded unchanged (empty
        segments are skipped). When ``include_placeholder`` is set, each tag is
        yielded right before the items produced by ``create_images`` for it.
        """
        position = 0
        for match in self._image_tag_pattern.finditer(text):
            placeholder, prompt = match.group(0), match.group(1)
            if match.start() > position:
                yield text[position:match.start()]
            if self.include_placeholder:
                yield placeholder
            if debug.logging:
                print(f"Create images with prompt: {prompt}")
            yield from self.create_images(prompt)
            position = match.end()
        if position < len(text):
            yield text[position:]

    def create_completion(
        self,
        model: str,
        messages: Messages,
        stream: bool = False,
        **kwargs
    ) -> CreateResult:
        """
        Creates a completion result, processing any image creation prompts found within the messages.

        Args:
            model (str): The model to use for creation.
            messages (Messages): The messages to send; the list itself is not modified.
            stream (bool, optional): Indicates whether to stream the results. Defaults to False.
            **kwargs: Additional keyword arguments for the provider.

        Yields:
            CreateResult: Yields chunks of the processed messages, including image data if applicable.

        Note:
            Chunks are passed through unchanged until one contains "<". From then on
            chunks are buffered until a ">" arrives, and every complete image tag in
            the buffer is replaced with the output of the synchronous image creation
            function. Text still buffered when the stream ends is yielded as-is.
        """
        messages = self._with_system_message(messages)
        buffer = ""
        for chunk in self.provider.create_completion(model, messages, stream, **kwargs):
            if not buffer and "<" not in chunk:
                # No tag can start here: pass the chunk straight through.
                yield chunk
                continue
            # A "<" was seen: hold text back until a ">" may close the tag.
            buffer += chunk
            if ">" in buffer:
                yield from self._expand_image_tags(buffer)
                buffer = ""
        if buffer:
            # The stream ended while waiting for ">"; don't drop the held-back text.
            yield buffer

    async def create_async(
        self,
        model: str,
        messages: Messages,
        **kwargs
    ) -> str:
        """
        Asynchronously creates a response, processing any image creation prompts found within the messages.

        Args:
            model (str): The model to use for creation.
            messages (Messages): The messages to send; the list itself is not modified.
            **kwargs: Additional keyword arguments for the provider.

        Returns:
            str: The processed response string, including asynchronously generated image data if applicable.

        Note:
            Each distinct image tag in the response triggers one call to the
            asynchronous image creation function; the calls run concurrently via
            ``asyncio.gather``. Every occurrence of a tag is then replaced with its
            result (prefixed by the tag itself when ``include_placeholder`` is set).
        """
        messages = self._with_system_message(messages)
        response = await self.provider.create_async(model, messages, **kwargs)

        # Distinct tags in first-appearance order (dicts preserve insertion order).
        prompts: dict[str, str] = {}
        for match in self._image_tag_pattern.finditer(response):
            placeholder = match.group(0)
            if placeholder in prompts:
                continue
            prompt = match.group(1)
            if debug.logging:
                print(f"Create images with prompt: {prompt}")
            prompts[placeholder] = prompt
        if not prompts:
            return response

        results = await asyncio.gather(
            *(self.create_images_async(prompt) for prompt in prompts.values())
        )
        replacements = {
            placeholder: placeholder + result if self.include_placeholder else result
            for placeholder, result in zip(prompts, results)
        }
        # A single substitution pass, so generated output is never re-scanned.
        return self._image_tag_pattern.sub(
            lambda match: replacements[match.group(0)], response
        )
|