fix: 🐛 brains (#2107)

Brain selection is now fixed.

# Description

Please include a summary of the changes and the related issue. Please
also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] Ideally, I have added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
This commit is contained in:
Stan Girard 2024-01-28 01:03:36 -08:00 committed by GitHub
parent 2e06b5c7f2
commit 652d2b32e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 80 additions and 21 deletions

View File

@ -2,10 +2,7 @@ from typing import Optional
from uuid import UUID
from fastapi import HTTPException
from langchain.embeddings.ollama import OllamaEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from logger import get_logger
from models.settings import BrainSettings, get_supabase_client
from modules.brain.dto.inputs import BrainUpdatableProperties, CreateBrainProperties
from modules.brain.entity.brain_entity import BrainEntity, BrainType, PublicBrain
from modules.brain.repository import (
@ -52,7 +49,13 @@ class BrainService:
return self.brain_repository.get_brain_by_id(brain_id)
def find_brain_from_question(
self, brain_id: UUID, question: str, user, chat_id: UUID, history
self,
brain_id: UUID,
question: str,
user,
chat_id: UUID,
history,
vector_store: CustomSupabaseVectorStore,
) -> (Optional[BrainEntity], dict[str, str]):
"""Find the brain to use for a question.
@ -67,19 +70,6 @@ class BrainService:
"""
metadata = {}
brain_settings = BrainSettings()
supabase_client = get_supabase_client()
embeddings = None
if brain_settings.ollama_api_base_url:
embeddings = OllamaEmbeddings(
base_url=brain_settings.ollama_api_base_url
) # pyright: ignore reportPrivateUsage=none
else:
embeddings = OpenAIEmbeddings()
vector_store = CustomSupabaseVectorStore(
supabase_client, embeddings, table_name="vectors", user_id=user.id
)
# Init
brain_id_to_use = brain_id
@ -92,13 +82,14 @@ class BrainService:
list_brains = [] # To return
if history and not brain_id_to_use:
# Replace the question with the first question from the history
question = history[0].user_message
if history and not brain_id:
brain_id_to_use = history[0].brain_id
brain_to_use = self.get_brain_by_id(brain_id_to_use)
# If a brain_id is provided, use it
if brain_id_to_use and not brain_to_use:
brain_to_use = self.get_brain_by_id(brain_id_to_use)
# Calculate the closest brains to the question
list_brains = vector_store.find_brain_closest_query(user.id, question)

View File

@ -0,0 +1,34 @@
from unittest.mock import Mock
from uuid import UUID
from modules.brain.entity.brain_entity import BrainEntity
from modules.brain.service.brain_service import BrainService
def test_find_brain_from_question_with_history_and_brain_id():
    """find_brain_from_question yields a BrainEntity plus close-brain metadata
    when both an explicit brain_id and a chat history are supplied."""
    service = BrainService()

    # Fake user exposing only the attribute the service reads.
    mock_user = Mock()
    mock_user.id = 1

    conversation_id = UUID("12345678123456781234567812345678")
    target_brain_id = UUID("87654321876543218765432187654321")
    query = "What is the meaning of life?"
    chat_history = [
        {
            "user_message": "What is AI?",
            "brain_id": UUID("87654321876543218765432187654321"),
        }
    ]

    # Vector-store stub that reports no close brains for the query.
    fake_store = Mock()
    fake_store.find_brain_closest_query.return_value = []

    # Stub the repository lookup so no database access is needed.
    expected_brain = Mock(spec=BrainEntity)
    service.get_brain_by_id = Mock(return_value=expected_brain)

    selected_brain, metadata = service.find_brain_from_question(
        target_brain_id, query, mock_user, conversation_id, chat_history, fake_store
    )

    assert isinstance(selected_brain, BrainEntity)
    assert "close_brains" in metadata

View File

@ -3,8 +3,11 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import StreamingResponse
from langchain.embeddings.ollama import OllamaEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from logger import get_logger
from middlewares.auth import AuthBearer, get_current_user
from models.settings import BrainSettings, get_supabase_client
from models.user_usage import UserUsage
from modules.brain.service.brain_service import BrainService
from modules.chat.controller.chat.brainful_chat import BrainfulChat
@ -23,6 +26,7 @@ from modules.chat.entity.chat import Chat
from modules.chat.service.chat_service import ChatService
from modules.notification.service.notification_service import NotificationService
from modules.user.entity.user_identity import UserIdentity
from vectorstore.supabase import CustomSupabaseVectorStore
logger = get_logger(__name__)
@ -33,6 +37,26 @@ brain_service = BrainService()
chat_service = ChatService()
def init_vector_store(user_id: UUID) -> CustomSupabaseVectorStore:
    """Build the vector store used for brain-similarity lookups.

    Selects Ollama embeddings when ``ollama_api_base_url`` is configured in
    ``BrainSettings``; otherwise falls back to OpenAI embeddings.

    Args:
        user_id: UUID that scopes vector queries to this user's rows.

    Returns:
        A ``CustomSupabaseVectorStore`` bound to the ``"vectors"`` table.
    """
    brain_settings = BrainSettings()
    supabase_client = get_supabase_client()
    # Prefer a locally hosted Ollama endpoint when one is configured.
    if brain_settings.ollama_api_base_url:
        embeddings = OllamaEmbeddings(
            base_url=brain_settings.ollama_api_base_url
        )  # pyright: ignore reportPrivateUsage=none
    else:
        embeddings = OpenAIEmbeddings()
    return CustomSupabaseVectorStore(
        supabase_client, embeddings, table_name="vectors", user_id=user_id
    )
def get_answer_generator(
chat_id: UUID,
chat_question: ChatQuestion,
@ -47,6 +71,8 @@ def get_answer_generator(
email=current_user.email,
)
vector_store = init_vector_store(user_id=current_user.id)
# Get History
history = chat_service.get_chat_history(chat_id)
@ -58,7 +84,7 @@ def get_answer_generator(
# Generic
brain, metadata_brain = brain_service.find_brain_from_question(
brain_id, chat_question.question, current_user, chat_id, history
brain_id, chat_question.question, current_user, chat_id, history, vector_store
)
model_to_use, metadata = find_model_and_generate_metadata(
@ -76,6 +102,7 @@ def get_answer_generator(
models_settings=models_settings,
model_name=model_to_use.name,
)
gpt_answer_generator = chat_instance.get_answer_generator(
chat_id=str(chat_id),
model=model_to_use.name,
@ -216,6 +243,13 @@ async def create_stream_question_handler(
| None = Query(..., description="The ID of the brain"),
current_user: UserIdentity = Depends(get_current_user),
) -> StreamingResponse:
chat_instance = BrainfulChat()
chat_instance.validate_authorization(user_id=current_user.id, brain_id=brain_id)
user_usage = UserUsage(
id=current_user.id,
email=current_user.email,
)
gpt_answer_generator = get_answer_generator(
chat_id, chat_question, brain_id, current_user
)