Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/Renci.SshNet/DownloadFileProgressReport.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
namespace Renci.SshNet
{
/// <summary>
/// Provides the progress for a file download.
/// </summary>
public struct DownloadFileProgressReport
{
/// <summary>
/// Gets the total number of bytes downloaded.
/// </summary>
public ulong TotalBytesDownloaded { get; internal set; }
}
}
36 changes: 35 additions & 1 deletion src/Renci.SshNet/ISftpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,23 @@ public interface ISftpClient : IBaseClient
/// <exception cref="ObjectDisposedException">The method was called after the client was disposed.</exception>
Task DownloadFileAsync(string path, Stream output, CancellationToken cancellationToken = default);

/// <summary>
/// Asynchronously downloads a remote file into a <see cref="Stream"/>.
/// </summary>
/// <param name="path">The path to the remote file.</param>
/// <param name="output">The <see cref="Stream"/> to write the file into.</param>
/// <param name="downloadProgress">The download progress.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to observe.</param>
/// <returns>A <see cref="Task"/> that represents the asynchronous download operation.</returns>
/// <exception cref="ArgumentNullException"><paramref name="output"/> or <paramref name="path"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentException"><paramref name="path"/> is empty or contains only whitespace characters.</exception>
/// <exception cref="SshConnectionException">Client is not connected.</exception>
/// <exception cref="SftpPermissionDeniedException">Permission to perform the operation was denied by the remote host. <para>-or-</para> An SSH command was denied by the server.</exception>
/// <exception cref="SftpPathNotFoundException"><paramref name="path"/> was not found on the remote host.</exception>
/// <exception cref="SshException">An SSH error where <see cref="Exception.Message" /> is the message from the remote host.</exception>
/// <exception cref="ObjectDisposedException">The method was called after the client was disposed.</exception>
Task DownloadFileAsync(string path, Stream output, IProgress<DownloadFileProgressReport>? downloadProgress, CancellationToken cancellationToken = default);

/// <summary>
/// Ends an asynchronous file downloading into the stream.
/// </summary>
Expand Down Expand Up @@ -1133,12 +1150,29 @@ public interface ISftpClient : IBaseClient
/// <exception cref="ObjectDisposedException">The method was called after the client was disposed.</exception>
Task UploadFileAsync(Stream input, string path, CancellationToken cancellationToken = default);

/// <summary>
/// Asynchronously uploads a <see cref="Stream"/> to a remote file path.
/// </summary>
/// <param name="input">The <see cref="Stream"/> to write to the remote path.</param>
/// <param name="path">The remote file path to write to.</param>
/// <param name="uploadProgress">The upload progress.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to observe.</param>
/// <returns>A <see cref="Task"/> that represents the asynchronous upload operation.</returns>
/// <exception cref="ArgumentNullException"><paramref name="input"/> or <paramref name="path"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentException"><paramref name="path" /> is empty or contains only whitespace characters.</exception>
/// <exception cref="SshConnectionException">Client is not connected.</exception>
/// <exception cref="SftpPermissionDeniedException">Permission to upload the file was denied by the remote host. <para>-or-</para> An SSH command was denied by the server.</exception>
/// <exception cref="SshException">An SSH error where <see cref="Exception.Message" /> is the message from the remote host.</exception>
/// <exception cref="ObjectDisposedException">The method was called after the client was disposed.</exception>
Task UploadFileAsync(Stream input, string path, IProgress<UploadFileProgressReport>? uploadProgress, CancellationToken cancellationToken = default);

/// <summary>
/// Asynchronously uploads a <see cref="Stream"/> to a remote file path.
/// </summary>
/// <param name="input">The <see cref="Stream"/> to write to the remote path.</param>
/// <param name="path">The remote file path to write to.</param>
/// <param name="canOverride">Whether the remote file can be overwritten if it already exists.</param>
/// <param name="uploadProgress">The upload progress.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to observe.</param>
/// <returns>A <see cref="Task"/> that represents the asynchronous upload operation.</returns>
/// <exception cref="ArgumentNullException"><paramref name="input"/> or <paramref name="path"/> is <see langword="null"/>.</exception>
Expand All @@ -1147,7 +1181,7 @@ public interface ISftpClient : IBaseClient
/// <exception cref="SftpPermissionDeniedException">Permission to upload the file was denied by the remote host. <para>-or-</para> An SSH command was denied by the server.</exception>
/// <exception cref="SshException">An SSH error where <see cref="Exception.Message" /> is the message from the remote host.</exception>
/// <exception cref="ObjectDisposedException">The method was called after the client was disposed.</exception>
Task UploadFileAsync(Stream input, string path, bool canOverride, CancellationToken cancellationToken = default);
Task UploadFileAsync(Stream input, string path, bool canOverride, IProgress<UploadFileProgressReport>? uploadProgress = null, CancellationToken cancellationToken = default);

