Use the token cache to infer greater n_past and reuse results (#3073)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel 2024-10-31 11:19:12 -04:00 committed by GitHub
parent 62cab695eb
commit f07e2e63df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 320 additions and 169 deletions

View File

@ -12,6 +12,7 @@ function(gpt4all_add_warning_options target)
-Wformat=2
-Wmissing-include-dirs
-Wstrict-overflow=2
-Wsuggest-override
-Wvla
# errors
-Werror=format-security

View File

@ -124,9 +124,7 @@ public:
};
struct PromptContext {
std::vector<int32_t> tokens; // current tokens in the context window
int32_t n_past = 0; // number of tokens in past conversation
int32_t n_ctx = 0; // number of tokens possible in context window
int32_t n_predict = 200;
int32_t top_k = 40;
float top_p = 0.9f;
@ -151,8 +149,8 @@ public:
virtual bool isModelLoaded() const = 0;
virtual size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) = 0;
virtual size_t stateSize() const = 0;
virtual size_t saveState(std::span<uint8_t> dest) const = 0;
virtual size_t restoreState(std::span<const uint8_t> src) = 0;
virtual size_t saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const = 0;
virtual size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) = 0;
// This method requires the model to return true from supportsCompletion otherwise it will throw
// an error
@ -210,6 +208,8 @@ public:
void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; }
virtual int32_t contextLength() const = 0;
protected:
// These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions
@ -218,9 +218,15 @@ protected:
virtual std::string tokenToString(Token id) const = 0;
virtual void initSampler(PromptContext &ctx) = 0;
virtual Token sampleToken() const = 0;
virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const = 0;
virtual bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const = 0;
virtual void shiftContext(PromptContext &promptCtx) = 0;
virtual int32_t contextLength() const = 0;
virtual int32_t inputLength() const = 0;
virtual void setTokenizeInputPosition(int32_t pos) = 0;
virtual auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator = 0;
virtual void setModelInputPosition(PromptContext &ctx, int32_t pos) = 0;
virtual void appendInputToken(PromptContext &ctx, Token tok) = 0;
virtual std::span<const Token> inputTokens() const = 0;
virtual const std::vector<Token> &endTokens() const = 0;
virtual bool shouldAddBOS() const = 0;
@ -252,11 +258,13 @@ protected:
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp,
bool isResponse = false);
bool isResponse = false,
bool alwaysDecode = false);
void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx);
protected:
Token m_tokenize_last_token = -1; // not serialized
friend class LLMImplementation;

View File

