backend: fix a crash on inputs greater than n_ctx (#2498)

This fixes a regression introduced in commit 4fc4d94b ("fix chat-style prompt
templates (#1970)"), which moved several early-return statements into a new
function (LLModel::decodePrompt) without making the caller, LLModel::prompt,
return as well when they fire.

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Author: Jared Van Bortel
Date:   2024-07-01 11:33:46 -04:00 (committed by GitHub)
Parent: 146428fa0a
Commit: bd307abfe6
2 changed files with 12 additions and 7 deletions
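
To illustrate the bug pattern the diff below addresses, here is a minimal,
self-contained sketch. The names and signatures (decode, prompt, the toy
PromptContext) are simplified placeholders, not the actual LLModel API: a
helper that used to return void now reports failure as a bool, and the caller
must check it and return early itself instead of continuing with a prompt
that exceeds the context window.

// Sketch only: hypothetical, simplified stand-ins for LLModel::decodePrompt
// and LLModel::prompt, showing why the helper's error must be propagated.
#include <cstdio>
#include <vector>

struct PromptContext { int n_ctx = 8; }; // toy context window size

// Fixed helper: reports failure to the caller instead of returning void.
static bool decode(const PromptContext &ctx, const std::vector<int> &tokens) {
    if (tokens.size() > static_cast<size_t>(ctx.n_ctx)) {
        std::fprintf(stderr, "ERROR: prompt of %zu tokens exceeds n_ctx = %d\n",
                     tokens.size(), ctx.n_ctx);
        return false; // before the fix this was a bare "return;"
    }
    // ... evaluate the tokens ...
    return true;
}

static void prompt(const PromptContext &ctx, const std::vector<int> &tokens) {
    // The caller must propagate the error; otherwise it would continue into
    // generation with an oversized prompt -- the crash this commit fixes.
    if (!decode(ctx, tokens))
        return; // error
    // ... generate a response ...
}

int main() {
    prompt(PromptContext{}, std::vector<int>(16, 0)); // 16 > n_ctx: rejected, no crash
}

With the check in place, the oversized prompt is rejected inside decode() and
prompt() stops early instead of generating against a broken context.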


@@ -248,7 +248,7 @@ protected:
         return true;
     }
 
-    void decodePrompt(std::function<bool(int32_t)> promptCallback,
+    bool decodePrompt(std::function<bool(int32_t)> promptCallback,
                       std::function<bool(int32_t, const std::string&)> responseCallback,
                       std::function<bool(bool)> recalculateCallback,
                       PromptContext &promptCtx,


@@ -135,14 +135,16 @@ void LLModel::prompt(const std::string &prompt,
     promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it
 
     // decode the user prompt
-    decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
+    if (!decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp))
+        return; // error
 
     // decode the assistant's reply, either generated or spoofed
     if (fakeReply == nullptr) {
         generateResponse(responseCallback, recalculateCallback, promptCtx);
     } else {
         embd_inp = tokenize(promptCtx, *fakeReply, false);
-        decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp);
+        if (!decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp))
+            return; // error
     }
 
     // decode the rest of the prompt template
@@ -160,7 +162,8 @@ void LLModel::prompt(const std::string &prompt,
     }
 }
 
-void LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
+// returns false on error
+bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
                            std::function<bool(int32_t, const std::string&)> responseCallback,
                            std::function<bool(bool)> recalculateCallback,
                            PromptContext &promptCtx,
@@ -172,7 +175,7 @@ void LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
         responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
         std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() <<
             " tokens and the context window is " << promptCtx.n_ctx << "!\n";
-        return;
+        return false;
     }
 
     promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size());
@@ -193,7 +196,7 @@ void LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
 
         if (!evalTokens(promptCtx, batch)) {
             std::cerr << implementation().modelType() << " ERROR: Failed to process prompt\n";
-            return;
+            return false;
         }
 
         size_t tokens = batch_end - i;
@@ -203,10 +206,12 @@ void LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
             promptCtx.tokens.push_back(batch.at(t));
             promptCtx.n_past += 1;
             if (!promptCallback(batch.at(t)))
-                return;
+                return false;
         }
         i = batch_end;
     }
+
+    return true;
 }
 
 void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,