// SocketExtensions
using System;
using System.Net;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
namespace Renci.SshNet.Abstractions
{
internal static class SocketExtensions
{
/// <summary>
/// A <see cref="SocketAsyncEventArgs"/> that is also its own awaiter, so a socket operation started
/// through <see cref="ExecuteAsync"/> can be awaited directly.
/// </summary>
/// <remarks>
/// No member resets <see cref="_isCancelled"/> or <see cref="_continuationAction"/>, so an instance
/// is single-use.
/// NOTE(review): <see cref="SetCancelled"/> only completes the awaiter; it does not abort the
/// underlying socket operation, which may still be in flight afterwards.
/// </remarks>
private sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, INotifyCompletion
{
    /// <summary>
    /// Marker stored in <see cref="_continuationAction"/> once the operation has completed.
    /// </summary>
    private static readonly Action s_sentinel = () => { };

    // Set by SetCancelled; makes GetResult throw TaskCanceledException.
    private bool _isCancelled;

    // null       -> not completed and no awaiter registered yet;
    // s_sentinel -> completed (any continuation has already been taken);
    // otherwise  -> continuation registered by the awaiter, waiting for completion.
    private Action _continuationAction;

    /// <summary>
    /// Gets or sets a value indicating whether the operation has completed, whether successfully,
    /// with a socket error, or by cancellation. This member is part of the awaiter pattern.
    /// </summary>
    public bool IsCompleted { get; set; }

    /// <summary>
    /// Initializes a new instance and routes the asynchronous <see cref="SocketAsyncEventArgs.Completed"/>
    /// notification through <see cref="SetCompleted"/>.
    /// </summary>
    public AwaitableSocketAsyncEventArgs()
    {
        Completed += (sender, e) => SetCompleted();
    }

    /// <summary>
    /// Starts a socket operation that uses this instance as its event args.
    /// </summary>
    /// <param name="func">
    /// A socket method such as <c>Socket.ConnectAsync</c> or <c>Socket.ReceiveAsync</c>. It returns
    /// <see langword="true"/> when the operation is pending, or <see langword="false"/> when it
    /// completed synchronously.
    /// </param>
    /// <returns>This instance, so that the call can be awaited directly.</returns>
    public AwaitableSocketAsyncEventArgs ExecuteAsync(Func<SocketAsyncEventArgs, bool> func)
    {
        // Cancellation already happened (for example, a cancellation callback ran synchronously before
        // the operation was started). Do not start an I/O operation whose result will be discarded.
        // SetCancelled has already marked this instance completed, so awaiting it throws
        // TaskCanceledException.
        if (_isCancelled)
        {
            return this;
        }

        // A false return means the operation completed synchronously. In that case the socket does
        // not raise Completed, so complete the awaiter here.
        if (!func(this))
        {
            SetCompleted();
        }

        return this;
    }

    /// <summary>
    /// Marks the operation as completed. If an awaiter has already registered a continuation, that
    /// continuation runs synchronously on the calling thread.
    /// </summary>
    private void SetCompleted()
    {
        IsCompleted = true;

        // Atomically publish the sentinel and take whatever was there before. Only the first caller
        // can obtain a real continuation. Later callers (for example, a socket completion racing a
        // cancellation) get the sentinel back, and invoking it does nothing.
        var continuation = Interlocked.Exchange(ref _continuationAction, s_sentinel);
        continuation?.Invoke();
    }

    /// <summary>
    /// Completes the awaiter as cancelled. A subsequent <see cref="GetResult"/> throws
    /// <see cref="TaskCanceledException"/>.
    /// </summary>
    public void SetCancelled()
    {
        _isCancelled = true;
        SetCompleted();
    }

    /// <summary>
    /// Returns this instance as the awaiter (awaitable pattern).
    /// </summary>
    /// <returns>This instance.</returns>
    public AwaitableSocketAsyncEventArgs GetAwaiter()
    {
        return this;
    }

    /// <summary>
    /// Registers the continuation to run when the operation completes.
    /// </summary>
    /// <param name="continuation">The continuation supplied by the awaiting code.</param>
    void INotifyCompletion.OnCompleted(Action continuation)
    {
        // The operation may already have completed. Either the sentinel is visible now, or it was
        // published between the read and the CompareExchange. In both cases SetCompleted will not
        // run this continuation, so schedule it on the thread pool here.
        if (_continuationAction == s_sentinel ||
            Interlocked.CompareExchange(ref _continuationAction, continuation, null) == s_sentinel)
        {
            _ = Task.Run(continuation);
        }
    }

    /// <summary>
    /// Ends the await and reports the outcome of the operation.
    /// </summary>
    /// <exception cref="TaskCanceledException">The operation was cancelled through <see cref="SetCancelled"/>.</exception>
    /// <exception cref="InvalidOperationException">The operation has not completed yet.</exception>
    /// <exception cref="SocketException">The socket operation completed with an error.</exception>
    public void GetResult()
    {
        if (_isCancelled)
        {
            throw new TaskCanceledException();
        }

        if (!IsCompleted)
        {
            throw new InvalidOperationException("The asynchronous operation has not yet completed.");
        }

        if (SocketError != SocketError.Success)
        {
            throw new SocketException((int)SocketError);
        }
    }
}
// NOTE(review): decompiled async stub. <ConnectAsync>d__1, <>t__builder and <>1__state are
// compiler-generated names that are not legal C# identifiers, so this method cannot be recompiled
// as written. The actual logic lives in the state machine type, which is not shown here.
// Restore the original `async` method source before building.
/// <summary>
/// Asynchronous extension method on <see cref="Socket"/>. Judging by its name and parameters, it
/// connects <paramref name="socket"/> to <paramref name="remoteEndpoint"/> while observing
/// <paramref name="cancellationToken"/>. NOTE(review): confirm against the state machine / original source.
/// </summary>
/// <param name="socket">The socket; copied into the state machine.</param>
/// <param name="remoteEndpoint">The remote endpoint; copied into the state machine.</param>
/// <param name="cancellationToken">The cancellation token; copied into the state machine.</param>
/// <returns>The <see cref="Task"/> exposed by the state machine's <see cref="AsyncTaskMethodBuilder"/>.</returns>
[AsyncStateMachine(typeof(<ConnectAsync>d__1))]
public static Task ConnectAsync(this Socket socket, EndPoint remoteEndpoint, CancellationToken cancellationToken)
{
// Create the state machine and its task builder.
<ConnectAsync>d__1 stateMachine = default(<ConnectAsync>d__1);
stateMachine.<>t__builder = AsyncTaskMethodBuilder.Create();
// Copy the method arguments into the state machine's fields.
stateMachine.socket = socket;
stateMachine.remoteEndpoint = remoteEndpoint;
stateMachine.cancellationToken = cancellationToken;
// Initial state value assigned by the compiler's async lowering before the first MoveNext.
stateMachine.<>1__state = -1;
// Begin running the state machine; the returned task represents its eventual completion.
stateMachine.<>t__builder.Start(ref stateMachine);
return stateMachine.<>t__builder.Task;
}
// NOTE(review): decompiled async stub. <ReceiveAsync>d__2, <>t__builder and <>1__state are
// compiler-generated names that are not legal C# identifiers, so this method cannot be recompiled
// as written. The actual logic lives in the state machine type, which is not shown here.
// Restore the original `async` method source before building.
/// <summary>
/// Asynchronous extension method on <see cref="Socket"/>. Judging by its name and parameters, it
/// receives into <paramref name="buffer"/> starting at <paramref name="offset"/>, for at most
/// <paramref name="length"/> bytes, while observing <paramref name="cancellationToken"/>.
/// NOTE(review): confirm against the state machine / original source.
/// </summary>
/// <param name="socket">The socket; copied into the state machine.</param>
/// <param name="buffer">The buffer; copied into the state machine.</param>
/// <param name="offset">The buffer offset; copied into the state machine.</param>
/// <param name="length">The length; copied into the state machine.</param>
/// <param name="cancellationToken">The cancellation token; copied into the state machine.</param>
/// <returns>
/// The <see cref="Task{TResult}"/> exposed by the state machine's builder. Presumably its result is the
/// number of bytes received (TODO confirm).
/// </returns>
[AsyncStateMachine(typeof(<ReceiveAsync>d__2))]
public static Task<int> ReceiveAsync(this Socket socket, byte[] buffer, int offset, int length, CancellationToken cancellationToken)
{
// Create the state machine and its task builder.
<ReceiveAsync>d__2 stateMachine = default(<ReceiveAsync>d__2);
stateMachine.<>t__builder = AsyncTaskMethodBuilder<int>.Create();
// Copy the method arguments into the state machine's fields.
stateMachine.socket = socket;
stateMachine.buffer = buffer;
stateMachine.offset = offset;
stateMachine.length = length;
stateMachine.cancellationToken = cancellationToken;
// Initial state value assigned by the compiler's async lowering before the first MoveNext.
stateMachine.<>1__state = -1;
// Begin running the state machine; the returned task represents its eventual result.
stateMachine.<>t__builder.Start(ref stateMachine);
return stateMachine.<>t__builder.Task;
}
}
}