Move the promptCallback to its own function.

Adam Treat 2023-04-27 11:08:15 -04:00
parent 0e9f85bcda
commit ba4b28fcd5
9 changed files with 81 additions and 47 deletions

llm.cpp

@@ -38,7 +38,7 @@ static QString modelFilePath(const QString &modelName)
LLMObject::LLMObject()
: QObject{nullptr}
, m_llmodel(nullptr)
, m_responseTokens(0)
, m_promptResponseTokens(0)
, m_responseLogits(0)
, m_isRecalc(false)
{
@@ -133,12 +133,12 @@ bool LLMObject::isModelLoaded() const
void LLMObject::regenerateResponse()
{
s_ctx.n_past -= m_responseTokens;
s_ctx.n_past -= m_promptResponseTokens;
s_ctx.n_past = std::max(0, s_ctx.n_past);
// FIXME: This does not seem to be needed in my testing and llama models don't do it. Remove?
s_ctx.logits.erase(s_ctx.logits.end() -= m_responseLogits, s_ctx.logits.end());
s_ctx.tokens.erase(s_ctx.tokens.end() -= m_responseTokens, s_ctx.tokens.end());
m_responseTokens = 0;
s_ctx.tokens.erase(s_ctx.tokens.end() -= m_promptResponseTokens, s_ctx.tokens.end());
m_promptResponseTokens = 0;
m_responseLogits = 0;
m_response = std::string();
emit responseChanged();
@@ -146,7 +146,7 @@ void LLMObject::regenerateResponse()
void LLMObject::resetResponse()
{
m_responseTokens = 0;
m_promptResponseTokens = 0;
m_responseLogits = 0;
m_response = std::string();
emit responseChanged();
@@ -263,6 +263,18 @@ QList<QString> LLMObject::modelList() const
return list;
}
bool LLMObject::handlePrompt(int32_t token)
{
if (s_ctx.tokens.size() == s_ctx.n_ctx)
s_ctx.tokens.erase(s_ctx.tokens.begin());
s_ctx.tokens.push_back(token);
// m_promptResponseTokens and m_responseLogits are related to last prompt/response not
// the entire context window which we can reset on regenerate prompt
++m_promptResponseTokens;
return !m_stopGenerating;
}
bool LLMObject::handleResponse(int32_t token, const std::string &response)
{
#if 0
@@ -282,13 +294,12 @@ bool LLMObject::handleResponse(int32_t token, const std::string &response)
s_ctx.tokens.erase(s_ctx.tokens.begin());
s_ctx.tokens.push_back(token);
// m_responseTokens and m_responseLogits are related to last prompt/response not
// m_promptResponseTokens and m_responseLogits are related to last prompt/response not
// the entire context window which we can reset on regenerate prompt
++m_responseTokens;
if (!response.empty()) {
m_response.append(response);
emit responseChanged();
}
++m_promptResponseTokens;
Q_ASSERT(!response.empty());
m_response.append(response);
emit responseChanged();
// Stop generation if we encounter prompt or response tokens
QString r = QString::fromStdString(m_response);
@@ -315,6 +326,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
QString instructPrompt = prompt_template.arg(prompt);
m_stopGenerating = false;
auto promptFunc = std::bind(&LLMObject::handlePrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto recalcFunc = std::bind(&LLMObject::handleRecalculate, this, std::placeholders::_1);
@@ -327,7 +339,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
s_ctx.n_batch = n_batch;
s_ctx.repeat_penalty = repeat_penalty;
s_ctx.repeat_last_n = repeat_penalty_tokens;
m_llmodel->prompt(instructPrompt.toStdString(), responseFunc, recalcFunc, s_ctx);
m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, s_ctx);
m_responseLogits += s_ctx.logits.size() - logitsBefore;
std::string trimmed = trim_whitespace(m_response);
if (trimmed != m_response) {
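
The std::bind wiring in LLMObject::prompt above could equally be written with lambdas; the following is only an illustrative sketch of the same hookup, not code from this commit. All names used (handlePrompt, handleResponse, handleRecalculate, m_llmodel, instructPrompt, s_ctx) are the ones that appear in llm.cpp above.

// Sketch: same callback wiring as the std::bind calls in LLMObject::prompt, using lambdas.
auto promptFunc = [this](int32_t token) {
    return handlePrompt(token);                 // counts prompt tokens, honors m_stopGenerating
};
auto responseFunc = [this](int32_t token, const std::string &response) {
    return handleResponse(token, response);     // appends to m_response and emits responseChanged()
};
auto recalcFunc = [this](bool isRecalc) {
    return handleRecalculate(isRecalc);         // forwards the context recalculation flag
};
m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, s_ctx);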

llm.h

@@ -58,13 +58,14 @@ Q_SIGNALS:
private:
void resetContextPrivate();
bool loadModelPrivate(const QString &modelName);
bool handlePrompt(int32_t token);
bool handleResponse(int32_t token, const std::string &response);
bool handleRecalculate(bool isRecalc);
private:
LLModel *m_llmodel;
std::string m_response;
quint32 m_responseTokens;
quint32 m_promptResponseTokens;
quint32 m_responseLogits;
QString m_modelName;
QThread m_llmThread;

gptj.cpp

@@ -686,8 +686,9 @@ bool GPTJ::isModelLoaded() const
}
void GPTJ::prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
std::function<bool(bool)> recalculate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &promptCtx) {
if (!isModelLoaded()) {
@@ -708,7 +709,7 @@ void GPTJ::prompt(const std::string &prompt,
promptCtx.n_ctx = d_ptr->model.hparams.n_ctx;
if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
response(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
std::cerr << "GPT-J ERROR: The prompt is" << embd_inp.size() <<
"tokens and the context window is" << promptCtx.n_ctx << "!\n";
return;
@@ -741,7 +742,7 @@ void GPTJ::prompt(const std::string &prompt,
std::cerr << "GPTJ: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculate);
recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
}
@@ -750,10 +751,10 @@ void GPTJ::prompt(const std::string &prompt,
std::cerr << "GPT-J ERROR: Failed to process prompt\n";
return;
}
// We pass a null string for each token to see if the user has asked us to stop...
size_t tokens = batch_end - i;
for (size_t t = 0; t < tokens; ++t)
if (!response(batch.at(t), ""))
if (!promptCallback(batch.at(t)))
return;
promptCtx.n_past += batch.size();
i = batch_end;
@@ -790,8 +791,8 @@ void GPTJ::prompt(const std::string &prompt,
std::cerr << "GPTJ: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculate);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
}
const int64_t t_start_predict_us = ggml_time_us();
@@ -805,7 +806,7 @@ void GPTJ::prompt(const std::string &prompt,
promptCtx.n_past += 1;
// display text
++totalPredictions;
if (id == 50256 /*end of text*/ || !response(id, d_ptr->vocab.id_to_token[id]))
if (id == 50256 /*end of text*/ || !responseCallback(id, d_ptr->vocab.id_to_token[id]))
goto stop_generating;
}

gptj.h

@@ -16,8 +16,9 @@ public:
bool loadModel(const std::string &modelPath, std::istream &fin) override;
bool isModelLoaded() const override;
void prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
std::function<bool(bool)> recalculate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &ctx) override;
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() override;

llamamodel.cpp

@@ -80,8 +80,9 @@ bool LLamaModel::isModelLoaded() const
}
void LLamaModel::prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
std::function<bool(bool)> recalculate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &promptCtx) {
if (!isModelLoaded()) {
@@ -102,7 +103,7 @@ void LLamaModel::prompt(const std::string &prompt,
promptCtx.n_ctx = llama_n_ctx(d_ptr->ctx);
if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
response(-1, "The prompt size exceeds the context window size and cannot be processed.");
responseCallback(-1, "The prompt size exceeds the context window size and cannot be processed.");
std::cerr << "LLAMA ERROR: The prompt is" << embd_inp.size() <<
"tokens and the context window is" << promptCtx.n_ctx << "!\n";
return;
@@ -128,7 +129,7 @@ void LLamaModel::prompt(const std::string &prompt,
std::cerr << "LLAMA: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculate);
recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
}
@@ -137,10 +138,9 @@ void LLamaModel::prompt(const std::string &prompt,
return;
}
// We pass a null string for each token to see if the user has asked us to stop...
size_t tokens = batch_end - i;
for (size_t t = 0; t < tokens; ++t)
if (!response(batch.at(t), ""))
if (!promptCallback(batch.at(t)))
return;
promptCtx.n_past += batch.size();
i = batch_end;
@@ -162,8 +162,8 @@ void LLamaModel::prompt(const std::string &prompt,
std::cerr << "LLAMA: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculate);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
recalculateContext(promptCtx, recalculateCallback);
assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
}
if (llama_eval(d_ptr->ctx, &id, 1, promptCtx.n_past, d_ptr->n_threads)) {
@@ -174,7 +174,7 @@ void LLamaModel::prompt(const std::string &prompt,
promptCtx.n_past += 1;
// display text
++totalPredictions;
if (id == llama_token_eos() || !response(id, llama_token_to_str(d_ptr->ctx, id)))
if (id == llama_token_eos() || !responseCallback(id, llama_token_to_str(d_ptr->ctx, id)))
return;
}
}

llamamodel.h

@@ -16,8 +16,9 @@ public:
bool loadModel(const std::string &modelPath, std::istream &fin) override;
bool isModelLoaded() const override;
void prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
std::function<bool(bool)> recalculate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &ctx) override;
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() override;

llmodel.h

@@ -29,8 +29,9 @@ public:
// window
};
virtual void prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
std::function<bool(bool)> recalculate,
std::function<bool(int32_t)> promptCallback,
std::function<bool(int32_t, const std::string&)> responseCallback,
std::function<bool(bool)> recalculateCallback,
PromptContext &ctx) = 0;
virtual void setThreadCount(int32_t n_threads) {}
virtual int32_t threadCount() { return 1; }
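
As a usage sketch only (not part of this commit), a caller of the updated interface supplies three separate callbacks. The helper below is hypothetical: runPrompt, the stop flag, the output handling, and the included header name are illustrative; it assumes PromptContext is the nested context type declared just above.

#include <atomic>
#include <cstdint>
#include <iostream>
#include <string>
#include "llmodel.h"   // assumed header name for the interface shown above

// Sketch: driving the new three-callback LLModel::prompt() with lambdas.
void runPrompt(LLModel &model, const std::string &prompt,
               LLModel::PromptContext &ctx, std::atomic<bool> &stop)
{
    model.prompt(prompt,
        // prompt callback: called once per prompt token; return false to abort early
        [&](int32_t) { return !stop.load(); },
        // response callback: called once per generated token with its decoded text
        [&](int32_t, const std::string &piece) {
            std::cout << piece << std::flush;
            return !stop.load();
        },
        // recalculate callback: the flag reports whether the context is being rebuilt
        [](bool) { return true; },
        ctx);
}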

llmodel_c.cpp

@@ -49,6 +49,11 @@ bool llmodel_isModelLoaded(llmodel_model model)
}
// Wrapper functions for the C callbacks
bool prompt_wrapper(int32_t token_id, void *user_data) {
llmodel_prompt_callback callback = reinterpret_cast<llmodel_prompt_callback>(user_data);
return callback(token_id);
}
bool response_wrapper(int32_t token_id, const std::string &response, void *user_data) {
llmodel_response_callback callback = reinterpret_cast<llmodel_response_callback>(user_data);
return callback(token_id, response.c_str());
@@ -60,17 +65,20 @@ bool recalculate_wrapper(bool is_recalculating, void *user_data) {
}
void llmodel_prompt(llmodel_model model, const char *prompt,
llmodel_response_callback response,
llmodel_recalculate_callback recalculate,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
llmodel_recalculate_callback recalculate_callback,
llmodel_prompt_context *ctx)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
// Create std::function wrappers that call the C function pointers
std::function<bool(int32_t)> prompt_func =
std::bind(&prompt_wrapper, std::placeholders::_1, reinterpret_cast<void*>(prompt_callback));
std::function<bool(int32_t, const std::string&)> response_func =
std::bind(&response_wrapper, std::placeholders::_1, std::placeholders::_2, reinterpret_cast<void*>(response));
std::bind(&response_wrapper, std::placeholders::_1, std::placeholders::_2, reinterpret_cast<void*>(response_callback));
std::function<bool(bool)> recalc_func =
std::bind(&recalculate_wrapper, std::placeholders::_1, reinterpret_cast<void*>(recalculate));
std::bind(&recalculate_wrapper, std::placeholders::_1, reinterpret_cast<void*>(recalculate_callback));
// Copy the C prompt context
wrapper->promptContext.n_past = ctx->n_past;
@@ -85,7 +93,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
wrapper->promptContext.contextErase = ctx->context_erase;
// Call the C++ prompt method
wrapper->llModel->prompt(prompt, response_func, recalc_func, wrapper->promptContext);
wrapper->llModel->prompt(prompt, prompt_func, response_func, recalc_func, wrapper->promptContext);
// Update the C context by giving access to the wrapper's raw pointers to std::vector data
// which involves no copies
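
An equivalent way to adapt the C function pointers, shown here only as a sketch of an alternative to the wrapper-function-plus-reinterpret_cast pattern above (not code from this commit), is to capture them directly in lambdas:

// Sketch: lambdas capturing the C callbacks, matching the behavior of the
// prompt_wrapper / response_wrapper / recalculate_wrapper std::bind setup above.
std::function<bool(int32_t)> prompt_func =
    [prompt_callback](int32_t token_id) { return prompt_callback(token_id); };
std::function<bool(int32_t, const std::string&)> response_func =
    [response_callback](int32_t token_id, const std::string &response) {
        return response_callback(token_id, response.c_str());
    };
std::function<bool(bool)> recalc_func =
    [recalculate_callback](bool is_recalculating) { return recalculate_callback(is_recalculating); };

Either form yields the same std::function objects passed to the C++ prompt method; the lambda version simply avoids routing the function pointers through void*.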

llmodel_c.h

@@ -37,10 +37,17 @@ typedef struct {
float context_erase; // percent of context to erase if we exceed the context window
} llmodel_prompt_context;
/**
* Callback type for prompt processing.
* @param token_id The token id of the prompt.
* @return a bool indicating whether the model should keep processing.
*/
typedef bool (*llmodel_prompt_callback)(int32_t token_id);
/**
* Callback type for response.
* @param token_id The token id of the response.
* @param response The response string.
* @param response The response string. NOTE: a token_id of -1 indicates the string is an error string.
* @return a bool indicating whether the model should keep generating.
*/
typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response);
@@ -95,13 +102,15 @@ bool llmodel_isModelLoaded(llmodel_model model);
* Generate a response using the model.
* @param model A pointer to the llmodel_model instance.
* @param prompt A string representing the input prompt.
* @param response A callback function for handling the generated response.
* @param recalculate A callback function for handling recalculation requests.
* @param prompt_callback A callback function for handling the processing of the prompt.
* @param response_callback A callback function for handling the generated response.
* @param recalculate_callback A callback function for handling recalculation requests.
* @param ctx A pointer to the llmodel_prompt_context structure.
*/
void llmodel_prompt(llmodel_model model, const char *prompt,
llmodel_response_callback response,
llmodel_recalculate_callback recalculate,
llmodel_prompt_callback prompt_callback,
llmodel_response_callback response_callback,
llmodel_recalculate_callback recalculate_callback,
llmodel_prompt_context *ctx);
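
For illustration only (not part of this header), a minimal caller of llmodel_prompt under the new signature could look like the sketch below. The callback names (on_prompt, on_response, on_recalculate), the ask helper, the included header name, and the context values are hypothetical; only n_past and context_erase are fields this diff shows being read, and the remaining fields are left value-initialized for brevity.

#include <cstdint>
#include <cstdio>
#include "llmodel_c.h"   // assumed header name for the declarations above

// Sketch: supplying the three callbacks expected by the new llmodel_prompt signature.
static bool on_prompt(int32_t /*token_id*/) {
    return true;                  // keep processing the prompt; return false to abort
}
static bool on_response(int32_t token_id, const char *response) {
    if (token_id == -1)           // -1 marks an error string, per the note above
        std::fprintf(stderr, "error: %s\n", response);
    else
        std::fputs(response, stdout);
    return true;                  // keep generating
}
static bool on_recalculate(bool /*is_recalculating*/) {
    return true;
}

static void ask(llmodel_model model, const char *prompt) {
    llmodel_prompt_context ctx = {};   // real callers would also fill in the sampling fields
    ctx.n_past = 0;
    ctx.context_erase = 0.5f;          // erase half the context if the window overflows
    llmodel_prompt(model, prompt, on_prompt, on_response, on_recalculate, &ctx);
}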
/**