// SocketAbstraction
using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;
using System;
using System.Globalization;
using System.Net;
using System.Net.Sockets;
using System.Threading;
namespace Renci.SshNet.Abstractions
{
/// <summary>
/// Blocking connect/send/receive helpers built on the <see cref="SocketAsyncEventArgs"/> socket API.
/// Each operation allocates its own <see cref="SocketAsyncEventArgs"/> and blocks on a
/// <see cref="ManualResetEvent"/> that the operation's token signals once the operation has finished.
/// </summary>
internal static class SocketAbstraction
{
    /// <summary>Delay, in milliseconds, before retrying an operation that failed with a resumable error.</summary>
    private const int ResumeDelayMilliseconds = 30;

    /// <summary>
    /// Per-operation state stored in <see cref="SocketAsyncEventArgs.UserToken"/>. After each completed
    /// receive/send it decides whether to restart the operation or to signal the waiting caller.
    /// </summary>
    private abstract class Token
    {
        private readonly Socket _socket;
        private readonly EventWaitHandle _completionWaitHandle;

        protected Token(Socket socket, EventWaitHandle completionWaitHandle)
        {
            _socket = socket;
            _completionWaitHandle = completionWaitHandle;
        }

        /// <summary>
        /// Handles a completed operation on <paramref name="args"/>. The operation is restarted for as long
        /// as the token asks for more. Once the token is done, the completion wait handle is signaled.
        /// </summary>
        /// <remarks>
        /// <see cref="Socket.ReceiveAsync(SocketAsyncEventArgs)"/> and <see cref="Socket.SendAsync(SocketAsyncEventArgs)"/>
        /// return <c>false</c> when they complete synchronously, and in that case the
        /// <see cref="SocketAsyncEventArgs.Completed"/> event is NOT raised. Looping here makes sure such
        /// completions are processed instead of dropped, which would leave the waiter blocked. It also
        /// avoids recursion when many completions in a row are synchronous.
        /// </remarks>
        public void Process(SocketAsyncEventArgs args)
        {
            while (true)
            {
                if (!ShouldResume(args))
                {
                    _completionWaitHandle.Set();
                    return;
                }

                if (ResumeOperation(args))
                {
                    // Pending: the Completed event will call Process again when the operation finishes.
                    return;
                }
            }
        }

        /// <summary>Handles a successful completion.</summary>
        /// <returns><c>true</c> to restart the operation; <c>false</c> when the operation is finished.</returns>
        protected abstract bool OnSuccess(SocketAsyncEventArgs args);

        /// <summary>Handles a non-resumable failure; the operation is finished afterwards.</summary>
        protected virtual void OnFailure(SocketError socketError)
        {
        }

        private bool ShouldResume(SocketAsyncEventArgs args)
        {
            var socketError = args.SocketError;

            if (socketError == SocketError.Success)
            {
                return OnSuccess(args);
            }

            if (IsErrorResumable(socketError))
            {
                // Transient condition: back off briefly, then retry the same operation.
                ThreadAbstraction.Sleep(ResumeDelayMilliseconds);
                return true;
            }

            OnFailure(socketError);
            return false;
        }

        /// <summary>Restarts the operation recorded in <see cref="SocketAsyncEventArgs.LastOperation"/>.</summary>
        /// <returns><c>true</c> if the operation is pending; <c>false</c> if it completed synchronously.</returns>
        private bool ResumeOperation(SocketAsyncEventArgs args)
        {
            switch (args.LastOperation)
            {
                case SocketAsyncOperation.Receive:
                    return _socket.ReceiveAsync(args);
                case SocketAsyncOperation.Send:
                    return _socket.SendAsync(args);
                default:
                    // Tokens are only attached to receive and send operations. Nothing else is restarted,
                    // and it is treated as pending (no further processing), matching the previous code.
                    return true;
            }
        }
    }

    /// <summary>
    /// Keeps receiving/sending until exactly the requested number of bytes has been transferred, or
    /// until an operation transfers zero bytes.
    /// </summary>
    private sealed class BlockingSendReceiveToken : Token
    {
        private readonly byte[] _buffer;
        private readonly int _bytesToTransfer;
        private int _offset;

