First attempt at providing a persistent chat list experience.

Limitations:

1) Context is not restored for gpt-j models
2) When you switch between different model types in an existing chat
   the context and all the conversation is lost
3) The settings are not chat or conversation specific
4) The sizes of the chat persisted files are very large due to how much
   data the llama.cpp backend tries to persist. Need to investigate how
   we can shrink this.
This commit is contained in:
Adam Treat 2023-05-04 15:31:41 -04:00
parent 081d32bd97
commit f291853e51
19 changed files with 530 additions and 208 deletions

View File

@ -60,7 +60,7 @@ qt_add_executable(chat
main.cpp
chat.h chat.cpp
chatllm.h chatllm.cpp
chatmodel.h chatlistmodel.h
chatmodel.h chatlistmodel.h chatlistmodel.cpp
download.h download.cpp
network.h network.cpp
llm.h llm.cpp

182
chat.cpp
View File

@ -1,32 +1,37 @@
#include "chat.h"
#include "llm.h"
#include "network.h"
#include "download.h"
Chat::Chat(QObject *parent)
: QObject(parent)
, m_llmodel(new ChatLLM)
, m_id(Network::globalInstance()->generateUniqueId())
, m_name(tr("New Chat"))
, m_chatModel(new ChatModel(this))
, m_responseInProgress(false)
, m_desiredThreadCount(std::min(4, (int32_t) std::thread::hardware_concurrency()))
, m_creationDate(QDateTime::currentSecsSinceEpoch())
, m_llmodel(new ChatLLM(this))
{
// Should be in same thread
connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
connect(this, &Chat::modelNameChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
// Should be in different threads
connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::responseChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseStarted, this, &Chat::responseStarted, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelNameChanged, this, &Chat::modelNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::threadCountChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::syncThreadCount, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::recalcChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelNameChanged, this, &Chat::handleModelNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection);
connect(this, &Chat::unloadRequested, m_llmodel, &ChatLLM::unload, Qt::QueuedConnection);
connect(this, &Chat::reloadRequested, m_llmodel, &ChatLLM::reload, Qt::QueuedConnection);
connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, Qt::QueuedConnection);
connect(this, &Chat::loadModelRequested, m_llmodel, &ChatLLM::loadModel, Qt::QueuedConnection);
connect(this, &Chat::unloadModelRequested, m_llmodel, &ChatLLM::unloadModel, Qt::QueuedConnection);
connect(this, &Chat::reloadModelRequested, m_llmodel, &ChatLLM::reloadModel, Qt::QueuedConnection);
connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection);
connect(this, &Chat::setThreadCountRequested, m_llmodel, &ChatLLM::setThreadCount, Qt::QueuedConnection);
// The following are blocking operations and will block the gui thread, therefore must be fast
// to respond to
@ -38,9 +43,21 @@ Chat::Chat(QObject *parent)
void Chat::reset()
{
stopGenerating();
// Erase our current on disk representation as we're completely resetting the chat along with id
LLM::globalInstance()->chatListModel()->removeChatFile(this);
emit resetContextRequested(); // blocking queued connection
m_id = Network::globalInstance()->generateUniqueId();
emit idChanged();
// NOTE: We deliberately do no reset the name or creation date to indictate that this was originally
// an older chat that was reset for another purpose. Resetting this data will lead to the chat
// name label changing back to 'New Chat' and showing up in the chat model list as a 'New Chat'
// further down in the list. This might surprise the user. In the future, we me might get rid of
// the "reset context" button in the UI. Right now, by changing the model in the combobox dropdown
// we effectively do a reset context. We *have* to do this right now when switching between different
// types of models. The only way to get rid of that would be a very long recalculate where we rebuild
// the context if we switch between different types of models. Probably the right way to fix this
// is to allow switching models but throwing up a dialog warning users if we switch between types
// of models that a long recalculation will ensue.
m_chatModel->clear();
}
@ -49,10 +66,12 @@ bool Chat::isModelLoaded() const
return m_llmodel->isModelLoaded();
}
void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict,
int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
int32_t repeat_penalty_tokens)
{
emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens);
emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch,
repeat_penalty, repeat_penalty_tokens, LLM::globalInstance()->threadCount());
}
void Chat::regenerateResponse()
@ -70,6 +89,13 @@ QString Chat::response() const
return m_llmodel->response();
}
void Chat::handleResponseChanged()
{
const int index = m_chatModel->count() - 1;
m_chatModel->updateValue(index, response());
emit responseChanged();
}
void Chat::responseStarted()
{
m_responseInProgress = true;
@ -98,21 +124,6 @@ void Chat::setModelName(const QString &modelName)
emit modelNameChangeRequested(modelName);
}
void Chat::syncThreadCount() {
emit setThreadCountRequested(m_desiredThreadCount);
}
void Chat::setThreadCount(int32_t n_threads) {
if (n_threads <= 0)
n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
m_desiredThreadCount = n_threads;
syncThreadCount();
}
int32_t Chat::threadCount() {
return m_llmodel->threadCount();
}
void Chat::newPromptResponsePair(const QString &prompt)
{
m_chatModel->appendPrompt(tr("Prompt: "), prompt);
@ -125,16 +136,25 @@ bool Chat::isRecalc() const
return m_llmodel->isRecalc();
}
void Chat::unload()
void Chat::loadDefaultModel()
{
m_savedModelName = m_llmodel->modelName();
stopGenerating();
emit unloadRequested();
emit loadDefaultModelRequested();
}
void Chat::reload()
void Chat::loadModel(const QString &modelName)
{
emit reloadRequested(m_savedModelName);
emit loadModelRequested(modelName);
}
void Chat::unloadModel()
{
stopGenerating();
emit unloadModelRequested();
}
void Chat::reloadModel()
{
emit reloadModelRequested(m_savedModelName);
}
void Chat::generatedNameChanged()
@ -150,4 +170,98 @@ void Chat::generatedNameChanged()
void Chat::handleRecalculating()
{
Network::globalInstance()->sendRecalculatingContext(m_chatModel->count());
emit recalcChanged();
}
void Chat::handleModelNameChanged()
{
m_savedModelName = modelName();
emit modelNameChanged();
}
bool Chat::serialize(QDataStream &stream) const
{
stream << m_creationDate;
stream << m_id;
stream << m_name;
stream << m_userName;
stream << m_savedModelName;
if (!m_llmodel->serialize(stream))
return false;
if (!m_chatModel->serialize(stream))
return false;
return stream.status() == QDataStream::Ok;
}
bool Chat::deserialize(QDataStream &stream)
{
stream >> m_creationDate;
stream >> m_id;
emit idChanged();
stream >> m_name;
stream >> m_userName;
emit nameChanged();
stream >> m_savedModelName;
if (!m_llmodel->deserialize(stream))
return false;
if (!m_chatModel->deserialize(stream))
return false;
emit chatModelChanged();
return stream.status() == QDataStream::Ok;
}
QList<QString> Chat::modelList() const
{
// Build a model list from exepath and from the localpath
QList<QString> list;
QString exePath = QCoreApplication::applicationDirPath() + QDir::separator();
QString localPath = Download::globalInstance()->downloadLocalModelsPath();
{
QDir dir(exePath);
dir.setNameFilters(QStringList() << "ggml-*.bin");
QStringList fileNames = dir.entryList();
for (QString f : fileNames) {
QString filePath = exePath + f;
QFileInfo info(filePath);
QString name = info.completeBaseName().remove(0, 5);
if (info.exists()) {
if (name == modelName())
list.prepend(name);
else
list.append(name);
}
}
}
if (localPath != exePath) {
QDir dir(localPath);
dir.setNameFilters(QStringList() << "ggml-*.bin");
QStringList fileNames = dir.entryList();
for (QString f : fileNames) {
QString filePath = localPath + f;
QFileInfo info(filePath);
QString name = info.completeBaseName().remove(0, 5);
if (info.exists() && !list.contains(name)) { // don't allow duplicates
if (name == modelName())
list.prepend(name);
else
list.append(name);
}
}
}
if (list.isEmpty()) {
if (exePath != localPath) {
qWarning() << "ERROR: Could not find any applicable models in"
<< exePath << "nor" << localPath;
} else {
qWarning() << "ERROR: Could not find any applicable models in"
<< exePath;
}
return QList<QString>();
}
return list;
}

