mirror of
https://github.com/StanGirard/quivr.git
synced 2024-11-26 03:15:19 +03:00
7acb52a963
# Description

Please include a summary of the changes and the related issue. Please also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
49 lines
1.4 KiB
Python
49 lines
1.4 KiB
Python
import os
|
|
|
|
import pytest
|
|
from langchain_core.language_models import FakeListChatModel
|
|
from pydantic.v1.error_wrappers import ValidationError
|
|
from quivr_core.config import LLMEndpointConfig
|
|
from quivr_core.llm import LLMEndpoint
|
|
|
|
|
|
@pytest.mark.base
def test_llm_endpoint_from_config_default():
    """Check ``LLMEndpoint.from_config`` with a default ``LLMEndpointConfig``.

    With ``OPENAI_API_KEY`` absent from the environment and no key passed
    explicitly, building the default config/endpoint is expected to raise
    ``ValidationError`` or ``ValueError``.  With an explicit ``llm_api_key``
    the endpoint must report function-calling support, wrap a ``ChatOpenAI``
    instance, and expose a model name contained in the configured model.
    """
    from langchain_openai import ChatOpenAI

    # pop() with a default instead of ``del``: the variable may legitimately
    # be unset (e.g. on CI), and ``del`` would then fail with KeyError before
    # the test reaches the behaviour it is meant to check.
    saved_api_key = os.environ.pop("OPENAI_API_KEY", None)
    try:
        with pytest.raises((ValidationError, ValueError)):
            LLMEndpoint.from_config(LLMEndpointConfig())

        # Working default
        config = LLMEndpointConfig(llm_api_key="test")
        llm = LLMEndpoint.from_config(config=config)

        assert llm.supports_func_calling()
        assert isinstance(llm._llm, ChatOpenAI)
        assert llm._llm.model_name in llm.get_config().model
    finally:
        # Put the key back so this test does not leak a modified process
        # environment into tests that run after it.
        if saved_api_key is not None:
            os.environ["OPENAI_API_KEY"] = saved_api_key
|
|
|
|
|
|
@pytest.mark.base
def test_llm_endpoint_from_config():
    """Build an endpoint from a config with a custom model and base URL.

    The resulting endpoint must report no function-calling support, wrap a
    ``ChatOpenAI`` instance, and expose a model name contained in the
    configured model.
    """
    from langchain_openai import ChatOpenAI

    settings = {
        "model": "llama2",
        "llm_api_key": "test",
        "llm_base_url": "http://localhost:8441",
    }
    endpoint = LLMEndpoint.from_config(LLMEndpointConfig(**settings))

    assert not endpoint.supports_func_calling()

    wrapped_llm = endpoint._llm
    assert isinstance(wrapped_llm, ChatOpenAI)
    assert wrapped_llm.model_name in endpoint.get_config().model
|
|
|
|
|
|
def test_llm_endpoint_constructor():
    """Construct an ``LLMEndpoint`` directly around a fake chat model.

    An endpoint wrapping ``FakeListChatModel`` with a config whose model is
    ``"test"`` must report that function calling is not supported.
    """
    fake_model = FakeListChatModel(responses=[])
    endpoint_config = LLMEndpointConfig(model="test")
    endpoint = LLMEndpoint(llm=fake_model, llm_config=endpoint_config)

    assert not endpoint.supports_func_calling()
|