        public BlockingSendReceiveToken(Socket socket, byte[] buffer, int offset, int size, EventWaitHandle completionWaitHandle)
            : base(socket, completionWaitHandle)
        {
            _buffer = buffer;
            _offset = offset;
            _bytesToTransfer = size;
        }

        /// <summary>Total number of bytes transferred so far across all resumed operations.</summary>
        public int TotalBytesTransferred { get; private set; }

        protected override bool OnSuccess(SocketAsyncEventArgs args)
        {
            var bytesTransferred = args.BytesTransferred;
            TotalBytesTransferred += bytesTransferred;

            // Finished when everything has been transferred, or when an operation transferred nothing.
            // For a receive, zero bytes means the remote side has shut down the connection.
            if (TotalBytesTransferred == _bytesToTransfer || bytesTransferred == 0)
            {
                return false;
            }

            // Continue with the remainder of the caller's buffer.
            _offset += bytesTransferred;
            args.SetBuffer(_buffer, _offset, _bytesToTransfer - TotalBytesTransferred);
            return true;
        }
    }

    /// <summary>Finishes after the first successful transfer, whatever its size.</summary>
    private sealed class PartialSendReceiveToken : Token
    {
        public PartialSendReceiveToken(Socket socket, EventWaitHandle completionWaitHandle)
            : base(socket, completionWaitHandle)
        {
        }

        protected override bool OnSuccess(SocketAsyncEventArgs args)
        {
            return false;
        }
    }

    /// <summary>
    /// Repeatedly receives into the same buffer and hands every non-empty chunk to a callback. It stops
    /// on a zero-byte receive or on a non-resumable error.
    /// </summary>
    private sealed class ContinuousReceiveToken : Token
    {
        private readonly Action<byte[], int, int> _processReceivedBytesAction;

        public ContinuousReceiveToken(Socket socket, Action<byte[], int, int> processReceivedBytesAction, EventWaitHandle completionWaitHandle)
            : base(socket, completionWaitHandle)
        {
            _processReceivedBytesAction = processReceivedBytesAction;
        }

        /// <summary>
        /// The <see cref="SocketException"/> that ended the loop. It stays <c>null</c> when the loop ended
        /// with a zero-byte receive or with <see cref="SocketError.OperationAborted"/>.
        /// </summary>
        public Exception Exception { get; private set; }

        protected override bool OnSuccess(SocketAsyncEventArgs args)
        {
            if (args.BytesTransferred == 0)
            {
                return false;
            }

            // NOTE(review): if this callback throws while running on a completion thread (rather than the
            // thread that called ReadContinuous), the exception is not marshalled back to the caller.
            _processReceivedBytesAction(args.Buffer, args.Offset, args.BytesTransferred);
            return true;
        }

        protected override void OnFailure(SocketError socketError)
        {
            // An aborted operation (e.g. the socket was closed locally) ends the loop without an error.
            if (socketError != SocketError.OperationAborted)
            {
                Exception = new SocketException((int) socketError);
            }
        }
    }

    /// <summary>Returns whether <paramref name="socket"/> is non-null and connected.</summary>
    public static bool CanRead(Socket socket)
    {
        return socket != null && socket.Connected;
    }

    /// <summary>Returns whether <paramref name="socket"/> is non-null and connected.</summary>
    public static bool CanWrite(Socket socket)
    {
        return socket != null && socket.Connected;
    }