42
chat.h
View File

@ -3,6 +3,7 @@
#include <QObject>
#include <QtQml>
#include <QDataStream>
#include "chatllm.h"
#include "chatmodel.h"
@ -17,8 +18,8 @@ class Chat : public QObject
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
Q_PROPERTY(QList<QString> modelList READ modelList NOTIFY modelListChanged)
QML_ELEMENT
QML_UNCREATABLE("Only creatable from c++!")
@ -36,13 +37,10 @@ public:
Q_INVOKABLE void reset();
Q_INVOKABLE bool isModelLoaded() const;
Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict,
int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
Q_INVOKABLE void regenerateResponse();
Q_INVOKABLE void stopGenerating();
Q_INVOKABLE void syncThreadCount();
Q_INVOKABLE void setThreadCount(int32_t n_threads);
Q_INVOKABLE int32_t threadCount();
Q_INVOKABLE void newPromptResponsePair(const QString &prompt);
QString response() const;
@ -51,8 +49,16 @@ public:
void setModelName(const QString &modelName);
bool isRecalc() const;
void unload();
void reload();
void loadDefaultModel();
void loadModel(const QString &modelName);
void unloadModel();
void reloadModel();
qint64 creationDate() const { return m_creationDate; }
bool serialize(QDataStream &stream) const;
bool deserialize(QDataStream &stream);
QList<QString> modelList() const;
Q_SIGNALS:
void idChanged();
@ -61,35 +67,39 @@ Q_SIGNALS:
void isModelLoadedChanged();
void responseChanged();
void responseInProgressChanged();
void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict,
int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens,
int32_t n_threads);
void regenerateResponseRequested();
void resetResponseRequested();
void resetContextRequested();
void modelNameChangeRequested(const QString &modelName);
void modelNameChanged();
void threadCountChanged();
void setThreadCountRequested(int32_t threadCount);
void recalcChanged();
void unloadRequested();
void reloadRequested(const QString &modelName);
void loadDefaultModelRequested();
void loadModelRequested(const QString &modelName);
void unloadModelRequested();
void reloadModelRequested(const QString &modelName);
void generateNameRequested();
void modelListChanged();
private Q_SLOTS:
void handleResponseChanged();
void responseStarted();
void responseStopped();
void generatedNameChanged();
void handleRecalculating();
void handleModelNameChanged();
private:
ChatLLM *m_llmodel;
QString m_id;
QString m_name;
QString m_userName;
QString m_savedModelName;
ChatModel *m_chatModel;
bool m_responseInProgress;
int32_t m_desiredThreadCount;
qint64 m_creationDate;
ChatLLM *m_llmodel;
};
#endif // CHAT_H

