Use the token cache to infer greater n_past and reuse results (#3073)
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in: parent 62cab695eb, commit f07e2e63df
@@ -12,6 +12,7 @@ function(gpt4all_add_warning_options target)
-Wformat=2
-Wmissing-include-dirs
-Wstrict-overflow=2
-Wsuggest-override
-Wvla
# errors
-Werror=format-security
@@ -124,9 +124,7 @@ public:
};

struct PromptContext {
std::vector<int32_t> tokens; // current tokens in the context window
int32_t n_past = 0; // number of tokens in past conversation
int32_t n_ctx = 0; // number of tokens possible in context window
int32_t n_predict = 200;
int32_t top_k = 40;
float top_p = 0.9f;
@@ -151,8 +149,8 @@ public:
virtual bool isModelLoaded() const = 0;
virtual size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) = 0;
virtual size_t stateSize() const = 0;
virtual size_t saveState(std::span<uint8_t> dest) const = 0;
virtual size_t restoreState(std::span<const uint8_t> src) = 0;
virtual size_t saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const = 0;
virtual size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) = 0;

// This method requires the model to return true from supportsCompletion otherwise it will throw
// an error
@@ -210,6 +208,8 @@ public:

void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; }

virtual int32_t contextLength() const = 0;

protected:
// These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions
@@ -218,9 +218,15 @@ protected:
virtual std::string tokenToString(Token id) const = 0;
virtual void initSampler(PromptContext &ctx) = 0;
virtual Token sampleToken() const = 0;
virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const = 0;
virtual bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const = 0;
virtual void shiftContext(PromptContext &promptCtx) = 0;
virtual int32_t contextLength() const = 0;
virtual int32_t inputLength() const = 0;
virtual void setTokenizeInputPosition(int32_t pos) = 0;
virtual auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator = 0;
virtual void setModelInputPosition(PromptContext &ctx, int32_t pos) = 0;
virtual void appendInputToken(PromptContext &ctx, Token tok) = 0;
virtual std::span<const Token> inputTokens() const = 0;
virtual const std::vector<Token> &endTokens() const = 0;
virtual bool shouldAddBOS() const = 0;

@@ -252,11 +258,13 @@ protected:
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp,
bool isResponse = false);
bool isResponse = false,
bool alwaysDecode = false);
void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx);

protected:
Token m_tokenize_last_token = -1; // not serialized

friend class LLMImplementation;
@@ -23,6 +23,11 @@ extern "C" {
*/
typedef void *llmodel_model;

/**
* A token.
*/
typedef int32_t token_t;

/**
* llmodel_prompt_context structure for holding the prompt context.
* NOTE: The implementation takes care of all the memory handling of the raw logits pointer and the
@@ -30,10 +35,7 @@ typedef void *llmodel_model;
* behavior.
*/
struct llmodel_prompt_context {
int32_t *tokens; // current tokens in the context window
size_t tokens_size; // the size of the raw tokens vector
int32_t n_past; // number of tokens in past conversation
int32_t n_ctx; // number of tokens possible in context window
int32_t n_predict; // number of tokens to predict
int32_t top_k; // top k logits to sample from
float top_p; // nucleus sampling probability threshold
@@ -141,27 +143,41 @@ bool llmodel_isModelLoaded(llmodel_model model);
* @param model A pointer to the llmodel_model instance.
* @return the size in bytes of the internal state of the model
*/
uint64_t llmodel_get_state_size(llmodel_model model);
uint64_t llmodel_state_get_size(llmodel_model model);

/**
* Saves the internal state of the model to the specified destination address.
* Saves the internal state of the model.
* NOTE: This state data is specific to the type of model you have created.
* @param model A pointer to the llmodel_model instance.
* @param dest A pointer to the destination.
* @param size The size of the destination buffer.
* @return the number of bytes copied, or zero on error.
* @param state Where to store the state. This must be a buffer of at least llmodel_state_get_size() bytes.
* @param state_size The size of the destination for the state.
* @param input_tokens_out Where to store the address of the token cache state. This is dynamically allocated and must
* be freed with llmodel_state_free_input_tokens.
* @param n_input_tokens Where to store the size of the token cache state.
* @return The number of bytes copied. On error, zero is returned, the token cache is set to NULL, and the token cache
* size is set to zero.
*/
uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest, uint64_t size);
uint64_t llmodel_state_get_data(llmodel_model model, uint8_t *state_out, uint64_t state_size,
token_t **input_tokens_out, uint64_t *n_input_tokens);

/**
* Frees the temporary token cache buffer created by a call to llmodel_state_get_data().
* @param input_tokens The token cache buffer.
*/
void llmodel_state_free_input_tokens(token_t *input_tokens);

/**
* Restores the internal state of the model using data from the specified address.
* NOTE: This state data is specific to the type of model you have created.
* @param model A pointer to the llmodel_model instance.
* @param src A pointer to the state data.
* @param size The size of the source data.
* @param state A pointer to the state data.
* @param state_size The size of the state data.
* @param input_tokens The token cache associated with the saved state.
* @param n_input_tokens The number of tokens in input_tokens.
* @return The number of bytes read, or zero on error.
*/
uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src, size_t size);
uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint64_t state_size,
const token_t *input_tokens, uint64_t n_input_tokens);

/**
* Generate a response using the model.
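To make the new ownership rules concrete, here is a minimal usage sketch of the revised C state API. It is not part of this commit; it assumes the header above is available as llmodel_c.h and that `model` is a valid, already-loaded llmodel_model handle. The token cache buffer returned by llmodel_state_get_data() is owned by the caller and must be released with llmodel_state_free_input_tokens().

    #include "llmodel_c.h" // assumed include path for the C header shown above

    #include <cstdint>
    #include <vector>

    // Hypothetical round trip over the new state API (sketch only).
    static bool snapshot_and_restore(llmodel_model model)
    {
        // 1. Size and save the KV state; the call also hands back the token cache.
        std::vector<uint8_t> state(llmodel_state_get_size(model));
        token_t *input_tokens = nullptr;
        uint64_t n_input_tokens = 0;
        uint64_t written = llmodel_state_get_data(model, state.data(), state.size(),
                                                  &input_tokens, &n_input_tokens);
        if (!written)
            return false; // on error the token cache is NULL and its size is zero

        // 2. Keep a copy of the token cache, then free the temporary buffer.
        std::vector<token_t> tokens(input_tokens, input_tokens + n_input_tokens);
        llmodel_state_free_input_tokens(input_tokens);

        // 3. Later, restore the raw state and the token cache together.
        uint64_t read = llmodel_state_set_data(model, state.data(), state.size(),
                                               tokens.data(), tokens.size());
        return read != 0;
    }

Keeping the token cache next to the raw state is what later lets the backend line the restored tokens up against a new prompt instead of re-evaluating everything.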
@@ -218,6 +218,7 @@ struct LLamaPrivate {
int64_t n_threads = 0;
std::vector<LLModel::Token> end_tokens;
const char *backend_name = nullptr;
std::vector<LLModel::Token> inputTokens;

llama_model *model = nullptr;
llama_context *ctx = nullptr;
@@ -501,14 +502,20 @@ size_t LLamaModel::stateSize() const
return llama_state_get_size(d_ptr->ctx);
}

size_t LLamaModel::saveState(std::span<uint8_t> dest) const
size_t LLamaModel::saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const
{
return llama_state_get_data(d_ptr->ctx, dest.data(), dest.size());
size_t bytesWritten = llama_state_get_data(d_ptr->ctx, stateOut.data(), stateOut.size());
if (bytesWritten)
inputTokensOut.assign(d_ptr->inputTokens.begin(), d_ptr->inputTokens.end());
return bytesWritten;
}

size_t LLamaModel::restoreState(std::span<const uint8_t> src)
size_t LLamaModel::restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens)
{
return llama_state_set_data(d_ptr->ctx, src.data(), src.size());
size_t bytesRead = llama_state_set_data(d_ptr->ctx, state.data(), state.size());
if (bytesRead)
d_ptr->inputTokens.assign(inputTokens.begin(), inputTokens.end());
return bytesRead;
}

std::vector<LLModel::Token> LLamaModel::tokenize(std::string_view str, bool special)
@@ -594,7 +601,7 @@ LLModel::Token LLamaModel::sampleToken() const
return llama_sampler_sample(d_ptr->sampler_chain, d_ptr->ctx, -1);
}

bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) const
{
llama_kv_cache_seq_rm(d_ptr->ctx, 0, ctx.n_past, -1);

@@ -625,7 +632,7 @@ void LLamaModel::shiftContext(PromptContext &promptCtx)
// erase up to n_ctx*contextErase tokens
int n_keep = shouldAddBOS();
int n_past = promptCtx.n_past;
int n_discard = std::min(n_past - n_keep, int(promptCtx.n_ctx * promptCtx.contextErase));
int n_discard = std::min(n_past - n_keep, int(contextLength() * promptCtx.contextErase));

assert(n_discard > 0);
if (n_discard <= 0)
@@ -638,8 +645,9 @@ void LLamaModel::shiftContext(PromptContext &promptCtx)
llama_kv_cache_seq_rm (d_ptr->ctx, 0, n_keep, n_keep + n_discard);
llama_kv_cache_seq_add(d_ptr->ctx, 0, n_keep + n_discard, n_past, -n_discard);

promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);
promptCtx.n_past = promptCtx.tokens.size();
auto &inp = d_ptr->inputTokens;
inp.erase(inp.begin() + n_keep, inp.begin() + n_keep + n_discard);
promptCtx.n_past = inp.size();
}

int32_t LLamaModel::contextLength() const
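As a quick illustration of the discard arithmetic used by shiftContext() above, with hypothetical numbers that are not taken from the commit:

    #include <algorithm>
    #include <cassert>

    int main()
    {
        int n_ctx = 2048;           // contextLength()
        float contextErase = 0.25f; // promptCtx.contextErase
        int n_keep = 1;             // BOS token kept when shouldAddBOS() is true
        int n_past = 2048;          // context is full

        int n_discard = std::min(n_past - n_keep, int(n_ctx * contextErase)); // 512
        assert(n_discard == 512);
        // Both the KV cache and the inputTokens cache drop tokens [n_keep, n_keep + n_discard),
        // so n_past becomes 2048 - 512 = 1536 afterwards.
        return 0;
    }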
@@ -647,6 +655,60 @@ int32_t LLamaModel::contextLength() const
return llama_n_ctx(d_ptr->ctx);
}

int32_t LLamaModel::inputLength() const
{
return d_ptr->inputTokens.size();
}

void LLamaModel::setTokenizeInputPosition(int32_t pos)
{
assert(pos >= 0);
m_tokenize_last_token = pos ? d_ptr->inputTokens.at(size_t(pos) - 1) : -1; // not serialized
}

auto LLamaModel::computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator
{
assert(ctx.n_past >= 0);
auto pos = size_t(ctx.n_past);
if (pos > d_ptr->inputTokens.size()) {
std::ostringstream ss;
ss << "n_past=" << pos << " is past end of token cache length=" << d_ptr->inputTokens.size();
throw std::out_of_range(ss.str());
}

// find common prefix
auto cacheIt = d_ptr->inputTokens.begin();
auto inputIt = input.begin();
while (cacheIt < d_ptr->inputTokens.end() && inputIt < input.end() && *cacheIt == *inputIt) {
++cacheIt; ++inputIt; ++pos;
}
// tell the caller to ignore the tokens between [begin, inputIt)
return inputIt;
}

void LLamaModel::setModelInputPosition(PromptContext &ctx, int32_t pos)
{
auto &inp = d_ptr->inputTokens;
assert(pos >= 0);
assert(pos <= inp.size());
// truncate token cache to end at the new n_past
if (pos < inp.size())
inp.resize(pos);
ctx.n_past = pos;
}

void LLamaModel::appendInputToken(PromptContext &ctx, Token tok)
{
d_ptr->inputTokens.push_back(tok);
ctx.n_past += 1;
}

auto LLamaModel::inputTokens() const -> std::span<const Token>
{
return d_ptr->inputTokens;
}

const std::vector<LLModel::Token> &LLamaModel::endTokens() const
{
return d_ptr->end_tokens;
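The common-prefix scan in computeModelInputPosition() is what makes the token cache useful: only the suffix of the new input that differs from the cached tokens has to be re-evaluated. Below is a standalone sketch of the same idea, using invented token ids and std::mismatch in place of the explicit loop:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    using Token = int32_t;

    int main()
    {
        // Token cache left over from the previous prompt, and the newly tokenized input.
        std::vector<Token> cache = {1, 15, 27, 42, 99};    // hypothetical token ids
        std::vector<Token> input = {1, 15, 27, 63, 70, 8};

        // Find the shared prefix; those tokens are already in the KV cache.
        auto [cacheIt, inputIt] = std::mismatch(cache.begin(), cache.end(),
                                                input.begin(), input.end());
        (void)cacheIt; // only the input-side iterator is needed here
        size_t reused = size_t(inputIt - input.begin()); // 3 tokens reused in this example

        // The model only has to decode input[reused..), i.e. {63, 70, 8};
        // n_past advances to `reused` and the cache is truncated to that point.
        std::printf("reused %zu tokens, decoding %zu new tokens\n",
                    reused, input.size() - reused);
        return 0;
    }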
@@ -28,8 +28,8 @@ public:
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
size_t stateSize() const override;
size_t saveState(std::span<uint8_t> dest) const override;
size_t restoreState(std::span<const uint8_t> src) override;
size_t saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const override;
size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) override;
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() const override;
std::vector<GPUDevice> availableGPUDevices(size_t memoryRequired = 0) const override;
@@ -48,10 +48,7 @@ public:
void embed(const std::vector<std::string> &texts, float *embeddings, bool isRetrieval, int dimensionality = -1,
size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override;

private:
std::unique_ptr<LLamaPrivate> d_ptr;
bool m_supportsEmbedding = false;
bool m_supportsCompletion = false;
int32_t contextLength() const override;

protected:
std::vector<Token> tokenize(std::string_view str, bool special) override;
@@ -59,9 +56,15 @@ protected:
std::string tokenToString(Token id) const override;
void initSampler(PromptContext &ctx) override;
Token sampleToken() const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const override;
void shiftContext(PromptContext &promptCtx) override;
int32_t contextLength() const override;
int32_t inputLength() const override;
void setTokenizeInputPosition(int32_t pos) override;
auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator override;
void setModelInputPosition(PromptContext &ctx, int32_t pos) override;
void appendInputToken(PromptContext &ctx, Token tok) override;
std::span<const Token> inputTokens() const override;
const std::vector<Token> &endTokens() const override;
bool shouldAddBOS() const override;
int32_t maxContextLength(std::string const &modelPath) const override;
@@ -70,6 +73,11 @@ protected:
void embedInternal(const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb,
const EmbModelSpec *spec);

private:
std::unique_ptr<LLamaPrivate> d_ptr;
bool m_supportsEmbedding = false;
bool m_supportsCompletion = false;
};

#endif // LLAMAMODEL_H
@@ -14,6 +14,11 @@
#include <string>
#include <string_view>
#include <vector>
#include <span>

namespace ranges = std::ranges;

static_assert(sizeof(token_t) == sizeof(LLModel::Token));

struct LLModelWrapper {
LLModel *llModel = nullptr;
@@ -85,22 +90,40 @@ bool llmodel_isModelLoaded(llmodel_model model)
return wrapper->llModel->isModelLoaded();
}

uint64_t llmodel_get_state_size(llmodel_model model)
uint64_t llmodel_state_get_size(llmodel_model model)
{
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->stateSize();
}

uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest, uint64_t size)
uint64_t llmodel_state_get_data(llmodel_model model, uint8_t *state_out, uint64_t state_size,
token_t **input_tokens_out, uint64_t *n_input_tokens)
{
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->saveState({dest, size_t(size)});
std::vector<LLModel::Token> inputTokens;
auto bytesWritten = wrapper->llModel->saveState({state_out, size_t(state_size)}, inputTokens);
if (bytesWritten) {
auto *buf = new LLModel::Token[inputTokens.size()];
ranges::copy(inputTokens, buf);
*input_tokens_out = buf;
*n_input_tokens = uint64_t(inputTokens.size());
} else {
*input_tokens_out = nullptr;
*n_input_tokens = 0;
}
return bytesWritten;
}

uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src, uint64_t size)
void llmodel_state_free_input_tokens(LLModel::Token *input_tokens)
{
delete[] input_tokens;
}

uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint64_t state_size,
const token_t *input_tokens, uint64_t n_input_tokens)
{
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->restoreState({src, size_t(size)});
return wrapper->llModel->restoreState({state, size_t(state_size)}, {input_tokens, size_t(n_input_tokens)});
}

void llmodel_prompt(llmodel_model model, const char *prompt,
@@ -120,7 +143,6 @@ void llmodel_prompt(llmodel_model model, const char *prompt,

// Copy the C prompt context
wrapper->promptContext.n_past = ctx->n_past;
wrapper->promptContext.n_ctx = ctx->n_ctx;
wrapper->promptContext.n_predict = ctx->n_predict;
wrapper->promptContext.top_k = ctx->top_k;
wrapper->promptContext.top_p = ctx->top_p;
@@ -136,14 +158,8 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
wrapper->promptContext, special,
fake_reply ? std::make_optional<std::string_view>(fake_reply) : std::nullopt);

// Update the C context by giving access to the wrappers raw pointers to std::vector data
// which involves no copies
ctx->tokens = wrapper->promptContext.tokens.data();
ctx->tokens_size = wrapper->promptContext.tokens.size();

// Update the rest of the C prompt context
ctx->n_past = wrapper->promptContext.n_past;
ctx->n_ctx = wrapper->promptContext.n_ctx;
ctx->n_predict = wrapper->promptContext.n_predict;
ctx->top_k = wrapper->promptContext.top_k;
ctx->top_p = wrapper->promptContext.top_p;
@@ -6,6 +6,7 @@
#include <cstdint>
#include <functional>
#include <iostream>
#include <iterator>
#include <optional>
#include <regex>
#include <sstream>
@@ -66,19 +67,14 @@ void LLModel::prompt(const std::string &prompt,
ss << "n_past=" << promptCtx.n_past << " is past end of context length=" << contextLength();
throw std::out_of_range(ss.str());
}
if (promptCtx.n_past > promptCtx.tokens.size()) {
if (promptCtx.n_past > inputLength()) {
std::ostringstream ss;
ss << "n_past=" << promptCtx.n_past << " is past end of token cache length=" << promptCtx.tokens.size();
ss << "n_past=" << promptCtx.n_past << " is past end of token cache length=" << inputLength();
throw std::out_of_range(ss.str());
}

promptCtx.n_ctx = contextLength();
promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);

if (promptCtx.n_past < promptCtx.tokens.size())
promptCtx.tokens.resize(promptCtx.n_past);
m_tokenize_last_token = promptCtx.tokens.empty() ? -1 : promptCtx.tokens.back(); // not serialized

// parse the prompt template
std::vector<std::smatch> placeholders;
{
@@ -90,6 +86,8 @@ void LLModel::prompt(const std::string &prompt,
}
}

setTokenizeInputPosition(promptCtx.n_past);

// tokenize the user prompt
std::vector<Token> embd_inp;
if (placeholders.empty()) {
@@ -118,7 +116,8 @@ void LLModel::prompt(const std::string &prompt,
}

// decode the user prompt
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, /*isResponse*/ false,
/*alwaysDecode*/ true))
return; // error

// decode the assistant's reply, either generated or spoofed
@@ -151,36 +150,67 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp,
bool isResponse) {
if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
bool isResponse,
bool alwaysDecode) {
if ((int) embd_inp.size() > contextLength() - 4) {
// FIXME: (Adam) We should find a way to bubble these strings to the UI level to allow for
// translation
responseCallback(-1, "Your message was too long and could not be processed. Please try again with something shorter.");
std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() <<
" tokens and the context window is " << promptCtx.n_ctx << "!\n";
" tokens and the context window is " << contextLength() << "!\n";
return false;
}

// FIXME(jared): There are mitigations for this situation, such as making room before
// copying the prompt context, or restoring the KV cache when we restore the prompt
// context.
if (!allowContextShift && promptCtx.n_past + embd_inp.size() > promptCtx.n_ctx) {
if (!allowContextShift && promptCtx.n_past + embd_inp.size() > contextLength()) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_eval=" << embd_inp.size()
<< ", n_ctx=" << promptCtx.n_ctx << "\n";
<< ", n_ctx=" << contextLength() << "\n";
return false;
}

// process the prompt in batches
// always decode something before generating, even if cached
if (alwaysDecode && embd_inp.empty()) {
auto cache = inputTokens();
if (!promptCtx.n_past)
throw std::runtime_error("zero token prompt is not supported");
assert(!cache.empty());
embd_inp.push_back(cache.back());
promptCtx.n_past--;
}

// Find the greatest n_past where the beginning of embd_inp matches the end of the token cache, starting at the
// requested n_past.
// This is used to skip unnecessary work when the prompt shares a common prefix with the previous result.
auto embd_inp_start = computeModelInputPosition(promptCtx, embd_inp);
size_t start_offset = embd_inp_start - embd_inp.begin();

// always decode up to a full batch before generating, even if cached
if (alwaysDecode)
start_offset -= std::min(promptCtx.n_batch, int32_t(start_offset));

setModelInputPosition(promptCtx, promptCtx.n_past + start_offset);

// execute the callback even for skipped tokens
size_t i = 0;
for (; i < start_offset; i++) {
Token tok = embd_inp[i];
bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
if (!res)
return false;
}

// process the prompt in batches
while (i < embd_inp.size()) {
size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
std::vector<Token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
std::span<const Token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);

// Check if the context has run out...
if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
if (promptCtx.n_past + int32_t(batch.size()) > contextLength()) {
assert(allowContextShift);
shiftContext(promptCtx);
assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
assert(promptCtx.n_past + int32_t(batch.size()) <= contextLength());
}

if (!evalTokens(promptCtx, batch)) {
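To see how these offsets play out, here is a small hypothetical trace (all numbers invented) of the values decodePrompt computes when the new input shares a 100-token prefix with the token cache:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main()
    {
        // Hypothetical sizes: the re-tokenized input is 120 tokens long and its
        // first 100 tokens match the cache left over from the previous turn.
        int32_t n_past       = 0;    // requested position at entry to decodePrompt
        int32_t n_batch      = 128;  // promptCtx.n_batch
        int32_t cached_match = 100;  // tokens matched by computeModelInputPosition
        int32_t input_size   = 120;

        int32_t start_offset = cached_match; // distance of embd_inp_start from begin()

        // With alwaysDecode set, back off by up to one batch so the model always
        // evaluates something before sampling the first response token.
        start_offset -= std::min(n_batch, start_offset); // 100 - 100 = 0 here

        // setModelInputPosition(promptCtx, n_past + start_offset): with n_batch=128 all
        // 120 tokens are decoded; with a smaller n_batch (say 16), 84 tokens would be
        // skipped (callback only) and just 36 decoded.
        std::printf("n_past=%d, tokens decoded=%d\n",
                    n_past + start_offset, input_size - start_offset);
        return 0;
    }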
@@ -190,9 +220,8 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,

size_t tokens = batch_end - i;
for (size_t t = 0; t < tokens; ++t) {
promptCtx.tokens.push_back(batch.at(t));
promptCtx.n_past += 1;
Token tok = batch.at(t);
Token tok = batch[t];
appendInputToken(promptCtx, tok);
bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
if (!res)
return false;
@@ -232,8 +261,8 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
// Don't even start if there is no room
if (!promptCtx.n_predict)
return;
if (!allowContextShift && promptCtx.n_past >= promptCtx.n_ctx) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" << promptCtx.n_ctx
if (!allowContextShift && promptCtx.n_past >= contextLength()) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" << contextLength()
<< "\n";
return;
}
@@ -254,23 +283,22 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>

auto accept = [this, &promptCtx, &new_tok, allowContextShift]() -> bool {
// Shift context if out of space
if (promptCtx.n_past >= promptCtx.n_ctx) {
if (promptCtx.n_past >= contextLength()) {
(void)allowContextShift;
assert(allowContextShift);
shiftContext(promptCtx);
assert(promptCtx.n_past < promptCtx.n_ctx);
assert(promptCtx.n_past < contextLength());
}

// Accept the token
Token tok = std::exchange(new_tok, std::nullopt).value();
if (!evalTokens(promptCtx, { tok })) {
if (!evalTokens(promptCtx, { &tok, 1 })) {
// TODO(jared): raise an exception
std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n";
return false;
}

promptCtx.tokens.push_back(tok);
promptCtx.n_past += 1;
appendInputToken(promptCtx, tok);
return true;
};

@@ -309,9 +337,9 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
}

// Optionally stop if the context will run out
if (!allowContextShift && promptCtx.n_past + cachedTokens.size() >= promptCtx.n_ctx) {
if (!allowContextShift && promptCtx.n_past + cachedTokens.size() >= contextLength()) {
std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx="
<< promptCtx.n_ctx << "\n";
<< contextLength() << "\n";
stop = true;
}

@@ -357,16 +385,17 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>
}
}

auto &tokens = promptCtx.tokens;
if (tokens.size() < cachedTokens.size()) {
if (inputLength() < cachedTokens.size()) {
/* This is theoretically possible if the longest stop sequence is greater than
* n_ctx * contextErase tokens. */
throw std::runtime_error("shifted too much context, can't go back");
}

auto discard_start = tokens.end() - cachedTokens.size();
assert(std::equal(discard_start, tokens.end(), cachedTokens.begin()));
tokens.erase(discard_start, tokens.end());
#ifndef NDEBUG
auto inp = inputTokens();
auto discard_start = inp.end() - cachedTokens.size();
assert(std::equal(discard_start, inp.end(), cachedTokens.begin()));
#endif

promptCtx.n_past -= cachedTokens.size();
}
@@ -113,10 +113,7 @@ def _old_loop(gpt4all_instance):
full_response = gpt4all_instance.chat_completion(
MESSAGES,
# preferential kwargs for chat ux
logits_size=0,
tokens_size=0,
n_past=0,
n_ctx=0,
n_predict=200,
top_k=40,
top_p=0.9,
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

### Added
- Warn on Windows if the Microsoft Visual C++ runtime libraries are not found ([#2920](https://github.com/nomic-ai/gpt4all/pull/2920))
- Basic cache for faster prefill when the input shares a prefix with previous context ([#3073](https://github.com/nomic-ai/gpt4all/pull/3073))

### Changed
- Rebase llama.cpp on latest upstream as of September 26th ([#2998](https://github.com/nomic-ai/gpt4all/pull/2998))
@@ -116,10 +116,7 @@ llmodel = load_llmodel_library()

class LLModelPromptContext(ctypes.Structure):
_fields_ = [
("tokens", ctypes.POINTER(ctypes.c_int32)),
("tokens_size", ctypes.c_size_t),
("n_past", ctypes.c_int32),
("n_ctx", ctypes.c_int32),
("n_predict", ctypes.c_int32),
("top_k", ctypes.c_int32),
("top_p", ctypes.c_float),
@@ -393,9 +390,7 @@ class LLModel:
):
if self.context is None:
context = LLModelPromptContext(
tokens_size=0,
n_past=0,
n_ctx=0,
n_predict=n_predict,
top_k=top_k,
top_p=top_p,
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
### Added
- Add ability to attach text, markdown, and rst files to chat ([#3135](https://github.com/nomic-ai/gpt4all/pull/3135))
- Add feature to minimize to system tray (by [@bgallois](https://github.com/bgallois) in ([#3109](https://github.com/nomic-ai/gpt4all/pull/3109))
- Basic cache for faster prefill when the input shares a prefix with previous context ([#3073](https://github.com/nomic-ai/gpt4all/pull/3073))

### Changed
- Implement Qt 6.8 compatibility ([#3121](https://github.com/nomic-ai/gpt4all/pull/3121))
@@ -51,7 +51,6 @@ bool ChatAPI::loadModel(const std::string &modelPath, int n_ctx, int ngl)
void ChatAPI::setThreadCount(int32_t n_threads)
{
Q_UNUSED(n_threads);
qt_noop();
}

int32_t ChatAPI::threadCount() const
@@ -68,24 +67,6 @@ bool ChatAPI::isModelLoaded() const
return true;
}

// All three of the state virtual functions are handled custom inside of chatllm save/restore
size_t ChatAPI::stateSize() const
{
throw std::logic_error("not implemented");
}

size_t ChatAPI::saveState(std::span<uint8_t> dest) const
{
Q_UNUSED(dest);
throw std::logic_error("not implemented");
}

size_t ChatAPI::restoreState(std::span<const uint8_t> src)
{
Q_UNUSED(src);
throw std::logic_error("not implemented");
}

void ChatAPI::prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
@@ -3,7 +3,7 @@

#include <gpt4all-backend/llmodel.h>

#include <QByteArray>
#include <QByteArray> // IWYU pragma: keep
#include <QNetworkReply>
#include <QObject>
#include <QString>
@@ -13,6 +13,8 @@
#include <cstddef>
#include <cstdint>
#include <functional>
#include <optional>
#include <span>
#include <stdexcept>
#include <string>
#include <string_view>
@@ -63,9 +65,15 @@ public:
bool loadModel(const std::string &modelPath, int n_ctx, int ngl) override;
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
size_t stateSize() const override;
size_t saveState(std::span<uint8_t> dest) const override;
size_t restoreState(std::span<const uint8_t> src) override;

// All three of the state virtual functions are handled custom inside of chatllm save/restore
size_t stateSize() const override
{ throwNotImplemented(); }
size_t saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const override
{ Q_UNUSED(stateOut); Q_UNUSED(inputTokensOut); throwNotImplemented(); }
size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) override
{ Q_UNUSED(state); Q_UNUSED(inputTokens); throwNotImplemented(); }

void prompt(const std::string &prompt,
const std::string &promptTemplate,
std::function<bool(int32_t)> promptCallback,
@@ -88,6 +96,10 @@ public:

bool callResponse(int32_t token, const std::string &string);

[[noreturn]]
int32_t contextLength() const override
{ throwNotImplemented(); }

Q_SIGNALS:
void request(const QString &apiKey,
LLModel::PromptContext *ctx,
@@ -98,60 +110,69 @@ protected:
// them as they are only called from the default implementation of 'prompt' which we override and
// completely replace

[[noreturn]]
static void throwNotImplemented() { throw std::logic_error("not implemented"); }

[[noreturn]]
std::vector<Token> tokenize(std::string_view str, bool special) override
{
(void)str;
(void)special;
throw std::logic_error("not implemented");
}
{ Q_UNUSED(str); Q_UNUSED(special); throwNotImplemented(); }

[[noreturn]]
bool isSpecialToken(Token id) const override
{
(void)id;
throw std::logic_error("not implemented");
}
{ Q_UNUSED(id); throwNotImplemented(); }

[[noreturn]]
std::string tokenToString(Token id) const override
{
(void)id;
throw std::logic_error("not implemented");
}
{ Q_UNUSED(id); throwNotImplemented(); }

[[noreturn]]
void initSampler(PromptContext &ctx) override
{
(void)ctx;
throw std::logic_error("not implemented");
}
{ Q_UNUSED(ctx); throwNotImplemented(); }

Token sampleToken() const override { throw std::logic_error("not implemented"); }
[[noreturn]]
Token sampleToken() const override
{ throwNotImplemented(); }

bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override
{
(void)ctx;
(void)tokens;
throw std::logic_error("not implemented");
}
[[noreturn]]
bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const override
{ Q_UNUSED(ctx); Q_UNUSED(tokens); throwNotImplemented(); }

[[noreturn]]
void shiftContext(PromptContext &promptCtx) override
{
(void)promptCtx;
throw std::logic_error("not implemented");
}
{ Q_UNUSED(promptCtx); throwNotImplemented(); }

int32_t contextLength() const override
{
throw std::logic_error("not implemented");
}
[[noreturn]]
int32_t inputLength() const override
{ throwNotImplemented(); }

[[noreturn]]
void setTokenizeInputPosition(int32_t pos) override
{ Q_UNUSED(pos); throwNotImplemented(); }

[[noreturn]]
auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator override
{ Q_UNUSED(ctx); Q_UNUSED(input); throwNotImplemented(); }

[[noreturn]]
void setModelInputPosition(PromptContext &ctx, int32_t pos) override
{ Q_UNUSED(ctx); Q_UNUSED(pos); throwNotImplemented(); }

[[noreturn]]
void appendInputToken(PromptContext &ctx, Token tok) override
{ Q_UNUSED(ctx); Q_UNUSED(tok); throwNotImplemented(); }

[[noreturn]]
const std::vector<Token> &endTokens() const override
{
throw std::logic_error("not implemented");
}
{ throwNotImplemented(); }

[[noreturn]]
bool shouldAddBOS() const override
{
throw std::logic_error("not implemented");
}
{ throwNotImplemented(); }

[[noreturn]]
std::span<const Token> inputTokens() const override
{ throwNotImplemented(); }

private:
std::function<bool(int32_t, const std::string&)> m_responseCallback;
@@ -33,6 +33,7 @@
#include <functional>
#include <limits>
#include <optional>
#include <span>
#include <string_view>
#include <utility>
#include <vector>
@@ -404,7 +405,6 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro

QString requestedDevice = MySettings::globalInstance()->device();
int n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo);
m_ctx.n_ctx = n_ctx;
int ngl = MySettings::globalInstance()->modelGpuLayers(modelInfo);

std::string backend = "auto";
@@ -632,7 +632,6 @@ void ChatLLM::regenerateResponse()
else
m_ctx.n_past -= m_promptResponseTokens;
m_ctx.n_past = std::max(0, m_ctx.n_past);
m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
m_promptResponseTokens = 0;
m_promptTokens = 0;
m_response = m_trimmedResponse = std::string();
@@ -1078,12 +1077,13 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
stream << responseLogits;
}
stream << m_ctx.n_past;
if (version >= 7) {
stream << m_ctx.n_ctx;
}
stream << quint64(m_ctx.tokens.size());
stream.writeRawData(reinterpret_cast<const char*>(m_ctx.tokens.data()), m_ctx.tokens.size() * sizeof(int));
saveState();
if (version >= 7) {
stream << m_stateContextLength;
}
stream << quint64(m_stateInputTokens.size());
stream.writeRawData(reinterpret_cast<const char *>(m_stateInputTokens.data()),
m_stateInputTokens.size() * sizeof(m_stateInputTokens[0]));
QByteArray compressed = qCompress(m_state);
stream << compressed;
#if defined(DEBUG)
@@ -1145,7 +1145,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
if (version >= 7) {
uint32_t n_ctx;
stream >> n_ctx;
if (!discardKV) m_ctx.n_ctx = n_ctx;
if (!discardKV) m_stateContextLength = n_ctx;
}

if (version < 9) {
@@ -1157,10 +1157,10 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
quint64 tokensSize;
stream >> tokensSize;
if (!discardKV) {
m_ctx.tokens.resize(tokensSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
m_stateInputTokens.resize(tokensSize);
stream.readRawData(reinterpret_cast<char *>(m_stateInputTokens.data()), tokensSize * sizeof(m_stateInputTokens[0]));
} else {
stream.skipRawData(tokensSize * sizeof(int));
stream.skipRawData(tokensSize * sizeof(m_stateInputTokens[0]));
}

if (version >= 1) {
@@ -1202,13 +1202,16 @@ void ChatLLM::saveState()
#if defined(DEBUG)
qDebug() << "saveState" << m_llmThread.objectName() << "size:" << m_state.size();
#endif
bool ok = m_llModelInfo.model->saveState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())});
bool ok = m_llModelInfo.model->saveState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())},
m_stateInputTokens);
if (!ok) {
// FIXME(jared): how badly does this situation break GPT4All?
qWarning() << "ChatLLM failed to save LLModel state";
m_state.clear();
m_state.squeeze();
m_stateContextLength = -1;
}
m_stateContextLength = m_llModelInfo.model->contextLength();
}

void ChatLLM::restoreState()
@@ -1235,13 +1238,22 @@ void ChatLLM::restoreState()
if (m_state.isEmpty())
return;

size_t bytesRead = m_llModelInfo.model->restoreState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())});
if (bytesRead) {
m_processedSystemPrompt = true;
m_pristineLoadedState = true;
} else {
qWarning() << "restoring state from text because of error reading state (mismatch or corrupt data)";
if (m_llModelInfo.model->contextLength() != m_stateContextLength) {
qWarning() << "restoring state from text because of n_ctx mismatch (state"
<< m_stateContextLength << "model" << m_llModelInfo.model->contextLength() << ")";
m_restoreStateFromText = true;
} else {
size_t bytesRead = m_llModelInfo.model->restoreState(
{reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())},
m_stateInputTokens
);
if (!bytesRead) {
qWarning() << "restoring state from text because of error reading state (mismatch or corrupt data)";
m_restoreStateFromText = true;
} else {
m_processedSystemPrompt = true;
m_pristineLoadedState = true;
}
}

// free local state copy unless unload is pending
@@ -9,7 +9,7 @@
#include <QByteArray>
#include <QElapsedTimer>
#include <QFileInfo>
#include <QList>
#include <QList> // IWYU pragma: keep
#include <QObject>
#include <QPointer>
#include <QString>
@@ -22,6 +22,7 @@
#include <memory>
#include <optional>
#include <string>
#include <vector>

using namespace Qt::Literals::StringLiterals;

@@ -277,6 +278,8 @@ private:
ModelInfo m_modelInfo;
TokenTimer *m_timer;
QByteArray m_state;
std::vector<LLModel::Token> m_stateInputTokens;
int32_t m_stateContextLength = -1;
QThread m_llmThread;
std::atomic<bool> m_stopGenerating;
std::atomic<bool> m_shouldBeLoaded;