    /// <summary>
    /// Creates a TCP stream socket with Nagle's algorithm disabled and connects it to
    /// <paramref name="remoteEndpoint"/>. If the connect fails, the new socket is disposed.
    /// </summary>
    /// <exception cref="SshOperationTimeoutException">The connection was not established within <paramref name="connectTimeout"/>.</exception>
    /// <exception cref="SocketException">The connection attempt failed.</exception>
    public static Socket Connect(IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
    {
        var socket = new Socket(remoteEndpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp)
        {
            NoDelay = true
        };

        ConnectCore(socket, remoteEndpoint, connectTimeout, true);
        return socket;
    }

    /// <summary>
    /// Connects the caller-supplied <paramref name="socket"/> to <paramref name="remoteEndpoint"/>. The
    /// socket is not disposed on failure.
    /// </summary>
    /// <exception cref="SshOperationTimeoutException">The connection was not established within <paramref name="connectTimeout"/>.</exception>
    /// <exception cref="SocketException">The connection attempt failed.</exception>
    public static void Connect(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
    {
        ConnectCore(socket, remoteEndpoint, connectTimeout, false);
    }

    /// <summary>
    /// Starts an asynchronous connect and waits up to <paramref name="connectTimeout"/> for it to finish.
    /// The wait handle and the event args are always released. When <paramref name="ownsSocket"/> is
    /// <c>true</c>, the socket is disposed on any failure, including an exception thrown by
    /// <see cref="Socket.ConnectAsync(SocketAsyncEventArgs)"/> itself.
    /// </summary>
    private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout, bool ownsSocket)
    {
        var connectCompleted = new ManualResetEvent(false);
        var args = new SocketAsyncEventArgs
        {
            UserToken = connectCompleted,
            RemoteEndPoint = remoteEndpoint
        };
        args.Completed += ConnectCompleted;

        var connected = false;
        try
        {
            if (socket.ConnectAsync(args) && !connectCompleted.WaitOne(connectTimeout))
            {
                // Detach the handler before the wait handle is disposed, so a late completion does not touch it.
                args.Completed -= ConnectCompleted;

                throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                    "Connection failed to establish within {0:F0} milliseconds.",
                    connectTimeout.TotalMilliseconds));
            }

            if (args.SocketError != SocketError.Success)
            {
                throw new SocketException((int) args.SocketError);
            }

            connected = true;
        }
        finally
        {
            if (!connected && ownsSocket)
            {
                socket.Dispose();
            }

            // ConnectCompleted tolerates a null token, so a racing late completion becomes a no-op.
            args.UserToken = null;
            connectCompleted.Dispose();
            args.Dispose();
        }
    }

    /// <summary>
    /// Reads and discards incoming data in 256-byte chunks, waiting at most 500 ms for each chunk.
    /// </summary>
    /// <remarks>
    /// The loop stops normally only on a zero-byte read. If no data arrives within 500 ms,
    /// <see cref="ReadPartial"/> throws <see cref="SshOperationTimeoutException"/>, and that exception
    /// propagates from here. NOTE(review): confirm that callers expect this.
    /// </remarks>
    public static void ClearReadBuffer(Socket socket)
    {
        var timeout = TimeSpan.FromMilliseconds(500);
        var buffer = new byte[256];
        int bytesReceived;

        do
        {
            bytesReceived = ReadPartial(socket, buffer, 0, buffer.Length, timeout);
        }
        while (bytesReceived > 0);
    }

    /// <summary>
    /// Receives up to <paramref name="size"/> bytes into <paramref name="buffer"/> and returns after the
    /// first successful receive, whatever its size.
    /// </summary>
    /// <returns>The number of bytes received; 0 when the receive transferred nothing.</returns>
    /// <exception cref="SshOperationTimeoutException">Nothing completed within <paramref name="timeout"/>.</exception>
    /// <exception cref="SocketException">The receive failed with a non-resumable error.</exception>
    /// <remarks>
    /// NOTE(review): on timeout the outstanding receive is not cancelled. Any data it later receives is
    /// not returned to any caller.
    /// </remarks>
    public static int ReadPartial(Socket socket, byte[] buffer, int offset, int size, TimeSpan timeout)
    {
        var remoteEndPoint = socket.RemoteEndPoint;
        var completionWaitHandle = new ManualResetEvent(false);
        var token = new PartialSendReceiveToken(socket, completionWaitHandle);
        var args = new SocketAsyncEventArgs
        {
            RemoteEndPoint = remoteEndPoint,
            UserToken = token
        };

        try
        {
            args.Completed += ReceiveCompleted;
            args.SetBuffer(buffer, offset, size);

            if (!socket.ReceiveAsync(args))
            {
                // Completed synchronously: Completed will not be raised, so drive the token here.
                token.Process(args);
            }

            // Always wait. After a resumable error, the token may have restarted the receive asynchronously.
            if (!completionWaitHandle.WaitOne(timeout))
            {
                throw CreateReadTimeoutException(timeout);
            }

            if (args.SocketError != SocketError.Success)
            {
                throw new SocketException((int) args.SocketError);
            }

            return args.BytesTransferred;
        }
        finally
        {
            // ReceiveCompleted ignores a null token, so a completion that arrives after a timeout is a no-op.
            args.UserToken = null;
            args.Dispose();
            completionWaitHandle.Dispose();
        }
    }