/// <summary>
/// Writes the specified byte array to the specified file, and closes the file.
Expand Down
82 changes: 64 additions & 18 deletions src/Renci.SshNet/SftpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -901,17 +901,30 @@ public void DownloadFile(string path, Stream output, Action<ulong>? downloadCall
ArgumentNullException.ThrowIfNull(output);
CheckDisposed();

IProgress<DownloadFileProgressReport>? downloadProgress = null;

if (downloadCallback != null)
{
downloadProgress = new Progress<DownloadFileProgressReport>(r => downloadCallback(r.TotalBytesDownloaded));
}

InternalDownloadFile(
path,
output,
asyncResult: null,
downloadCallback,
downloadProgress,
isAsync: false,
CancellationToken.None).GetAwaiter().GetResult();
}

/// <inheritdoc />
public Task DownloadFileAsync(string path, Stream output, CancellationToken cancellationToken = default)
{
return DownloadFileAsync(path, output, downloadProgress: null, cancellationToken);
}

/// <inheritdoc />
public Task DownloadFileAsync(string path, Stream output, IProgress<DownloadFileProgressReport>? downloadProgress, CancellationToken cancellationToken = default)
{
ArgumentException.ThrowIfNullOrWhiteSpace(path);
ArgumentNullException.ThrowIfNull(output);
Expand All @@ -921,7 +934,7 @@ public Task DownloadFileAsync(string path, Stream output, CancellationToken canc
path,
output,
asyncResult: null,
downloadCallback: null,
downloadProgress: downloadProgress,
isAsync: true,
cancellationToken);
}
Expand Down Expand Up @@ -994,6 +1007,13 @@ public IAsyncResult BeginDownloadFile(string path, Stream output, AsyncCallback?
ArgumentNullException.ThrowIfNull(output);
CheckDisposed();

IProgress<DownloadFileProgressReport>? downloadProgress = null;

if (downloadCallback != null)
{
downloadProgress = new Progress<DownloadFileProgressReport>(r => downloadCallback(r.TotalBytesDownloaded));
}

var asyncResult = new SftpDownloadAsyncResult(asyncCallback, state);

_ = DoDownloadAndSetResult();
Expand All @@ -1006,7 +1026,7 @@ await InternalDownloadFile(
path,
output,
asyncResult,
downloadCallback,
downloadProgress,
isAsync: true,
CancellationToken.None).ConfigureAwait(false);

Expand Down Expand Up @@ -1065,24 +1085,37 @@ public void UploadFile(Stream input, string path, bool canOverride, Action<ulong
flags |= Flags.CreateNew;
}

IProgress<UploadFileProgressReport>? uploadProgress = null;

if (uploadCallback != null)
{
uploadProgress = new Progress<UploadFileProgressReport>(r => uploadCallback(r.TotalBytesUploaded));
}

InternalUploadFile(
input,
path,
flags,
asyncResult: null,
uploadCallback,
uploadProgress,
isAsync: false,
CancellationToken.None).GetAwaiter().GetResult();
}

/// <inheritdoc />
public Task UploadFileAsync(Stream input, string path, CancellationToken cancellationToken = default)
{
return UploadFileAsync(input, path, canOverride: true, cancellationToken);
return UploadFileAsync(input, path, canOverride: true, uploadProgress: null, cancellationToken);
}

/// <inheritdoc />
public Task UploadFileAsync(Stream input, string path, bool canOverride, CancellationToken cancellationToken = default)
public Task UploadFileAsync(Stream input, string path, IProgress<UploadFileProgressReport>? uploadProgress, CancellationToken cancellationToken = default)
{
return UploadFileAsync(input, path, canOverride: true, uploadProgress, cancellationToken);
}

/// <inheritdoc />
public Task UploadFileAsync(Stream input, string path, bool canOverride, IProgress<UploadFileProgressReport>? uploadProgress = null, CancellationToken cancellationToken = default)
{
ArgumentNullException.ThrowIfNull(input);
ArgumentException.ThrowIfNullOrWhiteSpace(path);
Expand All @@ -1104,7 +1137,7 @@ public Task UploadFileAsync(Stream input, string path, bool canOverride, Cancell
path,
flags,
asyncResult: null,
uploadCallback: null,
uploadProgress,
isAsync: true,
cancellationToken);
}
Expand Down Expand Up @@ -1236,6 +1269,13 @@ public IAsyncResult BeginUploadFile(Stream input, string path, bool canOverride,
flags |= Flags.CreateNew;
}

IProgress<UploadFileProgressReport>? uploadProgress = null;

if (uploadCallback != null)
{
uploadProgress = new Progress<UploadFileProgressReport>(r => uploadCallback(r.TotalBytesUploaded));
}

var asyncResult = new SftpUploadAsyncResult(asyncCallback, state);

_ = DoUploadAndSetResult();
Expand All @@ -1249,7 +1289,7 @@ await InternalUploadFile(
path,
flags,
asyncResult,
uploadCallback,
uploadProgress,
isAsync: true,
CancellationToken.None).ConfigureAwait(false);

