chat: faster KV shift, continue generating, fix stop sequences (#2781)

* Don't stop generating at end of context
* Use llama_kv_cache ops to shift context
* Fix and improve reverse prompt detection
* Replace prompt recalc callback with a flag to disallow context shift
This commit is contained in:
Jared Van Bortel 2024-08-07 11:25:24 -04:00 committed by GitHub
parent 90de2d32f8
commit be66ec8ab5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 285 additions and 230 deletions

View File

@ -33,7 +33,7 @@ set(LLMODEL_VERSION_PATCH 0)
set(LLMODEL_VERSION "${LLMODEL_VERSION_MAJOR}.${LLMODEL_VERSION_MINOR}.${LLMODEL_VERSION_PATCH}")
project(llmodel VERSION ${LLMODEL_VERSION} LANGUAGES CXX C)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD 23)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
set(BUILD_SHARED_LIBS ON)

View File

@ -531,10 +531,7 @@ size_t LLamaModel::restoreState(const uint8_t *src)
std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special)
{
bool atStart = m_tokenize_last_token == -1;
bool insertSpace = atStart || (
llama_token_get_attr(d_ptr->model, m_tokenize_last_token)
& (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)
);
bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token);
std::vector<LLModel::Token> fres(str.length() + 4);
int32_t fres_len = llama_tokenize_gpt4all(
d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart,
@ -546,6 +543,12 @@ std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, const std::
return fres;
}
bool LLamaModel::isSpecialToken(Token id) const
{
return llama_token_get_attr(d_ptr->model, id)
& (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN);
}
std::string LLamaModel::tokenToString(Token id) const
{
std::vector<char> result(8, 0);
@ -595,6 +598,30 @@ bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &toke
return res == 0;
}
void LLamaModel::shiftContext(PromptContext &promptCtx)
{
// infinite text generation via context shifting
// erase up to n_ctx*contextErase tokens
int n_keep = shouldAddBOS();
int n_past = promptCtx.n_past;
int n_discard = std::min(n_past - n_keep, int(promptCtx.n_ctx * promptCtx.contextErase));
assert(n_discard > 0);
if (n_discard <= 0)
return;
std::cerr << "Llama: context full, swapping: n_past = " << n_past << ", n_keep = " << n_keep
<< ", n_discard = " << n_discard << "\n";
// erase the first n_discard tokens from the context
llama_kv_cache_seq_rm (d_ptr->ctx, 0, n_keep, n_keep + n_discard);
llama_kv_cache_seq_add(d_ptr->ctx, 0, n_keep + n_discard, n_past, -n_discard);
promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);
promptCtx.n_past = promptCtx.tokens.size();
}
int32_t LLamaModel::contextLength() const
{
return llama_n_ctx(d_ptr->ctx);

View File

@ -6,7 +6,6 @@
#include "llmodel.h"
#include <functional>
#include <memory>
#include <string>
#include <vector>
@ -54,9 +53,11 @@ private:
protected:
std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override;
bool isSpecialToken(Token id) const override;
std::string tokenToString(Token id) const override;
Token sampleToken(PromptContext &ctx) const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
void shiftContext(PromptContext &promptCtx) override;
int32_t contextLength() const override;
const std::vector<Token> &endTokens() const override;
bool shouldAddBOS() const override;

View File

@ -134,7 +134,7 @@ public:
int32_t n_batch = 9;
float repeat_penalty = 1.10f;
int32_t repeat_last_n = 64; // last n tokens to penalize
float contextErase = 0.75f; // percent of context to erase if we exceed the context window
float contextErase = 0.5f; // percent of context to erase if we exceed the context window
};
using ProgressCallback = std::function<bool(float progress)>;
@ -159,7 +159,7 @@ public:
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &ctx,
bool special = false,
std::string *fakeReply = nullptr);
@ -213,9 +213,11 @@ protected:
// These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions
virtual std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special = false) = 0;
virtual bool isSpecialToken(Token id) const = 0;
virtual std::string tokenToString(Token id) const = 0;
virtual Token sampleToken(PromptContext &ctx) const = 0;
virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const = 0;
virtual void shiftContext(PromptContext &promptCtx) = 0;
virtual int32_t contextLength() const = 0;
virtual const std::vector<Token> &endTokens() const = 0;
virtual bool shouldAddBOS() const = 0;
@ -232,10 +234,6 @@ protected:
return -1;
}
// This is a helper function called from the default implementation of 'prompt' but it can be
// shared by all base classes so it isn't virtual
void recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate);
const Implementation *m_implementation = nullptr;
ProgressCallback m_progressCallback;
@ -249,11 +247,11 @@ protected:
bool decodePrompt(std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp);
void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &promptCtx);
Token m_tokenize_last_token = -1; // not serialized

