mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-11-10 11:49:05 +03:00
Move the promptCallback to own function.
This commit is contained in:
parent
0e9f85bcda
commit
ba4b28fcd5
36
llm.cpp
36
llm.cpp
@ -38,7 +38,7 @@ static QString modelFilePath(const QString &modelName)
|
||||
LLMObject::LLMObject()
|
||||
: QObject{nullptr}
|
||||
, m_llmodel(nullptr)
|
||||
, m_responseTokens(0)
|
||||
, m_promptResponseTokens(0)
|
||||
, m_responseLogits(0)
|
||||
, m_isRecalc(false)
|
||||
{
|
||||
@ -133,12 +133,12 @@ bool LLMObject::isModelLoaded() const
|
||||
|
||||
void LLMObject::regenerateResponse()
|
||||
{
|
||||
s_ctx.n_past -= m_responseTokens;
|
||||
s_ctx.n_past -= m_promptResponseTokens;
|
||||
s_ctx.n_past = std::max(0, s_ctx.n_past);
|
||||
// FIXME: This does not seem to be needed in my testing and llama models don't to it. Remove?
|
||||
s_ctx.logits.erase(s_ctx.logits.end() -= m_responseLogits, s_ctx.logits.end());
|
||||
s_ctx.tokens.erase(s_ctx.tokens.end() -= m_responseTokens, s_ctx.tokens.end());
|
||||
m_responseTokens = 0;
|
||||
s_ctx.tokens.erase(s_ctx.tokens.end() -= m_promptResponseTokens, s_ctx.tokens.end());
|
||||
m_promptResponseTokens = 0;
|
||||
m_responseLogits = 0;
|
||||
m_response = std::string();
|
||||
emit responseChanged();
|
||||
@ -146,7 +146,7 @@ void LLMObject::regenerateResponse()
|
||||
|
||||
void LLMObject::resetResponse()
|
||||
{
|
||||
m_responseTokens = 0;
|
||||
m_promptResponseTokens = 0;
|
||||
m_responseLogits = 0;
|
||||
m_response = std::string();
|
||||
emit responseChanged();
|
||||
@ -263,6 +263,18 @@ QList<QString> LLMObject::modelList() const
|
||||
return list;
|
||||
}
|
||||
|
||||
bool LLMObject::handlePrompt(int32_t token)
|
||||
{
|
||||
if (s_ctx.tokens.size() == s_ctx.n_ctx)
|
||||
s_ctx.tokens.erase(s_ctx.tokens.begin());
|
||||
s_ctx.tokens.push_back(token);
|
||||
|
||||
// m_promptResponseTokens and m_responseLogits are related to last prompt/response not
|
||||
// the entire context window which we can reset on regenerate prompt
|
||||
++m_promptResponseTokens;
|
||||
return !m_stopGenerating;
|
||||
}
|
||||
|
||||
bool LLMObject::handleResponse(int32_t token, const std::string &response)
|
||||
{
|
||||
#if 0
|
||||
@ -282,13 +294,12 @@ bool LLMObject::handleResponse(int32_t token, const std::string &response)
|
||||
s_ctx.tokens.erase(s_ctx.tokens.begin());
|
||||
s_ctx.tokens.push_back(token);
|
||||
|
||||
// m_responseTokens and m_responseLogits are related to last prompt/response not
|
||||
// m_promptResponseTokens and m_responseLogits are related to last prompt/response not
|
||||
// the entire context window which we can reset on regenerate prompt
|
||||
++m_responseTokens;
|
||||
if (!response.empty()) {
|
||||
m_response.append(response);
|
||||
emit responseChanged();
|
||||
}
|
||||
++m_promptResponseTokens;
|
||||
Q_ASSERT(!response.empty());
|
||||
m_response.append(response);
|
||||
emit responseChanged();
|
||||
|
||||
// Stop generation if we encounter prompt or response tokens
|
||||
QString r = QString::fromStdString(m_response);
|
||||
@ -315,6 +326,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
|
||||
QString instructPrompt = prompt_template.arg(prompt);
|
||||
|
||||
m_stopGenerating = false;
|
||||
auto promptFunc = std::bind(&LLMObject::handlePrompt, this, std::placeholders::_1);
|
||||
auto responseFunc = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1,
|
||||
std::placeholders::_2);
|
||||
auto recalcFunc = std::bind(&LLMObject::handleRecalculate, this, std::placeholders::_1);
|
||||
@ -327,7 +339,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
|
||||
s_ctx.n_batch = n_batch;
|
||||
s_ctx.repeat_penalty = repeat_penalty;
|
||||
s_ctx.repeat_last_n = repeat_penalty_tokens;
|
||||
m_llmodel->prompt(instructPrompt.toStdString(), responseFunc, recalcFunc, s_ctx);
|
||||
m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, s_ctx);
|
||||
m_responseLogits += s_ctx.logits.size() - logitsBefore;
|
||||
std::string trimmed = trim_whitespace(m_response);
|
||||
if (trimmed != m_response) {
|
||||
|
3
llm.h
3
llm.h
@ -58,13 +58,14 @@ Q_SIGNALS:
|
||||
private:
|
||||
void resetContextPrivate();
|
||||
bool loadModelPrivate(const QString &modelName);
|
||||
bool handlePrompt(int32_t token);
|
||||
bool handleResponse(int32_t token, const std::string &response);
|
||||
bool handleRecalculate(bool isRecalc);
|
||||
|
||||
private:
|
||||
LLModel *m_llmodel;
|
||||
std::string m_response;
|
||||
quint32 m_responseTokens;
|
||||
quint32 m_promptResponseTokens;
|
||||
quint32 m_responseLogits;
|
||||
QString m_modelName;
|
||||
QThread m_llmThread;
|
||||
|
@ -686,8 +686,9 @@ bool GPTJ::isModelLoaded() const
|
||||
}
|
||||
|
||||
void GPTJ::prompt(const std::string &prompt,
|
||||
std::function<bool(int32_t, const std::string&)> response,
|
||||
std::function<bool(bool)> recalculate,
|
||||
std::function<bool(int32_t)> promptCallback,
|
||||
std::function<bool(int32_t, const std::string&)> responseCallback,
|
||||
std::function<bool(bool)> recalculateCallback,
|
||||
PromptContext &promptCtx) {
|
||||
|
||||
if (!isModelLoaded()) {
|
||||
@ -708,7 +709,7 @@ void GPTJ::prompt(const std::string &prompt,
|
||||
promptCtx.n_ctx = d_ptr->model.hparams.n_ctx;
|
||||
|
||||
if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
|
||||
response(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
|
||||
responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
|
||||
std::cerr << "GPT-J ERROR: The prompt is" << embd_inp.size() <<
|
||||
"tokens and the context window is" << promptCtx.n_ctx << "!\n";
|
||||
return;
|
||||
@ -741,7 +742,7 @@ void GPTJ::prompt(const std::string &prompt,
|
||||
std::cerr << "GPTJ: reached the end of the context window so resizing\n";
|
||||
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
|
||||
promptCtx.n_past = promptCtx.tokens.size();
|
||||
recalculateContext(promptCtx, recalculate);
|
||||
recalculateContext(promptCtx, recalculateCallback);
|
||||
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
|
||||
}
|
||||
|
||||
@ -750,10 +751,10 @@ void GPTJ::prompt(const std::string &prompt,
|
||||
std::cerr << "GPT-J ERROR: Failed to process prompt\n";
|
||||
return;
|
||||
}
|
||||
// We pass a null string for each token to see if the user has asked us to stop...
|
||||
|
||||
size_t tokens = batch_end - i;
|
||||
for (size_t t = 0; t < tokens; ++t)
|
||||
if (!response(batch.at(t), ""))
|
||||
if (!promptCallback(batch.at(t)))
|
||||
return;
|
||||
promptCtx.n_past += batch.size();
|
||||
i = batch_end;
|
||||
@ -790,8 +791,8 @@ void GPTJ::prompt(const std::string &prompt,
|
||||
std::cerr << "GPTJ: reached the end of the context window so resizing\n";
|
||||
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
|
||||
promptCtx.n_past = promptCtx.tokens.size();
|
||||
recalculateContext(promptCtx, recalculate);
|
||||
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
|
||||
recalculateContext(promptCtx, recalculateCallback);
|
||||
assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
|
||||
}
|
||||
|
||||
const int64_t t_start_predict_us = ggml_time_us();
|
||||
@ -805,7 +806,7 @@ void GPTJ::prompt(const std::string &prompt,
|
||||
promptCtx.n_past += 1;
|
||||
// display text
|
||||
++totalPredictions;
|
||||
if (id == 50256 /*end of text*/ || !response(id, d_ptr->vocab.id_to_token[id]))
|
||||
if (id == 50256 /*end of text*/ || !responseCallback(id, d_ptr->vocab.id_to_token[id]))
|
||||
goto stop_generating;
|
||||
}
|
||||
|
||||
|
@ -16,8 +16,9 @@ public:
|
||||
bool loadModel(const std::string &modelPath, std::istream &fin) override;
|
||||
bool isModelLoaded() const override;
|
||||
void prompt(const std::string &prompt,
|
||||
std::function<bool(int32_t, const std::string&)> response,
|
||||
std::function<bool(bool)> recalculate,
|
||||
std::function<bool(int32_t)> promptCallback,
|
||||
std::function<bool(int32_t, const std::string&)> responseCallback,
|
||||
std::function<bool(bool)> recalculateCallback,
|
||||
PromptContext &ctx) override;
|
||||
void setThreadCount(int32_t n_threads) override;
|
||||
int32_t threadCount() override;
|
||||
|
@ -80,8 +80,9 @@ bool LLamaModel::isModelLoaded() const
|
||||
}
|
||||
|
||||
void LLamaModel::prompt(const std::string &prompt,
|
||||
std::function<bool(int32_t, const std::string&)> response,
|
||||
std::function<bool(bool)> recalculate,
|
||||
std::function<bool(int32_t)> promptCallback,
|
||||
std::function<bool(int32_t, const std::string&)> responseCallback,
|
||||
std::function<bool(bool)> recalculateCallback,
|
||||
PromptContext &promptCtx) {
|
||||
|
||||
if (!isModelLoaded()) {
|
||||
@ -102,7 +103,7 @@ void LLamaModel::prompt(const std::string &prompt,
|
||||
promptCtx.n_ctx = llama_n_ctx(d_ptr->ctx);
|
||||
|
||||
if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
|
||||
response(-1, "The prompt size exceeds the context window size and cannot be processed.");
|
||||
responseCallback(-1, "The prompt size exceeds the context window size and cannot be processed.");
|
||||
std::cerr << "LLAMA ERROR: The prompt is" << embd_inp.size() <<
|
||||
"tokens and the context window is" << promptCtx.n_ctx << "!\n";
|
||||
return;
|
||||
@ -128,7 +129,7 @@ void LLamaModel::prompt(const std::string &prompt,
|
||||
std::cerr << "LLAMA: reached the end of the context window so resizing\n";
|
||||
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
|
||||
promptCtx.n_past = promptCtx.tokens.size();
|
||||
recalculateContext(promptCtx, recalculate);
|
||||
recalculateContext(promptCtx, recalculateCallback);
|
||||
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
|
||||
}
|
||||
|
||||
@ -137,10 +138,9 @@ void LLamaModel::prompt(const std::string &prompt,
|
||||
return;
|
||||
}
|
||||
|
||||
// We pass a null string for each token to see if the user has asked us to stop...
|
||||
size_t tokens = batch_end - i;
|
||||
for (size_t t = 0; t < tokens; ++t)
|
||||
if (!response(batch.at(t), ""))
|
||||
if (!promptCallback(batch.at(t)))
|
||||
return;
|
||||
promptCtx.n_past += batch.size();
|
||||
i = batch_end;
|
||||
@ -162,8 +162,8 @@ void LLamaModel::prompt(const std::string &prompt,
|
||||
std::cerr << "LLAMA: reached the end of the context window so resizing\n";
|
||||
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
|
||||
promptCtx.n_past = promptCtx.tokens.size();
|
||||
recalculateContext(promptCtx, recalculate);
|
||||
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
|
||||
recalculateContext(promptCtx, recalculateCallback);
|
||||
assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
|
||||
}
|
||||
|
||||
if (llama_eval(d_ptr->ctx, &id, 1, promptCtx.n_past, d_ptr->n_threads)) {
|
||||
@ -174,7 +174,7 @@ void LLamaModel::prompt(const std::string &prompt,
|
||||
promptCtx.n_past += 1;
|
||||
// display text
|
||||
++totalPredictions;
|
||||
if (id == llama_token_eos() || !response(id, llama_token_to_str(d_ptr->ctx, id)))
|
||||
if (id == llama_token_eos() || !responseCallback(id, llama_token_to_str(d_ptr->ctx, id)))
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -16,8 +16,9 @@ public:
|
||||
bool loadModel(const std::string &modelPath, std::istream &fin) override;
|
||||
bool isModelLoaded() const override;
|
||||
void prompt(const std::string &prompt,
|
||||
std::function<bool(int32_t, const std::string&)> response,
|
||||
std::function<bool(bool)> recalculate,
|
||||
std::function<bool(int32_t)> promptCallback,
|
||||
std::function<bool(int32_t, const std::string&)> responseCallback,
|
||||
std::function<bool(bool)> recalculateCallback,
|
||||
PromptContext &ctx) override;
|
||||
void setThreadCount(int32_t n_threads) override;
|
||||
int32_t threadCount() override;
|
||||
|
@ -29,8 +29,9 @@ public:
|
||||
// window
|
||||
};
|
||||
virtual void prompt(const std::string &prompt,
|
||||
std::function<bool(int32_t, const std::string&)> response,
|
||||
std::function<bool(bool)> recalculate,
|
||||
std::function<bool(int32_t)> promptCallback,
|
||||
std::function<bool(int32_t, const std::string&)> responseCallback,
|
||||
std::function<bool(bool)> recalculateCallback,
|
||||
PromptContext &ctx) = 0;
|
||||
virtual void setThreadCount(int32_t n_threads) {}
|
||||
virtual int32_t threadCount() { return 1; }
|
||||
|
@ -49,6 +49,11 @@ bool llmodel_isModelLoaded(llmodel_model model)
|
||||
}
|
||||
|
||||
// Wrapper functions for the C callbacks
|
||||
bool prompt_wrapper(int32_t token_id, void *user_data) {
|
||||
llmodel_prompt_callback callback = reinterpret_cast<llmodel_prompt_callback>(user_data);
|
||||
return callback(token_id);
|
||||
}
|
||||
|
||||
bool response_wrapper(int32_t token_id, const std::string &response, void *user_data) {
|
||||
llmodel_response_callback callback = reinterpret_cast<llmodel_response_callback>(user_data);
|
||||
return callback(token_id, response.c_str());
|
||||
@ -60,17 +65,20 @@ bool recalculate_wrapper(bool is_recalculating, void *user_data) {
|
||||
}
|
||||
|
||||
void llmodel_prompt(llmodel_model model, const char *prompt,
|
||||
llmodel_response_callback response,
|
||||
llmodel_recalculate_callback recalculate,
|
||||
llmodel_response_callback prompt_callback,
|
||||
llmodel_response_callback response_callback,
|
||||
llmodel_recalculate_callback recalculate_callback,
|
||||
llmodel_prompt_context *ctx)
|
||||
{
|
||||
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
|
||||
|
||||
// Create std::function wrappers that call the C function pointers
|
||||
std::function<bool(int32_t)> prompt_func =
|
||||
std::bind(&prompt_wrapper, std::placeholders::_1, reinterpret_cast<void*>(prompt_callback));
|
||||
std::function<bool(int32_t, const std::string&)> response_func =
|
||||
std::bind(&response_wrapper, std::placeholders::_1, std::placeholders::_2, reinterpret_cast<void*>(response));
|
||||
std::bind(&response_wrapper, std::placeholders::_1, std::placeholders::_2, reinterpret_cast<void*>(response_callback));
|
||||
std::function<bool(bool)> recalc_func =
|
||||
std::bind(&recalculate_wrapper, std::placeholders::_1, reinterpret_cast<void*>(recalculate));
|
||||
std::bind(&recalculate_wrapper, std::placeholders::_1, reinterpret_cast<void*>(recalculate_callback));
|
||||
|
||||
// Copy the C prompt context
|
||||
wrapper->promptContext.n_past = ctx->n_past;
|
||||
@ -85,7 +93,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
|
||||
wrapper->promptContext.contextErase = ctx->context_erase;
|
||||
|
||||
// Call the C++ prompt method
|
||||
wrapper->llModel->prompt(prompt, response_func, recalc_func, wrapper->promptContext);
|
||||
wrapper->llModel->prompt(prompt, prompt_func, response_func, recalc_func, wrapper->promptContext);
|
||||
|
||||
// Update the C context by giving access to the wrappers raw pointers to std::vector data
|
||||
// which involves no copies
|
||||
|
@ -37,10 +37,17 @@ typedef struct {
|
||||
float context_erase; // percent of context to erase if we exceed the context window
|
||||
} llmodel_prompt_context;
|
||||
|
||||
/**
|
||||
* Callback type for prompt processing.
|
||||
* @param token_id The token id of the prompt.
|
||||
* @return a bool indicating whether the model should keep processing.
|
||||
*/
|
||||
typedef bool (*llmodel_prompt_callback)(int32_t token_id);
|
||||
|
||||
/**
|
||||
* Callback type for response.
|
||||
* @param token_id The token id of the response.
|
||||
* @param response The response string.
|
||||
* @param response The response string. NOTE: a token_id of -1 indicates the string is an error string.
|
||||
* @return a bool indicating whether the model should keep generating.
|
||||
*/
|
||||
typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response);
|
||||
@ -95,13 +102,15 @@ bool llmodel_isModelLoaded(llmodel_model model);
|
||||
* Generate a response using the model.
|
||||
* @param model A pointer to the llmodel_model instance.
|
||||
* @param prompt A string representing the input prompt.
|
||||
* @param response A callback function for handling the generated response.
|
||||
* @param recalculate A callback function for handling recalculation requests.
|
||||
* @param prompt_callback A callback function for handling the processing of prompt.
|
||||
* @param response_callback A callback function for handling the generated response.
|
||||
* @param recalculate_callback A callback function for handling recalculation requests.
|
||||
* @param ctx A pointer to the llmodel_prompt_context structure.
|
||||
*/
|
||||
void llmodel_prompt(llmodel_model model, const char *prompt,
|
||||
llmodel_response_callback response,
|
||||
llmodel_recalculate_callback recalculate,
|
||||
llmodel_response_callback prompt_callback,
|
||||
llmodel_response_callback response_callback,
|
||||
llmodel_recalculate_callback recalculate_callback,
|
||||
llmodel_prompt_context *ctx);
|
||||
|
||||
/**
|
||||
|
Loading…
Reference in New Issue
Block a user