Add save/restore to chatgpt chats and allow serialize/deseralize from disk.

This commit is contained in:
Adam Treat 2023-05-15 18:36:41 -04:00 committed by AT
parent 0cd509d530
commit f931de21c5
7 changed files with 120 additions and 12 deletions

View File

@ -258,6 +258,7 @@ bool Chat::deserialize(QDataStream &stream, int version)
// unfortunately, we cannot deserialize these
if (version < 2 && m_savedModelName.contains("gpt4all-j"))
return false;
m_llmodel->setModelName(m_savedModelName);
if (!m_llmodel->deserialize(stream, version))
return false;
if (!m_chatModel->deserialize(stream, version))

View File

@ -46,6 +46,7 @@ bool ChatGPT::isModelLoaded() const
return true;
}
// All three of the state virtual functions are handled custom inside of chatllm save/restore
size_t ChatGPT::stateSize() const
{
return 0;
@ -53,11 +54,13 @@ size_t ChatGPT::stateSize() const
size_t ChatGPT::saveState(uint8_t *dest) const
{
Q_UNUSED(dest);
return 0;
}
size_t ChatGPT::restoreState(const uint8_t *src)
{
Q_UNUSED(src);
return 0;
}
@ -141,8 +144,8 @@ void ChatGPT::handleFinished()
bool ok;
int code = response.toInt(&ok);
if (!ok || code != 200) {
qWarning() << QString("\nERROR: ChatGPT responded with error code \"%1-%2%3\"\n")
.arg(code).arg(reply->errorString()).arg(reply->readAll()).toStdString();
qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"")
.arg(code).arg(reply->errorString()).toStdString();
}
reply->deleteLater();
}
@ -190,8 +193,11 @@ void ChatGPT::handleReadyRead()
const QString content = delta.value("content").toString();
Q_ASSERT(m_ctx);
Q_ASSERT(m_responseCallback);
m_responseCallback(0, content.toStdString());
m_currentResponse += content;
if (!m_responseCallback(0, content.toStdString())) {
reply->abort();
return;
}
}
}
@ -201,6 +207,6 @@ void ChatGPT::handleErrorOccurred(QNetworkReply::NetworkError code)
if (!reply)
return;
qWarning() << QString("\nERROR: ChatGPT responded with error code \"%1-%2%3\"\n")
.arg(code).arg(reply->errorString()).arg(reply->readAll()).toStdString();
qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"")
.arg(code).arg(reply->errorString()).toStdString();
}

View File

@ -30,6 +30,9 @@ public:
void setModelName(const QString &modelName) { m_modelName = modelName; }
void setAPIKey(const QString &apiKey) { m_apiKey = apiKey; }
QList<QString> context() const { return m_context; }
void setContext(const QList<QString> &context) { m_context = context; }
protected:
void recalculateContext(PromptContext &promptCtx,
std::function<bool(bool)> recalculate) override {}

View File