Expand Down Expand Up @@ -2195,7 +2235,7 @@ private List<FileInfo> InternalSynchronizeDirectories(string sourcePath, string
remoteFileName,
uploadFlag,
asyncResult: null,
uploadCallback: null,
uploadProgress: null,
isAsync: false,
CancellationToken.None).GetAwaiter().GetResult();
#pragma warning restore CA2025 // Do not pass 'IDisposable' instances into unawaited tasks
Expand Down Expand Up @@ -2291,7 +2331,7 @@ private async Task InternalDownloadFile(
string path,
Stream output,
SftpDownloadAsyncResult? asyncResult,
Action<ulong>? downloadCallback,
IProgress<DownloadFileProgressReport>? downloadProgress,
bool isAsync,
CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -2377,13 +2417,15 @@ private async Task InternalDownloadFile(

asyncResult?.Update(totalBytesRead);

if (downloadCallback is not null)
if (downloadProgress is not null)
{
// Copy offset to ensure it's not modified between now and execution of callback
var downloadOffset = totalBytesRead;
var report = new DownloadFileProgressReport()
{
TotalBytesDownloaded = totalBytesRead,
};

// Execute callback on different thread
ThreadAbstraction.ExecuteThread(() => { downloadCallback(downloadOffset); });
downloadProgress.Report(report);
}
}
}
Expand All @@ -2407,7 +2449,7 @@ private async Task InternalUploadFile(
string path,
Flags flags,
SftpUploadAsyncResult? asyncResult,
Action<ulong>? uploadCallback,
IProgress<UploadFileProgressReport>? uploadProgress,
bool isAsync,
CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -2495,10 +2537,14 @@ private async Task InternalUploadFile(
asyncResult?.Update(writtenBytes);

// Call callback to report number of bytes written
if (uploadCallback is not null)
if (uploadProgress is not null)
{
// Execute callback on different thread
ThreadAbstraction.ExecuteThread(() => uploadCallback(writtenBytes));
UploadFileProgressReport report = new()
{
TotalBytesUploaded = writtenBytes,
};

uploadProgress.Report(report);
}
}
finally
Expand Down
13 changes: 13 additions & 0 deletions src/Renci.SshNet/UploadFileProgressReport.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
namespace Renci.SshNet
{
/// <summary>
/// Provides the progress for a file upload.
/// </summary>
public struct UploadFileProgressReport
{
/// <summary>
/// Gets the total number of bytes uploaded.
/// </summary>
public ulong TotalBytesUploaded { get; internal set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,5 +124,34 @@ public void Test_Sftp_EndDownloadFile_Invalid_Async_Handle()
Assert.ThrowsExactly<ArgumentException>(() => sftp.EndDownloadFile(async1));
}
}

[TestMethod]
[TestCategory("Sftp")]
public async Task Test_Sftp_DownloadFileAsync_DownloadProgress()
{
using (var sftp = new SftpClient(SshServerHostName, SshServerPort, User.UserName, User.Password))
{
await sftp.ConnectAsync(CancellationToken.None);
var filename = Path.GetTempFileName();
int testFileSizeMB = 1;
CreateTestFile(filename, testFileSizeMB);
await sftp.UploadFileAsync(File.OpenRead(filename), "test123");
using ManualResetEventSlim finalCallbackCalledEvent = new();

IProgress<DownloadFileProgressReport> progress = new Progress<DownloadFileProgressReport>(r =>
{
if ((int)r.TotalBytesDownloaded == testFileSizeMB * 1024 * 1024)
{
finalCallbackCalledEvent.Set();
}
});

await sftp.DownloadFileAsync("test123", new MemoryStream(), progress, CancellationToken.None);

// since the callback is queued to the thread pool, wait for the event.
bool callbackCalled = finalCallbackCalledEvent.Wait(5000);
Assert.IsTrue(callbackCalled);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -453,5 +453,34 @@ public void Test_Sftp_EndUploadFile_Invalid_Async_Handle()
Assert.ThrowsExactly<ArgumentException>(() => sftp.EndUploadFile(async1));
}
}

[TestMethod]
[TestCategory("Sftp")]
public async Task Test_Sftp_UploadFileAsync_UploadProgress()
{
using (var sftp = new SftpClient(SshServerHostName, SshServerPort, User.UserName, User.Password))
{
await sftp.ConnectAsync(CancellationToken.None);
var filename = Path.GetTempFileName();
int testFileSizeMB = 1;
CreateTestFile(filename, testFileSizeMB);
using var fileStream = File.OpenRead(filename);
using ManualResetEventSlim finalCallbackCalledEvent = new();

IProgress<UploadFileProgressReport> progress = new Progress<UploadFileProgressReport>(r =>
{
if ((int)r.TotalBytesUploaded == testFileSizeMB * 1024 * 1024)
{
finalCallbackCalledEvent.Set();
}
});

await sftp.UploadFileAsync(fileStream, "test", progress);

// since the callback is queued to the thread pool, wait for the event.
bool callbackCalled = finalCallbackCalledEvent.Wait(5000);
Assert.IsTrue(callbackCalled);
}
}
}
}
Loading