From f618b713afdfa2a1596ea8c258a47e6e70becb1a Mon Sep 17 00:00:00 2001
From: ItzCrazyKns
Date: Thu, 2 May 2024 12:14:26 +0530
Subject: [PATCH] feat(chatModels): load model from localstorage

---
 README.md                          | 10 +++---
 package.json                       |  2 +-
 sample.config.toml                 |  2 --
 src/config.ts                      | 17 +++-------
 src/routes/config.ts               |  9 ------
 src/routes/images.ts               | 13 ++++----
 src/routes/index.ts                |  2 ++
 src/routes/models.ts               | 18 +++++++++++
 src/routes/videos.ts               | 13 ++++----
 src/websocket/connectionManager.ts | 16 +++++++---
 src/websocket/websocketServer.ts   |  4 +--
 ui/components/ChatWindow.tsx       | 36 ++++++++++++++++++---
 ui/components/SearchImages.tsx     |  6 ++++
 ui/components/SearchVideos.tsx     |  6 ++++
 ui/components/SettingsDialog.tsx   | 51 ++++++++++++++++--------------
 ui/package.json                    |  2 +-
 16 files changed, 126 insertions(+), 81 deletions(-)
 create mode 100644 src/routes/models.ts

diff --git a/README.md b/README.md
index 50e0e1d..bb7171b 100644
--- a/README.md
+++ b/README.md
@@ -59,13 +59,11 @@ There are mainly 2 ways of installing Perplexica - With Docker, Without Docker.
 
 4. Rename the `sample.config.toml` file to `config.toml`. For Docker setups, you need only fill in the following fields:
 
-   - `CHAT_MODEL`: The name of the LLM to use. Like `llama3:latest` (using Ollama), `gpt-3.5-turbo` (using OpenAI), etc.
-   - `CHAT_MODEL_PROVIDER`: The chat model provider, either `openai` or `ollama`. Depending upon which provider you use you would have to fill in the following fields:
+   - `OPENAI`: Your OpenAI API key. **You only need to fill this if you wish to use OpenAI's models**.
+   - `OLLAMA`: Your Ollama API URL. You should enter it as `http://host.docker.internal:PORT_NUMBER`. If you installed Ollama on port 11434, use `http://host.docker.internal:11434`. For other ports, adjust accordingly. **You need to fill this if you wish to use Ollama's models instead of OpenAI's**.
+   - `GROQ`: Your Groq API key. **You only need to fill this if you wish to use Groq's hosted models**
 
-     - `OPENAI`: Your OpenAI API key. **You only need to fill this if you wish to use OpenAI's models**.
-     - `OLLAMA`: Your Ollama API URL. You should enter it as `http://host.docker.internal:PORT_NUMBER`. If you installed Ollama on port 11434, use `http://host.docker.internal:11434`. For other ports, adjust accordingly. **You need to fill this if you wish to use Ollama's models instead of OpenAI's**.
-
-   **Note**: You can change these and use different models after running Perplexica as well from the settings page.
+   **Note**: You can change these after starting Perplexica from the settings dialog.
 
 - `SIMILARITY_MEASURE`: The similarity measure to use (This is filled by default; you can leave it as is if you are unsure about it.)
 
diff --git a/package.json b/package.json
index a4b91e1..9434569 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
 {
   "name": "perplexica-backend",
-  "version": "1.0.0",
+  "version": "1.1.0",
   "license": "MIT",
   "author": "ItzCrazyKns",
   "scripts": {
diff --git a/sample.config.toml b/sample.config.toml
index e283826..7bc8880 100644
--- a/sample.config.toml
+++ b/sample.config.toml
@@ -1,8 +1,6 @@
 [GENERAL]
 PORT = 3001 # Port to run the server on
 SIMILARITY_MEASURE = "cosine" # "cosine" or "dot"
-CHAT_MODEL_PROVIDER = "openai" # "openai" or "ollama" or "groq"
-CHAT_MODEL = "gpt-3.5-turbo" # Name of the model to use
 
 [API_KEYS]
 OPENAI = "" # OpenAI API key - sk-1234567890abcdef1234567890abcdef
diff --git a/src/config.ts b/src/config.ts
index 25dcbf4..7c0c7f1 100644
--- a/src/config.ts
+++ b/src/config.ts
@@ -8,8 +8,6 @@ interface Config {
   GENERAL: {
     PORT: number;
     SIMILARITY_MEASURE: string;
-    CHAT_MODEL_PROVIDER: string;
-    CHAT_MODEL: string;
   };
   API_KEYS: {
     OPENAI: string;
@@ -35,11 +33,6 @@ export const getPort = () => loadConfig().GENERAL.PORT;
 export const getSimilarityMeasure = () =>
   loadConfig().GENERAL.SIMILARITY_MEASURE;
 
-export const getChatModelProvider = () =>
-  loadConfig().GENERAL.CHAT_MODEL_PROVIDER;
-
-export const getChatModel = () => loadConfig().GENERAL.CHAT_MODEL;
-
 export const getOpenaiApiKey = () => loadConfig().API_KEYS.OPENAI;
 
 export const getGroqApiKey = () => loadConfig().API_KEYS.GROQ;
@@ -52,21 +45,19 @@ export const updateConfig = (config: RecursivePartial<Config>) => {
   const currentConfig = loadConfig();
 
   for (const key in currentConfig) {
-    /* if (currentConfig[key] && !config[key]) {
-      config[key] = currentConfig[key];
-    } */
+    if (!config[key]) config[key] = {};
 
-    if (currentConfig[key] && typeof currentConfig[key] === 'object') {
+    if (typeof currentConfig[key] === 'object' && currentConfig[key] !== null) {
       for (const nestedKey in currentConfig[key]) {
         if (
-          currentConfig[key][nestedKey] &&
           !config[key][nestedKey] &&
+          currentConfig[key][nestedKey] &&
           config[key][nestedKey] !== ''
         ) {
           config[key][nestedKey] = currentConfig[key][nestedKey];
         }
       }
-    } else if (currentConfig[key] && !config[key] && config[key] !== '') {
+    } else if (currentConfig[key] && config[key] !== '') {
       config[key] = currentConfig[key];
     }
   }
diff --git a/src/routes/config.ts b/src/routes/config.ts
index 4d22ec5..9518c5f 100644
--- a/src/routes/config.ts
+++ b/src/routes/config.ts
@@ -1,8 +1,6 @@
 import express from 'express';
 import { getAvailableProviders } from '../lib/providers';
 import {
-  getChatModel,
-  getChatModelProvider,
   getGroqApiKey,
   getOllamaApiEndpoint,
   getOpenaiApiKey,
@@ -26,9 +24,6 @@ router.get('/', async (_, res) => {
     config['providers'][provider] = Object.keys(providers[provider]);
   }
 
-  config['selectedProvider'] = getChatModelProvider();
-  config['selectedChatModel'] = getChatModel();
-
   config['openeaiApiKey'] = getOpenaiApiKey();
   config['ollamaApiUrl'] = getOllamaApiEndpoint();
   config['groqApiKey'] = getGroqApiKey();
@@ -40,10 +35,6 @@ router.post('/', async (req, res) => {
   const config = req.body;
 
   const updatedConfig = {
-    GENERAL: {
-      CHAT_MODEL_PROVIDER: config.selectedProvider,
-      CHAT_MODEL: config.selectedChatModel,
-    },
     API_KEYS: {
       OPENAI: config.openeaiApiKey,
       GROQ: config.groqApiKey,
diff --git a/src/routes/images.ts b/src/routes/images.ts
index 066a3ee..3906689 100644
--- a/src/routes/images.ts
+++ b/src/routes/images.ts
@@ -2,7 +2,6 @@ import express from 'express';
 import handleImageSearch from '../agents/imageSearchAgent';
 import { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import { getAvailableProviders } from '../lib/providers';
-import { getChatModel, getChatModelProvider } from '../config';
 import { HumanMessage, AIMessage } from '@langchain/core/messages';
 import logger from '../utils/logger';
 
@@ -10,7 +9,7 @@ const router = express.Router();
 
 router.post('/', async (req, res) => {
   try {
-    let { query, chat_history } = req.body;
+    let { query, chat_history, chat_model_provider, chat_model } = req.body;
 
     chat_history = chat_history.map((msg: any) => {
       if (msg.role === 'user') {
@@ -20,14 +19,14 @@ router.post('/', async (req, res) => {
       }
     });
 
-    const models = await getAvailableProviders();
-    const provider = getChatModelProvider();
-    const chatModel = getChatModel();
+    const chatModels = await getAvailableProviders();
+    const provider = chat_model_provider || Object.keys(chatModels)[0];
+    const chatModel = chat_model || Object.keys(chatModels[provider])[0];
 
     let llm: BaseChatModel | undefined;
 
-    if (models[provider] && models[provider][chatModel]) {
-      llm = models[provider][chatModel] as BaseChatModel | undefined;
+    if (chatModels[provider] && chatModels[provider][chatModel]) {
+      llm = chatModels[provider][chatModel] as BaseChatModel | undefined;
     }
 
     if (!llm) {
diff --git a/src/routes/index.ts b/src/routes/index.ts
index bcfc3d3..04390cd 100644
--- a/src/routes/index.ts
+++ b/src/routes/index.ts
@@ -2,11 +2,13 @@ import express from 'express';
 import imagesRouter from './images';
 import videosRouter from './videos';
 import configRouter from './config';
+import modelsRouter from './models';
 
 const router = express.Router();
 
 router.use('/images', imagesRouter);
 router.use('/videos', videosRouter);
 router.use('/config', configRouter);
+router.use('/models', modelsRouter);
 
 export default router;
diff --git a/src/routes/models.ts b/src/routes/models.ts
new file mode 100644
index 0000000..f2332f4
--- /dev/null
+++ b/src/routes/models.ts
@@ -0,0 +1,18 @@
+import express from 'express';
+import logger from '../utils/logger';
+import { getAvailableProviders } from '../lib/providers';
+
+const router = express.Router();
+
+router.get('/', async (req, res) => {
+  try {
+    const providers = await getAvailableProviders();
+
+    res.status(200).json({ providers });
+  } catch (err) {
+    res.status(500).json({ message: 'An error has occurred.' });
+    logger.error(err.message);
+  }
+});
+
+export default router;
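Taken together, the routes above move model selection to the caller: the frontend can discover what is installed from the new `GET /models` route and pass its choice as `chat_model_provider` / `chat_model` in the body of the image and video search requests, with the server falling back to the first available provider and model when nothing is supplied. The sketch below illustrates that flow from a client's point of view; it is not part of the patch, and the `NEXT_PUBLIC_API_URL` base URL, the sample query, and the error handling are assumptions.

```ts
// Sketch only — not part of this patch. Assumes NEXT_PUBLIC_API_URL points at
// the backend API, as the UI code elsewhere in this commit does.
const searchImagesWithSelectedModel = async (query: string) => {
  const apiUrl = process.env.NEXT_PUBLIC_API_URL;

  // 1. Ask the new /models route which providers and models are available.
  const { providers } = await fetch(`${apiUrl}/models`).then((res) =>
    res.json(),
  );

  // 2. Pick a provider/model — here simply the first of each, mirroring the
  //    server-side fallback in the image/video routes.
  const chatModelProvider = Object.keys(providers)[0];
  const chatModel = Object.keys(providers[chatModelProvider])[0];

  // 3. Pass the selection explicitly to the image search route.
  const res = await fetch(`${apiUrl}/images`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      query,
      chat_history: [],
      chat_model_provider: chatModelProvider,
      chat_model: chatModel,
    }),
  });

  return res.json();
};
```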
diff --git a/src/routes/videos.ts b/src/routes/videos.ts
index bfd5fa8..fecd874 100644
--- a/src/routes/videos.ts
+++ b/src/routes/videos.ts
@@ -1,7 +1,6 @@
 import express from 'express';
 import { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import { getAvailableProviders } from '../lib/providers';
-import { getChatModel, getChatModelProvider } from '../config';
 import { HumanMessage, AIMessage } from '@langchain/core/messages';
 import logger from '../utils/logger';
 import handleVideoSearch from '../agents/videoSearchAgent';
@@ -10,7 +9,7 @@ const router = express.Router();
 
 router.post('/', async (req, res) => {
   try {
-    let { query, chat_history } = req.body;
+    let { query, chat_history, chat_model_provider, chat_model } = req.body;
 
     chat_history = chat_history.map((msg: any) => {
       if (msg.role === 'user') {
@@ -20,14 +19,14 @@ router.post('/', async (req, res) => {
       }
     });
 
-    const models = await getAvailableProviders();
-    const provider = getChatModelProvider();
-    const chatModel = getChatModel();
+    const chatModels = await getAvailableProviders();
+    const provider = chat_model_provider || Object.keys(chatModels)[0];
+    const chatModel = chat_model || Object.keys(chatModels[provider])[0];
 
     let llm: BaseChatModel | undefined;
 
-    if (models[provider] && models[provider][chatModel]) {
-      llm = models[provider][chatModel] as BaseChatModel | undefined;
+    if (chatModels[provider] && chatModels[provider][chatModel]) {
+      llm = chatModels[provider][chatModel] as BaseChatModel | undefined;
     }
 
     if (!llm) {
diff --git a/src/websocket/connectionManager.ts b/src/websocket/connectionManager.ts
index afaaf44..c2f3798 100644
--- a/src/websocket/connectionManager.ts
+++ b/src/websocket/connectionManager.ts
@@ -1,15 +1,23 @@
 import { WebSocket } from 'ws';
 import { handleMessage } from './messageHandler';
-import { getChatModel, getChatModelProvider } from '../config';
 import { getAvailableProviders } from '../lib/providers';
 import { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import type { Embeddings } from '@langchain/core/embeddings';
+import type { IncomingMessage } from 'http';
 import logger from '../utils/logger';
 
-export const handleConnection = async (ws: WebSocket) => {
+export const handleConnection = async (
+  ws: WebSocket,
+  request: IncomingMessage,
+) => {
+  const searchParams = new URL(request.url, `http://${request.headers.host}`)
+    .searchParams;
+
   const models = await getAvailableProviders();
-  const provider = getChatModelProvider();
-  const chatModel = getChatModel();
+  const provider =
+    searchParams.get('chatModelProvider') || Object.keys(models)[0];
+  const chatModel =
+    searchParams.get('chatModel') || Object.keys(models[provider])[0];
 
   let llm: BaseChatModel | undefined;
   let embeddings: Embeddings | undefined;
diff --git a/src/websocket/websocketServer.ts b/src/websocket/websocketServer.ts
index bc84f52..3ab0b51 100644
--- a/src/websocket/websocketServer.ts
+++ b/src/websocket/websocketServer.ts
@@ -10,9 +10,7 @@ export const initServer = (
   const port = getPort();
   const wss = new WebSocketServer({ server });
 
-  wss.on('connection', (ws) => {
-    handleConnection(ws);
-  });
+  wss.on('connection', handleConnection);
 
   logger.info(`WebSocket server started on port ${port}`);
 };
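On the WebSocket side, the connection manager above now reads `chatModelProvider` and `chatModel` from the connection URL's query string and falls back to the first available provider and model when they are absent; the ChatWindow change below builds exactly such a URL from localStorage. For illustration only, the URL shape it expects looks roughly like this — the host, port, and model names here are assumptions, not part of the patch:

```ts
// Illustration of the query-string contract the connection manager parses.
// ws://localhost:3001 assumes the default PORT = 3001 from sample.config.toml.
const params = new URLSearchParams({
  chatModel: 'llama3:latest',
  chatModelProvider: 'ollama',
});
const ws = new WebSocket(`ws://localhost:3001?${params.toString()}`);
```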
diff --git a/ui/components/ChatWindow.tsx b/ui/components/ChatWindow.tsx
index 4c138ff..68a2ba0 100644
--- a/ui/components/ChatWindow.tsx
+++ b/ui/components/ChatWindow.tsx
@@ -19,14 +19,42 @@ const useSocket = (url: string) => {
 
   useEffect(() => {
     if (!ws) {
-      const ws = new WebSocket(url);
-      ws.onopen = () => {
-        console.log('[DEBUG] open');
-        setWs(ws);
+      const connectWs = async () => {
+        let chatModel = localStorage.getItem('chatModel');
+        let chatModelProvider = localStorage.getItem('chatModelProvider');
+
+        if (!chatModel || !chatModelProvider) {
+          const chatModelProviders = await fetch(
+            `${process.env.NEXT_PUBLIC_API_URL}/models`,
+          ).then(async (res) => (await res.json())['providers']);
+
+          if (
+            !chatModelProviders ||
+            Object.keys(chatModelProviders).length === 0
+          )
+            return console.error('No chat models available');
+
+          chatModelProvider = Object.keys(chatModelProviders)[0];
+          chatModel = Object.keys(chatModelProviders[chatModelProvider])[0];
+
+          localStorage.setItem('chatModel', chatModel!);
+          localStorage.setItem('chatModelProvider', chatModelProvider);
+        }
+
+        const ws = new WebSocket(
+          `${url}?chatModel=${chatModel}&chatModelProvider=${chatModelProvider}`,
+        );
+        ws.onopen = () => {
+          console.log('[DEBUG] open');
+          setWs(ws);
+        };
       };
+
+      connectWs();
     }
 
     return () => {
+      1;
       ws?.close();
       console.log('[DEBUG] closed');
     };
diff --git a/ui/components/SearchImages.tsx b/ui/components/SearchImages.tsx
index 137571c..aa70c96 100644
--- a/ui/components/SearchImages.tsx
+++ b/ui/components/SearchImages.tsx
@@ -29,6 +29,10 @@ const SearchImages = ({
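As the ChatWindow hunk shows, the chosen provider and model now persist in the browser under the `chatModel` and `chatModelProvider` localStorage keys. A small, purely illustrative snippet (not part of the patch) for inspecting or resetting that selection from the browser console:

```ts
// Keys come from the ChatWindow change above; removing them makes the next
// page load fall back to the first provider/model reported by GET /models.
console.log(
  localStorage.getItem('chatModelProvider'),
  localStorage.getItem('chatModel'),
);
localStorage.removeItem('chatModelProvider');
localStorage.removeItem('chatModel');
```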