72
chatlistmodel.cpp Normal file
View File

@ -0,0 +1,72 @@
#include "chatlistmodel.h"
#include <QFile>
#include <QDataStream>
void ChatListModel::removeChatFile(Chat *chat) const
{
QSettings settings;
QFileInfo settingsInfo(settings.fileName());
QString settingsPath = settingsInfo.absolutePath();
QFile file(settingsPath + "/gpt4all-" + chat->id() + ".chat");
if (!file.exists())
return;
bool success = file.remove();
if (!success)
qWarning() << "ERROR: Couldn't remove chat file:" << file.fileName();
}
void ChatListModel::saveChats() const
{
QSettings settings;
QFileInfo settingsInfo(settings.fileName());
QString settingsPath = settingsInfo.absolutePath();
for (Chat *chat : m_chats) {
QFile file(settingsPath + "/gpt4all-" + chat->id() + ".chat");
bool success = file.open(QIODevice::WriteOnly);
if (!success) {
qWarning() << "ERROR: Couldn't save chat to file:" << file.fileName();
continue;
}
QDataStream out(&file);
if (!chat->serialize(out)) {
qWarning() << "ERROR: Couldn't serialize chat to file:" << file.fileName();
file.remove();
}
file.close();
}
}
void ChatListModel::restoreChats()
{
QSettings settings;
QFileInfo settingsInfo(settings.fileName());
QString settingsPath = settingsInfo.absolutePath();
QDir dir(settingsPath);
dir.setNameFilters(QStringList() << "gpt4all-*.chat");
QStringList fileNames = dir.entryList();
beginResetModel();
for (QString f : fileNames) {
QString filePath = settingsPath + "/" + f;
QFile file(filePath);
bool success = file.open(QIODevice::ReadOnly);
if (!success) {
qWarning() << "ERROR: Couldn't restore chat from file:" << file.fileName();
continue;
}
QDataStream in(&file);
Chat *chat = new Chat(this);
if (!chat->deserialize(in)) {
qWarning() << "ERROR: Couldn't deserialize chat from file:" << file.fileName();
file.remove();
} else {
connect(chat, &Chat::nameChanged, this, &ChatListModel::nameChanged);
m_chats.append(chat);
}
file.close();
}
std::sort(m_chats.begin(), m_chats.end(), [](const Chat* a, const Chat* b) {
return a->creationDate() > b->creationDate();
});
endResetModel();
}

View File