    /// <summary>
    /// Receives into <paramref name="buffer"/> until a zero-byte receive or a non-resumable error, and
    /// passes every received chunk to <paramref name="processReceivedBytesAction"/>. The same buffer
    /// region is reused for each receive. Blocks without a timeout.
    /// </summary>
    /// <exception cref="SocketException">The receive failed with an error other than <see cref="SocketError.OperationAborted"/>.</exception>
    public static void ReadContinuous(Socket socket, byte[] buffer, int offset, int size, Action<byte[], int, int> processReceivedBytesAction)
    {
        var remoteEndPoint = socket.RemoteEndPoint;
        var completionWaitHandle = new ManualResetEvent(false);
        var token = new ContinuousReceiveToken(socket, processReceivedBytesAction, completionWaitHandle);
        var args = new SocketAsyncEventArgs
        {
            RemoteEndPoint = remoteEndPoint,
            UserToken = token
        };

        try
        {
            args.Completed += ReceiveCompleted;
            args.SetBuffer(buffer, offset, size);

            if (!socket.ReceiveAsync(args))
            {
                token.Process(args);
            }

            completionWaitHandle.WaitOne();
        }
        finally
        {
            args.UserToken = null;
            args.Dispose();
            completionWaitHandle.Dispose();
        }

        if (token.Exception != null)
        {
            throw token.Exception;
        }
    }

    /// <summary>Reads a single byte.</summary>
    /// <returns>The byte read, or -1 when the receive transferred nothing (connection closed).</returns>
    public static int ReadByte(Socket socket, TimeSpan timeout)
    {
        var buffer = new byte[1];
        return Read(socket, buffer, 0, 1, timeout) == 0 ? -1 : buffer[0];
    }

    /// <summary>Sends a single byte.</summary>
    public static void SendByte(Socket socket, byte value)
    {
        Send(socket, new[] { value }, 0, 1);
    }

    /// <summary>Reads <paramref name="size"/> bytes into a new array.</summary>
    /// <remarks>
    /// NOTE(review): the number of bytes actually read is not checked. If the connection closes early,
    /// the tail of the returned array stays zero-filled.
    /// </remarks>
    public static byte[] Read(Socket socket, int size, TimeSpan timeout)
    {
        var buffer = new byte[size];
        Read(socket, buffer, 0, size, timeout);
        return buffer;
    }