View File

@ -106,7 +106,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
const char *prompt_template,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
llmodel_recalculate_callback recalculate_callback,
bool allow_context_shift,
llmodel_prompt_context *ctx,
bool special,
const char *fake_reply)
@ -135,7 +135,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
auto *fake_reply_p = fake_reply ? &fake_reply_str : nullptr;
// Call the C++ prompt method
wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, recalculate_callback,
wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift,
wrapper->promptContext, special, fake_reply_p);
// Update the C context by giving access to the wrappers raw pointers to std::vector data

View File

@ -74,13 +74,6 @@ typedef bool (*llmodel_prompt_callback)(int32_t token_id);
*/
typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response);
/**
* Callback type for recalculation of context.
* @param whether the model is recalculating the context.
* @return a bool indicating whether the model should keep generating.
*/
typedef bool (*llmodel_recalculate_callback)(bool is_recalculating);
/**
* Embedding cancellation callback for use with llmodel_embed.
* @param batch_sizes The number of tokens in each batch that will be embedded.
@ -175,7 +168,7 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
* @param prompt_template A string representing the input prompt template.
* @param prompt_callback A callback function for handling the processing of prompt.
* @param response_callback A callback function for handling the generated response.
* @param recalculate_callback A callback function for handling recalculation requests.
* @param allow_context_shift Whether to allow shifting of context to make room for more input.
* @param special True if special tokens in the prompt should be processed, false otherwise.
* @param fake_reply A string to insert into context as the model's reply, or NULL to generate one.
* @param ctx A pointer to the llmodel_prompt_context structure.
@ -184,7 +177,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
const char *prompt_template,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
llmodel_recalculate_callback recalculate_callback,
bool allow_context_shift,
llmodel_prompt_context *ctx,
bool special,
const char *fake_reply);

View File

@ -11,42 +11,9 @@
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <vector>
// TODO(cebtenzzre): replace this with llama_kv_cache_seq_shift for llamamodel (GPT-J needs this as-is)
// FIXME(jared): if recalculate returns false, we leave n_past<tokens.size() and do not tell the caller to stop
// FIXME(jared): if we get here during chat name or follow-up generation, bad things will happen when we try to restore
// the old prompt context afterwards
void LLModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
{
int n_keep = shouldAddBOS();
const int32_t n_discard = (promptCtx.n_ctx - n_keep) * promptCtx.contextErase;
// Erase the first percentage of context from the tokens
std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);
size_t i = n_keep;
promptCtx.n_past = n_keep;
while (i < promptCtx.tokens.size()) {
size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
std::vector<int32_t> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
if (!evalTokens(promptCtx, batch)) {
std::cerr << "LLModel ERROR: Failed to process prompt\n";
goto stop_generating;
}
promptCtx.n_past += batch.size();
if (!recalculate(true))
goto stop_generating;
i = batch_end;
}
assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));
stop_generating:
recalculate(false);
}
namespace ranges = std::ranges;
static bool parsePromptTemplate(const std::string &tmpl, std::vector<std::smatch> &placeholders, std::string &err)
{
@ -75,7 +42,7 @@ void LLModel::prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &promptCtx,
bool special,
std::string *fakeReply)
@ -92,12 +59,21 @@ void LLModel::prompt(const std::string &prompt,
return;
}
// make sure token cache matches decode offset
if (promptCtx.tokens.size() < promptCtx.n_past) {
// sanity checks
if (promptCtx.n_past > contextLength()) {
std::ostringstream ss;
ss << "expected n_past to be at most " << promptCtx.tokens.size() << ", got " << promptCtx.n_past;
ss << "n_past=" << promptCtx.n_past << " is past end of context length=" << contextLength();
throw std::out_of_range(ss.str());
}
if (promptCtx.n_past > promptCtx.tokens.size()) {
std::ostringstream ss;
ss << "n_past=" << promptCtx.n_past << " is past end of token cache length=" << promptCtx.tokens.size();
throw std::out_of_range(ss.str());
}
promptCtx.n_ctx = contextLength();
promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);
if (promptCtx.n_past < promptCtx.tokens.size())
promptCtx.tokens.resize(promptCtx.n_past);
m_tokenize_last_token = promptCtx.tokens.empty() ? -1 : promptCtx.tokens.back(); // not serialized
@ -149,15 +125,15 @@ void LLModel::prompt(const std::string &prompt,
promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it
// decode the user prompt
if (!decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp))
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
return; // error
// decode the assistant's reply, either generated or spoofed
if (fakeReply == nullptr) {
generateResponse(responseCallback, recalculateCallback, promptCtx);
generateResponse(responseCallback, allowContextShift, promptCtx);
} else {
embd_inp = tokenize(promptCtx, *fakeReply, false);
if (!decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp))
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
return; // error
}
@ -172,19 +148,16 @@ void LLModel::prompt(const std::string &prompt,
}
if (!asstSuffix.empty()) {
embd_inp = tokenize(promptCtx, asstSuffix, true);
decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp);
}
}
// returns false on error
bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp) {
// save the context size
promptCtx.n_ctx = contextLength();
if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() <<
@ -192,9 +165,14 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
return false;
}
promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size());
promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx);
promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);
// FIXME(jared): There are mitigations for this situation, such as making room before
// copying the prompt context, or restoring the KV cache when we restore the prompt
// context.
if (!allowContextShift && promptCtx.n_past + embd_inp.size() > promptCtx.n_ctx) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_eval=" << embd_inp.size()
<< ", n_ctx=" << promptCtx.n_ctx << "\n";
return false;
}
// process the prompt in batches
size_t i = 0;
@ -204,7 +182,8 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
// Check if the context has run out...
if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
recalculateContext(promptCtx, recalculateCallback);
assert(allowContextShift);
shiftContext(promptCtx);
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
}
@ -226,70 +205,170 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
return true;
}
/*
* If string s overlaps with the string key such that some prefix of the key is at the end
* of the string, return the position in s where the first match starts. Otherwise, return
* std::string::npos. Examples:
* s = "bfo", key = "foo" -> 1
* s = "fooa", key = "foo" -> npos
*/
static std::string::size_type stringsOverlap(const std::string &s, const std::string &key)
{
if (s.empty() || key.empty())
throw std::invalid_argument("arguments to stringsOverlap must not be empty");
for (int start = std::max(0, int(s.size()) - int(key.size())); start < s.size(); start++) {
if (s.compare(start, s.size(), key, 0, s.size() - start) == 0)
return start;
}
return std::string::npos;
}
void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &promptCtx) {
static const char *stopSequences[] {
"### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context",
};
// Don't even start if there is no room
if (!promptCtx.n_predict)
return;
if (!allowContextShift && promptCtx.n_past >= promptCtx.n_ctx) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" << promptCtx.n_ctx
<< "\n";
return;
}
std::string cachedResponse;
std::vector<Token> cachedTokens;
std::unordered_set<std::string> reversePrompts
= { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" };
int n_predicted = 0;
// predict next tokens
for (int i = 0; i < promptCtx.n_predict; i++) {
// Predict next tokens
for (bool stop = false; !stop;) {
// Sample next token
std::optional<Token> new_tok = sampleToken(promptCtx);
std::string new_piece = tokenToString(new_tok.value());
cachedTokens.push_back(new_tok.value());
cachedResponse += new_piece;
// sample next token
auto id = sampleToken(promptCtx);
auto accept = [this, &promptCtx, &cachedTokens, &new_tok, allowContextShift]() -> bool {
// Shift context if out of space
if (promptCtx.n_past >= promptCtx.n_ctx) {
(void)allowContextShift;
assert(allowContextShift);
shiftContext(promptCtx);
assert(promptCtx.n_past < promptCtx.n_ctx);
}
// Check if the context has run out...
if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
}
// Accept the token
Token tok = std::exchange(new_tok, std::nullopt).value();
if (!evalTokens(promptCtx, { tok })) {
// TODO(jared): raise an exception
std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n";
return false;
}
if (!evalTokens(promptCtx, { id })) {
std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n";
return;
}
promptCtx.tokens.push_back(tok);
promptCtx.n_past += 1;
return true;
};
// display text
// Check for EOS
auto lengthLimit = std::string::npos;
for (const auto token : endTokens()) {
if (id == token) return;
}
const std::string str = tokenToString(id);
// Check if the provided str is part of our reverse prompts
bool foundPartialReversePrompt = false;
const std::string completed = cachedResponse + std::string(str);
if (reversePrompts.find(completed) != reversePrompts.end())
return;
// Check if it partially matches our reverse prompts and if so, cache
for (const auto& s : reversePrompts) {
if (s.compare(0, completed.size(), completed) == 0) {
foundPartialReversePrompt = true;
cachedResponse = completed;
break;
if (new_tok == token) {
stop = true;
lengthLimit = cachedResponse.size() - new_piece.size();
}
}
// Regardless the token gets added to our cache
cachedTokens.push_back(id);
if (lengthLimit != std::string::npos) {
// EOS matched
} else if (!isSpecialToken(new_tok.value())) {
// Check if the response contains a stop sequence
for (const auto &p : stopSequences) {
auto match = cachedResponse.find(p);
if (match != std::string::npos) stop = true;
lengthLimit = std::min(lengthLimit, match);
if (match == 0) break;
}
// Continue if we have found a partial match
if (foundPartialReversePrompt)
continue;
// Empty the cache
for (auto t : cachedTokens) {
promptCtx.tokens.push_back(t);
promptCtx.n_past += 1;
//TODO: Conversion to std::string can be avoided here...
if (!responseCallback(t, std::string(tokenToString(t))))
return;
// Check if the response matches the start of a stop sequence
if (lengthLimit == std::string::npos) {
for (const auto &p : stopSequences) {
auto match = stringsOverlap(cachedResponse, p);
lengthLimit = std::min(lengthLimit, match);
if (match == 0) break;
}
}
} else if (ranges::contains(stopSequences, new_piece)) {
// Special tokens must exactly match a stop sequence
stop = true;
lengthLimit = cachedResponse.size() - new_piece.size();
}
// Optionally stop if the context will run out
if (!allowContextShift && promptCtx.n_past + cachedTokens.size() >= promptCtx.n_ctx) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx="
<< promptCtx.n_ctx << "\n";
stop = true;
}
// Empty the cache, up to the length limit
std::string::size_type responseLength = 0;
while (!cachedTokens.empty()) {
Token tok = cachedTokens.front();
std::string piece = tokenToString(tok);
// Stop if the piece (or part of it) does not fit within the length limit
if (responseLength + (stop ? 1 : piece.size()) > lengthLimit)
break;
// Remove token from cache
assert(cachedResponse.starts_with(piece));
cachedTokens.erase(cachedTokens.begin(), cachedTokens.begin() + 1);
cachedResponse.erase(cachedResponse.begin(), cachedResponse.begin() + piece.size());
// Accept the token, if needed (not cached)
if (cachedTokens.empty() && new_tok && !accept())
return;
// Send the token
if (!responseCallback(tok, piece) || ++n_predicted >= promptCtx.n_predict) {
stop = true;
break;
}
// FIXME(jared): we could avoid printing partial stop sequences if we didn't have to
// output token IDs and could cache a partial token for the next prompt call
responseLength += piece.size();
}
assert(cachedTokens.empty() == cachedResponse.empty());
// Accept the token, if needed (in cache)
if (new_tok) {
assert(!cachedTokens.empty() && cachedTokens.back() == new_tok);
if (stop) {
cachedTokens.pop_back();
} else if (!accept()) {
return;
}
}
cachedTokens.clear();
}
auto &tokens = promptCtx.tokens;
if (tokens.size() < cachedTokens.size()) {
/* This is theoretically possible if the longest stop sequence is greater than
* n_ctx * contextErase tokens. */
throw std::runtime_error("shifted too much context, can't go back");
}
auto discard_start = tokens.end() - cachedTokens.size();
assert(std::equal(discard_start, tokens.end(), cachedTokens.begin()));
tokens.erase(discard_start, tokens.end());
promptCtx.n_past -= cachedTokens.size();
}
void LLModel::embed(

View File

@ -128,7 +128,6 @@ llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)
EmbCancelCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_uint), ctypes.c_uint, ctypes.c_char_p)
llmodel.llmodel_prompt.argtypes = [
@ -137,7 +136,7 @@ llmodel.llmodel_prompt.argtypes = [
ctypes.c_char_p,
PromptCallback,
ResponseCallback,
RecalculateCallback,
ctypes.c_bool,
ctypes.POINTER(LLModelPromptContext),
ctypes.c_bool,
ctypes.c_char_p,
@ -513,7 +512,7 @@ class LLModel:
ctypes.c_char_p(prompt_template.encode()),
PromptCallback(self._prompt_callback),
ResponseCallback(self._callback_decoder(callback)),
RecalculateCallback(self._recalculate_callback),
True,
self.context,
special,
ctypes.c_char_p(),
@ -606,8 +605,3 @@ class LLModel:
@staticmethod
def _prompt_callback(token_id: int) -> bool:
return True
# Empty recalculate callback
@staticmethod
def _recalculate_callback(is_recalculating: bool) -> bool:
return is_recalculating

View File

@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.16)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD 23)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
if(APPLE)
@ -31,7 +31,6 @@ project(gpt4all VERSION ${APP_VERSION_BASE} LANGUAGES CXX C)
set(CMAKE_AUTOMOC ON)
set(CMAKE_AUTORCC ON)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
option(GPT4ALL_TRANSLATIONS OFF "Build with translations")
option(GPT4ALL_LOCALHOST OFF "Build installer for localhost repo")

