import asyncio
import json
from uuid import uuid4

from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.vectorstores import InMemoryVectorStore

from quivr_core.rag.entities.chat import ChatHistory
from quivr_core.rag.entities.config import LLMEndpointConfig, RetrievalConfig
from quivr_core.llm import LLMEndpoint
from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph


async def main():
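    # Configure retrieval; LLMEndpointConfig selects the model used for generation.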
    retrieval_config = RetrievalConfig(llm_config=LLMEndpointConfig(model="gpt-4o"))
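
    # A deterministic fake embedder plus an in-memory vector store keep the
    # example self-contained: no external embedding service is required.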
    embedder = DeterministicFakeEmbedding(size=20)
    vec = InMemoryVectorStore(embedder)
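
    # Build the LLM endpoint from the same config the pipeline uses.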
    llm = LLMEndpoint.from_config(retrieval_config.llm_config)
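
    # Start an empty chat history with fresh ids for this session.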
    chat_history = ChatHistory(uuid4(), uuid4())

    rag_pipeline = QuivrQARAGLangGraph(
        retrieval_config=retrieval_config, llm=llm, vector_store=vec
    )
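
    # Compile the graph into a runnable chain.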
    conversational_qa_chain = rag_pipeline.build_chain()
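
    # Stream graph events and persist each generated chunk as a line of JSON.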
    with open("response.jsonl", "w") as f:
        async for event in conversational_qa_chain.astream_events(
            {
                "messages": [
                    ("user", "What is NLP, give a very long detailed answer"),
                ],
                "chat_history": chat_history,
                "custom_personality": None,
            },
            version="v1",
            config={"metadata": {}},
        ):
            kind = event["event"]
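            # Keep only token chunks emitted by the graph's "generate" node.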
            if (
                kind == "on_chat_model_stream"
                and event["metadata"]["langgraph_node"] == "generate"
            ):
                chunk = event["data"]["chunk"]
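                # AIMessageChunk is not JSON-serializable; convert it to a dict.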
                dict_chunk = {
                    k: v.dict() if isinstance(v, AIMessageChunk) else v
                    for k, v in chunk.items()
                }
                f.write(json.dumps(dict_chunk) + "\n")


asyncio.run(main())