quivr/backend/modules/tools/web_search.py
2024-05-10 23:49:08 +02:00

86 lines
2.8 KiB
Python

import os
from typing import Dict, Optional, Type
import requests
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.pydantic_v1 import BaseModel as BaseModelV1
from langchain.pydantic_v1 import Field as FieldV1
from langchain_core.tools import BaseTool
from logger import get_logger
from pydantic import BaseModel
# Module-level logger, created by the project's ``get_logger`` helper and
# named after this module.
logger = get_logger(__name__)
# Argument schema for the web search tool: a single required ``query`` string.
# Documented with comments rather than a class docstring so the schema
# generated from this model stays exactly as it was.
class WebSearchInput(BaseModelV1):
    query: str = FieldV1(
        ...,  # Ellipsis marks the field as required.
        description="search query to look up",
        title="query",
    )
class WebSearchTool(BaseTool):
    """Tool that searches the web through the Brave Search API.

    The API key is read from the ``BRAVE_SEARCH_API_KEY`` environment
    variable when the tool is constructed. Results are returned as a single
    string with one ``**title**\\ndescription\\nurl`` entry per hit.
    """

    name: str = "brave-web-search"
    description: str = "useful for when you need to search the web for something."
    args_schema: Type[BaseModel] = WebSearchInput
    # Import-time default only; ``__init__`` re-reads the environment so a key
    # exported after this module was imported is still picked up.
    api_key: str = os.getenv("BRAVE_SEARCH_API_KEY", "")
    # Number of results requested from the API per query.
    result_count: int = 3
    # Seconds to wait for the HTTP request before giving up.
    request_timeout: float = 10.0

    def _check_environment_variable(self) -> bool:
        """Return True if ``BRAVE_SEARCH_API_KEY`` is set to a non-empty value."""
        # An empty string is treated as unset: it can never be a valid token.
        return bool(os.getenv("BRAVE_SEARCH_API_KEY"))

    def __init__(self, **kwargs):
        """Create the tool, taking the API key from the environment.

        Args:
            **kwargs: Optional field overrides (for example ``result_count``
                or ``request_timeout``) forwarded to the base class.

        Raises:
            ValueError: If ``BRAVE_SEARCH_API_KEY`` is missing or empty.
        """
        if not self._check_environment_variable():
            raise ValueError("BRAVE_SEARCH_API_KEY environment variable is not set")
        # Use the key as it is now, not as it was when the class body ran.
        kwargs.setdefault("api_key", os.environ["BRAVE_SEARCH_API_KEY"])
        super().__init__(**kwargs)

    def _search(self, query: str) -> str:
        """Send ``query`` to the Brave web search endpoint and format the hits.

        Raises:
            requests.HTTPError: If the API answers with an error status.
            requests.RequestException: On connection problems or timeout.
        """
        headers = {
            "Accept": "application/json",
            "Accept-Encoding": "gzip",
            "X-Subscription-Token": self.api_key,
        }
        response = requests.get(
            "https://api.search.brave.com/res/v1/web/search",
            # Passing the query via ``params`` lets requests URL-encode it, so
            # characters such as ``&``, ``#`` or spaces cannot break the URL.
            params={"q": query, "count": self.result_count},
            headers=headers,
            timeout=self.request_timeout,
        )
        # Fail with the real HTTP error instead of a confusing parse error.
        response.raise_for_status()
        return self._parse_response(response.json())

    def _run(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        """Run the tool."""
        return self._search(query)

    async def _arun(
        self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
    ) -> str:
        """Run the tool asynchronously."""
        import asyncio

        # ``requests`` is blocking; run it in the default executor so the
        # event loop stays responsive while the HTTP call is in flight.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._search, query)

    def _parse_response(self, response: Dict) -> str:
        """Turn the decoded JSON payload into newline-separated result entries.

        A missing ``web`` section, a missing ``results`` list, or missing
        per-result fields are tolerated; with no results the return value is
        an empty string.
        """
        web_section = response.get("web") or {}
        results = web_section.get("results") or []
        return "\n".join(
            self._format_result(
                result.get("title", ""),
                result.get("description", ""),
                result.get("url", ""),
            )
            for result in results
        )

    def _format_result(self, title: str, description: str, url: str) -> str:
        """Format one hit as a bold title, then description, then URL."""
        return f"**{title}**\n{description}\n{url}"
if __name__ == "__main__":
    # Manual smoke test: build the tool and print the results for "python".
    print(WebSearchTool().run("python"))