View File

@ -62,7 +62,7 @@ void Chat::connectLLM()
connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::restoringFromTextChanged, this, &Chat::handleRestoringFromText, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
@ -252,9 +252,9 @@ void Chat::serverNewPromptResponsePair(const QString &prompt)
m_chatModel->appendResponse("Response: ", prompt);
}
bool Chat::isRecalc() const
bool Chat::restoringFromText() const
{
return m_llmodel->isRecalc();
return m_llmodel->restoringFromText();
}
void Chat::unloadAndDeleteLater()
@ -320,10 +320,10 @@ void Chat::generatedQuestionFinished(const QString &question)
emit generatedQuestionsChanged();
}
void Chat::handleRecalculating()
void Chat::handleRestoringFromText()
{
Network::globalInstance()->trackChatEvent("recalc_context", { {"length", m_chatModel->count()} });
emit recalcChanged();
emit restoringFromTextChanged();
}
void Chat::handleModelLoadingError(const QString &error)

View File

@ -27,7 +27,7 @@ class Chat : public QObject
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(ModelInfo modelInfo READ modelInfo WRITE setModelInfo NOTIFY modelInfoChanged)
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)
Q_PROPERTY(bool isServer READ isServer NOTIFY isServerChanged)
Q_PROPERTY(ResponseState responseState READ responseState NOTIFY responseStateChanged)
Q_PROPERTY(QList<QString> collectionList READ collectionList NOTIFY collectionListChanged)
@ -88,7 +88,7 @@ public:
ResponseState responseState() const;
ModelInfo modelInfo() const;
void setModelInfo(const ModelInfo &modelInfo);
bool isRecalc() const;
bool restoringFromText() const;
Q_INVOKABLE void unloadModel();
Q_INVOKABLE void reloadModel();
@ -144,7 +144,7 @@ Q_SIGNALS:
void processSystemPromptRequested();
void modelChangeRequested(const ModelInfo &modelInfo);
void modelInfoChanged();
void recalcChanged();
void restoringFromTextChanged();
void loadDefaultModelRequested();
void loadModelRequested(const ModelInfo &modelInfo);
void generateNameRequested();
@ -167,7 +167,7 @@ private Q_SLOTS:
void responseStopped(qint64 promptResponseMs);
void generatedNameChanged(const QString &name);
void generatedQuestionFinished(const QString &question);
void handleRecalculating();
void handleRestoringFromText();
void handleModelLoadingError(const QString &error);
void handleTokenSpeedChanged(const QString &tokenSpeed);
void handleDatabaseResultsChanged(const QList<ResultInfo> &results);

