diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h index 5a26f4a4..800c0bf6 100644 --- a/gpt4all-backend/llmodel.h +++ b/gpt4all-backend/llmodel.h @@ -248,7 +248,7 @@ protected: return true; } - void decodePrompt(std::function promptCallback, + bool decodePrompt(std::function promptCallback, std::function responseCallback, std::function recalculateCallback, PromptContext &promptCtx, diff --git a/gpt4all-backend/llmodel_shared.cpp b/gpt4all-backend/llmodel_shared.cpp index 1f797e8f..be02b65b 100644 --- a/gpt4all-backend/llmodel_shared.cpp +++ b/gpt4all-backend/llmodel_shared.cpp @@ -135,14 +135,16 @@ void LLModel::prompt(const std::string &prompt, promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it // decode the user prompt - decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp); + if (!decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp)) + return; // error // decode the assistant's reply, either generated or spoofed if (fakeReply == nullptr) { generateResponse(responseCallback, recalculateCallback, promptCtx); } else { embd_inp = tokenize(promptCtx, *fakeReply, false); - decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp); + if (!decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp)) + return; // error } // decode the rest of the prompt template @@ -160,7 +162,8 @@ void LLModel::prompt(const std::string &prompt, } } -void LLModel::decodePrompt(std::function promptCallback, +// returns false on error +bool LLModel::decodePrompt(std::function promptCallback, std::function responseCallback, std::function recalculateCallback, PromptContext &promptCtx, @@ -172,7 +175,7 @@ void LLModel::decodePrompt(std::function promptCallback, responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed."); std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() << " tokens and the context window is " << promptCtx.n_ctx << "!\n"; - return; + return false; } promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size()); @@ -193,7 +196,7 @@ void LLModel::decodePrompt(std::function promptCallback, if (!evalTokens(promptCtx, batch)) { std::cerr << implementation().modelType() << " ERROR: Failed to process prompt\n"; - return; + return false; } size_t tokens = batch_end - i; @@ -203,10 +206,12 @@ void LLModel::decodePrompt(std::function promptCallback, promptCtx.tokens.push_back(batch.at(t)); promptCtx.n_past += 1; if (!promptCallback(batch.at(t))) - return; + return false; } i = batch_end; } + + return true; } void LLModel::generateResponse(std::function responseCallback,