@ -55,7 +55,7 @@ public:
Q_INVOKABLE void addChat()
{
// Don't add a new chat if the current chat is empty
// Don't add a new chat if we already have one
if (m_newChat)
return;
@ -73,13 +73,29 @@ public:
setCurrentChat(m_newChat);
}
void setNewChat(Chat* chat)
{
// Don't add a new chat if we already have one
if (m_newChat)
return;
m_newChat = chat;
connect(m_newChat->chatModel(), &ChatModel::countChanged,
this, &ChatListModel::newChatCountChanged);
connect(m_newChat, &Chat::nameChanged,
this, &ChatListModel::nameChanged);
setCurrentChat(m_newChat);
}
Q_INVOKABLE void removeChat(Chat* chat)
{
if (!m_chats.contains(chat)) {
qDebug() << "WARNING: Removing chat failed with id" << chat->id();
qWarning() << "WARNING: Removing chat failed with id" << chat->id();
return;
}
removeChatFile(chat);
emit disconnectChat(chat);
if (chat == m_newChat) {
m_newChat->disconnect(this);
@ -115,20 +131,20 @@ public:
void setCurrentChat(Chat *chat)
{
if (!m_chats.contains(chat)) {
qDebug() << "ERROR: Setting current chat failed with id" << chat->id();
qWarning() << "ERROR: Setting current chat failed with id" << chat->id();
return;
}
if (m_currentChat) {
if (m_currentChat->isModelLoaded())
m_currentChat->unload();
m_currentChat->unloadModel();
emit disconnect(m_currentChat);
}
emit connectChat(chat);
m_currentChat = chat;
if (!m_currentChat->isModelLoaded())
m_currentChat->reload();
m_currentChat->reloadModel();
emit currentChatChanged();
}
@ -138,9 +154,12 @@ public:
return m_chats.at(index);
}
int count() const { return m_chats.size(); }
void removeChatFile(Chat *chat) const;
void saveChats() const;
void restoreChats();
Q_SIGNALS:
void countChanged();
void connectChat(Chat*);

View File

@ -1,7 +1,7 @@
#include "chatllm.h"
#include "chat.h"
#include "download.h"
#include "network.h"
#include "llm.h"
#include "llmodel/gptj.h"
#include "llmodel/llamamodel.h"
@ -32,28 +32,29 @@ static QString modelFilePath(const QString &modelName)
return QString();
}
ChatLLM::ChatLLM()
ChatLLM::ChatLLM(Chat *parent)
: QObject{nullptr}
, m_llmodel(nullptr)
, m_promptResponseTokens(0)
, m_responseLogits(0)
, m_isRecalc(false)
, m_chat(parent)
{
moveToThread(&m_llmThread);
connect(&m_llmThread, &QThread::started, this, &ChatLLM::loadModel);
connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
m_llmThread.setObjectName("llm thread"); // FIXME: Should identify these with chat name
connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
m_llmThread.setObjectName(m_chat->id());
m_llmThread.start();
}
bool ChatLLM::loadModel()
bool ChatLLM::loadDefaultModel()
{
const QList<QString> models = LLM::globalInstance()->modelList();
const QList<QString> models = m_chat->modelList();
if (models.isEmpty()) {
// try again when we get a list of models
connect(Download::globalInstance(), &Download::modelListChanged, this,
&ChatLLM::loadModel, Qt::SingleShotConnection);
&ChatLLM::loadDefaultModel, Qt::SingleShotConnection);
return false;
}
@ -62,10 +63,10 @@ bool ChatLLM::loadModel()
QString defaultModel = settings.value("defaultModel", "gpt4all-j-v1.3-groovy").toString();
if (defaultModel.isEmpty() || !models.contains(defaultModel))
defaultModel = models.first();
return loadModelPrivate(defaultModel);
return loadModel(defaultModel);
}
bool ChatLLM::loadModelPrivate(const QString &modelName)
bool ChatLLM::loadModel(const QString &modelName)
{
if (isModelLoaded() && m_modelName == modelName)
return true;
@ -100,12 +101,13 @@ bool ChatLLM::loadModelPrivate(const QString &modelName)
}
emit isModelLoadedChanged();
emit threadCountChanged();
if (isFirstLoad)
emit sendStartup();
else
emit sendModelLoaded();
} else {
qWarning() << "ERROR: Could not find model at" << filePath;
}
if (m_llmodel)
@ -114,19 +116,6 @@ bool ChatLLM::loadModelPrivate(const QString &modelName)
return m_llmodel;
}
void ChatLLM::setThreadCount(int32_t n_threads) {
if (m_llmodel && m_llmodel->threadCount() != n_threads) {
m_llmodel->setThreadCount(n_threads);
emit threadCountChanged();
}
}
int32_t ChatLLM::threadCount() {
if (!m_llmodel)
return 1;
return m_llmodel->threadCount();
}
bool ChatLLM::isModelLoaded() const
{
return m_llmodel && m_llmodel->isModelLoaded();
@ -203,7 +192,7 @@ void ChatLLM::setModelName(const QString &modelName)
void ChatLLM::modelNameChangeRequested(const QString &modelName)
{
if (!loadModelPrivate(modelName))
if (!loadModel(modelName))
qWarning() << "ERROR: Could not load model" << modelName;
}
@ -247,8 +236,8 @@ bool ChatLLM::handleRecalculate(bool isRecalc)
return !m_stopGenerating;
}
bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k,
float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, int n_threads)
{
if (!isModelLoaded())
return false;
@ -269,6 +258,7 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
m_ctx.n_batch = n_batch;
m_ctx.repeat_penalty = repeat_penalty;
m_ctx.repeat_last_n = repeat_penalty_tokens;
m_llmodel->setThreadCount(n_threads);
#if defined(DEBUG)
printf("%s", qPrintable(instructPrompt));
fflush(stdout);
@ -288,19 +278,22 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
return true;
}
void ChatLLM::unload()
void ChatLLM::unloadModel()
{
saveState();
delete m_llmodel;
m_llmodel = nullptr;
emit isModelLoadedChanged();
}
void ChatLLM::reload(const QString &modelName)
void ChatLLM::reloadModel(const QString &modelName)
{
if (modelName.isEmpty())
loadModel();
else
loadModelPrivate(modelName);
if (modelName.isEmpty()) {
loadDefaultModel();
} else {
loadModel(modelName);
}
restoreState();
}
void ChatLLM::generateName()
@ -333,6 +326,11 @@ void ChatLLM::generateName()
}
}
void ChatLLM::handleChatIdChanged()
{
m_llmThread.setObjectName(m_chat->id());
}
bool ChatLLM::handleNamePrompt(int32_t token)
{
Q_UNUSED(token);
@ -354,3 +352,60 @@ bool ChatLLM::handleNameRecalculate(bool isRecalc)
Q_UNREACHABLE();
return true;
}
bool ChatLLM::serialize(QDataStream &stream)
{
stream << response();
stream << generatedName();
stream << m_promptResponseTokens;
stream << m_responseLogits;
stream << m_ctx.n_past;
stream << quint64(m_ctx.logits.size());
stream.writeRawData(reinterpret_cast<const char*>(m_ctx.logits.data()), m_ctx.logits.size() * sizeof(float));
stream << quint64(m_ctx.tokens.size());
stream.writeRawData(reinterpret_cast<const char*>(m_ctx.tokens.data()), m_ctx.tokens.size() * sizeof(int));
saveState();
stream << m_state;
return stream.status() == QDataStream::Ok;
}
bool ChatLLM::deserialize(QDataStream &stream)
{
QString response;
stream >> response;
m_response = response.toStdString();
QString nameResponse;
stream >> nameResponse;
m_nameResponse = nameResponse.toStdString();
stream >> m_promptResponseTokens;
stream >> m_responseLogits;
stream >> m_ctx.n_past;
quint64 logitsSize;
stream >> logitsSize;
m_ctx.logits.resize(logitsSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.logits.data()), logitsSize * sizeof(float));
quint64 tokensSize;
stream >> tokensSize;
m_ctx.tokens.resize(tokensSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
stream >> m_state;
return stream.status() == QDataStream::Ok;
}
void ChatLLM::saveState()
{
if (!isModelLoaded())
return;
const size_t stateSize = m_llmodel->stateSize();
m_state.resize(stateSize);
m_llmodel->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data())));
}
void ChatLLM::restoreState()
{
if (!isModelLoaded())
return;
m_llmodel->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
}

View File

@ -6,18 +6,18 @@
#include "llmodel/llmodel.h"
class Chat;
class ChatLLM : public QObject
{
Q_OBJECT
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
Q_PROPERTY(QString generatedName READ generatedName NOTIFY generatedNameChanged)
public:
ChatLLM();
ChatLLM(Chat *parent);
bool isModelLoaded() const;
void regenerateResponse();
@ -25,8 +25,6 @@ public:
void resetContext();
void stopGenerating() { m_stopGenerating = true; }
void setThreadCount(int32_t n_threads);
int32_t threadCount();
QString response() const;
QString modelName() const;
@ -37,14 +35,20 @@ public:
QString generatedName() const { return QString::fromStdString(m_nameResponse); }
bool serialize(QDataStream &stream);
bool deserialize(QDataStream &stream);
public Q_SLOTS:
bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
bool loadModel();
bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict,
int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens,
int32_t n_threads);
bool loadDefaultModel();
bool loadModel(const QString &modelName);
void modelNameChangeRequested(const QString &modelName);
void unload();
void reload(const QString &modelName);
void unloadModel();
void reloadModel(const QString &modelName);
void generateName();
void handleChatIdChanged();
Q_SIGNALS:
void isModelLoadedChanged();
@ -52,22 +56,23 @@ Q_SIGNALS:
void responseStarted();
void responseStopped();
void modelNameChanged();
void threadCountChanged();
void recalcChanged();
void sendStartup();
void sendModelLoaded();
void sendResetContext();
void generatedNameChanged();
void stateChanged();
private:
void resetContextPrivate();
bool loadModelPrivate(const QString &modelName);
bool handlePrompt(int32_t token);
bool handleResponse(int32_t token, const std::string &response);
bool handleRecalculate(bool isRecalc);
bool handleNamePrompt(int32_t token);
bool handleNameResponse(int32_t token, const std::string &response);
bool handleNameRecalculate(bool isRecalc);
void saveState();
void restoreState();
private:
LLModel::PromptContext m_ctx;
@ -77,6 +82,8 @@ private:
quint32 m_promptResponseTokens;
quint32 m_responseLogits;
QString m_modelName;
Chat *m_chat;
QByteArray m_state;
QThread m_llmThread;
std::atomic<bool> m_stopGenerating;
bool m_isRecalc;

View File

