mirror of https://github.com/QuivrHQ/quivr.git — synced 2024-12-15 17:43:03 +03:00
feat(azure): quivr compatible with it (#3005)
# Description

Please include a summary of the changes and the related issue. Please also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
This commit is contained in:
parent 94c7e6501a
commit b5f31a83d4
API settings module (`BrainSettings`, `get_embedding_client`):

```diff
@@ -1,10 +1,11 @@
 from typing import Optional
+from urllib.parse import urlparse
 from uuid import UUID

 from langchain.embeddings.base import Embeddings
 from langchain_community.embeddings.ollama import OllamaEmbeddings
 from langchain_community.vectorstores.supabase import SupabaseVectorStore
-from langchain_openai import OpenAIEmbeddings
+from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
 from posthog import Posthog
 from pydantic_settings import BaseSettings, SettingsConfigDict
 from sqlalchemy import Engine, create_engine
```
```diff
@@ -116,6 +117,7 @@ class PostHogSettings(BaseSettings):
 class BrainSettings(BaseSettings):
     model_config = SettingsConfigDict(validate_default=False)
     openai_api_key: str = ""
+    azure_openai_embeddings_url: str = ""
     supabase_url: str = ""
     supabase_service_key: str = ""
     resend_api_key: str = "null"
```
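As a reading aid (not part of the commit): because `BrainSettings` is a pydantic-settings model, the new field would normally be populated from a same-named environment variable. A minimal sketch, assuming the settings class is importable at the path shown:

```python
import os

# Assumed import path for illustration; the diff does not show the module location.
from quivr_api.models.settings import BrainSettings

# pydantic-settings resolves fields from matching environment variables,
# so the new field is driven by AZURE_OPENAI_EMBEDDINGS_URL.
os.environ["AZURE_OPENAI_EMBEDDINGS_URL"] = (
    "https://quivr-test.openai.azure.com/openai/deployments/embedding/"
    "embeddings?api-version=2023-05-15"
)

settings = BrainSettings()
print(settings.azure_openai_embeddings_url)  # the full deployment URL
```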
```diff
@@ -190,6 +192,21 @@ def get_embedding_client() -> Embeddings:
         embeddings = OllamaEmbeddings(
             base_url=settings.ollama_api_base_url,
         )  # pyright: ignore reportPrivateUsage=none
-    else:
-        embeddings = OpenAIEmbeddings()  # pyright: ignore reportPrivateUsage=none
+    else:
+        if settings.azure_openai_embeddings_url:
+            # https://quivr-test.openai.azure.com/openai/deployments/embedding/embeddings?api-version=2023-05-15
+            # parse the url to get the deployment name
+            deployment = settings.azure_openai_embeddings_url.split("/")[5]
+            netloc = "https://" + urlparse(settings.azure_openai_embeddings_url).netloc
+            api_version = settings.azure_openai_embeddings_url.split("=")[1]
+            logger.debug(f"Using Azure OpenAI embeddings: {deployment}")
+            logger.debug(f"Using Azure OpenAI embeddings: {netloc}")
+            logger.debug(f"Using Azure OpenAI embeddings: {api_version}")
+            embeddings = AzureOpenAIEmbeddings(
+                azure_deployment=deployment,
+                azure_endpoint=netloc,
+                api_version=api_version,
+            )
+        else:
+            embeddings = OpenAIEmbeddings()  # pyright: ignore reportPrivateUsage=none
     return embeddings
```
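For clarity (not part of the commit), here is how that ad-hoc parsing decomposes the example URL from the comment above. Note the assumptions it bakes in: the path must have the `/openai/deployments/<name>/...` shape, and `split("=")[1]` only works while `api-version` is the sole query parameter.

```python
from urllib.parse import urlparse

# Example deployment URL taken from the comment in the diff above.
url = "https://quivr-test.openai.azure.com/openai/deployments/embedding/embeddings?api-version=2023-05-15"

# url.split("/") -> ['https:', '', 'quivr-test.openai.azure.com',
#                    'openai', 'deployments', 'embedding', ...]
deployment = url.split("/")[5]              # 'embedding'
netloc = "https://" + urlparse(url).netloc  # 'https://quivr-test.openai.azure.com'
api_version = url.split("=")[1]             # '2023-05-15'

assert deployment == "embedding"
assert netloc == "https://quivr-test.openai.azure.com"
assert api_version == "2023-05-15"
```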
Upload controller (`upload_file`):

```diff
@@ -3,6 +3,7 @@ from typing import Optional
 from uuid import UUID

 from fastapi import HTTPException, UploadFile
+
 from quivr_api.celery_worker import process_file_and_notify
 from quivr_api.logger import get_logger
 from quivr_api.modules.brain.entity.brain_entity import RoleEnum
```
```diff
@@ -80,7 +81,7 @@ async def upload_file(
         notification_id,
         NotificationUpdatableProperties(
             status=NotificationsStatusEnum.ERROR,
-            description=f"There was an error uploading the file",
+            description="There was an error uploading the file",
         ),
     )
     raise HTTPException(
```
`quivr_core` LLM endpoint module (`LLMEndpoint`):

```diff
@@ -1,9 +1,15 @@
+import logging
+from urllib.parse import parse_qs, urlparse
+
 from langchain_core.language_models.chat_models import BaseChatModel
 from pydantic.v1 import SecretStr

 from quivr_core.brain.info import LLMInfo
 from quivr_core.config import LLMEndpointConfig
 from quivr_core.utils import model_supports_function_calling

+logger = logging.getLogger("quivr_core")
+
+
 class LLMEndpoint:
     def __init__(self, llm_config: LLMEndpointConfig, llm: BaseChatModel):
```
```diff
@@ -19,11 +25,28 @@ class LLMEndpoint:
     @classmethod
     def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()):
         try:
-            from langchain_openai import ChatOpenAI
+            from langchain_openai import AzureChatOpenAI, ChatOpenAI

-            _llm = ChatOpenAI(
-                model=config.model,
-                api_key=SecretStr(config.llm_api_key) if config.llm_api_key else None,
-                base_url=config.llm_base_url,
-            )
+            if config.model.startswith("azure/"):
+                # Parse the URL
+                parsed_url = urlparse(config.llm_base_url)
+                deployment = parsed_url.path.split("/")[3]  # type: ignore
+                api_version = parse_qs(parsed_url.query).get("api-version", [None])[0]  # type: ignore
+                azure_endpoint = f"https://{parsed_url.netloc}"  # type: ignore
+                _llm = AzureChatOpenAI(
+                    azure_deployment=deployment,  # type: ignore
+                    api_version=api_version,
+                    api_key=SecretStr(config.llm_api_key)
+                    if config.llm_api_key
+                    else None,
+                    azure_endpoint=azure_endpoint,
+                )
+            else:
+                _llm = ChatOpenAI(
+                    model=config.model,
+                    api_key=SecretStr(config.llm_api_key)
+                    if config.llm_api_key
+                    else None,
+                    base_url=config.llm_base_url,
+                )
             return cls(llm=_llm, llm_config=config)
```
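As a usage illustration (not part of the commit), a hypothetical call that would route through the new Azure branch; the model name, resource URL, key placeholder, and import path below are invented for the example:

```python
from quivr_core.config import LLMEndpointConfig
from quivr_core.llm import LLMEndpoint  # import path assumed for illustration

# A model prefixed with "azure/" selects AzureChatOpenAI; llm_base_url is the
# full deployment URL, whose fourth path segment is the deployment name and
# whose query string carries the api-version.
config = LLMEndpointConfig(
    model="azure/gpt-4o",
    llm_base_url=(
        "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/"
        "chat/completions?api-version=2023-05-15"
    ),
    llm_api_key="<azure-api-key>",  # placeholder, not a real credential
)
llm_endpoint = LLMEndpoint.from_config(config)
```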