// gpt4all/gpt4all-chat/chatllm.cpp
#include "chatllm.h"
#include "chat.h"
#include "chatapi.h"
#include "localdocs.h"
#include "modellist.h"
#include "network.h"
#include "mysettings.h"
#include "../gpt4all-backend/llmodel.h"

#include <algorithm>
#include <atomic>

//#define DEBUG
//#define DEBUG_MODEL_LOADING

#define GPTJ_INTERNAL_STATE_VERSION 0
#define LLAMA_INTERNAL_STATE_VERSION 0
// Process-wide pool of LLModelInfo slots that chat threads hand back and forth.
// Only MyLLModelStore (a friend) can construct/destroy it, via Q_GLOBAL_STATIC below.
class LLModelStore {
public:
    static LLModelStore *globalInstance();

    // Blocks while the pool is empty, then removes and returns the first slot.
    LLModelInfo acquireModel();
    // Puts a slot back into the pool and wakes waiters; pair with every acquireModel().
    void releaseModel(const LLModelInfo &info);

private:
    // The pool starts out holding one empty slot so the first acquirer does not block.
    LLModelStore() { m_availableModels.append(LLModelInfo()); }
    ~LLModelStore() = default;

    QVector<LLModelInfo> m_availableModels;
    QMutex m_mutex;
    QWaitCondition m_condition;

    friend class MyLLModelStore;
};
// Friend subclass so that Q_GLOBAL_STATIC can create/destroy the store despite its private ctor/dtor.
class MyLLModelStore : public LLModelStore { };
Q_GLOBAL_STATIC(MyLLModelStore, storeInstance)

