DownloadController uses HttpClient instead of WebClient

This can use Task and download into a Stream. Thus DownloadFormat is now
accepting a Stream and canceling a download (before extraction) does not
require it to remove the temporary file.
This commit is contained in:
Fabian Neundorf 2023-06-10 10:47:51 +02:00
parent dd6b7ccc45
commit 758e1f2e01
5 changed files with 122 additions and 117 deletions

View File

@ -10,10 +10,12 @@ public class DownloadFormatTests : ContextualTests
public void Gzip()
{
var path = Path.Combine(FolderPath, "f.gz");
var gzData = BinaryResources.stock_dog_jpeg;
File.WriteAllBytes(path, gzData);
var extractedPath = DownloadFormat.Gzip.Prepare(path);
string extractedPath;
using (MemoryStream stream = new(BinaryResources.stock_dog_jpeg))
{
extractedPath = DownloadFormat.Gzip.Prepare(stream, path);
}
var expectedDog = BinaryResources.stock_dog;
Assert.Equal(expectedDog, File.ReadAllBytes(extractedPath));
@ -23,10 +25,12 @@ public class DownloadFormatTests : ContextualTests
public void Zip()
{
var path = Path.Combine(FolderPath, "f.zip");
var zipData = BinaryResources.animals;
File.WriteAllBytes(path, zipData);
var extractedPath = DownloadFormat.Zip.Prepare(path);
string extractedPath;
using (MemoryStream stream = new(BinaryResources.animals))
{
extractedPath = DownloadFormat.Zip.Prepare(stream, path);
}
var dogPath = Path.Combine(extractedPath, "animals/dogs/stock-dog.jpeg");
var catPath = Path.Combine(extractedPath, "animals/cats/stock-cat.jpeg");

View File

@ -202,8 +202,7 @@ public class AutomatedScanning
{
_errorOutput.DisplayError(MiscResources.FilesCouldNotBeDownloaded);
};
downloadController.Start();
await downloadController.CompletionTask;
await downloadController.StartDownloadsAsync();
}
private void ReorderScannedImages()

View File

@ -1,5 +1,4 @@
using System.ComponentModel;
using System.Net;
using System.Net.Http;
using System.Security.Cryptography;
using Microsoft.Extensions.Logging;
using NAPS2.Scan;
@ -11,25 +10,15 @@ public class DownloadController
private readonly ScanningContext _scanningContext;
private readonly ILogger _logger;
// TODO: Migrate to HttpClient
#pragma warning disable SYSLIB0014
private readonly WebClient _client = new();
#pragma warning restore SYSLIB0014
private static readonly HttpClient _client = new();
private readonly List<QueueItem> _filesToDownload = new();
private readonly TaskCompletionSource<bool> _completionSource = new();
private int _urlIndex;
private bool _hasError;
private bool _cancel;
public DownloadController(ScanningContext scanningContext)
{
_scanningContext = scanningContext;
_logger = scanningContext.Logger;
// TODO: Is this needed for net462?
ServicePointManager.SecurityProtocol = SecurityProtocolType.Tls12;
_client.DownloadFileCompleted += client_DownloadFileCompleted;
_client.DownloadProgressChanged += client_DownloadProgressChanged;
}
public int FilesDownloaded { get; private set; }
@ -40,41 +29,6 @@ public class DownloadController
public long CurrentFileProgress { get; private set; }
public Task CompletionTask => _completionSource.Task;
void client_DownloadProgressChanged(object? sender, DownloadProgressChangedEventArgs e)
{
CurrentFileProgress = e.BytesReceived;
CurrentFileSize = e.TotalBytesToReceive;
DownloadProgress?.Invoke(this, EventArgs.Empty);
}
void client_DownloadFileCompleted(object? sender, AsyncCompletedEventArgs e)
{
var file = _filesToDownload[FilesDownloaded];
if (e.Error != null)
{
_hasError = true;
if (!_cancel)
{
_logger.LogError(e.Error, "Error downloading file: {FileName}", file.DownloadInfo.FileName);
}
}
else if (file.DownloadInfo.Sha1 != CalculateSha1(Path.Combine(file.TempFolder!, file.DownloadInfo.FileName)))
{
_hasError = true;
_logger.LogError("Error downloading file (invalid checksum): {FileName}", file.DownloadInfo.FileName);
}
else
{
FilesDownloaded++;
}
CurrentFileProgress = 0;
CurrentFileSize = 0;
DownloadProgress?.Invoke(this, EventArgs.Empty);
StartNextDownload();
}
public void QueueFile(DownloadInfo downloadInfo, Action<string> fileCallback)
{
_filesToDownload.Add(new QueueItem { DownloadInfo = downloadInfo, FileCallback = fileCallback });
@ -88,7 +42,7 @@ public class DownloadController
public void Stop()
{
_cancel = true;
_client.CancelAsync();
_client.CancelPendingRequests();
}
public event EventHandler? DownloadError;
@ -97,63 +51,116 @@ public class DownloadController
public event EventHandler? DownloadProgress;
private void StartNextDownload()
private async Task<MemoryStream?> TryDownloadFromUrlAsync(string filename, string url)
{
if (_hasError)
CurrentFileProgress = 0;
CurrentFileSize = 0;
DownloadProgress?.Invoke(this, EventArgs.Empty);
try
{
var prev = _filesToDownload[FilesDownloaded];
Directory.Delete(prev.TempFolder!, true);
if (_cancel)
var response = await _client.GetAsync(url, HttpCompletionOption.ResponseHeadersRead);
response.EnsureSuccessStatusCode();
CurrentFileSize = response.Content.Headers.ContentLength.GetValueOrDefault();
using var contentStream = await response.Content.ReadAsStreamAsync();
var result = new MemoryStream();
long previousLength;
byte[] buffer = new byte[1024 * 40];
do
{
return;
previousLength = result.Length;
int length = await contentStream.ReadAsync(buffer, 0, buffer.Length);
if (length > 0)
{
result.Write(buffer, 0, length);
CurrentFileProgress = result.Length;
DownloadProgress?.Invoke(this, EventArgs.Empty);
}
if (_cancel)
{
throw new OperationCanceledException();
}
}
// Retry if possible
_urlIndex++;
_hasError = false;
while (previousLength < result.Length);
return result;
}
else
catch (OperationCanceledException)
{
_urlIndex = 0;
throw;
}
if (FilesDownloaded > 0 && _urlIndex == 0)
catch (Exception ex)
{
var prev = _filesToDownload[FilesDownloaded - 1];
var filePath = Path.Combine(prev.TempFolder!, prev.DownloadInfo.FileName);
_logger.LogError(ex, "Error downloading file: {FileName}", filename);
return null;
}
}
private async Task<MemoryStream?> TryDownloadQueueItemAsync(QueueItem fileToDownload)
{
foreach (var url in fileToDownload.DownloadInfo.Urls)
{
var result = await TryDownloadFromUrlAsync(fileToDownload.DownloadInfo.FileName, url);
if (result != null)
{
result.Position = 0;
if (fileToDownload.DownloadInfo.Sha1 == CalculateSha1(result))
{
return result;
}
_logger.LogError("Error downloading file (invalid checksum): {FileName}", fileToDownload.DownloadInfo.FileName);
}
}
return null;
}
private async Task<bool> InternalStartDownloadsAsync()
{
FilesDownloaded = 0;
foreach (var fileToDownload in _filesToDownload)
{
MemoryStream? result;
try
{
var preparedFilePath = prev.DownloadInfo.Format.Prepare(filePath);
prev.FileCallback(preparedFilePath);
result = await TryDownloadQueueItemAsync(fileToDownload);
}
catch (OperationCanceledException)
{
return false;
}
if (result == null)
{
DownloadComplete?.Invoke(this, EventArgs.Empty);
DownloadError?.Invoke(this, EventArgs.Empty);
return false;
}
fileToDownload.TempFolder = Path.Combine(_scanningContext.TempFolderPath, Path.GetRandomFileName());
Directory.CreateDirectory(fileToDownload.TempFolder);
string p = Path.Combine(fileToDownload.TempFolder, fileToDownload.DownloadInfo.FileName);
try
{
result.Position = 0;
var preparedFilePath = fileToDownload.DownloadInfo.Format.Prepare(result, p);
fileToDownload.FileCallback(preparedFilePath);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error preparing downloaded file");
DownloadError?.Invoke(this, EventArgs.Empty);
}
Directory.Delete(prev.TempFolder!, true);
FilesDownloaded++;
Directory.Delete(fileToDownload.TempFolder, true);
}
if (FilesDownloaded >= _filesToDownload.Count)
{
DownloadComplete?.Invoke(this, EventArgs.Empty);
_completionSource.SetResult(true);
return;
}
if (_urlIndex >= _filesToDownload[FilesDownloaded].DownloadInfo.Urls.Count)
{
DownloadComplete?.Invoke(this, EventArgs.Empty);
_completionSource.SetResult(false);
DownloadError?.Invoke(this, EventArgs.Empty);
return;
}
var next = _filesToDownload[FilesDownloaded];
next.TempFolder = Path.Combine(_scanningContext.TempFolderPath, Path.GetRandomFileName());
Directory.CreateDirectory(next.TempFolder);
_client.DownloadFileAsync(new Uri(next.DownloadInfo.Urls[_urlIndex]), Path.Combine(next.TempFolder, next.DownloadInfo.FileName));
DownloadComplete?.Invoke(this, EventArgs.Empty);
return true;
}
private string CalculateSha1(string filePath)
private string CalculateSha1(Stream stream)
{
using var sha = SHA1.Create();
using FileStream stream = File.OpenRead(filePath);
byte[] checksum = sha.ComputeHash(stream);
string str = BitConverter.ToString(checksum).Replace("-", String.Empty).ToLowerInvariant();
return str;
@ -168,9 +175,9 @@ public class DownloadController
public required Action<string> FileCallback { get; set; }
}
public void Start()
public async Task<bool> StartDownloadsAsync()
{
DownloadProgress?.Invoke(this, EventArgs.Empty);
StartNextDownload();
return await InternalStartDownloadsAsync();
}
}

View File

@ -4,46 +4,41 @@ namespace NAPS2.Dependencies;
public abstract class DownloadFormat
{
public static DownloadFormat Gzip = new GzipDownloadFormat();
public static readonly DownloadFormat Gzip = new GzipDownloadFormat();
public static DownloadFormat Zip = new ZipDownloadFormat();
public static readonly DownloadFormat Zip = new ZipDownloadFormat();
public abstract string Prepare(string tempFilePath);
public abstract string Prepare(MemoryStream stream, string tempFilePath);
private class GzipDownloadFormat : DownloadFormat
{
public override string Prepare(string tempFilePath)
private const string FileExtension = ".gz";
public override string Prepare(MemoryStream stream, string tempFilePath)
{
if (!tempFilePath.EndsWith(".gz", StringComparison.InvariantCultureIgnoreCase))
if (tempFilePath.EndsWith(FileExtension, StringComparison.InvariantCultureIgnoreCase))
{
throw new ArgumentException();
tempFilePath = tempFilePath.Substring(0, tempFilePath.Length - 3);
}
var pathWithoutGz = tempFilePath.Substring(0, tempFilePath.Length - 3);
Extract(tempFilePath, pathWithoutGz);
return pathWithoutGz;
Extract(stream, tempFilePath);
return tempFilePath;
}
private static void Extract(string sourcePath, string destPath)
private static void Extract(MemoryStream stream, string destPath)
{
using FileStream inFile = new FileInfo(sourcePath).OpenRead();
using FileStream outFile = File.Create(destPath);
using GZipStream decompress = new GZipStream(inFile, CompressionMode.Decompress);
using GZipStream decompress = new(stream, CompressionMode.Decompress);
decompress.CopyTo(outFile);
}
}
private class ZipDownloadFormat : DownloadFormat
{
public override string Prepare(string tempFilePath)
public override string Prepare(MemoryStream stream, string tempFilePath)
{
if (!tempFilePath.EndsWith(".zip", StringComparison.InvariantCultureIgnoreCase))
{
throw new ArgumentException();
}
var tempDir = Path.GetDirectoryName(tempFilePath) ?? throw new ArgumentNullException();
ZipFile.ExtractToDirectory(tempFilePath, tempDir);
File.Delete(tempFilePath);
var tempDir = Path.GetDirectoryName(tempFilePath) ?? throw new ArgumentException("Path was a root path", nameof(tempFilePath));
ZipArchive archive = new(stream);
archive.ExtractToDirectory(tempDir);
return tempDir;
}
}

View File

@ -51,7 +51,7 @@ public class DownloadProgressForm : EtoDialogBase
protected override void OnLoad(EventArgs e)
{
base.OnLoad(e);
Controller.Start();
Controller.StartDownloadsAsync().AssertNoAwait();
}
private void OnDownloadProgress(object? sender, EventArgs e)