Add a new model download feature.

This commit is contained in:
Adam Treat 2023-04-18 21:10:06 -04:00
parent 1eda8f030e
commit e6cb6a2ae3
7 changed files with 562 additions and 8 deletions

View File

@ -33,6 +33,7 @@ add_subdirectory(ggml)
qt_add_executable(chat
main.cpp
download.h download.cpp
gptj.h gptj.cpp
llm.h llm.cpp
llmodel.h
@ -41,7 +42,7 @@ qt_add_executable(chat
qt_add_qml_module(chat
URI gpt4all-chat
VERSION 1.0
QML_FILES main.qml
QML_FILES main.qml qml/ModelDownloaderDialog.qml
RESOURCES
icons/send_message.svg
icons/stop_generating.svg

209
download.cpp Normal file
View File

@ -0,0 +1,209 @@
#include "download.h"
#include <QCoreApplication>
#include <QNetworkRequest>
#include <QNetworkAccessManager>
#include <QJsonDocument>
#include <QJsonObject>
#include <QJsonArray>
#include <QUrl>
#include <QDir>
class MyDownload: public Download { };
Q_GLOBAL_STATIC(MyDownload, downloadInstance)
Download *Download::globalInstance()
{
return downloadInstance();
}
Download::Download()
: QObject(nullptr)
{
updateModelList();
}
QList<ModelInfo> Download::modelList() const
{
// We make sure the default model is listed first
QList<ModelInfo> values = m_modelMap.values();
ModelInfo defaultInfo;
for (ModelInfo v : values) {
if (v.isDefault) {
defaultInfo = v;
break;
}
}
values.removeAll(defaultInfo);
values.prepend(defaultInfo);
return values;
}
void Download::updateModelList()
{
QUrl jsonUrl("http://gpt4all.io/models/models.json");
QNetworkRequest request(jsonUrl);
QNetworkReply *jsonReply = m_networkManager.get(request);
connect(jsonReply, &QNetworkReply::finished, this, &Download::handleJsonDownloadFinished);
}
void Download::downloadModel(const QString &modelFile)
{
QNetworkRequest request("http://gpt4all.io/models/" + modelFile);
QNetworkReply *modelReply = m_networkManager.get(request);
connect(modelReply, &QNetworkReply::downloadProgress, this, &Download::handleDownloadProgress);
connect(modelReply, &QNetworkReply::finished, this, &Download::handleModelDownloadFinished);
m_activeDownloads.append(modelReply);
}
void Download::cancelDownload(const QString &modelFile)
{
for (int i = 0; i < m_activeDownloads.size(); ++i) {
QNetworkReply *modelReply = m_activeDownloads.at(i);
QUrl url = modelReply->request().url();
if (url.toString().endsWith(modelFile)) {
// Disconnect the signals
disconnect(modelReply, &QNetworkReply::downloadProgress, this, &Download::handleDownloadProgress);
disconnect(modelReply, &QNetworkReply::finished, this, &Download::handleModelDownloadFinished);
modelReply->abort(); // Abort the download
modelReply->deleteLater(); // Schedule the reply for deletion
m_activeDownloads.removeAll(modelReply);
// Emit downloadFinished signal for cleanup
emit downloadFinished(modelFile);
break;
}
}
}
void Download::handleJsonDownloadFinished()
{
#if 0
QByteArray jsonData = QString(""
"["
" {"
" \"md5sum\": \"61d48a82cb188cceb14ebb8082bfec37\","
" \"filename\": \"ggml-gpt4all-j-v1.1-breezy.bin\""
" },"
" {"
" \"md5sum\": \"879344aaa9d62fdccbda0be7a09e7976\","
" \"filename\": \"ggml-gpt4all-j-v1.2-jazzy.bin\","
" \"isDefault\": \"true\""
" },"
" {"
" \"md5sum\": \"5b5a3f9b858d33b29b52b89692415595\","
" \"filename\": \"ggml-gpt4all-j.bin\""
" }"
"]"
).toUtf8();
printf("%s\n", jsonData.toStdString().c_str());
fflush(stdout);
#else
QNetworkReply *jsonReply = qobject_cast<QNetworkReply *>(sender());
if (!jsonReply)
return;
QByteArray jsonData = jsonReply->readAll();
jsonReply->deleteLater();
#endif
parseJsonFile(jsonData);
}
void Download::parseJsonFile(const QByteArray &jsonData)
{
QJsonParseError err;
QJsonDocument document = QJsonDocument::fromJson(jsonData, &err);
if (err.error != QJsonParseError::NoError) {
qDebug() << "ERROR: Couldn't parse: " << jsonData << err.errorString();
return;
}
QJsonArray jsonArray = document.array();
m_modelMap.clear();
for (const QJsonValue &value : jsonArray) {
QJsonObject obj = value.toObject();
QString modelFilename = obj["filename"].toString();
QByteArray modelMd5sum = obj["md5sum"].toString().toLatin1().constData();
bool isDefault = obj.contains("isDefault") && obj["isDefault"] == QString("true");
QString filePath = QCoreApplication::applicationDirPath() + QDir::separator() + modelFilename;
QFileInfo info(filePath);
ModelInfo modelInfo;
modelInfo.filename = modelFilename;
modelInfo.md5sum = modelMd5sum;
modelInfo.installed = info.exists();
modelInfo.isDefault = isDefault;
m_modelMap.insert(modelInfo.filename, modelInfo);
}
emit modelListChanged();
}
void Download::handleDownloadProgress(qint64 bytesReceived, qint64 bytesTotal)
{
QNetworkReply *modelReply = qobject_cast<QNetworkReply *>(sender());
if (!modelReply)
return;
QString modelFilename = modelReply->url().fileName();
// qDebug() << "handleDownloadProgress" << bytesReceived << bytesTotal << modelFilename;
emit downloadProgress(bytesReceived, bytesTotal, modelFilename);
}
bool operator==(const ModelInfo& lhs, const ModelInfo& rhs) {
return lhs.filename == rhs.filename && lhs.md5sum == rhs.md5sum;
}
void Download::handleModelDownloadFinished()
{
QNetworkReply *modelReply = qobject_cast<QNetworkReply *>(sender());
if (!modelReply)
return;
QString modelFilename = modelReply->url().fileName();
// qDebug() << "handleModelDownloadFinished" << modelFilename;
m_activeDownloads.removeAll(modelReply);
if (modelReply->error()) {
qWarning() << "ERROR: downloading:" << modelReply->errorString();
modelReply->deleteLater();
emit downloadFinished(modelFilename);
return;
}
QByteArray modelData = modelReply->readAll();
if (!m_modelMap.contains(modelFilename)) {
qWarning() << "ERROR: Cannot find in download map:" << modelFilename;
modelReply->deleteLater();
emit downloadFinished(modelFilename);
return;
}
ModelInfo info = m_modelMap.value(modelFilename);
QCryptographicHash hash(QCryptographicHash::Md5);
hash.addData(modelData);
if (hash.result().toHex() != info.md5sum) {
qWarning() << "ERROR: Download error MD5SUM did not match:"
<< hash.result().toHex()
<< "!=" << info.md5sum << "for" << modelFilename;
modelReply->deleteLater();
emit downloadFinished(modelFilename);
return;
}
// Save the model file to disk
QFile file(QCoreApplication::applicationDirPath() + QDir::separator() + modelFilename);
if (file.open(QIODevice::WriteOnly)) {
file.write(modelData);
file.close();
}
modelReply->deleteLater();
emit downloadFinished(modelFilename);
info.installed = true;
m_modelMap.insert(modelFilename, info);
emit modelListChanged();
}

62
download.h Normal file
View File

@ -0,0 +1,62 @@
#ifndef DOWNLOAD_H
#define DOWNLOAD_H
#include <QObject>
#include <QObject>
#include <QNetworkAccessManager>
#include <QNetworkReply>
#include <QFile>
#include <QVariant>
struct ModelInfo {
Q_GADGET
Q_PROPERTY(QString filename MEMBER filename)
Q_PROPERTY(QByteArray md5sum MEMBER md5sum)
Q_PROPERTY(bool installed MEMBER installed)
Q_PROPERTY(bool isDefault MEMBER isDefault)
public:
QString filename;
QByteArray md5sum;
bool installed = false;
bool isDefault = false;
};
Q_DECLARE_METATYPE(ModelInfo)
class Download : public QObject
{
Q_OBJECT
Q_PROPERTY(QList<ModelInfo> modelList READ modelList NOTIFY modelListChanged)
public:
static Download *globalInstance();
QList<ModelInfo> modelList() const;
Q_INVOKABLE void updateModelList();
Q_INVOKABLE void downloadModel(const QString &modelFile);
Q_INVOKABLE void cancelDownload(const QString &modelFile);
public Q_SLOTS:
void handleJsonDownloadFinished();
void handleDownloadProgress(qint64 bytesReceived, qint64 bytesTotal);
void handleModelDownloadFinished();
Q_SIGNALS:
void downloadProgress(qint64 bytesReceived, qint64 bytesTotal, const QString &modelFile);
void downloadFinished(const QString &modelFile);
void modelListChanged();
private:
void parseJsonFile(const QByteArray &jsonData);
QMap<QString, ModelInfo> m_modelMap;
QNetworkAccessManager m_networkManager;
QVector<QNetworkReply*> m_activeDownloads;
private:
explicit Download();
~Download() {}
friend class MyDownload;
};
#endif // DOWNLOAD_H

View File

@ -1,4 +1,5 @@
#include "llm.h"
#include "download.h"
#include <QCoreApplication>
#include <QDir>
@ -30,6 +31,13 @@ LLMObject::LLMObject()
bool LLMObject::loadModel()
{
if (modelList().isEmpty()) {
// try again when we get a list of models
connect(Download::globalInstance(), &Download::modelListChanged, this,
&LLMObject::loadModel, Qt::SingleShotConnection);
return false;
}
return loadModelPrivate(modelList().first());
}
@ -210,6 +218,7 @@ LLM::LLM()
, m_llmodel(new LLMObject)
, m_responseInProgress(false)
{
connect(Download::globalInstance(), &Download::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::isModelLoadedChanged, this, &LLM::isModelLoadedChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::responseChanged, this, &LLM::responseChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::responseStarted, this, &LLM::responseStarted, Qt::QueuedConnection);

View File

@ -5,15 +5,21 @@
#include <QDirIterator>
#include "llm.h"
#include "download.h"
#include "config.h"
int main(int argc, char *argv[])
{
QCoreApplication::setOrganizationName("nomic.ai");
QCoreApplication::setOrganizationDomain("gpt4all.io");
QCoreApplication::setApplicationName("GPT4All");
QCoreApplication::setApplicationVersion(APP_VERSION);
QGuiApplication app(argc, argv);
QQmlApplicationEngine engine;
qmlRegisterSingletonInstance("llm", 1, 0, "LLM", LLM::globalInstance());
qmlRegisterSingletonInstance("download", 1, 0, "Download", Download::globalInstance());
const QUrl url(u"qrc:/gpt4all-chat/main.qml"_qs);
QObject::connect(&engine, &QQmlApplicationEngine::objectCreated,
@ -23,7 +29,7 @@ int main(int argc, char *argv[])
}, Qt::QueuedConnection);
engine.load(url);
#if 1
#if 0
QDirIterator it("qrc:", QDirIterator::Subdirectories);
while (it.hasNext()) {
qDebug() << it.next();

View File

@ -27,7 +27,6 @@ Window {
Item {
anchors.centerIn: parent
width: childrenRect.width
height: childrenRect.height
visible: LLM.isModelLoaded
@ -93,6 +92,16 @@ Window {
title: qsTr("Settings")
height: 600
width: 600
opacity: 0.9
background: Rectangle {
anchors.fill: parent
anchors.margins: -20
color: "#202123"
border.width: 1
border.color: "white"
radius: 10
}
property real defaultTemperature: 0.28
property real defaultTopP: 0.95
property int defaultTopK: 40
@ -134,10 +143,7 @@ Window {
columns: 2
rowSpacing: 10
columnSpacing: 10
anchors.top: parent.top
anchors.left: parent.left
anchors.right: parent.right
anchors.bottom: parent.bottom
anchors.fill: parent
Label {
id: tempLabel
@ -558,6 +564,7 @@ Window {
}
background: Rectangle {
anchors.fill: parent
anchors.margins: -20
color: "#202123"
border.width: 1
border.color: "white"
@ -565,6 +572,16 @@ Window {
}
}
ModelDownloaderDialog {
id: downloadNewModels
anchors.centerIn: parent
Item {
Accessible.role: Accessible.Dialog
Accessible.name: qsTr("Download new models dialog")
Accessible.description: qsTr("Dialog for downloading new models")
}
}
Drawer {
id: drawer
y: header.height
@ -638,7 +655,8 @@ Window {
Button {
anchors.left: parent.left
anchors.right: parent.right
anchors.bottom: parent.bottom
anchors.bottom: downloadButton.top
anchors.bottomMargin: 20
padding: 15
contentItem: Text {
text: qsTr("Check for updates...")
@ -663,6 +681,36 @@ Window {
checkForUpdatesError.open()
}
}
Button {
id: downloadButton
anchors.left: parent.left
anchors.right: parent.right
anchors.bottom: parent.bottom
padding: 15
contentItem: Text {
text: qsTr("Download new models...")
horizontalAlignment: Text.AlignHCenter
color: "#d1d5db"
Accessible.role: Accessible.Button
Accessible.name: text
Accessible.description: qsTr("Use this to launch a dialog to download new models")
}
background: Rectangle {
opacity: .5
border.color: "#7d7d8e"
border.width: 1
radius: 10
color: "#343541"
}
onClicked: {
downloadNewModels.open()
}
}
}
}

View File

@ -0,0 +1,219 @@
import QtQuick 6.5
import QtQuick.Controls 6.5
import QtQuick.Layouts 1.12
import download
import llm
Dialog {
id: modelDownloaderDialog
width: 900
height: 400
title: "Model Downloader"
modal: true
opacity: 0.9
closePolicy: LLM.modelList.length === 0 ? Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside)
background: Rectangle {
anchors.fill: parent
anchors.margins: -20
color: "#202123"
border.width: 1
border.color: "white"
radius: 10
}
Component.onCompleted: {
if (LLM.modelList.length === 0)
open();
}
ColumnLayout {
anchors.fill: parent
anchors.margins: 20
spacing: 10
Label {
id: listLabel
text: "Available Models:"
Layout.alignment: Qt.AlignLeft
Layout.fillWidth: true
color: "#d1d5db"
}
ListView {
id: modelList
Layout.fillWidth: true
Layout.fillHeight: true
model: Download.modelList
clip: true
boundsBehavior: Flickable.StopAtBounds
delegate: Item {
id: delegateItem
width: modelList.width
height: 50
objectName: "delegateItem"
property bool downloading: false
Rectangle {
anchors.fill: parent
color: index % 2 === 0 ? "#2c2f33" : "#1e2125"
}
Text {
id: modelName
objectName: "modelName"
text: modelData.filename
anchors.verticalCenter: parent.verticalCenter
anchors.left: parent.left
anchors.leftMargin: 10
font.pixelSize: 24
color: "#d1d5db"
Accessible.role: Accessible.Paragraph
Accessible.name: qsTr("Model file")
Accessible.description: qsTr("Model file to be downloaded")
}
Text {
text: qsTr("(default)")
visible: modelData.isDefault
anchors.verticalCenter: parent.verticalCenter
anchors.left: modelName.right
anchors.leftMargin: 10
font.pixelSize: 24
color: "#d1d5db"
Accessible.role: Accessible.Paragraph
Accessible.name: qsTr("Default file")
Accessible.description: qsTr("Whether the file is the default model")
}
Label {
id: speedLabel
anchors.verticalCenter: parent.verticalCenter
anchors.right: itemProgressBar.left
anchors.rightMargin: 10
objectName: "speedLabel"
color: "#d1d5db"
text: ""
visible: downloading
Accessible.role: Accessible.Paragraph
Accessible.name: qsTr("Download speed")
Accessible.description: qsTr("Download speed in bytes/kilobytes/megabytes per second")
}
ProgressBar {
id: itemProgressBar
objectName: "itemProgressBar"
anchors.verticalCenter: parent.verticalCenter
anchors.right: downloadButton.left
anchors.rightMargin: 10
width: 100
visible: downloading
Accessible.role: Accessible.ProgressBar
Accessible.name: qsTr("Download progressBar")
Accessible.description: qsTr("Shows the progress made in the download")
}
Label {
id: installedLabel
anchors.verticalCenter: parent.verticalCenter
anchors.right: parent.right
anchors.rightMargin: 15
objectName: "installedLabel"
color: "#d1d5db"
text: qsTr("Already installed")
visible: modelData.installed
Accessible.role: Accessible.Paragraph
Accessible.name: text
Accessible.description: qsTr("Whether the file is already installed on your system")
}
Button {
id: downloadButton
text: downloading ? "Cancel" : "Download"
anchors.verticalCenter: parent.verticalCenter
anchors.right: parent.right
anchors.rightMargin: 10
visible: !modelData.installed
onClicked: {
if (!downloading) {
downloading = true;
Download.downloadModel(modelData.filename);
} else {
downloading = false;
Download.cancelDownload(modelData.filename);
}
}
Accessible.role: Accessible.Button
Accessible.name: text
Accessible.description: qsTr("Cancel/Download button to stop/start the download")
}
}
Component.onCompleted: {
Download.downloadProgress.connect(updateProgress);
Download.downloadFinished.connect(resetProgress);
}
property var lastUpdate: ({})
function updateProgress(bytesReceived, bytesTotal, modelName) {
let currentTime = new Date().getTime();
for (let i = 0; i < modelList.contentItem.children.length; i++) {
let delegateItem = modelList.contentItem.children[i];
if (delegateItem.objectName === "delegateItem") {
let modelNameText = delegateItem.children.find(child => child.objectName === "modelName").text;
if (modelNameText === modelName) {
let progressBar = delegateItem.children.find(child => child.objectName === "itemProgressBar");
progressBar.value = bytesReceived / bytesTotal;
// Calculate the download speed
if (lastUpdate[modelName] && lastUpdate[modelName].timestamp) {
let timeDifference = currentTime - lastUpdate[modelName].timestamp;
let bytesDifference = bytesReceived - lastUpdate[modelName].bytesReceived;
let speed = (bytesDifference / timeDifference) * 1000; // bytes per second
// Update the speed label
let speedLabel = delegateItem.children.find(child => child.objectName === "speedLabel");
if (speed < 1024) {
speedLabel.text = speed.toFixed(2) + " B/s";
} else if (speed < 1024 * 1024) {
speedLabel.text = (speed / 1024).toFixed(2) + " KB/s";
} else {
speedLabel.text = (speed / (1024 * 1024)).toFixed(2) + " MB/s";
}
}
// Update the lastUpdate object for the current model
lastUpdate[modelName] = {"timestamp": currentTime, "bytesReceived": bytesReceived};
break;
}
}
}
}
function resetProgress(modelName) {
for (let i = 0; i < modelList.contentItem.children.length; i++) {
let delegateItem = modelList.contentItem.children[i];
if (delegateItem.objectName === "delegateItem") {
let modelNameText = delegateItem.children.find(child => child.objectName === "modelName").text;
if (modelNameText === modelName) {
let progressBar = delegateItem.children.find(child => child.objectName === "itemProgressBar");
progressBar.value = 0;
delegateItem.downloading = false;
// Remove speed label text
let speedLabel = delegateItem.children.find(child => child.objectName === "speedLabel");
speedLabel.text = "";
// Remove the lastUpdate object for the canceled model
delete lastUpdate[modelName];
break;
}
}
}
}
}
}
}