@ -3,6 +3,7 @@
#include <QAbstractListModel>
#include <QtQml>
#include <QDataStream>
struct ChatItem
{
@ -209,6 +210,46 @@ public:
int count() const { return m_chatItems.size(); }
bool serialize(QDataStream &stream) const
{
stream << count();
for (auto c : m_chatItems) {
stream << c.id;
stream << c.name;
stream << c.value;
stream << c.prompt;
stream << c.newResponse;
stream << c.currentResponse;
stream << c.stopped;
stream << c.thumbsUpState;
stream << c.thumbsDownState;
}
return stream.status() == QDataStream::Ok;
}
bool deserialize(QDataStream &stream)
{
int size;
stream >> size;
for (int i = 0; i < size; ++i) {
ChatItem c;
stream >> c.id;
stream >> c.name;
stream >> c.value;
stream >> c.prompt;
stream >> c.newResponse;
stream >> c.currentResponse;
stream >> c.stopped;
stream >> c.thumbsUpState;
stream >> c.thumbsDownState;
beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size());
m_chatItems.append(c);
endInsertRows();
}
emit countChanged();
return stream.status() == QDataStream::Ok;
}
Q_SIGNALS:
void countChanged();

96
llm.cpp
View File

@ -20,77 +20,22 @@ LLM *LLM::globalInstance()
LLM::LLM()
: QObject{nullptr}
, m_chatListModel(new ChatListModel(this))
, m_threadCount(std::min(4, (int32_t) std::thread::hardware_concurrency()))
{
// Should be in the same thread
connect(Download::globalInstance(), &Download::modelListChanged,
this, &LLM::modelListChanged, Qt::DirectConnection);
connect(m_chatListModel, &ChatListModel::connectChat,
this, &LLM::connectChat, Qt::DirectConnection);
connect(m_chatListModel, &ChatListModel::disconnectChat,
this, &LLM::disconnectChat, Qt::DirectConnection);
connect(QCoreApplication::instance(), &QCoreApplication::aboutToQuit,
this, &LLM::aboutToQuit);
if (!m_chatListModel->count())
m_chatListModel->restoreChats();
if (m_chatListModel->count()) {
Chat *firstChat = m_chatListModel->get(0);
if (firstChat->chatModel()->count() < 2)
m_chatListModel->setNewChat(firstChat);
else
m_chatListModel->setCurrentChat(firstChat);
} else
m_chatListModel->addChat();
}
QList<QString> LLM::modelList() const
{
Q_ASSERT(m_chatListModel->currentChat());
const Chat *currentChat = m_chatListModel->currentChat();
// Build a model list from exepath and from the localpath
QList<QString> list;
QString exePath = QCoreApplication::applicationDirPath() + QDir::separator();
QString localPath = Download::globalInstance()->downloadLocalModelsPath();
{
QDir dir(exePath);
dir.setNameFilters(QStringList() << "ggml-*.bin");
QStringList fileNames = dir.entryList();
for (QString f : fileNames) {
QString filePath = exePath + f;
QFileInfo info(filePath);
QString name = info.completeBaseName().remove(0, 5);
if (info.exists()) {
if (name == currentChat->modelName())
list.prepend(name);
else
list.append(name);
}
}
}
if (localPath != exePath) {
QDir dir(localPath);
dir.setNameFilters(QStringList() << "ggml-*.bin");
QStringList fileNames = dir.entryList();
for (QString f : fileNames) {
QString filePath = localPath + f;
QFileInfo info(filePath);
QString name = info.completeBaseName().remove(0, 5);
if (info.exists() && !list.contains(name)) { // don't allow duplicates
if (name == currentChat->modelName())
list.prepend(name);
else
list.append(name);
}
}
}
if (list.isEmpty()) {
if (exePath != localPath) {
qWarning() << "ERROR: Could not find any applicable models in"
<< exePath << "nor" << localPath;
} else {
qWarning() << "ERROR: Could not find any applicable models in"
<< exePath;
}
return QList<QString>();
}
return list;
}
bool LLM::checkForUpdates() const
{
Network::globalInstance()->sendCheckForUpdates();
@ -113,21 +58,20 @@ bool LLM::checkForUpdates() const
return QProcess::startDetached(fileName);
}
bool LLM::isRecalc() const
int32_t LLM::threadCount() const
{
Q_ASSERT(m_chatListModel->currentChat());
return m_chatListModel->currentChat()->isRecalc();
return m_threadCount;
}
void LLM::connectChat(Chat *chat)
void LLM::setThreadCount(int32_t n_threads)
{
// Should be in the same thread
connect(chat, &Chat::modelNameChanged, this, &LLM::modelListChanged, Qt::DirectConnection);
connect(chat, &Chat::recalcChanged, this, &LLM::recalcChanged, Qt::DirectConnection);
connect(chat, &Chat::responseChanged, this, &LLM::responseChanged, Qt::DirectConnection);
if (n_threads <= 0)
n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
m_threadCount = n_threads;
emit threadCountChanged();
}
void LLM::disconnectChat(Chat *chat)
void LLM::aboutToQuit()
{
chat->disconnect(this);
m_chatListModel->saveChats();
}

16
llm.h
View File

