From b5f31a83d4a1c4432943bbbaa0766c46927ef125 Mon Sep 17 00:00:00 2001
From: Stan Girard
Date: Wed, 14 Aug 2024 15:00:19 +0200
Subject: [PATCH] feat(azure): quivr compatible with it (#3005)

# Description

Adds Azure OpenAI support to Quivr: embeddings are created with
`AzureOpenAIEmbeddings` when `azure_openai_embeddings_url` is set in
`BrainSettings`, and `LLMEndpoint.from_config` builds an `AzureChatOpenAI`
client when the configured model name starts with `azure/`. The deployment
name, endpoint and `api-version` are parsed from the configured Azure URL.
Also cleans up the upload error notification message.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):

---
 backend/api/quivr_api/models/settings.py    | 21 +++++++++--
 .../quivr_api/modules/sync/utils/upload.py  |  5 +--
 backend/core/quivr_core/llm/llm_endpoint.py | 35 +++++++++++++++----
 3 files changed, 51 insertions(+), 10 deletions(-)

diff --git a/backend/api/quivr_api/models/settings.py b/backend/api/quivr_api/models/settings.py
index fcb2bf212..d5a3303ca 100644
--- a/backend/api/quivr_api/models/settings.py
+++ b/backend/api/quivr_api/models/settings.py
@@ -1,10 +1,11 @@
 from typing import Optional
+from urllib.parse import urlparse
 from uuid import UUID

 from langchain.embeddings.base import Embeddings
 from langchain_community.embeddings.ollama import OllamaEmbeddings
 from langchain_community.vectorstores.supabase import SupabaseVectorStore
-from langchain_openai import OpenAIEmbeddings
+from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
 from posthog import Posthog
 from pydantic_settings import BaseSettings, SettingsConfigDict
 from sqlalchemy import Engine, create_engine
@@ -116,6 +117,7 @@ class PostHogSettings(BaseSettings):
 class BrainSettings(BaseSettings):
     model_config = SettingsConfigDict(validate_default=False)
     openai_api_key: str = ""
+    azure_openai_embeddings_url: str = ""
     supabase_url: str = ""
     supabase_service_key: str = ""
     resend_api_key: str = "null"
@@ -191,7 +193,22 @@ def get_embedding_client() -> Embeddings:
             base_url=settings.ollama_api_base_url,
         )  # pyright: ignore reportPrivateUsage=none
     else:
-        embeddings = OpenAIEmbeddings()  # pyright: ignore reportPrivateUsage=none
+        if settings.azure_openai_embeddings_url:
+            # https://quivr-test.openai.azure.com/openai/deployments/embedding/embeddings?api-version=2023-05-15
+            # parse the url to get the deployment name
+            deployment = settings.azure_openai_embeddings_url.split("/")[5]
+            netloc = "https://" + urlparse(settings.azure_openai_embeddings_url).netloc
+            api_version = settings.azure_openai_embeddings_url.split("=")[1]
+            logger.debug(f"Using Azure OpenAI embeddings: {deployment}")
+            logger.debug(f"Using Azure OpenAI embeddings: {netloc}")
+            logger.debug(f"Using Azure OpenAI embeddings: {api_version}")
+            embeddings = AzureOpenAIEmbeddings(
+                azure_deployment=deployment,
+                azure_endpoint=netloc,
+                api_version=api_version,
+            )
+        else:
+            embeddings = OpenAIEmbeddings()  # pyright: ignore reportPrivateUsage=none

     return embeddings

diff --git a/backend/api/quivr_api/modules/sync/utils/upload.py b/backend/api/quivr_api/modules/sync/utils/upload.py
index 8ce23d1de..26b75536c 100644
--- a/backend/api/quivr_api/modules/sync/utils/upload.py
+++ b/backend/api/quivr_api/modules/sync/utils/upload.py
@@ -3,6 +3,7 @@ from typing import Optional
 from uuid import UUID

 from fastapi import HTTPException, UploadFile
+
 from quivr_api.celery_worker import process_file_and_notify
 from quivr_api.logger import get_logger
 from quivr_api.modules.brain.entity.brain_entity import RoleEnum
@@ -74,13 +75,13 @@ async def upload_file(
                 status_code=403,
                 detail=f"File {upload_file.filename} already exists in storage.",
             )
-
+
         else:
             notification_service.update_notification_by_id(
                 notification_id,
                 NotificationUpdatableProperties(
                     status=NotificationsStatusEnum.ERROR,
-                    description=f"There was an error uploading the file",
+                    description="There was an error uploading the file",
                 ),
             )
             raise HTTPException(
diff --git a/backend/core/quivr_core/llm/llm_endpoint.py b/backend/core/quivr_core/llm/llm_endpoint.py
index b3eac0a13..fec513b26 100644
--- a/backend/core/quivr_core/llm/llm_endpoint.py
+++ b/backend/core/quivr_core/llm/llm_endpoint.py
@@ -1,9 +1,15 @@
+import logging
+from urllib.parse import parse_qs, urlparse
+
 from langchain_core.language_models.chat_models import BaseChatModel
 from pydantic.v1 import SecretStr
+
 from quivr_core.brain.info import LLMInfo
 from quivr_core.config import LLMEndpointConfig
 from quivr_core.utils import model_supports_function_calling

+logger = logging.getLogger("quivr_core")
+

 class LLMEndpoint:
     def __init__(self, llm_config: LLMEndpointConfig, llm: BaseChatModel):
@@ -19,13 +25,30 @@ class LLMEndpoint:
     @classmethod
     def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()):
         try:
-            from langchain_openai import ChatOpenAI
+            from langchain_openai import AzureChatOpenAI, ChatOpenAI

-            _llm = ChatOpenAI(
-                model=config.model,
-                api_key=SecretStr(config.llm_api_key) if config.llm_api_key else None,
-                base_url=config.llm_base_url,
-            )
+            if config.model.startswith("azure/"):
+                # Parse the URL
+                parsed_url = urlparse(config.llm_base_url)
+                deployment = parsed_url.path.split("/")[3]  # type: ignore
+                api_version = parse_qs(parsed_url.query).get("api-version", [None])[0]  # type: ignore
+                azure_endpoint = f"https://{parsed_url.netloc}"  # type: ignore
+                _llm = AzureChatOpenAI(
+                    azure_deployment=deployment,  # type: ignore
+                    api_version=api_version,
+                    api_key=SecretStr(config.llm_api_key)
+                    if config.llm_api_key
+                    else None,
+                    azure_endpoint=azure_endpoint,
+                )
+            else:
+                _llm = ChatOpenAI(
+                    model=config.model,
+                    api_key=SecretStr(config.llm_api_key)
+                    if config.llm_api_key
+                    else None,
+                    base_url=config.llm_base_url,
+                )
             return cls(llm=_llm, llm_config=config)
         except ImportError as e:
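
Both Azure code paths in this patch come down to the same operation: splitting an Azure OpenAI deployments URL into an endpoint, a deployment name and an `api-version`. The sketch below is a minimal standalone illustration of that parsing, using the same stdlib helpers (`urlparse`/`parse_qs`) as the `llm_endpoint.py` hunk; the function name and the example URL (taken from the patch's own comment) are illustrative only, not part of the codebase.

```python
# Standalone sketch of the URL parsing used above. The URL below is the
# illustrative value from the patch comment, not a real Azure resource.
from urllib.parse import parse_qs, urlparse


def parse_azure_openai_url(url: str) -> dict:
    """Split an Azure OpenAI deployment URL into endpoint, deployment and api-version."""
    parsed = urlparse(url)
    # Path looks like /openai/deployments/<deployment>/<operation>; index 3 is
    # the deployment name because the leading "/" yields an empty first element.
    deployment = parsed.path.split("/")[3]
    api_version = parse_qs(parsed.query).get("api-version", [None])[0]
    return {
        "azure_endpoint": f"https://{parsed.netloc}",
        "azure_deployment": deployment,
        "api_version": api_version,
    }


if __name__ == "__main__":
    url = "https://quivr-test.openai.azure.com/openai/deployments/embedding/embeddings?api-version=2023-05-15"
    print(parse_azure_openai_url(url))
    # -> {'azure_endpoint': 'https://quivr-test.openai.azure.com',
    #     'azure_deployment': 'embedding', 'api_version': '2023-05-15'}
```

The `settings.py` hunk extracts the same three values with plain string splits (`split("/")[5]` and `split("=")[1]`), which assumes a URL shaped exactly like the example; the `urlparse`-based variant is the more forgiving of the two approaches used in the patch.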