LLModelStore *LLModelStore::globalInstance()
{
    // The singleton is the Q_GLOBAL_STATIC instance declared just above.
    MyLLModelStore *instance = storeInstance();
    return instance;
}
LLModelInfo LLModelStore::acquireModel()
{
    QMutexLocker guard(&m_mutex);
    // Sleep on the condition variable until releaseModel() returns a slot to the pool.
    for (;;) {
        if (!m_availableModels.isEmpty())
            break;
        m_condition.wait(&m_mutex);
    }
    return m_availableModels.takeFirst();
}
void LLModelStore::releaseModel(const LLModelInfo &info)
{
    QMutexLocker guard(&m_mutex);
    m_availableModels.push_back(info);
    // Only one slot ever circulates, so the pool can never hold two entries at once.
    Q_ASSERT(m_availableModels.size() < 2);
    m_condition.wakeAll();
}
// Creates the per-chat LLM worker. The object is moved onto its own thread (m_llmThread), which is
// started at the end of construction; `parent` is only used for signal wiring and the thread name.
ChatLLM::ChatLLM(Chat *parent, bool isServer)
    : QObject{nullptr}
    , m_promptResponseTokens(0)
    , m_promptTokens(0)
    , m_isRecalc(false)
    , m_shouldBeLoaded(false)
    , m_forceUnloadModel(false)
    , m_markedForDeletion(false)
    , m_shouldTrySwitchContext(false)
    , m_stopGenerating(false)
    , m_timer(nullptr)
    , m_isServer(isServer)
    , m_forceMetal(MySettings::globalInstance()->forceMetal())
    , m_reloadingToChangeVariant(false)
    , m_processedSystemPrompt(false)
    , m_restoreStateFromText(false)
{
    moveToThread(&m_llmThread);
    connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
    connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
    connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
        Qt::QueuedConnection); // explicitly queued
    connect(this, &ChatLLM::shouldTrySwitchContextChanged, this, &ChatLLM::handleShouldTrySwitchContextChanged,
        Qt::QueuedConnection); // explicitly queued
    connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
    connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
    connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged);
    connect(MySettings::globalInstance(), &MySettings::deviceChanged, this, &ChatLLM::handleDeviceChanged);

    // The following are blocking operations and will block the llm thread
    connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
        Qt::BlockingQueuedConnection);

    m_llmThread.setObjectName(parent->id());
    m_llmThread.start();
}
2023-05-12 21:06:03 +03:00
ChatLLM::~ChatLLM()
{
destroy();
}
void ChatLLM::destroy() {
2023-07-09 21:42:11 +03:00
m_stopGenerating = true;
2023-05-12 21:06:03 +03:00
m_llmThread.quit();
m_llmThread.wait();
// The only time we should have a model loaded here is on shutdown
// as we explicitly unload the model in all other circumstances
if (isModelLoaded()) {
2023-06-22 22:44:49 +03:00
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
}
2023-05-12 21:06:03 +03:00
}
void ChatLLM::handleThreadStarted()
{
    // Invoked via QThread::started, i.e. on m_llmThread: the token timer is created here as a
    // child of this object, and its speed reports are forwarded through reportSpeed.
    auto *timer = new TokenTimer(this);
    connect(timer, &TokenTimer::report, this, &ChatLLM::reportSpeed);
    m_timer = timer;
    emit threadStarted();
}
2023-06-27 18:54:34 +03:00
void ChatLLM::handleForceMetalChanged(bool forceMetal)
{
#if defined(Q_OS_MAC) && defined(__arm__)
m_forceMetal = forceMetal;
if (isModelLoaded() && m_shouldBeLoaded) {
m_reloadingToChangeVariant = true;
unloadModel();
reloadModel();
m_reloadingToChangeVariant = false;
}
#endif
}
2023-09-13 17:32:08 +03:00
void ChatLLM::handleDeviceChanged()
{
if (isModelLoaded() && m_shouldBeLoaded) {
m_reloadingToChangeVariant = true;
unloadModel();
reloadModel();
m_reloadingToChangeVariant = false;
}
}
bool ChatLLM::loadDefaultModel()
{
2023-06-22 22:44:49 +03:00
ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo();
if (defaultModel.filename().isEmpty()) {
2023-06-22 22:44:49 +03:00
emit modelLoadingError(QString("Could not find any model to load"));
return false;
}
2023-06-22 22:44:49 +03:00
return loadModel(defaultModel);
}
bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
{
// We're trying to see if the store already has the model fully loaded that we wish to use
// and if so we just acquire it from the store and switch the context and return true. If the
// store doesn't have it or we're already loaded or in any other case just return false.
// If we're already loaded or a server or we're reloading to change the variant/device or the
// modelInfo is empty, then this should fail
if (isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty()) {
m_shouldTrySwitchContext = false;
emit trySwitchContextOfLoadedModelCompleted(false);
return false;
}
QString filePath = modelInfo.dirpath + modelInfo.filename();
QFileInfo fileInfo(filePath);
m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
// The store gave us no already loaded model, the wrong type of model, then give it back to the
// store and fail
if (!m_llModelInfo.model || m_llModelInfo.fileInfo != fileInfo) {
LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
m_llModelInfo = LLModelInfo();
m_shouldTrySwitchContext = false;
emit trySwitchContextOfLoadedModelCompleted(false);
return false;
}
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
// We should be loaded and now we are
m_shouldBeLoaded = true;
m_shouldTrySwitchContext = false;
// Restore, signal and process
restoreState();
emit modelLoadingPercentageChanged(1.0f);
emit trySwitchContextOfLoadedModelCompleted(true);
processSystemPrompt();
return true;
}
2023-06-22 22:44:49 +03:00
bool ChatLLM::loadModel(const ModelInfo &modelInfo)
{
// This is a complicated method because N different possible threads are interested in the outcome
// of this method. Why? Because we have a main/gui thread trying to monitor the state of N different
// possible chat threads all vying for a single resource - the currently loaded model - as the user
// switches back and forth between chats. It is important for our main/gui thread to never block
// but simultaneously always have up2date information with regards to which chat has the model loaded
// and what the type and name of that model is. I've tried to comment extensively in this method
// to provide an overview of what we're doing here.
// We're already loaded with this model
2023-06-22 22:44:49 +03:00
if (isModelLoaded() && this->modelInfo() == modelInfo)
return true;
QString filePath = modelInfo.dirpath + modelInfo.filename();
QFileInfo fileInfo(filePath);
// We have a live model, but it isn't the one we want
bool alreadyAcquired = isModelLoaded();
if (alreadyAcquired) {
resetContext();
#if defined(DEBUG_MODEL_LOADING)
2023-06-22 22:44:49 +03:00
qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
2023-06-22 22:44:49 +03:00
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
} else if (!m_isServer) {
// This is a blocking call that tries to retrieve the model we need from the model store.
// If it succeeds, then we just have to restore state. If the store has never had a model
// returned to it, then the modelInfo.model pointer should be null which will happen on startup
2023-06-22 22:44:49 +03:00
m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
#if defined(DEBUG_MODEL_LOADING)
2023-06-27 18:54:34 +03:00
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
// At this point it is possible that while we were blocked waiting to acquire the model from the
// store, that our state was changed to not be loaded. If this is the case, release the model
// back into the store and quit loading
if (!m_shouldBeLoaded) {
#if defined(DEBUG_MODEL_LOADING)
2023-06-22 22:44:49 +03:00
qDebug() << "no longer need model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
2023-06-22 22:44:49 +03:00
LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
m_llModelInfo = LLModelInfo();
emit modelLoadingPercentageChanged(0.0f);
return false;
}
// Check if the store just gave us exactly the model we were looking for
2023-06-27 18:54:34 +03:00
if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo && !m_reloadingToChangeVariant) {
#if defined(DEBUG_MODEL_LOADING)
2023-06-22 22:44:49 +03:00
qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
restoreState();
emit modelLoadingPercentageChanged(1.0f);
setModelInfo(modelInfo);
Q_ASSERT(!m_modelInfo.filename().isEmpty());
if (m_modelInfo.filename().isEmpty())
emit modelLoadingError(QString("Modelinfo is left null for %1").arg(modelInfo.filename()));
else
processSystemPrompt();
return true;
} else {
// Release the memory since we have to switch to a different model.
#if defined(DEBUG_MODEL_LOADING)
2023-06-22 22:44:49 +03:00
qDebug() << "deleting model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
2023-06-22 22:44:49 +03:00
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
}
}
// Guarantee we've released the previous models memory
2023-06-22 22:44:49 +03:00
Q_ASSERT(!m_llModelInfo.model);
// Store the file info in the modelInfo in case we have an error loading
2023-06-22 22:44:49 +03:00
m_llModelInfo.fileInfo = fileInfo;
if (fileInfo.exists()) {
if (modelInfo.isOnline) {
QString apiKey;
QString modelName;
{
QFile file(filePath);
bool success = file.open(QIODeviceBase::ReadOnly);
(void)success;
Q_ASSERT(success);
QJsonDocument doc = QJsonDocument::fromJson(file.readAll());
QJsonObject obj = doc.object();
apiKey = obj["apiKey"].toString();
modelName = obj["modelName"].toString();
}
m_llModelType = LLModelType::API_;
ChatAPI *model = new ChatAPI();
model->setModelName(modelName);
model->setRequestURL(modelInfo.url());
model->setAPIKey(apiKey);
2023-06-22 22:44:49 +03:00
m_llModelInfo.model = model;
} else {
auto n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo);
m_ctx.n_ctx = n_ctx;
auto ngl = MySettings::globalInstance()->modelGpuLayers(modelInfo);
std::string buildVariant = "auto";
2023-06-27 18:54:34 +03:00
#if defined(Q_OS_MAC) && defined(__arm__)
if (m_forceMetal)
buildVariant = "metal";
2023-06-27 18:54:34 +03:00
#endif
m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
2023-06-27 18:54:34 +03:00
2023-06-22 22:44:49 +03:00
if (m_llModelInfo.model) {
if (m_llModelInfo.model->isModelBlacklisted(filePath.toStdString())) {
static QSet<QString> warned;
auto fname = modelInfo.filename();
if (!warned.contains(fname)) {
emit modelLoadingWarning(QString(
"%1 is known to be broken. Please get a replacement via the download dialog."
).arg(fname));
warned.insert(fname); // don't warn again until restart
}
}
m_llModelInfo.model->setProgressCallback([this](float progress) -> bool {
emit modelLoadingPercentageChanged(progress);
return m_shouldBeLoaded;
});
2023-09-13 17:32:08 +03:00
// Pick the best match for the device
2023-09-14 15:25:37 +03:00
QString actualDevice = m_llModelInfo.model->implementation().buildVariant() == "metal" ? "Metal" : "CPU";
2023-09-13 17:32:08 +03:00
const QString requestedDevice = MySettings::globalInstance()->device();
if (requestedDevice == "CPU") {
emit reportFallbackReason(""); // fallback not applicable
} else {
const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString(), n_ctx, ngl);
2023-09-13 17:32:08 +03:00
std::vector<LLModel::GPUDevice> availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory);
LLModel::GPUDevice *device = nullptr;
if (!availableDevices.empty() && requestedDevice == "Auto" && availableDevices.front().type == 2 /*a discrete gpu*/) {
device = &availableDevices.front();
2023-09-13 17:32:08 +03:00
} else {
for (LLModel::GPUDevice &d : availableDevices) {
if (QString::fromStdString(d.name) == requestedDevice) {
device = &d;
break;
}
2023-09-13 17:32:08 +03:00
}
}
emit reportFallbackReason(""); // no fallback yet
std::string unavail_reason;
if (!device) {
// GPU not available
} else if (!m_llModelInfo.model->initializeGPUDevice(device->index, &unavail_reason)) {
2023-10-06 18:30:55 +03:00
emit reportFallbackReason(QString::fromStdString("<br>" + unavail_reason));
} else {
actualDevice = QString::fromStdString(device->name);
}
2023-09-13 17:32:08 +03:00
}
2023-09-14 15:25:37 +03:00
// Report which device we're actually using
emit reportDevice(actualDevice);
bool success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, ngl);
if (actualDevice == "CPU") {
// we asked llama.cpp to use the CPU
} else if (!success) {
// llama_init_from_file returned nullptr
2023-09-14 23:52:31 +03:00
emit reportDevice("CPU");
2023-10-06 18:30:55 +03:00
emit reportFallbackReason("<br>GPU loading failed (out of VRAM?)");
success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, 0);
} else if (!m_llModelInfo.model->usingGPUDevice()) {
// ggml_vk_init was not called in llama.cpp
// We might have had to fallback to CPU after load if the model is not possible to accelerate
// for instance if the quantization method is not supported on Vulkan yet
emit reportDevice("CPU");
2023-10-06 18:30:55 +03:00
emit reportFallbackReason("<br>model or quant has no GPU support");
2023-09-14 23:52:31 +03:00
}
if (!success) {
2023-09-14 23:52:31 +03:00
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
if (!m_isServer)
2023-06-22 22:44:49 +03:00
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelInfo.filename()));
} else {
switch (m_llModelInfo.model->implementation().modelType()[0]) {
2023-06-22 22:44:49 +03:00
case 'L': m_llModelType = LLModelType::LLAMA_; break;
case 'G': m_llModelType = LLModelType::GPTJ_; break;
default:
{
2023-09-14 23:52:31 +03:00
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
if (!m_isServer)
2023-06-22 22:44:49 +03:00
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not determine model type for %1").arg(modelInfo.filename()));
}
}
}
} else {
if (!m_isServer)
2023-06-22 22:44:49 +03:00
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not load model due to invalid format for %1").arg(modelInfo.filename()));
}
}
#if defined(DEBUG_MODEL_LOADING)
2023-06-22 22:44:49 +03:00
qDebug() << "new model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
restoreState();
#if defined(DEBUG)
qDebug() << "modelLoadedChanged" << m_llmThread.objectName();
fflush(stdout);
#endif
emit modelLoadingPercentageChanged(isModelLoaded() ? 1.0f : 0.0f);
2023-05-10 06:43:16 +03:00
static bool isFirstLoad = true;
2023-05-09 18:46:33 +03:00
if (isFirstLoad) {
emit sendStartup();
2023-05-09 18:46:33 +03:00
isFirstLoad = false;
} else
emit sendModelLoaded();
} else {
if (!m_isServer)
2023-06-22 22:44:49 +03:00
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
m_llModelInfo = LLModelInfo();
emit modelLoadingError(QString("Could not find file for model %1").arg(modelInfo.filename()));
}
if (m_llModelInfo.model) {
2023-06-22 22:44:49 +03:00
setModelInfo(modelInfo);
processSystemPrompt();
}
2023-06-22 22:44:49 +03:00
return m_llModelInfo.model;
}
bool ChatLLM::isModelLoaded() const
{
2023-06-22 22:44:49 +03:00
return m_llModelInfo.model && m_llModelInfo.model->isModelLoaded();
}
// Returns a copy of `input` without its leading std::isspace characters
// (an all-whitespace or empty input yields an empty string).
std::string remove_leading_whitespace(const std::string& input) {
    std::string::size_type start = 0;
    while (start < input.size() && std::isspace(static_cast<unsigned char>(input[start])))
        ++start;
    return input.substr(start);
}
// Returns a copy of `input` with std::isspace characters stripped from both ends
// (an all-whitespace or empty input yields an empty string).
std::string trim_whitespace(const std::string& input) {
    const auto isBlank = [](unsigned char c) { return std::isspace(c) != 0; };
    std::string::size_type begin = 0;
    std::string::size_type end = input.size();
    while (begin < end && isBlank(input[begin]))
        ++begin;
    while (end > begin && isBlank(input[end - 1]))
        --end;
    return input.substr(begin, end - begin);
}
void ChatLLM::regenerateResponse()
{
// ChatGPT uses a different semantic meaning for n_past than local models. For ChatGPT, the meaning
// of n_past is of the number of prompt/response pairs, rather than for total tokens.
if (m_llModelType == LLModelType::API_)
m_ctx.n_past -= 1;
else
m_ctx.n_past -= m_promptResponseTokens;
m_ctx.n_past = std::max(0, m_ctx.n_past);
m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
m_promptResponseTokens = 0;
2023-05-11 23:46:25 +03:00
m_promptTokens = 0;
m_response = std::string();
emit responseChanged(QString::fromStdString(m_response));
}
void ChatLLM::resetResponse()
{
2023-05-11 23:46:25 +03:00
m_promptTokens = 0;
m_promptResponseTokens = 0;
m_response = std::string();
emit responseChanged(QString::fromStdString(m_response));
}
void ChatLLM::resetContext()
{
    // Start over: drop the current response, mark the system prompt as not yet decoded,
    // and replace the prompt context with a value-initialized one.
    resetResponse();
    m_processedSystemPrompt = false;
    m_ctx = LLModel::PromptContext{};
}
QString ChatLLM::response() const
{
    // The accumulated response with its leading whitespace stripped.
    const std::string withoutLeadingBlanks = remove_leading_whitespace(m_response);
    return QString::fromStdString(withoutLeadingBlanks);
}
2023-06-22 22:44:49 +03:00
ModelInfo ChatLLM::modelInfo() const
{
2023-06-22 22:44:49 +03:00
return m_modelInfo;
}
2023-06-22 22:44:49 +03:00
void ChatLLM::setModelInfo(const ModelInfo &modelInfo)
{
2023-06-22 22:44:49 +03:00
m_modelInfo = modelInfo;
emit modelInfoChanged(modelInfo);
}
2023-06-22 22:44:49 +03:00
void ChatLLM::modelChangeRequested(const ModelInfo &modelInfo)
{
m_shouldBeLoaded = true;
2023-06-22 22:44:49 +03:00
loadModel(modelInfo);
}
bool ChatLLM::handlePrompt(int32_t token)
{
2023-06-20 01:23:05 +03:00
// m_promptResponseTokens is related to last prompt/response not
// the entire context window which we can reset on regenerate prompt
#if defined(DEBUG)
qDebug() << "prompt process" << m_llmThread.objectName() << token;
#endif
2023-05-11 23:46:25 +03:00
++m_promptTokens;
++m_promptResponseTokens;
m_timer->start();
return !m_stopGenerating;
}
bool ChatLLM::handleResponse(int32_t token, const std::string &response)
{
#if defined(DEBUG)
printf("%s", response.c_str());
fflush(stdout);
#endif
// check for error
if (token < 0) {
m_response.append(response);
emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response)));
return false;
}
2023-06-20 01:23:05 +03:00
// m_promptResponseTokens is related to last prompt/response not
// the entire context window which we can reset on regenerate prompt
++m_promptResponseTokens;
m_timer->inc();
Q_ASSERT(!response.empty());
m_response.append(response);
emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response)));
return !m_stopGenerating;
}
bool ChatLLM::handleRecalculate(bool isRecalc)
{
#if defined(DEBUG)
    qDebug() << "recalculate" << m_llmThread.objectName() << isRecalc;
#endif
    // Track the recalculation flag, notifying listeners only on an actual transition.
    const bool changed = (m_isRecalc != isRecalc);
    if (changed) {
        m_isRecalc = isRecalc;
        emit recalcChanged();
    }
    return !m_stopGenerating;
}
bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt)
{
    // Rebuild model state from the chat text if that is pending, and make sure the system
    // prompt has been processed before the user's prompt.
    if (m_restoreStateFromText) {
        Q_ASSERT(m_state.isEmpty());
        processRestoreStateFromText();
    }
    if (!m_processedSystemPrompt)
        processSystemPrompt();

    // Read this model's generation settings, then delegate to promptInternal().
    auto *settings = MySettings::globalInstance();
    const QString promptTemplate = settings->modelPromptTemplate(m_modelInfo);
    const int32_t n_predict = settings->modelMaxLength(m_modelInfo);
    const int32_t top_k = settings->modelTopK(m_modelInfo);
    const float top_p = settings->modelTopP(m_modelInfo);
    const float min_p = settings->modelMinP(m_modelInfo);
    const float temp = settings->modelTemperature(m_modelInfo);
    const int32_t n_batch = settings->modelPromptBatchSize(m_modelInfo);
    const float repeat_penalty = settings->modelRepeatPenalty(m_modelInfo);
    const int32_t repeat_penalty_tokens = settings->modelRepeatPenaltyTokens(m_modelInfo);
    return promptInternal(collectionList, prompt, promptTemplate, n_predict, top_k, top_p, min_p, temp, n_batch,
        repeat_penalty, repeat_penalty_tokens);
}
bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
int32_t repeat_penalty_tokens)
{
if (!isModelLoaded())
return false;
2023-06-20 01:23:54 +03:00
QList<ResultInfo> databaseResults;
const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize();
if (!collectionList.isEmpty()) {
emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks
emit databaseResultsChanged(databaseResults);
}
2023-06-01 21:13:12 +03:00
// Augment the prompt template with the results if any
QList<QString> docsContext;
2023-06-20 01:23:54 +03:00
if (!databaseResults.isEmpty())
docsContext.append("### Context:");
2023-06-20 01:23:54 +03:00
for (const ResultInfo &info : databaseResults)
docsContext.append(info.text);
int n_threads = MySettings::globalInstance()->threadCount();
m_stopGenerating = false;
auto promptFunc = std::bind(&ChatLLM::handlePrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleRecalculate, this, std::placeholders::_1);
emit promptProcessing();
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;
m_ctx.top_p = top_p;
m_ctx.min_p = min_p;
m_ctx.temp = temp;
m_ctx.n_batch = n_batch;
m_ctx.repeat_penalty = repeat_penalty;
m_ctx.repeat_last_n = repeat_penalty_tokens;
2023-06-22 22:44:49 +03:00
m_llModelInfo.model->setThreadCount(n_threads);
#if defined(DEBUG)
printf("%s", qPrintable(prompt));
fflush(stdout);
#endif
m_timer->start();
if (!docsContext.isEmpty()) {
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response
m_llModelInfo.model->prompt(docsContext.join("\n").toStdString(), "%1", promptFunc, responseFunc, recalcFunc, m_ctx);
m_ctx.n_predict = old_n_predict; // now we are ready for a response
}
m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
#if defined(DEBUG)
printf("\n");
fflush(stdout);
#endif
m_timer->stop();
std::string trimmed = trim_whitespace(m_response);
if (trimmed != m_response) {
m_response = trimmed;
emit responseChanged(QString::fromStdString(m_response));
}
emit responseStopped();
return true;
}
// Records whether this chat wants a model loaded. The actual (un)load happens in the
// handler connected to shouldBeLoadedChanged.
void ChatLLM::setShouldBeLoaded(bool b)
{
#if defined(DEBUG_MODEL_LOADING)
    qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model;
#endif
    m_shouldBeLoaded = b; // atomic
    emit shouldBeLoadedChanged();
}
void ChatLLM::setShouldTrySwitchContext(bool b)
{
    // Record the request (member documented as atomic); the handler connected to
    // shouldTrySwitchContextChanged acts on it.
    m_shouldTrySwitchContext = b;
    emit shouldTrySwitchContextChanged();
}
void ChatLLM::handleShouldBeLoadedChanged()
{
    // Bring the model in or out to match the current m_shouldBeLoaded request.
    if (!m_shouldBeLoaded) {
        unloadModel();
        return;
    }
    reloadModel();
}
void ChatLLM::handleShouldTrySwitchContextChanged()
{
    // Nothing to do unless a context switch was actually requested.
    if (!m_shouldTrySwitchContext)
        return;
    trySwitchContextOfLoadedModel(modelInfo());
}
void ChatLLM::unloadModel()
{
if (!isModelLoaded() || m_isServer)
return;
if (!m_forceUnloadModel || !m_shouldBeLoaded)
emit modelLoadingPercentageChanged(0.0f);
else
emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
if (!m_markedForDeletion)
saveState();
#if defined(DEBUG_MODEL_LOADING)
2023-06-22 22:44:49 +03:00
qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
if (m_forceUnloadModel) {
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
m_forceUnloadModel = false;
}
2023-06-22 22:44:49 +03:00
LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
m_llModelInfo = LLModelInfo();
}
void ChatLLM::reloadModel()
{
if (isModelLoaded() && m_forceUnloadModel)
unloadModel(); // we unload first if we are forcing an unload
if (isModelLoaded() || m_isServer)
return;
#if defined(DEBUG_MODEL_LOADING)
2023-06-22 22:44:49 +03:00
qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
2023-06-22 22:44:49 +03:00
const ModelInfo m = modelInfo();
if (m.name().isEmpty())
loadDefaultModel();
else
loadModel(m);
}
2023-05-02 18:19:17 +03:00
void ChatLLM::generateName()
{
Q_ASSERT(isModelLoaded());
if (!isModelLoaded())
return;
std::string instructPrompt("### Instruction:\n%1\n### Response:\n"); // standard Alpaca
2023-05-02 18:19:17 +03:00
auto promptFunc = std::bind(&ChatLLM::handleNamePrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1, std::placeholders::_2);
2023-05-02 18:19:17 +03:00
auto recalcFunc = std::bind(&ChatLLM::handleNameRecalculate, this, std::placeholders::_1);
LLModel::PromptContext ctx = m_ctx;
m_llModelInfo.model->prompt("Describe response above in three words.", instructPrompt, promptFunc, responseFunc,
recalcFunc, ctx);
2023-05-02 18:19:17 +03:00
std::string trimmed = trim_whitespace(m_nameResponse);
if (trimmed != m_nameResponse) {
m_nameResponse = trimmed;
emit generatedNameChanged(QString::fromStdString(m_nameResponse));
2023-05-02 18:19:17 +03:00
}
}
void ChatLLM::handleChatIdChanged(const QString &id)
{
    // Name the worker thread after the chat id (the name appears in this file's debug traces).
    m_llmThread.setObjectName(id);
}
2023-05-02 18:19:17 +03:00
bool ChatLLM::handleNamePrompt(int32_t token)
{
#if defined(DEBUG)
qDebug() << "name prompt" << m_llmThread.objectName() << token;
#endif
2023-05-02 18:19:17 +03:00
Q_UNUSED(token);
qt_noop();
2023-07-09 21:42:11 +03:00
return !m_stopGenerating;
2023-05-02 18:19:17 +03:00
}
bool ChatLLM::handleNameResponse(int32_t token, const std::string &response)
{
#if defined(DEBUG)
qDebug() << "name response" << m_llmThread.objectName() << token << response;
#endif
2023-05-02 18:19:17 +03:00
Q_UNUSED(token);
2023-05-08 19:02:31 +03:00
2023-05-02 18:19:17 +03:00
m_nameResponse.append(response);
emit generatedNameChanged(QString::fromStdString(m_nameResponse));
2023-05-08 19:02:31 +03:00
QString gen = QString::fromStdString(m_nameResponse).simplified();
QStringList words = gen.split(' ', Qt::SkipEmptyParts);
return words.size() <= 3;
2023-05-02 18:19:17 +03:00
}
bool ChatLLM::handleNameRecalculate(bool isRecalc)
{
#if defined(DEBUG)
qDebug() << "name recalc" << m_llmThread.objectName() << isRecalc;
#endif
2023-05-02 18:19:17 +03:00
Q_UNUSED(isRecalc);
2023-07-09 18:32:51 +03:00
qt_noop();
return true;
2023-05-02 18:19:17 +03:00
}
bool ChatLLM::handleSystemPrompt(int32_t token)
{
#if defined(DEBUG)
qDebug() << "system prompt" << m_llmThread.objectName() << token << m_stopGenerating;
#endif
Q_UNUSED(token);
2023-07-09 21:42:11 +03:00
return !m_stopGenerating;
}
bool ChatLLM::handleSystemRecalculate(bool isRecalc)
{
#if defined(DEBUG)
    qDebug() << "system recalc" << m_llmThread.objectName() << isRecalc;
#endif
    Q_UNUSED(isRecalc);
    // Unlike handleRecalculate, this callback never asks to continue.
    constexpr bool keepGoing = false;
    return keepGoing;
}
2023-10-10 23:43:02 +03:00
bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
{
#if defined(DEBUG)
qDebug() << "restore state from text prompt" << m_llmThread.objectName() << token << m_stopGenerating;
#endif
Q_UNUSED(token);
return !m_stopGenerating;
}
bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
{
#if defined(DEBUG)
    qDebug() << "restore state from text recalc" << m_llmThread.objectName() << isRecalc;
#endif
    Q_UNUSED(isRecalc);
    // Unlike handleRecalculate, this callback never asks to continue.
    constexpr bool keepGoing = false;
    return keepGoing;
}
// this function serialized the cached model state to disk.
// we want to also serialize n_ctx, and read it at load time.
2023-10-10 23:43:02 +03:00
bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
{
if (version > 1) {
2023-06-22 22:44:49 +03:00
stream << m_llModelType;
switch (m_llModelType) {
case GPTJ_: stream << GPTJ_INTERNAL_STATE_VERSION; break;
case LLAMA_: stream << LLAMA_INTERNAL_STATE_VERSION; break;
default: Q_UNREACHABLE();
}
}
stream << response();
stream << generatedName();
stream << m_promptResponseTokens;
2023-10-10 23:43:02 +03:00
if (!serializeKV) {
#if defined(DEBUG)
qDebug() << "serialize" << m_llmThread.objectName() << m_state.size();
#endif
return stream.status() == QDataStream::Ok;
}
2023-06-20 01:23:05 +03:00
if (version <= 3) {
int responseLogits = 0;
2023-06-20 01:23:05 +03:00
stream << responseLogits;
}
stream << m_ctx.n_past;
if (version >= 7) {
stream << m_ctx.n_ctx;
}
stream << quint64(m_ctx.logits.size());
stream.writeRawData(reinterpret_cast<const char*>(m_ctx.logits.data()), m_ctx.logits.size() * sizeof(float));
stream << quint64(m_ctx.tokens.size());
stream.writeRawData(reinterpret_cast<const char*>(m_ctx.tokens.data()), m_ctx.tokens.size() * sizeof(int));
saveState();
QByteArray compressed = qCompress(m_state);
stream << compressed;
#if defined(DEBUG)
qDebug() << "serialize" << m_llmThread.objectName() << m_state.size();
#endif
return stream.status() == QDataStream::Ok;
}
2023-10-10 23:43:02 +03:00
bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV)
{
if (version > 1) {
int internalStateVersion;
2023-06-22 22:44:49 +03:00
stream >> m_llModelType;
stream >> internalStateVersion; // for future use
}
QString response;
stream >> response;
m_response = response.toStdString();
QString nameResponse;
stream >> nameResponse;
m_nameResponse = nameResponse.toStdString();
stream >> m_promptResponseTokens;
2023-10-10 23:43:02 +03:00
// If we do not deserialize the KV or it is discarded, then we need to restore the state from the
// text only. This will be a costly operation, but the chat has to be restored from the text archive
// alone.
m_restoreStateFromText = !deserializeKV || discardKV;
if (!deserializeKV) {
#if defined(DEBUG)
qDebug() << "deserialize" << m_llmThread.objectName();
#endif
return stream.status() == QDataStream::Ok;
}
2023-06-20 01:23:05 +03:00
if (version <= 3) {
int responseLogits;
stream >> responseLogits;
}
2023-10-10 23:43:02 +03:00
int32_t n_past;
stream >> n_past;
if (!discardKV) m_ctx.n_past = n_past;
if (version >= 7) {
uint32_t n_ctx;
stream >> n_ctx;
if (!discardKV) m_ctx.n_ctx = n_ctx;
}
quint64 logitsSize;
stream >> logitsSize;
2023-10-10 23:43:02 +03:00
if (!discardKV) {
m_ctx.logits.resize(logitsSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.logits.data()), logitsSize * sizeof(float));
} else {
stream.skipRawData(logitsSize * sizeof(float));
}
quint64 tokensSize;
stream >> tokensSize;
2023-10-10 23:43:02 +03:00
if (!discardKV) {
m_ctx.tokens.resize(tokensSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
} else {
stream.skipRawData(tokensSize * sizeof(int));
}
2023-05-08 12:52:57 +03:00
if (version > 0) {
QByteArray compressed;
stream >> compressed;
2023-10-10 23:43:02 +03:00
if (!discardKV)
m_state = qUncompress(compressed);
2023-05-08 12:52:57 +03:00
} else {
if (!discardKV) {
2023-10-10 23:43:02 +03:00
stream >> m_state;
} else {
2023-10-10 23:43:02 +03:00
QByteArray state;
stream >> state;
2023-10-10 23:43:02 +03:00
}
2023-05-08 12:52:57 +03:00
}
2023-10-10 23:43:02 +03:00
#if defined(DEBUG)
qDebug() << "deserialize" << m_llmThread.objectName();
#endif
return stream.status() == QDataStream::Ok;
}
void ChatLLM::saveState()
{
if (!isModelLoaded())
return;
if (m_llModelType == LLModelType::API_) {
m_state.clear();
QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
stream.setVersion(QDataStream::Qt_6_4);
ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model);
stream << chatAPI->context();
return;
}
2023-06-22 22:44:49 +03:00
const size_t stateSize = m_llModelInfo.model->stateSize();
m_state.resize(stateSize);
#if defined(DEBUG)
qDebug() << "saveState" << m_llmThread.objectName() << "size:" << m_state.size();
#endif
2023-06-22 22:44:49 +03:00
m_llModelInfo.model->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data())));
}
void ChatLLM::restoreState()
{
2023-10-10 23:43:02 +03:00
if (!isModelLoaded())
return;
if (m_llModelType == LLModelType::API_) {
QDataStream stream(&m_state, QIODeviceBase::ReadOnly);
stream.setVersion(QDataStream::Qt_6_4);
ChatAPI *chatAPI = static_cast<ChatAPI*>(m_llModelInfo.model);
QList<QString> context;
stream >> context;
chatAPI->setContext(context);
m_state.clear();
m_state.squeeze();
return;
}
#if defined(DEBUG)
qDebug() << "restoreState" << m_llmThread.objectName() << "size:" << m_state.size();
#endif
2023-10-10 23:43:02 +03:00
if (m_state.isEmpty())
return;
if (m_llModelInfo.model->stateSize() == m_state.size()) {
m_llModelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
m_processedSystemPrompt = true;
} else {
qWarning() << "restoring state from text because" << m_llModelInfo.model->stateSize() << "!=" << m_state.size();
m_restoreStateFromText = true;
}
m_state.clear();
m_state.squeeze();
}
void ChatLLM::processSystemPrompt()
{
Q_ASSERT(isModelLoaded());
2023-11-21 18:42:12 +03:00
if (!isModelLoaded() || m_processedSystemPrompt || m_restoreStateFromText || m_isServer)
return;
const std::string systemPrompt = MySettings::globalInstance()->modelSystemPrompt(m_modelInfo).toStdString();
2023-07-12 21:30:11 +03:00
if (QString::fromStdString(systemPrompt).trimmed().isEmpty()) {
m_processedSystemPrompt = true;
return;
}
2023-10-10 23:43:02 +03:00
// Start with a whole new context
2023-07-09 21:42:11 +03:00
m_stopGenerating = false;
2023-10-10 23:43:02 +03:00
m_ctx = LLModel::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
auto recalcFunc = std::bind(&ChatLLM::handleSystemRecalculate, this, std::placeholders::_1);
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
int n_threads = MySettings::globalInstance()->threadCount();
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;
m_ctx.top_p = top_p;
m_ctx.min_p = min_p;
m_ctx.temp = temp;
m_ctx.n_batch = n_batch;
m_ctx.repeat_penalty = repeat_penalty;
m_ctx.repeat_last_n = repeat_penalty_tokens;
m_llModelInfo.model->setThreadCount(n_threads);
#if defined(DEBUG)
printf("%s", qPrintable(QString::fromStdString(systemPrompt)));
fflush(stdout);
#endif
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response
// use "%1%2" and not "%1" to avoid implicit whitespace
m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, recalcFunc, m_ctx, true);
m_ctx.n_predict = old_n_predict;
#if defined(DEBUG)
printf("\n");
fflush(stdout);
#endif
2023-10-10 23:43:02 +03:00
2023-11-21 18:42:12 +03:00
m_processedSystemPrompt = m_stopGenerating == false;
2023-10-10 23:43:02 +03:00
}
void ChatLLM::processRestoreStateFromText()
{
Q_ASSERT(isModelLoaded());
if (!isModelLoaded() || !m_restoreStateFromText || m_isServer)
return;
m_isRecalc = true;
emit recalcChanged();
m_stopGenerating = false;
m_ctx = LLModel::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
auto recalcFunc = std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1);
const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
2023-10-10 23:43:02 +03:00
const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
int n_threads = MySettings::globalInstance()->threadCount();
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;
m_ctx.top_p = top_p;
m_ctx.min_p = min_p;
2023-10-10 23:43:02 +03:00
m_ctx.temp = temp;
m_ctx.n_batch = n_batch;
m_ctx.repeat_penalty = repeat_penalty;
m_ctx.repeat_last_n = repeat_penalty_tokens;
m_llModelInfo.model->setThreadCount(n_threads);
auto it = m_stateFromText.begin();
while (it < m_stateFromText.end()) {
auto &prompt = *it++;
Q_ASSERT(prompt.first == "Prompt: ");
Q_ASSERT(it < m_stateFromText.end());
auto &response = *it++;
Q_ASSERT(response.first != "Prompt: ");
auto responseText = response.second.toStdString();
m_llModelInfo.model->prompt(prompt.second.toStdString(), promptTemplate.toStdString(), promptFunc, nullptr,
recalcFunc, m_ctx, false, &responseText);
2023-10-10 23:43:02 +03:00
}
if (!m_stopGenerating) {
m_restoreStateFromText = false;
m_stateFromText.clear();
}
m_isRecalc = false;
emit recalcChanged();
2023-09-29 20:53:43 +03:00
}