View File

@ -90,13 +90,13 @@ void ChatAPI::prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &promptCtx,
bool special,
std::string *fakeReply) {
Q_UNUSED(promptCallback);
Q_UNUSED(recalculateCallback);
Q_UNUSED(allowContextShift);
Q_UNUSED(special);
if (!isModelLoaded()) {

View File

@ -69,7 +69,7 @@ public:
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
bool allowContextShift,
PromptContext &ctx,
bool special,
std::string *fakeReply) override;
@ -97,38 +97,57 @@ protected:
// them as they are only called from the default implementation of 'prompt' which we override and
// completely replace
std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override {
std::vector<Token> tokenize(PromptContext &ctx, const std::string &str, bool special) override
{
(void)ctx;
(void)str;
(void)special;
throw std::logic_error("not implemented");
}
std::string tokenToString(Token id) const override {
bool isSpecialToken(Token id) const override
{
(void)id;
throw std::logic_error("not implemented");
}
Token sampleToken(PromptContext &ctx) const override {
std::string tokenToString(Token id) const override
{
(void)id;
throw std::logic_error("not implemented");
}
Token sampleToken(PromptContext &ctx) const override
{
(void)ctx;
throw std::logic_error("not implemented");
}
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override {
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override
{
(void)ctx;
(void)tokens;
throw std::logic_error("not implemented");
}
int32_t contextLength() const override {
void shiftContext(PromptContext &promptCtx) override
{
(void)promptCtx;
throw std::logic_error("not implemented");
}
const std::vector<Token> &endTokens() const override {
int32_t contextLength() const override
{
throw std::logic_error("not implemented");
}
bool shouldAddBOS() const override {
const std::vector<Token> &endTokens() const override
{
throw std::logic_error("not implemented");
}
bool shouldAddBOS() const override
{
throw std::logic_error("not implemented");
}

View File

@ -102,7 +102,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
: QObject{nullptr}
, m_promptResponseTokens(0)
, m_promptTokens(0)
, m_isRecalc(false)
, m_restoringFromText(false)
, m_shouldBeLoaded(false)
, m_forceUnloadModel(false)
, m_markedForDeletion(false)
@ -712,17 +712,6 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
return !m_stopGenerating;
}
bool ChatLLM::handleRecalculate(bool isRecalc)
{
#if defined(DEBUG)
qDebug() << "recalculate" << m_llmThread.objectName() << isRecalc;
#endif
if (m_isRecalc != isRecalc) {
m_isRecalc = isRecalc;
emit recalcChanged();
}
return !m_stopGenerating;
}
bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt)
{
if (m_restoreStateFromText) {
@ -776,7 +765,6 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
auto promptFunc = std::bind(&ChatLLM::handlePrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleRecalculate, this, std::placeholders::_1);
emit promptProcessing();
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;
@ -796,10 +784,12 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
m_timer->start();
if (!docsContext.isEmpty()) {
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response
m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc, recalcFunc, m_ctx);
m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc,
/*allowContextShift*/ true, m_ctx);
m_ctx.n_predict = old_n_predict; // now we are ready for a response
}
m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc,
/*allowContextShift*/ true, m_ctx);
#if defined(DEBUG)
printf("\n");
fflush(stdout);
@ -904,10 +894,9 @@ void ChatLLM::generateName()
auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
auto promptFunc = std::bind(&ChatLLM::handleNamePrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1, std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleNameRecalculate, this, std::placeholders::_1);
LLModel::PromptContext ctx = m_ctx;
m_llModelInfo.model->prompt(chatNamePrompt.toStdString(), promptTemplate.toStdString(),
promptFunc, responseFunc, recalcFunc, ctx);
promptFunc, responseFunc, /*allowContextShift*/ false, ctx);
std::string trimmed = trim_whitespace(m_nameResponse);
if (trimmed != m_nameResponse) {
m_nameResponse = trimmed;
@ -944,15 +933,6 @@ bool ChatLLM::handleNameResponse(int32_t token, const std::string &response)
return words.size() <= 3;
}
bool ChatLLM::handleNameRecalculate(bool isRecalc)
{
#if defined(DEBUG)
qDebug() << "name recalc" << m_llmThread.objectName() << isRecalc;
#endif
Q_UNUSED(isRecalc);
return true;
}
bool ChatLLM::handleQuestionPrompt(int32_t token)
{
#if defined(DEBUG)
@ -991,15 +971,6 @@ bool ChatLLM::handleQuestionResponse(int32_t token, const std::string &response)
return true;
}
bool ChatLLM::handleQuestionRecalculate(bool isRecalc)
{
#if defined(DEBUG)
qDebug() << "name recalc" << m_llmThread.objectName() << isRecalc;
#endif
Q_UNUSED(isRecalc);
return true;
}
void ChatLLM::generateQuestions(qint64 elapsed)
{
Q_ASSERT(isModelLoaded());
@ -1019,12 +990,11 @@ void ChatLLM::generateQuestions(qint64 elapsed)
auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
auto promptFunc = std::bind(&ChatLLM::handleQuestionPrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleQuestionResponse, this, std::placeholders::_1, std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleQuestionRecalculate, this, std::placeholders::_1);
LLModel::PromptContext ctx = m_ctx;
QElapsedTimer totalTime;
totalTime.start();
m_llModelInfo.model->prompt(suggestedFollowUpPrompt,
promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
m_llModelInfo.model->prompt(suggestedFollowUpPrompt, promptTemplate.toStdString(), promptFunc, responseFunc,
/*allowContextShift*/ false, ctx);
elapsed += totalTime.elapsed();
emit responseStopped(elapsed);
}
@ -1039,15 +1009,6 @@ bool ChatLLM::handleSystemPrompt(int32_t token)
return !m_stopGenerating;
}
bool ChatLLM::handleSystemRecalculate(bool isRecalc)
{
#if defined(DEBUG)
qDebug() << "system recalc" << m_llmThread.objectName() << isRecalc;
#endif
Q_UNUSED(isRecalc);
return false;
}
bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
{
#if defined(DEBUG)
@ -1057,15 +1018,6 @@ bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
return !m_stopGenerating;
}
bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
{
#if defined(DEBUG)
qDebug() << "restore state from text recalc" << m_llmThread.objectName() << isRecalc;
#endif
Q_UNUSED(isRecalc);
return false;
}
// this function serialized the cached model state to disk.
// we want to also serialize n_ctx, and read it at load time.
bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
@ -1268,7 +1220,6 @@ void ChatLLM::processSystemPrompt()
m_ctx = LLModel::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
auto recalcFunc = std::bind(&ChatLLM::handleSystemRecalculate, this, std::placeholders::_1);
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
@ -1294,7 +1245,7 @@ void ChatLLM::processSystemPrompt()
#endif
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response
// use "%1%2" and not "%1" to avoid implicit whitespace
m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, recalcFunc, m_ctx, true);
m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, /*allowContextShift*/ true, m_ctx, true);
m_ctx.n_predict = old_n_predict;
#if defined(DEBUG)
printf("\n");
@ -1311,14 +1262,13 @@ void ChatLLM::processRestoreStateFromText()
if (!isModelLoaded() || !m_restoreStateFromText || m_isServer)
return;
m_isRecalc = true;
emit recalcChanged();
m_restoringFromText = true;
emit restoringFromTextChanged();
m_stopGenerating = false;
m_ctx = LLModel::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
auto recalcFunc = std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1);
const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
@ -1351,7 +1301,7 @@ void ChatLLM::processRestoreStateFromText()
auto responseText = response.second.toStdString();
m_llModelInfo.model->prompt(prompt.second.toStdString(), promptTemplate.toStdString(), promptFunc, nullptr,
recalcFunc, m_ctx, false, &responseText);
/*allowContextShift*/ true, m_ctx, false, &responseText);
}
if (!m_stopGenerating) {
@ -1359,8 +1309,8 @@ void ChatLLM::processRestoreStateFromText()
m_stateFromText.clear();
}
m_isRecalc = false;
emit recalcChanged();
m_restoringFromText = false;
emit restoringFromTextChanged();
m_pristineLoadedState = false;
}

View File

@ -93,7 +93,7 @@ class Chat;
class ChatLLM : public QObject
{
Q_OBJECT
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)
Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged)
Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged)
Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged)
@ -121,7 +121,7 @@ public:
ModelInfo modelInfo() const;
void setModelInfo(const ModelInfo &info);
bool isRecalc() const { return m_isRecalc; }
bool restoringFromText() const { return m_restoringFromText; }
void acquireModel();
void resetModel();
@ -172,7 +172,7 @@ public Q_SLOTS:
void processRestoreStateFromText();
Q_SIGNALS:
void recalcChanged();
void restoringFromTextChanged();
void loadedModelInfoChanged();
void modelLoadingPercentageChanged(float);
void modelLoadingError(const QString &error);
@ -201,19 +201,14 @@ protected:
int32_t repeat_penalty_tokens);
bool handlePrompt(int32_t token);
bool handleResponse(int32_t token, const std::string &response);
bool handleRecalculate(bool isRecalc);
bool handleNamePrompt(int32_t token);
bool handleNameResponse(int32_t token, const std::string &response);
bool handleNameRecalculate(bool isRecalc);
bool handleSystemPrompt(int32_t token);
bool handleSystemResponse(int32_t token, const std::string &response);
bool handleSystemRecalculate(bool isRecalc);
bool handleRestoreStateFromTextPrompt(int32_t token);
bool handleRestoreStateFromTextResponse(int32_t token, const std::string &response);
bool handleRestoreStateFromTextRecalculate(bool isRecalc);
bool handleQuestionPrompt(int32_t token);
bool handleQuestionResponse(int32_t token, const std::string &response);
bool handleQuestionRecalculate(bool isRecalc);
void saveState();
void restoreState();
@ -236,7 +231,7 @@ private:
QThread m_llmThread;
std::atomic<bool> m_stopGenerating;
std::atomic<bool> m_shouldBeLoaded;
std::atomic<bool> m_isRecalc;
std::atomic<bool> m_restoringFromText; // status indication
std::atomic<bool> m_forceUnloadModel;
std::atomic<bool> m_markedForDeletion;
bool m_isServer;