@ -23,6 +23,11 @@ extern "C" {
*/
typedef void *llmodel_model;
/**
* A token.
*/
typedef int32_t token_t;
/**
* llmodel_prompt_context structure for holding the prompt context.
* NOTE: The implementation takes care of all the memory handling of the raw logits pointer and the
@ -30,10 +35,7 @@ typedef void *llmodel_model;
* behavior.
*/
struct llmodel_prompt_context {
int32_t *tokens; // current tokens in the context window
size_t tokens_size; // the size of the raw tokens vector
int32_t n_past; // number of tokens in past conversation
int32_t n_ctx; // number of tokens possible in context window
int32_t n_predict; // number of tokens to predict
int32_t top_k; // top k logits to sample from
float top_p; // nucleus sampling probability threshold
@ -141,27 +143,41 @@ bool llmodel_isModelLoaded(llmodel_model model);
* @param model A pointer to the llmodel_model instance.
* @return the size in bytes of the internal state of the model
*/
uint64_t llmodel_get_state_size(llmodel_model model);
uint64_t llmodel_state_get_size(llmodel_model model);
/**
* Saves the internal state of the model to the specified destination address.
* Saves the internal state of the model.
* NOTE: This state data is specific to the type of model you have created.
* @param model A pointer to the llmodel_model instance.
* @param dest A pointer to the destination.
* @param size The size of the destination buffer.
* @return the number of bytes copied, or zero on error.
* @param state Where to store the state. This must be a buffer of at least llmodel_state_get_size() bytes.
* @param state_size The size of the destination for the state.
* @param input_tokens_out Where to store the address of the token cache state. This is dynamically allocated and must
* be freed with llmodel_state_free_input_tokens.
* @param n_input_tokens Where to store the size of the token cache state.
* @return The number of bytes copied. On error, zero is returned, the token cache is set to NULL, and the token cache
* size is set to zero.
*/
uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest, uint64_t size);
uint64_t llmodel_state_get_data(llmodel_model model, uint8_t *state_out, uint64_t state_size,
token_t **input_tokens_out, uint64_t *n_input_tokens);
/**
* Frees the temporary token cache buffer created by a call to llmodel_state_get_data().
* @param input_tokens The token cache buffer.
*/
void llmodel_state_free_input_tokens(token_t *input_tokens);
/**
* Restores the internal state of the model using data from the specified address.
* NOTE: This state data is specific to the type of model you have created.
* @param model A pointer to the llmodel_model instance.
* @param src A pointer to the state data.
* @param size The size of the source data.
* @param state A pointer to the state data.
* @param state_size The size of the state data.
* @param input_tokens The token cache associated with the saved state.
* @param n_input_tokens The number of tokens in input_tokens.
* @return The number of bytes read, or zero on error.
*/
uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src, size_t size);
uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint64_t state_size,
const token_t *input_tokens, uint64_t n_input_tokens);
/**
* Generate a response using the model.

View File

@ -218,6 +218,7 @@ struct LLamaPrivate {
int64_t n_threads = 0;
std::vector<LLModel::Token> end_tokens;
const char *backend_name = nullptr;
std::vector<LLModel::Token> inputTokens;
llama_model *model = nullptr;
llama_context *ctx = nullptr;
@ -501,14 +502,20 @@ size_t LLamaModel::stateSize() const
return llama_state_get_size(d_ptr->ctx);
}
size_t LLamaModel::saveState(std::span<uint8_t> dest) const
size_t LLamaModel::saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const
{
return llama_state_get_data(d_ptr->ctx, dest.data(), dest.size());
size_t bytesWritten = llama_state_get_data(d_ptr->ctx, stateOut.data(), stateOut.size());
if (bytesWritten)
inputTokensOut.assign(d_ptr->inputTokens.begin(), d_ptr->inputTokens.end());
return bytesWritten;
}
size_t LLamaModel::restoreState(std::span<const uint8_t> src)
size_t LLamaModel::restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens)
{
return llama_state_set_data(d_ptr->ctx, src.data(), src.size());
size_t bytesRead = llama_state_set_data(d_ptr->ctx, state.data(), state.size());
if (bytesRead)
d_ptr->inputTokens.assign(inputTokens.begin(), inputTokens.end());
return bytesRead;
}
std::vector<LLModel::Token> LLamaModel::tokenize(std::string_view str, bool special)
@ -594,7 +601,7 @@ LLModel::Token LLamaModel::sampleToken() const
return llama_sampler_sample(d_ptr->sampler_chain, d_ptr->ctx, -1);
}
bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) const
{
llama_kv_cache_seq_rm(d_ptr->ctx, 0, ctx.n_past, -1);
@ -625,7 +632,7 @@ void LLamaModel::shiftContext(PromptContext &promptCtx)
// erase up to n_ctx*contextErase tokens
int n_keep = shouldAddBOS();
int n_past = promptCtx.n_past;
int n_discard = std::min(n_past - n_keep, int(promptCtx.n_ctx * promptCtx.contextErase));
int n_discard = std::min(n_past - n_keep, int(contextLength() * promptCtx.contextErase));
assert(n_discard > 0);
if (n_discard <= 0)
@ -638,8 +645,9 @@ void LLamaModel::shiftContext(PromptContext &promptCtx)
llama_kv_cache_seq_rm (d_ptr->ctx, 0, n_keep, n_keep + n_discard);
llama_kv_cache_seq_add(d_ptr->ctx, 0, n_keep + n_discard, n_past, -n_discard);
promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);
promptCtx.n_past = promptCtx.tokens.size();
auto &inp = d_ptr->inputTokens;
inp.erase(inp.begin() + n_keep, inp.begin() + n_keep + n_discard);
promptCtx.n_past = inp.size();
}
int32_t LLamaModel::contextLength() const
@ -647,6 +655,60 @@ int32_t LLamaModel::contextLength() const
return llama_n_ctx(d_ptr->ctx);
}
int32_t LLamaModel::inputLength() const
{
return d_ptr->inputTokens.size();
}
void LLamaModel::setTokenizeInputPosition(int32_t pos)
{
assert(pos >= 0);
m_tokenize_last_token = pos ? d_ptr->inputTokens.at(size_t(pos) - 1) : -1; // not serialized
}
auto LLamaModel::computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator
{
assert(ctx.n_past >= 0);
auto pos = size_t(ctx.n_past);
if (pos > d_ptr->inputTokens.size()) {
std::ostringstream ss;
ss << "n_past=" << pos << " is past end of token cache length=" << d_ptr->inputTokens.size();
throw std::out_of_range(ss.str());
}
// find common prefix
auto cacheIt = d_ptr->inputTokens.begin();
auto inputIt = input.begin();
while (cacheIt < d_ptr->inputTokens.end() && inputIt < input.end() && *cacheIt == *inputIt) {
++cacheIt; ++inputIt; ++pos;
}
// tell the caller to ignore the tokens between [begin, inputIt)
return inputIt;
}
void LLamaModel::setModelInputPosition(PromptContext &ctx, int32_t pos)
{
auto &inp = d_ptr->inputTokens;
assert(pos >= 0);
assert(pos <= inp.size());
// truncate token cache to end at the new n_past
if (pos < inp.size())
inp.resize(pos);
ctx.n_past = pos;
}
void LLamaModel::appendInputToken(PromptContext &ctx, Token tok)
{
d_ptr->inputTokens.push_back(tok);
ctx.n_past += 1;
}
auto LLamaModel::inputTokens() const -> std::span<const Token>
{
return d_ptr->inputTokens;
}
const std::vector<LLModel::Token> &LLamaModel::endTokens() const
{
return d_ptr->end_tokens;

View File

@ -28,8 +28,8 @@ public:
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
size_t stateSize() const override;
size_t saveState(std::span<uint8_t> dest) const override;
size_t restoreState(std::span<const uint8_t> src) override;
size_t saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const override;
size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) override;
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() const override;
std::vector<GPUDevice> availableGPUDevices(size_t memoryRequired = 0) const override;
@ -48,10 +48,7 @@ public:
void embed(const std::vector<std::string> &texts, float *embeddings, bool isRetrieval, int dimensionality = -1,
size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override;
private:
std::unique_ptr<LLamaPrivate> d_ptr;
bool m_supportsEmbedding = false;
bool m_supportsCompletion = false;
int32_t contextLength() const override;
protected:
std::vector<Token> tokenize(std::string_view str, bool special) override;
@ -59,9 +56,15 @@ protected:
std::string tokenToString(Token id) const override;
void initSampler(PromptContext &ctx) override;
Token sampleToken() const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const override;
void shiftContext(PromptContext &promptCtx) override;
int32_t contextLength() const override;
int32_t inputLength() const override;
void setTokenizeInputPosition(int32_t pos) override;
auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator override;
void setModelInputPosition(PromptContext &ctx, int32_t pos) override;
void appendInputToken(PromptContext &ctx, Token tok) override;
std::span<const Token> inputTokens() const override;
const std::vector<Token> &endTokens() const override;
bool shouldAddBOS() const override;
int32_t maxContextLength(std::string const &modelPath) const override;
@ -70,6 +73,11 @@ protected:
void embedInternal(const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb,
const EmbModelSpec *spec);
private:
std::unique_ptr<LLamaPrivate> d_ptr;
bool m_supportsEmbedding = false;
bool m_supportsCompletion = false;
};
#endif // LLAMAMODEL_H

View File

@ -14,6 +14,11 @@
#include <string>
#include <string_view>
#include <vector>
#include <span>
namespace ranges = std::ranges;
static_assert(sizeof(token_t) == sizeof(LLModel::Token));
struct LLModelWrapper {
LLModel *llModel = nullptr;
@ -85,22 +90,40 @@ bool llmodel_isModelLoaded(llmodel_model model)
return wrapper->llModel->isModelLoaded();
}
uint64_t llmodel_get_state_size(llmodel_model model)
uint64_t llmodel_state_get_size(llmodel_model model)
{
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->stateSize();
}
uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest, uint64_t size)
uint64_t llmodel_state_get_data(llmodel_model model, uint8_t *state_out, uint64_t state_size,
token_t **input_tokens_out, uint64_t *n_input_tokens)
{
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->saveState({dest, size_t(size)});
std::vector<LLModel::Token> inputTokens;
auto bytesWritten = wrapper->llModel->saveState({state_out, size_t(state_size)}, inputTokens);
if (bytesWritten) {
auto *buf = new LLModel::Token[inputTokens.size()];
ranges::copy(inputTokens, buf);
*input_tokens_out = buf;
*n_input_tokens = uint64_t(inputTokens.size());
} else {
*input_tokens_out = nullptr;
*n_input_tokens = 0;
}
return bytesWritten;
}
uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src, uint64_t size)
void llmodel_state_free_input_tokens(LLModel::Token *input_tokens)
{
delete[] input_tokens;
}
uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint64_t state_size,
const token_t *input_tokens, uint64_t n_input_tokens)
{
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->restoreState({src, size_t(size)});
return wrapper->llModel->restoreState({state, size_t(state_size)}, {input_tokens, size_t(n_input_tokens)});
}
void llmodel_prompt(llmodel_model model, const char *prompt,
@ -120,7 +143,6 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
// Copy the C prompt context
wrapper->promptContext.n_past = ctx->n_past;
wrapper->promptContext.n_ctx = ctx->n_ctx;
wrapper->promptContext.n_predict = ctx->n_predict;
wrapper->promptContext.top_k = ctx->top_k;
wrapper->promptContext.top_p = ctx->top_p;
@ -136,14 +158,8 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
wrapper->promptContext, special,
fake_reply ? std::make_optional<std::string_view>(fake_reply) : std::nullopt);
// Update the C context by giving access to the wrappers raw pointers to std::vector data
// which involves no copies
ctx->tokens = wrapper->promptContext.tokens.data();
ctx->tokens_size = wrapper->promptContext.tokens.size();
// Update the rest of the C prompt context
ctx->n_past = wrapper->promptContext.n_past;
ctx->n_ctx = wrapper->promptContext.n_ctx;
ctx->n_predict = wrapper->promptContext.n_predict;
ctx->top_k = wrapper->promptContext.top_k;
ctx->top_p = wrapper->promptContext.top_p;

View File

@ -6,6 +6,7 @@
#include <cstdint>
#include <functional>
#include <iostream>
#include <iterator>
#include <optional>
#include <regex>
#include <sstream>
@ -66,19 +67,14 @@ void LLModel::prompt(const std::string &prompt,
ss << "n_past=" << promptCtx.n_past << " is past end of context length=" << contextLength();
throw std::out_of_range(ss.str());
}
if (promptCtx.n_past > promptCtx.tokens.size()) {
if (promptCtx.n_past > inputLength()) {
std::ostringstream ss;
ss << "n_past=" << promptCtx.n_past << " is past end of token cache length=" << promptCtx.tokens.size();
ss << "n_past=" << promptCtx.n_past << " is past end of token cache length=" << inputLength();
throw std::out_of_range(ss.str());
}
promptCtx.n_ctx = contextLength();
promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);
if (promptCtx.n_past < promptCtx.tokens.size())
promptCtx.tokens.resize(promptCtx.n_past);
m_tokenize_last_token = promptCtx.tokens.empty() ? -1 : promptCtx.tokens.back(); // not serialized
// parse the prompt template
std::vector<std::smatch> placeholders;
{
@ -90,6 +86,8 @@ void LLModel::prompt(const std::string &prompt,
}
}
setTokenizeInputPosition(promptCtx.n_past);
// tokenize the user prompt
std::vector<Token> embd_inp;
if (placeholders.empty()) {
@ -118,7 +116,8 @@ void LLModel::prompt(const std::string &prompt,
}
// decode the user prompt
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, /*isResponse*/ false,
/*alwaysDecode*/ true))
return; // error
// decode the assistant's reply, either generated or spoofed
@ -151,36 +150,67 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp,
bool isResponse) {
if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
bool isResponse,
bool alwaysDecode) {
if ((int) embd_inp.size() > contextLength() - 4) {
// FIXME: (Adam) We should find a way to bubble these strings to the UI level to allow for
// translation
responseCallback(-1, "Your message was too long and could not be processed. Please try again with something shorter.");
std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() <<
" tokens and the context window is " << promptCtx.n_ctx << "!\n";
" tokens and the context window is " << contextLength() << "!\n";
return false;
}
// FIXME(jared): There are mitigations for this situation, such as making room before
// copying the prompt context, or restoring the KV cache when we restore the prompt
// context.
if (!allowContextShift && promptCtx.n_past + embd_inp.size() > promptCtx.n_ctx) {
if (!allowContextShift && promptCtx.n_past + embd_inp.size() > contextLength()) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_eval=" << embd_inp.size()
<< ", n_ctx=" << promptCtx.n_ctx << "\n";
<< ", n_ctx=" << contextLength() << "\n";
return false;
}
// process the prompt in batches
// always decode something before generating, even if cached
if (alwaysDecode && embd_inp.empty()) {
auto cache = inputTokens();
if (!promptCtx.n_past)
throw std::runtime_error("zero token prompt is not supported");
assert(!cache.empty());
embd_inp.push_back(cache.back());
promptCtx.n_past--;
}
// Find the greatest n_past where the beginning of embd_inp matches the end of the token cache, starting at the
// requested n_past.
// This is used to skip unnecessary work when the prompt shares a common prefix with the previous result.
auto embd_inp_start = computeModelInputPosition(promptCtx, embd_inp);
size_t start_offset = embd_inp_start - embd_inp.begin();
// always decode up to a full batch before generating, even if cached
if (alwaysDecode)
start_offset -= std::min(promptCtx.n_batch, int32_t(start_offset));
setModelInputPosition(promptCtx, promptCtx.n_past + start_offset);
// execute the callback even for skipped tokens
size_t i = 0;
for (; i < start_offset; i++) {
Token tok = embd_inp[i];
bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
if (!res)
return false;
}
// process the prompt in batches
while (i < embd_inp.size()) {
size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
std::vector<Token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
std::span<const Token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
// Check if the context has run out...
if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
if (promptCtx.n_past + int32_t(batch.size()) > contextLength()) {
assert(allowContextShift);
shiftContext(promptCtx);
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
assert(promptCtx.n_past + int32_t(batch.size()) <= contextLength());
}
if (!evalTokens(promptCtx, batch)) {
@ -190,9 +220,8 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
size_t tokens = batch_end - i;
for (size_t t = 0; t < tokens; ++t) {
promptCtx.tokens.push_back(batch.at(t));
promptCtx.n_past += 1;
Token tok = batch.at(t);
Token tok = batch[t];
appendInputToken(promptCtx, tok);
bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
if (!res)
return false;
@ -232,8 +261,8 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
// Don't even start if there is no room
if (!promptCtx.n_predict)
return;
if (!allowContextShift && promptCtx.n_past >= promptCtx.n_ctx) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" << promptCtx.n_ctx
if (!allowContextShift && promptCtx.n_past >= contextLength()) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" << contextLength()
<< "\n";
return;
}
@ -254,23 +283,22 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
auto accept = [this, &promptCtx, &new_tok, allowContextShift]() -> bool {
// Shift context if out of space
if (promptCtx.n_past >= promptCtx.n_ctx) {
if (promptCtx.n_past >= contextLength()) {
(void)allowContextShift;
assert(allowContextShift);
shiftContext(promptCtx);
assert(promptCtx.n_past < promptCtx.n_ctx);
assert(promptCtx.n_past < contextLength());
}
// Accept the token
Token tok = std::exchange(new_tok, std::nullopt).value();
if (!evalTokens(promptCtx, { tok })) {
if (!evalTokens(promptCtx, { &tok, 1 })) {
// TODO(jared): raise an exception
std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n";
return false;
}
promptCtx.tokens.push_back(tok);
promptCtx.n_past += 1;
appendInputToken(promptCtx, tok);
return true;
};
@ -309,9 +337,9 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
}
// Optionally stop if the context will run out
if (!allowContextShift && promptCtx.n_past + cachedTokens.size() >= promptCtx.n_ctx) {
if (!allowContextShift && promptCtx.n_past + cachedTokens.size() >= contextLength()) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx="
<< promptCtx.n_ctx << "\n";
<< contextLength() << "\n";
stop = true;
}
@ -357,16 +385,17 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
}
}
auto &tokens = promptCtx.tokens;
if (tokens.size() < cachedTokens.size()) {
if (inputLength() < cachedTokens.size()) {
/* This is theoretically possible if the longest stop sequence is greater than
* n_ctx * contextErase tokens. */
throw std::runtime_error("shifted too much context, can't go back");
}
auto discard_start = tokens.end() - cachedTokens.size();
assert(std::equal(discard_start, tokens.end(), cachedTokens.begin()));
tokens.erase(discard_start, tokens.end());
#ifndef NDEBUG
auto inp = inputTokens();
auto discard_start = inp.end() - cachedTokens.size();
assert(std::equal(discard_start, inp.end(), cachedTokens.begin()));
#endif
promptCtx.n_past -= cachedTokens.size();
}

View File

@ -113,10 +113,7 @@ def _old_loop(gpt4all_instance):
full_response = gpt4all_instance.chat_completion(
MESSAGES,
# preferential kwargs for chat ux
logits_size=0,
tokens_size=0,
n_past=0,
n_ctx=0,
n_predict=200,
top_k=40,
top_p=0.9,

View File

@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
### Added
- Warn on Windows if the Microsoft Visual C++ runtime libraries are not found ([#2920](https://github.com/nomic-ai/gpt4all/pull/2920))
- Basic cache for faster prefill when the input shares a prefix with previous context ([#3073](https://github.com/nomic-ai/gpt4all/pull/3073))
### Changed
- Rebase llama.cpp on latest upstream as of September 26th ([#2998](https://github.com/nomic-ai/gpt4all/pull/2998))

View File

@ -116,10 +116,7 @@ llmodel = load_llmodel_library()
class LLModelPromptContext(ctypes.Structure):
_fields_ = [
("tokens", ctypes.POINTER(ctypes.c_int32)),
("tokens_size", ctypes.c_size_t),
("n_past", ctypes.c_int32),
("n_ctx", ctypes.c_int32),
("n_predict", ctypes.c_int32),
("top_k", ctypes.c_int32),
("top_p", ctypes.c_float),
@ -393,9 +390,7 @@ class LLModel:
):
if self.context is None:
context = LLModelPromptContext(
tokens_size=0,
n_past=0,
n_ctx=0,
n_predict=n_predict,
top_k=top_k,
top_p=top_p,

View File

@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
### Added
- Add ability to attach text, markdown, and rst files to chat ([#3135](https://github.com/nomic-ai/gpt4all/pull/3135))
- Add feature to minimize to system tray (by [@bgallois](https://github.com/bgallois) in ([#3109](https://github.com/nomic-ai/gpt4all/pull/3109))
- Basic cache for faster prefill when the input shares a prefix with previous context ([#3073](https://github.com/nomic-ai/gpt4all/pull/3073))
### Changed
- Implement Qt 6.8 compatibility ([#3121](https://github.com/nomic-ai/gpt4all/pull/3121))

View File

@ -51,7 +51,6 @@ bool ChatAPI::loadModel(const std::string &modelPath, int n_ctx, int ngl)
void ChatAPI::setThreadCount(int32_t n_threads)
{
Q_UNUSED(n_threads);
qt_noop();
}
int32_t ChatAPI::threadCount() const
@ -68,24 +67,6 @@ bool ChatAPI::isModelLoaded() const
return true;
}
// All three of the state virtual functions are handled custom inside of chatllm save/restore
size_t ChatAPI::stateSize() const
{
throw std::logic_error("not implemented");
}
size_t ChatAPI::saveState(std::span<uint8_t> dest) const
{
Q_UNUSED(dest);
throw std::logic_error("not implemented");
}
size_t ChatAPI::restoreState(std::span<const uint8_t> src)
{
Q_UNUSED(src);
throw std::logic_error("not implemented");
}
void ChatAPI::prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,

View File

@ -3,7 +3,7 @@
#include <gpt4all-backend/llmodel.h>
#include <QByteArray>
#include <QByteArray> // IWYU pragma: keep
#include <QNetworkReply>
#include <QObject>
#include <QString>
@ -13,6 +13,8 @@
#include <cstddef>
#include <cstdint>
#include <functional>
#include <optional>
#include <span>
#include <stdexcept>
#include <string>
#include <string_view>
@ -63,9 +65,15 @@ public:
bool loadModel(const std::string &modelPath, int n_ctx, int ngl) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
size_t stateSize() const override;
size_t saveState(std::span<uint8_t> dest) const override;
size_t restoreState(std::span<const uint8_t> src) override;
// All three of the state virtual functions are handled custom inside of chatllm save/restore
size_t stateSize() const override
{ throwNotImplemented(); }
size_t saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const override
{ Q_UNUSED(stateOut); Q_UNUSED(inputTokensOut); throwNotImplemented(); }
size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) override
{ Q_UNUSED(state); Q_UNUSED(inputTokens); throwNotImplemented(); }
void prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
@ -88,6 +96,10 @@ public:
bool callResponse(int32_t token, const std::string &string);
[[noreturn]]
int32_t contextLength() const override
{ throwNotImplemented(); }
Q_SIGNALS:
void request(const QString &apiKey,
LLModel::PromptContext *ctx,
@ -98,60 +110,69 @@ protected:
// them as they are only called from the default implementation of 'prompt' which we override and
// completely replace
[[noreturn]]
static void throwNotImplemented() { throw std::logic_error("not implemented"); }
[[noreturn]]
std::vector<Token> tokenize(std::string_view str, bool special) override
{
(void)str;
(void)special;
throw std::logic_error("not implemented");
}
{ Q_UNUSED(str); Q_UNUSED(special); throwNotImplemented(); }
[[noreturn]]
bool isSpecialToken(Token id) const override
{
(void)id;
throw std::logic_error("not implemented");
}
{ Q_UNUSED(id); throwNotImplemented(); }
[[noreturn]]
std::string tokenToString(Token id) const override
{
(void)id;
throw std::logic_error("not implemented");
}
{ Q_UNUSED(id); throwNotImplemented(); }
[[noreturn]]
void initSampler(PromptContext &ctx) override
{
(void)ctx;
throw std::logic_error("not implemented");
}
{ Q_UNUSED(ctx); throwNotImplemented(); }
Token sampleToken() const override { throw std::logic_error("not implemented"); }
[[noreturn]]
Token sampleToken() const override
{ throwNotImplemented(); }
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override
{
(void)ctx;
(void)tokens;
throw std::logic_error("not implemented");
}
[[noreturn]]
bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const override
{ Q_UNUSED(ctx); Q_UNUSED(tokens); throwNotImplemented(); }
[[noreturn]]
void shiftContext(PromptContext &promptCtx) override
{
(void)promptCtx;
throw std::logic_error("not implemented");
}
{ Q_UNUSED(promptCtx); throwNotImplemented(); }
int32_t contextLength() const override
{
throw std::logic_error("not implemented");
}
[[noreturn]]
int32_t inputLength() const override
{ throwNotImplemented(); }
[[noreturn]]
void setTokenizeInputPosition(int32_t pos) override
{ Q_UNUSED(pos); throwNotImplemented(); }
[[noreturn]]
auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator override
{ Q_UNUSED(ctx); Q_UNUSED(input); throwNotImplemented(); }
[[noreturn]]
void setModelInputPosition(PromptContext &ctx, int32_t pos) override
{ Q_UNUSED(ctx); Q_UNUSED(pos); throwNotImplemented(); }
[[noreturn]]
void appendInputToken(PromptContext &ctx, Token tok) override
{ Q_UNUSED(ctx); Q_UNUSED(tok); throwNotImplemented(); }
[[noreturn]]
const std::vector<Token> &endTokens() const override
{
throw std::logic_error("not implemented");
}
{ throwNotImplemented(); }
[[noreturn]]
bool shouldAddBOS() const override
{
throw std::logic_error("not implemented");
}
{ throwNotImplemented(); }
[[noreturn]]
std::span<const Token> inputTokens() const override
{ throwNotImplemented(); }
private:
std::function<bool(int32_t, const std::string&)> m_responseCallback;

View File

@ -33,6 +33,7 @@
#include <functional>
#include <limits>
#include <optional>
#include <span>
#include <string_view>
#include <utility>
#include <vector>
@ -404,7 +405,6 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro
QString requestedDevice = MySettings::globalInstance()->device();
int n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo);
m_ctx.n_ctx = n_ctx;
int ngl = MySettings::globalInstance()->modelGpuLayers(modelInfo);
std::string backend = "auto";
@ -632,7 +632,6 @@ void ChatLLM::regenerateResponse()
else
m_ctx.n_past -= m_promptResponseTokens;
m_ctx.n_past = std::max(0, m_ctx.n_past);
m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
m_promptResponseTokens = 0;
m_promptTokens = 0;
m_response = m_trimmedResponse = std::string();
@ -1078,12 +1077,13 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
stream << responseLogits;
}
stream << m_ctx.n_past;
if (version >= 7) {
stream << m_ctx.n_ctx;
}
stream << quint64(m_ctx.tokens.size());
stream.writeRawData(reinterpret_cast<const char*>(m_ctx.tokens.data()), m_ctx.tokens.size() * sizeof(int));
saveState();
if (version >= 7) {
stream << m_stateContextLength;
}
stream << quint64(m_stateInputTokens.size());
stream.writeRawData(reinterpret_cast<const char *>(m_stateInputTokens.data()),
m_stateInputTokens.size() * sizeof(m_stateInputTokens[0]));
QByteArray compressed = qCompress(m_state);
stream << compressed;
#if defined(DEBUG)
@ -1145,7 +1145,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
if (version >= 7) {
uint32_t n_ctx;
stream >> n_ctx;
if (!discardKV) m_ctx.n_ctx = n_ctx;
if (!discardKV) m_stateContextLength = n_ctx;
}
if (version < 9) {
@ -1157,10 +1157,10 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
quint64 tokensSize;
stream >> tokensSize;
if (!discardKV) {
m_ctx.tokens.resize(tokensSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
m_stateInputTokens.resize(tokensSize);
stream.readRawData(reinterpret_cast<char *>(m_stateInputTokens.data()), tokensSize * sizeof(m_stateInputTokens[0]));
} else {
stream.skipRawData(tokensSize * sizeof(int));
stream.skipRawData(tokensSize * sizeof(m_stateInputTokens[0]));
}
if (version >= 1) {
@ -1202,13 +1202,16 @@ void ChatLLM::saveState()
#if defined(DEBUG)
qDebug() << "saveState" << m_llmThread.objectName() << "size:" << m_state.size();
#endif
bool ok = m_llModelInfo.model->saveState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())});
bool ok = m_llModelInfo.model->saveState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())},
m_stateInputTokens);
if (!ok) {
// FIXME(jared): how badly does this situation break GPT4All?
qWarning() << "ChatLLM failed to save LLModel state";
m_state.clear();
m_state.squeeze();
m_stateContextLength = -1;
}
m_stateContextLength = m_llModelInfo.model->contextLength();
}
void ChatLLM::restoreState()
@ -1235,13 +1238,22 @@ void ChatLLM::restoreState()
if (m_state.isEmpty())
return;
size_t bytesRead = m_llModelInfo.model->restoreState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())});
if (bytesRead) {
m_processedSystemPrompt = true;
m_pristineLoadedState = true;
} else {
qWarning() << "restoring state from text because of error reading state (mismatch or corrupt data)";
if (m_llModelInfo.model->contextLength() != m_stateContextLength) {
qWarning() << "restoring state from text because of n_ctx mismatch (state"
<< m_stateContextLength << "model" << m_llModelInfo.model->contextLength() << ")";
m_restoreStateFromText = true;
} else {
size_t bytesRead = m_llModelInfo.model->restoreState(
{reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())},
m_stateInputTokens
);
if (!bytesRead) {
qWarning() << "restoring state from text because of error reading state (mismatch or corrupt data)";
m_restoreStateFromText = true;
} else {
m_processedSystemPrompt = true;
m_pristineLoadedState = true;
}
}
// free local state copy unless unload is pending

View File

@ -9,7 +9,7 @@
#include <QByteArray>
#include <QElapsedTimer>
#include <QFileInfo>
#include <QList>
#include <QList> // IWYU pragma: keep
#include <QObject>
#include <QPointer>
#include <QString>
@ -22,6 +22,7 @@
#include <memory>
#include <optional>
#include <string>
#include <vector>
using namespace Qt::Literals::StringLiterals;
@ -277,6 +278,8 @@ private:
ModelInfo m_modelInfo;
TokenTimer *m_timer;
QByteArray m_state;
std::vector<LLModel::Token> m_stateInputTokens;
int32_t m_stateContextLength = -1;
QThread m_llmThread;
std::atomic<bool> m_stopGenerating;
std::atomic<bool> m_shouldBeLoaded;