feat: Add support for langchain-anthropic in LLMEndpoint

This commit adds support for the langchain-anthropic language model in the LLMEndpoint class. The `model_supports_function_calling` function has been updated to include the new model names. Now, when configuring the LLMEndpoint, if the model name starts with `claude`, the ChatAnthropic class from langchain-anthropic will be used instead of ChatOpenAI from langchain-openai.
This commit is contained in:
Stan Girard 2024-08-12 11:09:05 +02:00
parent d19d01e556
commit f1407cc8cf
7 changed files with 867 additions and 625 deletions

1377
backend/core/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -18,6 +18,7 @@ aiofiles = ">=23.0.0,<25.0.0"
faiss-cpu = { version = "^1.8.0.post1", optional = true }
langchain-community = { version = "^0.2.6", optional = true }
langchain-openai = { version = "^0.1.14", optional = true }
langchain-anthropic = { version = "^0.1.22", optional = true }
# To install unstructured, you'll also need to install the following system dependencies:
# libmagic, poppler, libreoffice, pandoc, and tesseract.
# NOTE: for now poetry doesn't support groups as extra:
@ -49,7 +50,7 @@ unstructured = { version = "^0.15.0", optional = true, extras = [
docx2txt = { version = "^0.8", optional = true }
[tool.poetry.extras]
base = ["langchain-community", "faiss-cpu", "langchain-openai"]
base = ["langchain-community", "faiss-cpu", "langchain-openai", "langchain-anthropic"]
csv = ["langchain-community"]
md = ["langchain-community", "unstructured"]
ipynb = ["langchain-community"]
@ -66,6 +67,7 @@ all = [
"unstructured",
"docx2txt",
"megaparse",
"langchain-anthropic",
]
[tool.poetry.group.dev]

View File

@ -3,6 +3,7 @@ from typing import Any, Generator, Tuple
from uuid import UUID, uuid4
from langchain_core.messages import AIMessage, HumanMessage
from quivr_core.models import ChatMessage

View File

@ -1,5 +1,6 @@
from langchain_core.language_models.chat_models import BaseChatModel
from pydantic.v1 import SecretStr
from quivr_core.brain.info import LLMInfo
from quivr_core.config import LLMEndpointConfig
from quivr_core.utils import model_supports_function_calling
@ -19,13 +20,26 @@ class LLMEndpoint:
@classmethod
def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()):
try:
from langchain_openai import ChatOpenAI
if config.model.startswith("claude"):
from langchain_anthropic import ChatAnthropic
_llm = ChatOpenAI(
model=config.model,
api_key=SecretStr(config.llm_api_key) if config.llm_api_key else None,
base_url=config.llm_base_url,
)
_llm = ChatAnthropic(
model=config.model,
api_key=SecretStr(config.llm_api_key)
if config.llm_api_key
else None,
base_url=config.llm_base_url,
)
else:
from langchain_openai import ChatOpenAI
_llm = ChatOpenAI(
model=config.model,
api_key=SecretStr(config.llm_api_key)
if config.llm_api_key
else None,
base_url=config.llm_base_url,
)
return cls(llm=_llm, llm_config=config)
except ImportError as e:

View File

@ -1,4 +1,5 @@
import logging
import re
from typing import Any, List, Tuple, no_type_check
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
@ -23,19 +24,19 @@ logger = logging.getLogger("quivr_core")
def model_supports_function_calling(model_name: str) -> bool:
    """Return True when *model_name* is a known function-calling model.

    The name is checked against a fixed list of regular expressions and must
    match one of them in full. Whole-name matching is deliberate: with a
    prefix match (``re.match``) any string that merely starts with a
    supported name -- for example ``"gpt-3.5-turbo-instruct"`` or
    ``"mistral-small-latest-foo"`` -- would be reported as supported.

    To allow a new model or a dated snapshot of an existing one, add a
    pattern that covers its complete name.

    Args:
        model_name: Model identifier as passed to the LLM endpoint config.

    Returns:
        True if the whole name matches one of the supported patterns,
        False otherwise (including for the empty string).
    """
    models_supporting_function_calls = [
        r"gpt-4(-\d{4}-preview)?",
        r"gpt-4-\d{4}",
        r"gpt-3\.5-turbo(-\d{4})?",
        r"gpt-4-turbo",
        r"gpt-4o(-mini)?",
        r"mistral-(small|large)-latest",
        r"claude-3-opus-\d{8}",
        r"claude-3-haiku-\d{8}",
        r"claude-3-sonnet-\d{8}",
    ]
    # fullmatch anchors both ends, so each pattern describes a complete name.
    return any(
        re.fullmatch(pattern, model_name)
        for pattern in models_supporting_function_calls
    )
def format_history_to_openai_mesages(

View File

@ -1,6 +1,7 @@
from uuid import uuid4
import pytest
from quivr_core.chat import ChatHistory
from quivr_core.config import LLMEndpointConfig, RAGConfig
from quivr_core.llm import LLMEndpoint
@ -66,3 +67,50 @@ async def test_quivrqarag(
# Assert whole response makes sense
assert "".join([r.answer for r in stream_responses]) == full_response
@pytest.mark.base
@pytest.mark.asyncio
async def test_quivrqarag_claude(
    mem_vector_store, full_response, mock_chain_qa_stream, openai_api_key
):
    """Stream a RAG answer using a Claude model name and check the chunk contract.

    Checks that only the final chunk is flagged as last, that intermediate
    chunks carry text and default metadata, and that the concatenated chunks
    equal ``full_response``.
    """
    # NOTE(review): mock_chain_qa_stream and openai_api_key are requested but
    # never referenced -- presumably for their fixture side effects; confirm
    # in conftest.
    claude_config = LLMEndpointConfig(model="claude-3-sonnet-20240229")
    endpoint = LLMEndpoint.from_config(claude_config)
    history = ChatHistory(uuid4(), uuid4())
    pipeline = QuivrQARAG(
        rag_config=RAGConfig(llm_config=claude_config),
        llm=endpoint,
        vector_store=mem_vector_store,
    )

    # The test is only meaningful on the function-calling code path.
    assert pipeline.llm_endpoint.supports_func_calling()

    chunks: list[ParsedRAGChunkResponse] = [
        chunk
        async for chunk in pipeline.answer_astream(
            "answer in bullet points. tell me something", history, []
        )
    ]
    leading = chunks[:-1]

    # Only the final chunk may be flagged as the last one.
    assert not any(
        chunk.last_chunk for chunk in leading
    ), "Some chunks before last have last_chunk=True"
    final = chunks[-1]
    assert final.last_chunk

    # Every chunk strictly between the first and the last must carry text.
    for idx, response in enumerate(chunks[1:-1]):
        assert (
            len(response.answer) > 0
        ), f"Sent an empty answer {response} at index {idx+1}"

    # All chunks but the last carry default (empty) metadata.
    expected_metadata = RAGResponseMetadata().model_dump()
    assert all(chunk.metadata.model_dump() == expected_metadata for chunk in leading)

    # TODO(@aminediro) : test responses with sources
    assert final.metadata.sources == []
    assert final.metadata.citations == []

    # The concatenated chunks must reproduce the full expected answer.
    assert "".join(chunk.answer for chunk in chunks) == full_response

View File

@ -13,7 +13,18 @@ from quivr_core.utils import (
def test_model_supports_function_calling():
    """Known function-calling models are accepted; unknown names are rejected."""
    supported = (
        "gpt-4",
        "gpt-4-turbo",
        "gpt-4o",
        "gpt-4o-mini",
        "mistral-small-latest",
        "mistral-large-latest",
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-0125",
        "claude-3-opus-20240229",
    )
    unsupported = (
        "ollama3",
        "gpt-8.4",
        "mistral-burger-latest",
    )

    # `is True` / `is False` also pins the return type to a real bool.
    for model in supported:
        assert model_supports_function_calling(model) is True
    for model in unsupported:
        assert model_supports_function_calling(model) is False
def test_get_prev_message_incorrect_message():