View File

@ -834,7 +834,7 @@ Rectangle {
to: 360
duration: 1000
loops: Animation.Infinite
running: currentResponse && (currentChat.responseInProgress || currentChat.isRecalc)
running: currentResponse && (currentChat.responseInProgress || currentChat.restoringFromText)
}
}
}
@ -867,13 +867,13 @@ Rectangle {
color: theme.mutedTextColor
}
RowLayout {
visible: currentResponse && ((value === "" && currentChat.responseInProgress) || currentChat.isRecalc)
visible: currentResponse && ((value === "" && currentChat.responseInProgress) || currentChat.restoringFromText)
Text {
color: theme.mutedTextColor
font.pixelSize: theme.fontSizeLarger
text: {
if (currentChat.isRecalc)
return qsTr("recalculating context ...");
if (currentChat.restoringFromText)
return qsTr("restoring from text ...");
switch (currentChat.responseState) {
case Chat.ResponseStopped: return qsTr("response stopped ...");
case Chat.LocalDocsRetrieval: return qsTr("retrieving localdocs: %1 ...").arg(currentChat.collectionList.join(", "));
@ -1861,7 +1861,7 @@ Rectangle {
}
}
function sendMessage() {
if (textInput.text === "" || currentChat.responseInProgress || currentChat.isRecalc)
if (textInput.text === "" || currentChat.responseInProgress || currentChat.restoringFromText)
return
currentChat.stopGenerating()