feat(azure): quivr compatible with it (#3005)

# Description

Please include a summary of the changes and the related issue. Please
also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
This commit is contained in:
Stan Girard 2024-08-14 15:00:19 +02:00 committed by GitHub
parent 94c7e6501a
commit b5f31a83d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 51 additions and 10 deletions

View File

@ -1,10 +1,11 @@
from typing import Optional from typing import Optional
from urllib.parse import urlparse
from uuid import UUID from uuid import UUID
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain_community.embeddings.ollama import OllamaEmbeddings from langchain_community.embeddings.ollama import OllamaEmbeddings
from langchain_community.vectorstores.supabase import SupabaseVectorStore from langchain_community.vectorstores.supabase import SupabaseVectorStore
from langchain_openai import OpenAIEmbeddings from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from posthog import Posthog from posthog import Posthog
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
from sqlalchemy import Engine, create_engine from sqlalchemy import Engine, create_engine
@ -116,6 +117,7 @@ class PostHogSettings(BaseSettings):
class BrainSettings(BaseSettings): class BrainSettings(BaseSettings):
model_config = SettingsConfigDict(validate_default=False) model_config = SettingsConfigDict(validate_default=False)
openai_api_key: str = "" openai_api_key: str = ""
azure_openai_embeddings_url: str = ""
supabase_url: str = "" supabase_url: str = ""
supabase_service_key: str = "" supabase_service_key: str = ""
resend_api_key: str = "null" resend_api_key: str = "null"
@ -191,7 +193,22 @@ def get_embedding_client() -> Embeddings:
base_url=settings.ollama_api_base_url, base_url=settings.ollama_api_base_url,
) # pyright: ignore reportPrivateUsage=none ) # pyright: ignore reportPrivateUsage=none
else: else:
embeddings = OpenAIEmbeddings() # pyright: ignore reportPrivateUsage=none if settings.azure_openai_embeddings_url:
# https://quivr-test.openai.azure.com/openai/deployments/embedding/embeddings?api-version=2023-05-15
# parse the url to get the deployment name
deployment = settings.azure_openai_embeddings_url.split("/")[5]
netloc = "https://" + urlparse(settings.azure_openai_embeddings_url).netloc
api_version = settings.azure_openai_embeddings_url.split("=")[1]
logger.debug(f"Using Azure OpenAI embeddings: {deployment}")
logger.debug(f"Using Azure OpenAI embeddings: {netloc}")
logger.debug(f"Using Azure OpenAI embeddings: {api_version}")
embeddings = AzureOpenAIEmbeddings(
azure_deployment=deployment,
azure_endpoint=netloc,
api_version=api_version,
)
else:
embeddings = OpenAIEmbeddings() # pyright: ignore reportPrivateUsage=none
return embeddings return embeddings

View File

@ -3,6 +3,7 @@ from typing import Optional
from uuid import UUID from uuid import UUID
from fastapi import HTTPException, UploadFile from fastapi import HTTPException, UploadFile
from quivr_api.celery_worker import process_file_and_notify from quivr_api.celery_worker import process_file_and_notify
from quivr_api.logger import get_logger from quivr_api.logger import get_logger
from quivr_api.modules.brain.entity.brain_entity import RoleEnum from quivr_api.modules.brain.entity.brain_entity import RoleEnum
@ -74,13 +75,13 @@ async def upload_file(
status_code=403, status_code=403,
detail=f"File {upload_file.filename} already exists in storage.", detail=f"File {upload_file.filename} already exists in storage.",
) )
else: else:
notification_service.update_notification_by_id( notification_service.update_notification_by_id(
notification_id, notification_id,
NotificationUpdatableProperties( NotificationUpdatableProperties(
status=NotificationsStatusEnum.ERROR, status=NotificationsStatusEnum.ERROR,
description=f"There was an error uploading the file", description="There was an error uploading the file",
), ),
) )
raise HTTPException( raise HTTPException(

View File

@ -1,9 +1,15 @@
import logging
from urllib.parse import parse_qs, urlparse
from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.chat_models import BaseChatModel
from pydantic.v1 import SecretStr from pydantic.v1 import SecretStr
from quivr_core.brain.info import LLMInfo from quivr_core.brain.info import LLMInfo
from quivr_core.config import LLMEndpointConfig from quivr_core.config import LLMEndpointConfig
from quivr_core.utils import model_supports_function_calling from quivr_core.utils import model_supports_function_calling
logger = logging.getLogger("quivr_core")
class LLMEndpoint: class LLMEndpoint:
def __init__(self, llm_config: LLMEndpointConfig, llm: BaseChatModel): def __init__(self, llm_config: LLMEndpointConfig, llm: BaseChatModel):
@ -19,13 +25,30 @@ class LLMEndpoint:
@classmethod @classmethod
def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()): def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()):
try: try:
from langchain_openai import ChatOpenAI from langchain_openai import AzureChatOpenAI, ChatOpenAI
_llm = ChatOpenAI( if config.model.startswith("azure/"):
model=config.model, # Parse the URL
api_key=SecretStr(config.llm_api_key) if config.llm_api_key else None, parsed_url = urlparse(config.llm_base_url)
base_url=config.llm_base_url, deployment = parsed_url.path.split("/")[3] # type: ignore
) api_version = parse_qs(parsed_url.query).get("api-version", [None])[0] # type: ignore
azure_endpoint = f"https://{parsed_url.netloc}" # type: ignore
_llm = AzureChatOpenAI(
azure_deployment=deployment, # type: ignore
api_version=api_version,
api_key=SecretStr(config.llm_api_key)
if config.llm_api_key
else None,
azure_endpoint=azure_endpoint,
)
else:
_llm = ChatOpenAI(
model=config.model,
api_key=SecretStr(config.llm_api_key)
if config.llm_api_key
else None,
base_url=config.llm_base_url,
)
return cls(llm=_llm, llm_config=config) return cls(llm=_llm, llm_config=config)
except ImportError as e: except ImportError as e: