From 03ebbe70fefcaa9a9d8ef57239bc9bab859b14dc Mon Sep 17 00:00:00 2001
From: Stan Girard
Date: Fri, 10 May 2024 11:42:49 +0200
Subject: [PATCH] chore: tools (#2575)

# Description

Moves `ImageGeneratorTool` out of `backend/modules/brain/integrations/GPT4/Brain.py` and into a new shared `backend/modules/tools` package, which the GPT4 brain now imports.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):

---
 .../modules/brain/integrations/GPT4/Brain.py | 51 +---------------
 backend/modules/tools/__init__.py            |  2 +
 backend/modules/tools/image_generator.py     | 61 +++++++++++++++++++
 3 files changed, 64 insertions(+), 50 deletions(-)
 create mode 100644 backend/modules/tools/__init__.py
 create mode 100644 backend/modules/tools/image_generator.py

diff --git a/backend/modules/brain/integrations/GPT4/Brain.py b/backend/modules/brain/integrations/GPT4/Brain.py
index 6a8af8c91..eb4457932 100644
--- a/backend/modules/brain/integrations/GPT4/Brain.py
+++ b/backend/modules/brain/integrations/GPT4/Brain.py
@@ -24,6 +24,7 @@ from modules.chat.dto.outputs import GetChatHistoryOutput
 from modules.chat.service.chat_service import ChatService
 from openai import OpenAI
 from pydantic import BaseModel
+from modules.tools import ImageGeneratorTool
 
 
 class AgentState(TypedDict):
@@ -36,56 +37,6 @@ logger = get_logger(__name__)
 chat_service = ChatService()
-
-class ImageGenerationInput(BaseModelV1):
-    query: str = FieldV1(
-        ...,
-        title="description",
-        description="A detailled prompt to generate the image from. Takes into account the history of the chat.",
-    )
-
-
-class ImageGeneratorTool(BaseTool):
-    name = "image-generator"
-    description = "useful for when you need to generate an image from a prompt."
-    args_schema: Type[BaseModel] = ImageGenerationInput
-    return_direct = True
-
-    def _run(
-        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
-    ) -> str:
-        client = OpenAI()
-
-        response = client.images.generate(
-            model="dall-e-3",
-            prompt=query,
-            size="1024x1024",
-            quality="standard",
-            n=1,
-        )
-        image_url = response.data[0].url
-        revised_prompt = response.data[0].revised_prompt
-        # Make the url a markdown image
-        return f"{revised_prompt} \n ![Generated Image]({image_url}) "
-
-    async def _arun(
-        self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
-    ) -> str:
-        """Use the tool asynchronously."""
-        client = OpenAI()
-        response = await run_manager.run_async(
-            client.images.generate,
-            model="dall-e-3",
-            prompt=query,
-            size="1024x1024",
-            quality="standard",
-            n=1,
-        )
-        image_url = response.data[0].url
-        # Make the url a markdown image
-        return f"![Generated Image]({image_url})"
-
-
 class GPT4Brain(KnowledgeBrainQA):
     """This is the Notion brain class. it is a KnowledgeBrainQA has the data is stored locally. It is going to call the Data Store internally to get the data.
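With this change the GPT4 brain pulls `ImageGeneratorTool` from the shared `modules.tools` package instead of defining it inline. A minimal sketch of exercising the imported tool on its own, assuming `backend` is on the Python path and `OPENAI_API_KEY` is set (the prompt string is illustrative only):

```python
from modules.tools import ImageGeneratorTool

# Instantiate the LangChain tool; BaseTool.run drives the synchronous _run path.
tool = ImageGeneratorTool()

# Returns the DALL-E 3 revised prompt followed by the image as a markdown link.
markdown = tool.run("A watercolor lighthouse at dawn, soft morning light")
print(markdown)
```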
diff --git a/backend/modules/tools/__init__.py b/backend/modules/tools/__init__.py
new file mode 100644
index 000000000..3267c76bf
--- /dev/null
+++ b/backend/modules/tools/__init__.py
@@ -0,0 +1,2 @@
+from .image_generator import ImageGeneratorTool
+
diff --git a/backend/modules/tools/image_generator.py b/backend/modules/tools/image_generator.py
new file mode 100644
index 000000000..debbc1b31
--- /dev/null
+++ b/backend/modules/tools/image_generator.py
@@ -0,0 +1,61 @@
+from typing import Optional, Type
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForToolRun,
+    CallbackManagerForToolRun,
+)
+from langchain.pydantic_v1 import BaseModel as BaseModelV1
+from langchain.pydantic_v1 import Field as FieldV1
+from langchain.tools import BaseTool
+from langchain_core.tools import BaseTool
+from openai import OpenAI
+from pydantic import BaseModel
+
+
+class ImageGenerationInput(BaseModelV1):
+    query: str = FieldV1(
+        ...,
+        title="description",
+        description="A detailled prompt to generate the image from. Takes into account the history of the chat.",
+    )
+
+
+class ImageGeneratorTool(BaseTool):
+    name = "image-generator"
+    description = "useful for when you need to generate an image from a prompt."
+    args_schema: Type[BaseModel] = ImageGenerationInput
+    return_direct = True
+
+    def _run(
+        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
+    ) -> str:
+        client = OpenAI()
+
+        response = client.images.generate(
+            model="dall-e-3",
+            prompt=query,
+            size="1024x1024",
+            quality="standard",
+            n=1,
+        )
+        image_url = response.data[0].url
+        revised_prompt = response.data[0].revised_prompt
+        # Make the url a markdown image
+        return f"{revised_prompt} \n ![Generated Image]({image_url}) "
+
+    async def _arun(
+        self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
+    ) -> str:
+        """Use the tool asynchronously."""
+        client = OpenAI()
+        response = await run_manager.run_async(
+            client.images.generate,
+            model="dall-e-3",
+            prompt=query,
+            size="1024x1024",
+            quality="standard",
+            n=1,
+        )
+        image_url = response.data[0].url
+        # Make the url a markdown image
+        return f"![Generated Image]({image_url})"
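One thing to flag in `image_generator.py`: `BaseTool` is imported twice, and LangChain's `AsyncCallbackManagerForToolRun` has no `run_async` helper, so the `_arun` path above would fail at runtime. A minimal sketch of an async implementation using `openai.AsyncOpenAI` with the same DALL-E 3 parameters (an assumption on my part, not the code in this patch) could look like:

```python
from typing import Optional

from langchain.callbacks.manager import AsyncCallbackManagerForToolRun
from openai import AsyncOpenAI


async def _arun(
    self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
) -> str:
    """Use the tool asynchronously."""
    # Await the image generation on the async client directly instead of going
    # through the callback manager, which exposes no run_async method.
    client = AsyncOpenAI()
    response = await client.images.generate(
        model="dall-e-3",
        prompt=query,
        size="1024x1024",
        quality="standard",
        n=1,
    )
    image_url = response.data[0].url
    revised_prompt = response.data[0].revised_prompt
    # Mirror _run: revised prompt followed by the image rendered as markdown.
    return f"{revised_prompt} \n ![Generated Image]({image_url})"
```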