@ -3,37 +3,33 @@
#include <QObject>
#include "chat.h"
#include "chatlistmodel.h"
class LLM : public QObject
{
Q_OBJECT
Q_PROPERTY(QList<QString> modelList READ modelList NOTIFY modelListChanged)
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
Q_PROPERTY(ChatListModel *chatListModel READ chatListModel NOTIFY chatListModelChanged)
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
public:
static LLM *globalInstance();
QList<QString> modelList() const;
bool isRecalc() const;
ChatListModel *chatListModel() const { return m_chatListModel; }
int32_t threadCount() const;
void setThreadCount(int32_t n_threads);
Q_INVOKABLE bool checkForUpdates() const;
Q_SIGNALS:
void modelListChanged();
void recalcChanged();
void responseChanged();
void chatListModelChanged();
void threadCountChanged();
private Q_SLOTS:
void connectChat(Chat*);
void disconnectChat(Chat*);
void aboutToQuit();
private:
ChatListModel *m_chatListModel;
int32_t m_threadCount;
private:
explicit LLM();

View File

@ -67,6 +67,7 @@ int32_t LLamaModel::threadCount() {
LLamaModel::~LLamaModel()
{
llama_free(d_ptr->ctx);
}
bool LLamaModel::isModelLoaded() const
@ -74,6 +75,21 @@ bool LLamaModel::isModelLoaded() const
return d_ptr->modelLoaded;
}
size_t LLamaModel::stateSize() const
{
return llama_get_state_size(d_ptr->ctx);
}
size_t LLamaModel::saveState(uint8_t *dest) const
{
return llama_copy_state_data(d_ptr->ctx, dest);
}
size_t LLamaModel::restoreState(const uint8_t *src)
{
return llama_set_state_data(d_ptr->ctx, src);
}
void LLamaModel::prompt(const std::string &prompt,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,

View File

@ -14,6 +14,9 @@ public:
bool loadModel(const std::string &modelPath) override;
bool isModelLoaded() const override;
size_t stateSize() const override;
size_t saveState(uint8_t *dest) const override;
size_t restoreState(const uint8_t *src) override;
void prompt(const std::string &prompt,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,

View File

@ -12,6 +12,9 @@ public:
virtual bool loadModel(const std::string &modelPath) = 0;
virtual bool isModelLoaded() const = 0;
virtual size_t stateSize() const { return 0; }
virtual size_t saveState(uint8_t *dest) const { return 0; }
virtual size_t restoreState(const uint8_t *src) { return 0; }
struct PromptContext {
std::vector<float> logits; // logits of current context
std::vector<int32_t> tokens; // current tokens in the context window

View File

@ -48,6 +48,24 @@ bool llmodel_isModelLoaded(llmodel_model model)
return wrapper->llModel->isModelLoaded();
}
uint64_t llmodel_get_state_size(llmodel_model model)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
return wrapper->llModel->stateSize();
}
uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
return wrapper->llModel->saveState(dest);
}
uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
return wrapper->llModel->restoreState(src);
}
// Wrapper functions for the C callbacks
bool prompt_wrapper(int32_t token_id, void *user_data) {
llmodel_prompt_callback callback = reinterpret_cast<llmodel_prompt_callback>(user_data);

View File

@ -98,6 +98,32 @@ bool llmodel_loadModel(llmodel_model model, const char *model_path);
*/
bool llmodel_isModelLoaded(llmodel_model model);
/**
* Get the size of the internal state of the model.
* NOTE: This state data is specific to the type of model you have created.
* @param model A pointer to the llmodel_model instance.
* @return the size in bytes of the internal state of the model
*/
uint64_t llmodel_get_state_size(llmodel_model model);
/**
* Saves the internal state of the model to the specified destination address.
* NOTE: This state data is specific to the type of model you have created.
* @param model A pointer to the llmodel_model instance.
* @param dest A pointer to the destination.
* @return the number of bytes copied
*/
uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest);
/**
* Restores the internal state of the model using data from the specified address.
* NOTE: This state data is specific to the type of model you have created.
* @param model A pointer to the llmodel_model instance.
* @param src A pointer to the src.
* @return the number of bytes read
*/
uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
/**
* Generate a response using the model.
* @param model A pointer to the llmodel_model instance.

View File

@ -65,7 +65,7 @@ Window {
}
// check for any current models and if not, open download dialog
if (LLM.modelList.length === 0 && !firstStartDialog.opened) {
if (currentChat.modelList.length === 0 && !firstStartDialog.opened) {
downloadNewModels.open();
return;
}
@ -125,7 +125,7 @@ Window {
anchors.horizontalCenter: parent.horizontalCenter
font.pixelSize: theme.fontSizeLarge
spacing: 0
model: LLM.modelList
model: currentChat.modelList
Accessible.role: Accessible.ComboBox
Accessible.name: qsTr("ComboBox for displaying/picking the current model")
Accessible.description: qsTr("Use this for picking the current model to use; the first item is the current model")
@ -367,9 +367,9 @@ Window {
text: qsTr("Recalculating context.")
Connections {
target: LLM
target: currentChat
function onRecalcChanged() {
if (LLM.isRecalc)
if (currentChat.isRecalc)
recalcPopup.open()
else
recalcPopup.close()
@ -422,10 +422,7 @@ Window {
var item = chatModel.get(i)
var string = item.name;
var isResponse = item.name === qsTr("Response: ")
if (item.currentResponse)
string += currentChat.response
else
string += chatModel.get(i).value
string += chatModel.get(i).value
if (isResponse && item.stopped)
string += " <stopped>"
string += "\n"
@ -440,10 +437,7 @@ Window {
var item = chatModel.get(i)
var isResponse = item.name === qsTr("Response: ")
str += "{\"content\": ";
if (item.currentResponse)
str += JSON.stringify(currentChat.response)
else
str += JSON.stringify(item.value)
str += JSON.stringify(item.value)
str += ", \"role\": \"" + (isResponse ? "assistant" : "user") + "\"";
if (isResponse && item.thumbsUpState !== item.thumbsDownState)
str += ", \"rating\": \"" + (item.thumbsUpState ? "positive" : "negative") + "\"";
@ -572,14 +566,14 @@ Window {
Accessible.description: qsTr("This is the list of prompt/response pairs comprising the actual conversation with the model")
delegate: TextArea {
text: currentResponse ? currentChat.response : (value ? value : "")
text: value
width: listView.width
color: theme.textColor
wrapMode: Text.WordWrap
focus: false
readOnly: true
font.pixelSize: theme.fontSizeLarge
cursorVisible: currentResponse ? (currentChat.response !== "" ? currentChat.responseInProgress : false) : false
cursorVisible: currentResponse ? currentChat.responseInProgress : false
cursorPosition: text.length
background: Rectangle {
color: name === qsTr("Response: ") ? theme.backgroundLighter : theme.backgroundLight
@ -599,8 +593,8 @@ Window {
anchors.leftMargin: 90
anchors.top: parent.top
anchors.topMargin: 5
visible: (currentResponse ? true : false) && currentChat.response === "" && currentChat.responseInProgress
running: (currentResponse ? true : false) && currentChat.response === "" && currentChat.responseInProgress
visible: (currentResponse ? true : false) && value === "" && currentChat.responseInProgress
running: (currentResponse ? true : false) && value === "" && currentChat.responseInProgress
Accessible.role: Accessible.Animation
Accessible.name: qsTr("Busy indicator")
@ -631,7 +625,7 @@ Window {
window.height / 2 - height / 2)
x: globalPoint.x
y: globalPoint.y
property string text: currentResponse ? currentChat.response : (value ? value : "")
property string text: value
response: newResponse === undefined || newResponse === "" ? text : newResponse
onAccepted: {
var responseHasChanged = response !== text && response !== newResponse
@ -711,7 +705,7 @@ Window {
property bool isAutoScrolling: false
Connections {
target: LLM
target: currentChat
function onResponseChanged() {
if (listView.shouldAutoScroll) {
listView.isAutoScrolling = true
@ -762,7 +756,6 @@ Window {
if (listElement.name === qsTr("Response: ")) {
chatModel.updateCurrentResponse(index, true);
chatModel.updateStopped(index, false);
chatModel.updateValue(index, currentChat.response);
chatModel.updateThumbsUpState(index, false);
chatModel.updateThumbsDownState(index, false);
chatModel.updateNewResponse(index, "");
@ -840,7 +833,6 @@ Window {
var index = Math.max(0, chatModel.count - 1);
var listElement = chatModel.get(index);
chatModel.updateCurrentResponse(index, false);
chatModel.updateValue(index, currentChat.response);
}
currentChat.newPromptResponsePair(textInput.text);
currentChat.prompt(textInput.text, settingsDialog.promptTemplate,

View File

@ -458,7 +458,6 @@ void Network::handleIpifyFinished()
void Network::handleMixpanelFinished()
{
Q_ASSERT(m_usageStatsActive);
QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
if (!reply)
return;

View File

@ -83,6 +83,7 @@ Drawer {
opacity: 0.9
property bool isCurrent: LLM.chatListModel.currentChat === LLM.chatListModel.get(index)
property bool trashQuestionDisplayed: false
z: isCurrent ? 199 : 1
color: index % 2 === 0 ? theme.backgroundLight : theme.backgroundLighter
border.width: isCurrent
border.color: chatName.readOnly ? theme.assistantColor : theme.userColor
@ -112,6 +113,11 @@ Drawer {
color: "transparent"
}
onEditingFinished: {
// Work around a bug in qml where we're losing focus when the whole window
// goes out of focus even though this textfield should be marked as not
// having focus
if (chatName.readOnly)
return;
changeName();
Network.sendRenameChat()
}
@ -188,6 +194,7 @@ Drawer {
visible: isCurrent && trashQuestionDisplayed
opacity: 1.0
radius: 10
z: 200
Row {
spacing: 10
Button {

View File

@ -12,7 +12,7 @@ Dialog {
id: modelDownloaderDialog
modal: true
opacity: 0.9
closePolicy: LLM.modelList.length === 0 ? Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside)
closePolicy: LLM.chatListModel.currentChat.modelList.length === 0 ? Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside)
background: Rectangle {
anchors.fill: parent
anchors.margins: -20