@ -38,6 +38,19 @@ void ChatListModel::setShouldSaveChats(bool b)
emit shouldSaveChatsChanged();
}
bool ChatListModel::shouldSaveChatGPTChats() const
{
return m_shouldSaveChatGPTChats;
}
void ChatListModel::setShouldSaveChatGPTChats(bool b)
{
if (m_shouldSaveChatGPTChats == b)
return;
m_shouldSaveChatGPTChats = b;
emit shouldSaveChatGPTChatsChanged();
}
void ChatListModel::removeChatFile(Chat *chat) const
{
Q_ASSERT(chat != m_serverChat);
@ -52,15 +65,17 @@ void ChatListModel::removeChatFile(Chat *chat) const
void ChatListModel::saveChats() const
{
if (!m_shouldSaveChats)
return;
QElapsedTimer timer;
timer.start();
const QString savePath = Download::globalInstance()->downloadLocalModelsPath();
for (Chat *chat : m_chats) {
if (chat == m_serverChat)
continue;
const bool isChatGPT = chat->modelName().startsWith("chatgpt-");
if (!isChatGPT && !m_shouldSaveChats)
continue;
if (isChatGPT && !m_shouldSaveChatGPTChats)
continue;
QString fileName = "gpt4all-" + chat->id() + ".chat";
QFile file(savePath + "/" + fileName);
bool success = file.open(QIODevice::WriteOnly);

View File

@ -20,6 +20,7 @@ class ChatListModel : public QAbstractListModel
Q_PROPERTY(int count READ count NOTIFY countChanged)
Q_PROPERTY(Chat *currentChat READ currentChat WRITE setCurrentChat NOTIFY currentChatChanged)
Q_PROPERTY(bool shouldSaveChats READ shouldSaveChats WRITE setShouldSaveChats NOTIFY shouldSaveChatsChanged)
Q_PROPERTY(bool shouldSaveChatGPTChats READ shouldSaveChatGPTChats WRITE setShouldSaveChatGPTChats NOTIFY shouldSaveChatGPTChatsChanged)
public:
explicit ChatListModel(QObject *parent = nullptr);
@ -62,6 +63,9 @@ public:
bool shouldSaveChats() const;
void setShouldSaveChats(bool b);
bool shouldSaveChatGPTChats() const;
void setShouldSaveChatGPTChats(bool b);
Q_INVOKABLE void addChat()
{
// Don't add a new chat if we already have one
@ -199,6 +203,7 @@ Q_SIGNALS:
void countChanged();
void currentChatChanged();
void shouldSaveChatsChanged();
void shouldSaveChatGPTChatsChanged();
private Q_SLOTS:
void newChatCountChanged()
@ -240,6 +245,7 @@ private Q_SLOTS:
private:
bool m_shouldSaveChats;
bool m_shouldSaveChatGPTChats;
Chat* m_newChat;
Chat* m_dummyChat;
Chat* m_serverChat;

View File

@ -611,6 +611,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version)
stream >> compressed;
m_state = qUncompress(compressed);
} else {
stream >> m_state;
}
#if defined(DEBUG)
@ -624,6 +625,15 @@ void ChatLLM::saveState()
if (!isModelLoaded())
return;
if (m_isChatGPT) {
m_state.clear();
QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
stream.setVersion(QDataStream::Qt_6_5);
ChatGPT *chatGPT = static_cast<ChatGPT*>(m_modelInfo.model);
stream << chatGPT->context();
return;
}
const size_t stateSize = m_modelInfo.model->stateSize();
m_state.resize(stateSize);
#if defined(DEBUG)
@ -637,6 +647,18 @@ void ChatLLM::restoreState()
if (!isModelLoaded() || m_state.isEmpty())
return;
if (m_isChatGPT) {
QDataStream stream(&m_state, QIODeviceBase::ReadOnly);
stream.setVersion(QDataStream::Qt_6_5);
ChatGPT *chatGPT = static_cast<ChatGPT*>(m_modelInfo.model);
QList<QString> context;
stream >> context;
chatGPT->setContext(context);
m_state.clear();
m_state.resize(0);
return;
}
#if defined(DEBUG)
qDebug() << "restoreState" << m_chat->id() << "size:" << m_state.size();
#endif

View File

@ -40,6 +40,7 @@ Dialog {
property int defaultRepeatPenaltyTokens: 64
property int defaultThreadCount: 0
property bool defaultSaveChats: false
property bool defaultSaveChatGPTChats: true
property bool defaultServerChat: false
property string defaultPromptTemplate: "### Human:
%1
@ -57,6 +58,7 @@ Dialog {
property alias repeatPenaltyTokens: settings.repeatPenaltyTokens
property alias threadCount: settings.threadCount
property alias saveChats: settings.saveChats
property alias saveChatGPTChats: settings.saveChatGPTChats
property alias serverChat: settings.serverChat
property alias modelPath: settings.modelPath
property alias userDefaultModel: settings.userDefaultModel
@ -70,6 +72,7 @@ Dialog {
property int promptBatchSize: settingsDialog.defaultPromptBatchSize
property int threadCount: settingsDialog.defaultThreadCount
property bool saveChats: settingsDialog.defaultSaveChats
property bool saveChatGPTChats: settingsDialog.defaultSaveChatGPTChats
property bool serverChat: settingsDialog.defaultServerChat
property real repeatPenalty: settingsDialog.defaultRepeatPenalty
property int repeatPenaltyTokens: settingsDialog.defaultRepeatPenaltyTokens
@ -94,12 +97,14 @@ Dialog {
settings.modelPath = settingsDialog.defaultModelPath
settings.threadCount = defaultThreadCount
settings.saveChats = defaultSaveChats
settings.saveChatGPTChats = defaultSaveChatGPTChats
settings.serverChat = defaultServerChat
settings.userDefaultModel = defaultUserDefaultModel
Download.downloadLocalModelsPath = settings.modelPath
LLM.threadCount = settings.threadCount
LLM.serverEnabled = settings.serverChat
LLM.chatListModel.shouldSaveChats = settings.saveChats
LLM.chatListModel.shouldSaveChatGPTChats = settings.saveChatGPTChats
settings.sync()
}
@ -107,6 +112,7 @@ Dialog {
LLM.threadCount = settings.threadCount
LLM.serverEnabled = settings.serverChat
LLM.chatListModel.shouldSaveChats = settings.saveChats
LLM.chatListModel.shouldSaveChatGPTChats = settings.saveChatGPTChats
Download.downloadLocalModelsPath = settings.modelPath
}
@ -803,16 +809,65 @@ Dialog {
}
}
Label {
id: serverChatLabel
text: qsTr("Enable web server:")
id: saveChatGPTChatsLabel
text: qsTr("Save ChatGPT chats to disk:")
color: theme.textColor
Layout.row: 5
Layout.column: 0
}
CheckBox {
id: serverChatBox
id: saveChatGPTChatsBox
Layout.row: 5
Layout.column: 1
checked: settingsDialog.saveChatGPTChats
onClicked: {
settingsDialog.saveChatGPTChats = saveChatGPTChatsBox.checked
LLM.chatListModel.shouldSaveChatGPTChats = saveChatGPTChatsBox.checked
settings.sync()
}
background: Rectangle {
color: "transparent"
}
indicator: Rectangle {
implicitWidth: 26
implicitHeight: 26
x: saveChatGPTChatsBox.leftPadding
y: parent.height / 2 - height / 2
border.color: theme.dialogBorder
color: "transparent"
Rectangle {
width: 14
height: 14
x: 6
y: 6
color: theme.textColor
visible: saveChatGPTChatsBox.checked
}
}
contentItem: Text {
text: saveChatGPTChatsBox.text
font: saveChatGPTChatsBox.font
opacity: enabled ? 1.0 : 0.3
color: theme.textColor
verticalAlignment: Text.AlignVCenter
leftPadding: saveChatGPTChatsBox.indicator.width + saveChatGPTChatsBox.spacing
}
}
Label {
id: serverChatLabel
text: qsTr("Enable web server:")
color: theme.textColor
Layout.row: 6
Layout.column: 0
}
CheckBox {
id: serverChatBox
Layout.row: 6
Layout.column: 1
checked: settings.serverChat
onClicked: {
settingsDialog.serverChat = serverChatBox.checked
@ -855,7 +910,7 @@ Dialog {
}
}
Button {
Layout.row: 6
Layout.row: 7
Layout.column: 1
Layout.fillWidth: true
padding: 10