diff --git a/NAPS2.Lib.Tests/Dependencies/DownloadFormatTests.cs b/NAPS2.Lib.Tests/Dependencies/DownloadFormatTests.cs index 2915ad506..279c7004e 100644 --- a/NAPS2.Lib.Tests/Dependencies/DownloadFormatTests.cs +++ b/NAPS2.Lib.Tests/Dependencies/DownloadFormatTests.cs @@ -10,10 +10,12 @@ public class DownloadFormatTests : ContextualTests public void Gzip() { var path = Path.Combine(FolderPath, "f.gz"); - var gzData = BinaryResources.stock_dog_jpeg; - File.WriteAllBytes(path, gzData); - var extractedPath = DownloadFormat.Gzip.Prepare(path); + string extractedPath; + using (MemoryStream stream = new(BinaryResources.stock_dog_jpeg)) + { + extractedPath = DownloadFormat.Gzip.Prepare(stream, path); + } var expectedDog = BinaryResources.stock_dog; Assert.Equal(expectedDog, File.ReadAllBytes(extractedPath)); @@ -23,10 +25,12 @@ public class DownloadFormatTests : ContextualTests public void Zip() { var path = Path.Combine(FolderPath, "f.zip"); - var zipData = BinaryResources.animals; - File.WriteAllBytes(path, zipData); - var extractedPath = DownloadFormat.Zip.Prepare(path); + string extractedPath; + using (MemoryStream stream = new(BinaryResources.animals)) + { + extractedPath = DownloadFormat.Zip.Prepare(stream, path); + } var dogPath = Path.Combine(extractedPath, "animals/dogs/stock-dog.jpeg"); var catPath = Path.Combine(extractedPath, "animals/cats/stock-cat.jpeg"); diff --git a/NAPS2.Lib/Automation/AutomatedScanning.cs b/NAPS2.Lib/Automation/AutomatedScanning.cs index 0997a0e57..56035b5c0 100644 --- a/NAPS2.Lib/Automation/AutomatedScanning.cs +++ b/NAPS2.Lib/Automation/AutomatedScanning.cs @@ -202,8 +202,7 @@ public class AutomatedScanning { _errorOutput.DisplayError(MiscResources.FilesCouldNotBeDownloaded); }; - downloadController.Start(); - await downloadController.CompletionTask; + await downloadController.StartDownloadsAsync(); } private void ReorderScannedImages() diff --git a/NAPS2.Lib/Dependencies/DownloadController.cs b/NAPS2.Lib/Dependencies/DownloadController.cs index 31ccfd275..7be2cb951 100644 --- a/NAPS2.Lib/Dependencies/DownloadController.cs +++ b/NAPS2.Lib/Dependencies/DownloadController.cs @@ -1,5 +1,4 @@ -using System.ComponentModel; -using System.Net; +using System.Net.Http; using System.Security.Cryptography; using Microsoft.Extensions.Logging; using NAPS2.Scan; @@ -11,25 +10,15 @@ public class DownloadController private readonly ScanningContext _scanningContext; private readonly ILogger _logger; - // TODO: Migrate to HttpClient -#pragma warning disable SYSLIB0014 - private readonly WebClient _client = new(); -#pragma warning restore SYSLIB0014 + private static readonly HttpClient _client = new(); private readonly List _filesToDownload = new(); - private readonly TaskCompletionSource _completionSource = new(); - private int _urlIndex; - private bool _hasError; private bool _cancel; public DownloadController(ScanningContext scanningContext) { _scanningContext = scanningContext; _logger = scanningContext.Logger; - // TODO: Is this needed for net462? - ServicePointManager.SecurityProtocol = SecurityProtocolType.Tls12; - _client.DownloadFileCompleted += client_DownloadFileCompleted; - _client.DownloadProgressChanged += client_DownloadProgressChanged; } public int FilesDownloaded { get; private set; } @@ -40,41 +29,6 @@ public class DownloadController public long CurrentFileProgress { get; private set; } - public Task CompletionTask => _completionSource.Task; - - void client_DownloadProgressChanged(object? sender, DownloadProgressChangedEventArgs e) - { - CurrentFileProgress = e.BytesReceived; - CurrentFileSize = e.TotalBytesToReceive; - DownloadProgress?.Invoke(this, EventArgs.Empty); - } - - void client_DownloadFileCompleted(object? sender, AsyncCompletedEventArgs e) - { - var file = _filesToDownload[FilesDownloaded]; - if (e.Error != null) - { - _hasError = true; - if (!_cancel) - { - _logger.LogError(e.Error, "Error downloading file: {FileName}", file.DownloadInfo.FileName); - } - } - else if (file.DownloadInfo.Sha1 != CalculateSha1(Path.Combine(file.TempFolder!, file.DownloadInfo.FileName))) - { - _hasError = true; - _logger.LogError("Error downloading file (invalid checksum): {FileName}", file.DownloadInfo.FileName); - } - else - { - FilesDownloaded++; - } - CurrentFileProgress = 0; - CurrentFileSize = 0; - DownloadProgress?.Invoke(this, EventArgs.Empty); - StartNextDownload(); - } - public void QueueFile(DownloadInfo downloadInfo, Action fileCallback) { _filesToDownload.Add(new QueueItem { DownloadInfo = downloadInfo, FileCallback = fileCallback }); @@ -88,7 +42,7 @@ public class DownloadController public void Stop() { _cancel = true; - _client.CancelAsync(); + _client.CancelPendingRequests(); } public event EventHandler? DownloadError; @@ -97,63 +51,116 @@ public class DownloadController public event EventHandler? DownloadProgress; - private void StartNextDownload() + private async Task TryDownloadFromUrlAsync(string filename, string url) { - if (_hasError) + CurrentFileProgress = 0; + CurrentFileSize = 0; + DownloadProgress?.Invoke(this, EventArgs.Empty); + try { - var prev = _filesToDownload[FilesDownloaded]; - Directory.Delete(prev.TempFolder!, true); - if (_cancel) + var response = await _client.GetAsync(url, HttpCompletionOption.ResponseHeadersRead); + response.EnsureSuccessStatusCode(); + CurrentFileSize = response.Content.Headers.ContentLength.GetValueOrDefault(); + + using var contentStream = await response.Content.ReadAsStreamAsync(); + + var result = new MemoryStream(); + long previousLength; + byte[] buffer = new byte[1024 * 40]; + do { - return; + previousLength = result.Length; + int length = await contentStream.ReadAsync(buffer, 0, buffer.Length); + if (length > 0) + { + result.Write(buffer, 0, length); + CurrentFileProgress = result.Length; + DownloadProgress?.Invoke(this, EventArgs.Empty); + } + if (_cancel) + { + throw new OperationCanceledException(); + } } - // Retry if possible - _urlIndex++; - _hasError = false; + while (previousLength < result.Length); + return result; } - else + catch (OperationCanceledException) { - _urlIndex = 0; + throw; } - if (FilesDownloaded > 0 && _urlIndex == 0) + catch (Exception ex) { - var prev = _filesToDownload[FilesDownloaded - 1]; - var filePath = Path.Combine(prev.TempFolder!, prev.DownloadInfo.FileName); + _logger.LogError(ex, "Error downloading file: {FileName}", filename); + return null; + } + } + + private async Task TryDownloadQueueItemAsync(QueueItem fileToDownload) + { + foreach (var url in fileToDownload.DownloadInfo.Urls) + { + var result = await TryDownloadFromUrlAsync(fileToDownload.DownloadInfo.FileName, url); + if (result != null) + { + result.Position = 0; + if (fileToDownload.DownloadInfo.Sha1 == CalculateSha1(result)) + { + return result; + } + _logger.LogError("Error downloading file (invalid checksum): {FileName}", fileToDownload.DownloadInfo.FileName); + } + } + return null; + } + + private async Task InternalStartDownloadsAsync() + { + FilesDownloaded = 0; + foreach (var fileToDownload in _filesToDownload) + { + MemoryStream? result; try { - var preparedFilePath = prev.DownloadInfo.Format.Prepare(filePath); - prev.FileCallback(preparedFilePath); + result = await TryDownloadQueueItemAsync(fileToDownload); + } + catch (OperationCanceledException) + { + return false; + } + + if (result == null) + { + DownloadComplete?.Invoke(this, EventArgs.Empty); + DownloadError?.Invoke(this, EventArgs.Empty); + return false; + } + + fileToDownload.TempFolder = Path.Combine(_scanningContext.TempFolderPath, Path.GetRandomFileName()); + Directory.CreateDirectory(fileToDownload.TempFolder); + string p = Path.Combine(fileToDownload.TempFolder, fileToDownload.DownloadInfo.FileName); + try + { + result.Position = 0; + var preparedFilePath = fileToDownload.DownloadInfo.Format.Prepare(result, p); + fileToDownload.FileCallback(preparedFilePath); } catch (Exception ex) { _logger.LogError(ex, "Error preparing downloaded file"); DownloadError?.Invoke(this, EventArgs.Empty); } - Directory.Delete(prev.TempFolder!, true); + FilesDownloaded++; + Directory.Delete(fileToDownload.TempFolder, true); } - if (FilesDownloaded >= _filesToDownload.Count) - { - DownloadComplete?.Invoke(this, EventArgs.Empty); - _completionSource.SetResult(true); - return; - } - if (_urlIndex >= _filesToDownload[FilesDownloaded].DownloadInfo.Urls.Count) - { - DownloadComplete?.Invoke(this, EventArgs.Empty); - _completionSource.SetResult(false); - DownloadError?.Invoke(this, EventArgs.Empty); - return; - } - var next = _filesToDownload[FilesDownloaded]; - next.TempFolder = Path.Combine(_scanningContext.TempFolderPath, Path.GetRandomFileName()); - Directory.CreateDirectory(next.TempFolder); - _client.DownloadFileAsync(new Uri(next.DownloadInfo.Urls[_urlIndex]), Path.Combine(next.TempFolder, next.DownloadInfo.FileName)); + + DownloadComplete?.Invoke(this, EventArgs.Empty); + return true; } - private string CalculateSha1(string filePath) + private string CalculateSha1(Stream stream) { using var sha = SHA1.Create(); - using FileStream stream = File.OpenRead(filePath); byte[] checksum = sha.ComputeHash(stream); string str = BitConverter.ToString(checksum).Replace("-", String.Empty).ToLowerInvariant(); return str; @@ -168,9 +175,9 @@ public class DownloadController public required Action FileCallback { get; set; } } - public void Start() + public async Task StartDownloadsAsync() { DownloadProgress?.Invoke(this, EventArgs.Empty); - StartNextDownload(); + return await InternalStartDownloadsAsync(); } } \ No newline at end of file diff --git a/NAPS2.Lib/Dependencies/DownloadFormat.cs b/NAPS2.Lib/Dependencies/DownloadFormat.cs index e5ba34869..2e7a29ba8 100644 --- a/NAPS2.Lib/Dependencies/DownloadFormat.cs +++ b/NAPS2.Lib/Dependencies/DownloadFormat.cs @@ -4,46 +4,41 @@ namespace NAPS2.Dependencies; public abstract class DownloadFormat { - public static DownloadFormat Gzip = new GzipDownloadFormat(); + public static readonly DownloadFormat Gzip = new GzipDownloadFormat(); - public static DownloadFormat Zip = new ZipDownloadFormat(); + public static readonly DownloadFormat Zip = new ZipDownloadFormat(); - public abstract string Prepare(string tempFilePath); + public abstract string Prepare(MemoryStream stream, string tempFilePath); private class GzipDownloadFormat : DownloadFormat { - public override string Prepare(string tempFilePath) + private const string FileExtension = ".gz"; + + public override string Prepare(MemoryStream stream, string tempFilePath) { - if (!tempFilePath.EndsWith(".gz", StringComparison.InvariantCultureIgnoreCase)) + if (tempFilePath.EndsWith(FileExtension, StringComparison.InvariantCultureIgnoreCase)) { - throw new ArgumentException(); + tempFilePath = tempFilePath.Substring(0, tempFilePath.Length - 3); } - var pathWithoutGz = tempFilePath.Substring(0, tempFilePath.Length - 3); - Extract(tempFilePath, pathWithoutGz); - return pathWithoutGz; + Extract(stream, tempFilePath); + return tempFilePath; } - private static void Extract(string sourcePath, string destPath) + private static void Extract(MemoryStream stream, string destPath) { - using FileStream inFile = new FileInfo(sourcePath).OpenRead(); using FileStream outFile = File.Create(destPath); - using GZipStream decompress = new GZipStream(inFile, CompressionMode.Decompress); + using GZipStream decompress = new(stream, CompressionMode.Decompress); decompress.CopyTo(outFile); } } private class ZipDownloadFormat : DownloadFormat { - public override string Prepare(string tempFilePath) + public override string Prepare(MemoryStream stream, string tempFilePath) { - if (!tempFilePath.EndsWith(".zip", StringComparison.InvariantCultureIgnoreCase)) - { - throw new ArgumentException(); - } - - var tempDir = Path.GetDirectoryName(tempFilePath) ?? throw new ArgumentNullException(); - ZipFile.ExtractToDirectory(tempFilePath, tempDir); - File.Delete(tempFilePath); + var tempDir = Path.GetDirectoryName(tempFilePath) ?? throw new ArgumentException("Path was a root path", nameof(tempFilePath)); + ZipArchive archive = new(stream); + archive.ExtractToDirectory(tempDir); return tempDir; } } diff --git a/NAPS2.Lib/EtoForms/Ui/DownloadProgressForm.cs b/NAPS2.Lib/EtoForms/Ui/DownloadProgressForm.cs index 35f0da3d6..5bc5f343f 100644 --- a/NAPS2.Lib/EtoForms/Ui/DownloadProgressForm.cs +++ b/NAPS2.Lib/EtoForms/Ui/DownloadProgressForm.cs @@ -51,7 +51,7 @@ public class DownloadProgressForm : EtoDialogBase protected override void OnLoad(EventArgs e) { base.OnLoad(e); - Controller.Start(); + Controller.StartDownloadsAsync().AssertNoAwait(); } private void OnDownloadProgress(object? sender, EventArgs e)