    /// <summary>
    /// Receives exactly <paramref name="size"/> bytes into <paramref name="buffer"/>. It stops early only if
    /// a receive transfers zero bytes. The timeout covers the whole read.
    /// </summary>
    /// <returns>The number of bytes received; less than <paramref name="size"/> only if the connection was closed.</returns>
    /// <exception cref="SshOperationTimeoutException">The read did not finish within <paramref name="readTimeout"/>.</exception>
    /// <exception cref="SocketException">A receive failed with a non-resumable error.</exception>
    /// <remarks>
    /// NOTE(review): on timeout the outstanding receive is not cancelled. Any data it later receives is
    /// not returned to any caller.
    /// </remarks>
    public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeSpan readTimeout)
    {
        var remoteEndPoint = socket.RemoteEndPoint;
        var completionWaitHandle = new ManualResetEvent(false);
        var token = new BlockingSendReceiveToken(socket, buffer, offset, size, completionWaitHandle);
        var args = new SocketAsyncEventArgs
        {
            UserToken = token,
            RemoteEndPoint = remoteEndPoint
        };

        try
        {
            args.Completed += ReceiveCompleted;
            args.SetBuffer(buffer, offset, size);

            if (!socket.ReceiveAsync(args))
            {
                // Completed synchronously: Completed will not be raised, so drive the token here.
                token.Process(args);
            }

            // Always wait. Even after a synchronous first completion, the token may have restarted the
            // receive asynchronously to fetch the remaining bytes.
            if (!completionWaitHandle.WaitOne(readTimeout))
            {
                throw CreateReadTimeoutException(readTimeout);
            }

            if (args.SocketError != SocketError.Success)
            {
                throw new SocketException((int) args.SocketError);
            }

            return token.TotalBytesTransferred;
        }
        finally
        {
            args.UserToken = null;
            args.Dispose();
            completionWaitHandle.Dispose();
        }
    }

    /// <summary>Sends all of <paramref name="data"/>.</summary>
    public static void Send(Socket socket, byte[] data)
    {
        Send(socket, data, 0, data.Length);
    }

    /// <summary>
    /// Sends exactly <paramref name="size"/> bytes of <paramref name="data"/> starting at
    /// <paramref name="offset"/>, and blocks (without a timeout) until they are sent or the send fails.
    /// </summary>
    /// <exception cref="SocketException">A send failed with a non-resumable error.</exception>
    /// <exception cref="SshConnectionException">A send transferred zero bytes before all data was sent.</exception>
    public static void Send(Socket socket, byte[] data, int offset, int size)
    {
        var remoteEndPoint = socket.RemoteEndPoint;
        var completionWaitHandle = new ManualResetEvent(false);
        var token = new BlockingSendReceiveToken(socket, data, offset, size, completionWaitHandle);
        var args = new SocketAsyncEventArgs
        {
            RemoteEndPoint = remoteEndPoint,
            UserToken = token
        };

        try
        {
            args.SetBuffer(data, offset, size);
            args.Completed += SendCompleted;

            if (!socket.SendAsync(args))
            {
                // Completed synchronously: Completed will not be raised, so drive the token here.
                token.Process(args);
            }

            // Always wait. The token may have restarted the send asynchronously for the remaining bytes.
            completionWaitHandle.WaitOne();

            if (args.SocketError != SocketError.Success)
            {
                throw new SocketException((int) args.SocketError);
            }

            // A short transfer means a send transferred zero bytes part-way. A zero-length send is not an error.
            if (token.TotalBytesTransferred < size)
            {
                throw new SshConnectionException("An established connection was aborted by the server.", DisconnectReason.ConnectionLost);
            }
        }
        finally
        {
            args.UserToken = null;
            args.Dispose();
            completionWaitHandle.Dispose();
        }
    }

    /// <summary>
    /// Returns whether <paramref name="socketError"/> is transient. Operations that fail with such an error
    /// are retried after a short delay.
    /// </summary>
    public static bool IsErrorResumable(SocketError socketError)
    {
        return socketError == SocketError.IOPending
            || socketError == SocketError.WouldBlock
            || socketError == SocketError.NoBufferSpaceAvailable;
    }

    /// <summary>Builds the timeout exception thrown by the read operations.</summary>
    private static SshOperationTimeoutException CreateReadTimeoutException(TimeSpan timeout)
    {
        return new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
            "Socket read operation has timed out after {0:F0} milliseconds.",
            timeout.TotalMilliseconds));
    }

    private static void ConnectCompleted(object sender, SocketAsyncEventArgs e)
    {
        ((ManualResetEvent) e.UserToken)?.Set();
    }

    private static void ReceiveCompleted(object sender, SocketAsyncEventArgs e)
    {
        // A null token means the initiating call has already returned (e.g. timed out); ignore it.
        ((Token) e.UserToken)?.Process(e);
    }

    private static void SendCompleted(object sender, SocketAsyncEventArgs e)
    {
        ((Token) e.UserToken)?.Process(e);
    }
}
}