ForwardedPortDynamic
Provides functionality for forwarding connections from the client to destination servers via the SSH server,
also known as dynamic port forwarding.
using Renci.SshNet.Abstractions;
using Renci.SshNet.Channels;
using Renci.SshNet.Common;
using System;
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading;
namespace Renci.SshNet
{
public class ForwardedPortDynamic : ForwardedPort
{
/// <summary>Listening socket created by <c>InternalStart</c>; disposed to stop accepting connections.</summary>
private Socket _listener;

/// <summary>Unsignaled while the listener thread runs; set when that thread exits.</summary>
private EventWaitHandle _listenerCompleted;

/// <summary>Set by <c>StopListener</c> to release the listener thread that waits on it.</summary>
private ManualResetEvent _stoppingListener;

/// <summary>Number of accepted connections currently inside <c>ProcessAccept</c>.</summary>
private int _pendingRequests;

/// <summary>Becomes <c>true</c> once <c>Dispose(bool)</c> has completed.</summary>
private bool _isDisposed;

/// <summary>
/// Host name or address the listener binds to. When null or empty, the listener
/// binds to <see cref="IPAddress.Any"/>.
/// </summary>
public string BoundHost { get; set; }

/// <summary>Local port the listener binds to.</summary>
public uint BoundPort { get; set; }
/// <summary>
/// Gets a value indicating whether the port is started, i.e. a completion handle
/// exists and has not been signaled yet.
/// </summary>
public override bool IsStarted
{
    // WaitOne(0) polls the handle without blocking.
    get { return _listenerCompleted != null && !_listenerCompleted.WaitOne(0); }
}
/// <summary>
/// Initializes a new instance that listens on <paramref name="port"/> with an empty
/// bound host (which makes the listener bind to any local address).
/// </summary>
/// <param name="port">The local port to listen on.</param>
public ForwardedPortDynamic(uint port) : this(string.Empty, port)
{
}
/// <summary>
/// Initializes a new instance that listens on the given host and port.
/// </summary>
/// <param name="host">The host name or address to bind to; stored in <see cref="BoundHost"/>.</param>
/// <param name="port">The local port to listen on; stored in <see cref="BoundPort"/>.</param>
public ForwardedPortDynamic(string host, uint port)
{
    BoundPort = port;
    BoundHost = host;
}
/// <summary>
/// Starts the forwarded port by binding the listener and launching the accept thread.
/// </summary>
protected override void StartPort() => InternalStart();
/// <summary>
/// Stops the forwarded port and waits for in-flight requests to finish.
/// </summary>
/// <param name="timeout">Maximum time to wait for pending requests to complete.</param>
protected override void StopPort(TimeSpan timeout)
{
    bool wasRunning = IsStarted;
    if (wasRunning) {
        // Stop accepting new connections first, then let the base class run its own stop handling.
        StopListener();
        base.StopPort(timeout);
    }

    // Always give already accepted requests a chance to drain.
    InternalStop(timeout);
}
/// <summary>
/// Throws when this instance has already been disposed.
/// </summary>
/// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
protected override void CheckDisposed()
{
    if (!_isDisposed) {
        return;
    }

    throw new ObjectDisposedException(GetType().FullName);
}
/// <summary>
/// Binds the listening socket, subscribes to session events and starts the thread
/// that accepts incoming SOCKS connections until the listener is stopped.
/// </summary>
private void InternalStart()
{
    // Bind to any address unless a specific host was configured.
    IPAddress address = IPAddress.Any;
    if (!string.IsNullOrEmpty(BoundHost))
        address = DnsAbstraction.GetHostAddresses(BoundHost)[0];
    IPEndPoint endPoint = new IPEndPoint(address, (int)BoundPort);

    Socket listener = new Socket(endPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
    try {
        listener.Bind(endPoint);
        listener.Listen(5);
    } catch {
        // Do not leave a half-initialized socket behind when the port cannot be bound.
        listener.Dispose();
        throw;
    }
    _listener = listener;

    base.Session.ErrorOccured += Session_ErrorOccured;
    base.Session.Disconnected += Session_Disconnected;

    // Create the stop signal BEFORE publishing the completion handle: IsStarted becomes
    // true as soon as _listenerCompleted is assigned, and StopListener dereferences
    // _stoppingListener whenever IsStarted is true.
    ManualResetEvent stopping = new ManualResetEvent(false);
    EventWaitHandle completed = new ManualResetEvent(false);
    _stoppingListener = stopping;
    _listenerCompleted = completed;

    ThreadAbstraction.ExecuteThread(delegate {
        try {
            StartAccept();
            // Block until StopListener signals that the listener must shut down.
            stopping.WaitOne();
        } catch (Exception exception) {
            RaiseExceptionEvent(exception);
        } finally {
            // Read the session once so a concurrent reset cannot slip in between
            // the null check and the unsubscription.
            var session = base.Session;
            if (session != null) {
                session.ErrorOccured -= Session_ErrorOccured;
                session.Disconnected -= Session_Disconnected;
            }
            // Mark the listener as stopped; this flips IsStarted back to false.
            completed.Set();
        }
    });
}
/// <summary>
/// Stops the listener, if it is running, and blocks until the listener thread has exited.
/// </summary>
private void StopListener()
{
    if (!IsStarted) {
        return;
    }

    // Release the listener thread, then close the socket so pending accepts end.
    _stoppingListener.Set();
    _listener.Dispose();

    // Wait for the listener thread to signal completion.
    _listenerCompleted.WaitOne();
}
/// <summary>
/// Waits for pending requests to complete after the listener has been stopped.
/// </summary>
/// <param name="timeout">
/// Maximum time to wait. <see cref="TimeSpan.Zero"/> skips waiting entirely; the
/// session's infinite time span disables the time limit.
/// </param>
private void InternalStop(TimeSpan timeout)
{
    if (timeout == TimeSpan.Zero) {
        return;
    }

    Stopwatch elapsed = Stopwatch.StartNew();
    while (!IsStarted) {
        // All accepted requests have finished.
        if (Interlocked.CompareExchange(ref _pendingRequests, 0, 0) == 0) {
            break;
        }

        // The time limit has been reached (never the case for an infinite timeout).
        if (elapsed.Elapsed >= timeout && timeout != Renci.SshNet.Session.InfiniteTimeSpan) {
            break;
        }

        // Give the pending requests some time before polling again.
        ThreadAbstraction.Sleep(50);
    }
    elapsed.Stop();
}
/// <summary>
/// Releases the resources held by this instance and suppresses finalization.
/// </summary>
public void Dispose()
{
    Dispose(disposing: true);
    GC.SuppressFinalize(this);
}
/// <summary>
/// Disposes the listening socket and the stop signal.
/// </summary>
/// <param name="disposing">
/// <c>true</c> when called from <c>Dispose()</c>; when <c>false</c> nothing is released here.
/// </param>
private void InternalDispose(bool disposing)
{
    if (!disposing) {
        return;
    }

    if (_listener != null) {
        _listener.Dispose();
        _listener = null;
    }

    if (_stoppingListener != null) {
        _stoppingListener.Dispose();
        _stoppingListener = null;
    }
}
/// <summary>
/// Releases the resources held by this instance; subsequent calls are no-ops.
/// </summary>
/// <param name="disposing">
/// <c>true</c> to also release managed resources; <c>false</c> when invoked from the finalizer.
/// </param>
protected override void Dispose(bool disposing)
{
    if (_isDisposed) {
        return;
    }

    base.Dispose(disposing);

    if (disposing) {
        // Work on a snapshot so the field is only cleared for the handle that was disposed.
        EventWaitHandle completed = _listenerCompleted;
        if (completed != null) {
            completed.Dispose();
            _listenerCompleted = null;
        }
    }

    InternalDispose(disposing);
    _isDisposed = true;
}
/// <summary>
/// Finalizer; runs the dispose logic without touching managed resources.
/// </summary>
~ForwardedPortDynamic()
{
    Dispose(disposing: false);
}
/// <summary>
/// Stops the listener when the session reports that it was disconnected.
/// </summary>
private void Session_Disconnected(object sender, EventArgs e) => StopListener();
/// <summary>
/// Stops the listener when the session reports an error.
/// </summary>
private void Session_ErrorOccured(object sender, ExceptionEventArgs e) => StopListener();
/// <summary>
/// Begins an asynchronous accept on the listening socket; the result is delivered
/// to <c>AcceptCompleted</c>.
/// </summary>
private void StartAccept()
{
    SocketAsyncEventArgs acceptArgs = new SocketAsyncEventArgs();
    acceptArgs.Completed += AcceptCompleted;

    bool completesLater = _listener.AcceptAsync(acceptArgs);
    if (completesLater) {
        return;
    }

    // The accept finished synchronously, in which case the Completed event is not
    // raised and the result has to be handled right here.
    AcceptCompleted(null, acceptArgs);
}
/// <summary>
/// Handles the outcome of an accept operation: re-arms the listener and, on success,
/// processes the accepted connection.
/// </summary>
/// <param name="sender">The event source; <c>null</c> when the accept completed synchronously.</param>
/// <param name="acceptAsyncEventArgs">The accept result.</param>
private void AcceptCompleted(object sender, SocketAsyncEventArgs acceptAsyncEventArgs)
{
    if (acceptAsyncEventArgs.SocketError == SocketError.OperationAborted) {
        // The listening socket was closed (the port is stopping). Re-arming the accept
        // here would touch a disposed socket and throw on the completion thread.
        return;
    }

    if (acceptAsyncEventArgs.SocketError != SocketError.Success) {
        // Keep listening for the next connection, then drop the broken one.
        StartAccept();

        // A failed accept does not necessarily carry a socket.
        Socket brokenSocket = acceptAsyncEventArgs.AcceptSocket;
        if (brokenSocket != null)
            brokenSocket.Dispose();
        return;
    }

    // Accept the next connection before handling this one, as handling blocks
    // for the lifetime of the forwarded connection.
    StartAccept();
    ProcessAccept(acceptAsyncEventArgs.AcceptSocket);
}
/// <summary>
/// Serves one accepted client connection: negotiates SOCKS, then relays traffic
/// through a direct-tcpip channel. Failures are reported through the exception
/// event instead of being thrown.
/// </summary>
/// <param name="remoteSocket">The socket of the accepted SOCKS client.</param>
private void ProcessAccept(Socket remoteSocket)
{
    // Counted so InternalStop can wait for in-flight requests.
    Interlocked.Increment(ref _pendingRequests);
    try {
        using (IChannelDirectTcpip channel = base.Session.CreateChannelDirectTcpip()) {
            channel.Exception += Channel_Exception;
            try {
                bool negotiated = HandleSocks(channel, remoteSocket, base.Session.ConnectionInfo.Timeout);
                if (negotiated) {
                    channel.Bind();
                } else {
                    CloseSocket(remoteSocket);
                }
            } finally {
                channel.Close();
            }
        }
    } catch (SocketException socketException) {
        // Interrupted socket operations are not reported.
        bool interrupted = socketException.SocketErrorCode == SocketError.Interrupted;
        if (!interrupted) {
            RaiseExceptionEvent(socketException);
        }
        CloseSocket(remoteSocket);
    } catch (Exception exception) {
        RaiseExceptionEvent(exception);
        CloseSocket(remoteSocket);
    } finally {
        Interlocked.Decrement(ref _pendingRequests);
    }
}
/// <summary>
/// Reads the SOCKS version byte sent by the client and dispatches to the matching
/// protocol handler.
/// </summary>
/// <param name="channel">The channel to open towards the requested destination.</param>
/// <param name="remoteSocket">The socket of the SOCKS client.</param>
/// <param name="timeout">The timeout applied to each read from the client.</param>
/// <returns>
/// The result of the version-specific handler, or <c>false</c> when the client
/// closed the connection before sending a version.
/// </returns>
/// <exception cref="NotSupportedException">The client requested a SOCKS version other than 4 or 5.</exception>
private bool HandleSocks(IChannelDirectTcpip channel, Socket remoteSocket, TimeSpan timeout)
{
    // While the handshake is in progress, close the client socket as soon as the
    // port starts closing so that blocking reads below are interrupted.
    EventHandler closeClientSocket = delegate {
        CloseSocket(remoteSocket);
    };
    base.Closing += closeClientSocket;
    try {
        int version = SocketAbstraction.ReadByte(remoteSocket, timeout);
        switch (version) {
        case -1:
            // The client closed the connection.
            return false;
        case 4:
            return HandleSocks4(remoteSocket, channel, timeout);
        case 5:
            return HandleSocks5(remoteSocket, channel, timeout);
        default:
            throw new NotSupportedException($"SOCKS version {version} is not supported.");
        }
    } finally {
        base.Closing -= closeClientSocket;
    }
}
/// <summary>
/// Shuts down the socket when it is still connected and always releases it.
/// </summary>
/// <param name="socket">The socket to close.</param>
private static void CloseSocket(Socket socket)
{
    try {
        if (socket.Connected)
            socket.Shutdown(SocketShutdown.Both);
    } catch (SocketException) {
        // The peer already tore the connection down; there is nothing left to shut down.
    } catch (ObjectDisposedException) {
        // The socket was closed concurrently, e.g. by the port's closing handler.
    } finally {
        // Dispose unconditionally: a socket that is no longer connected still owns
        // an OS handle that would otherwise leak until finalization.
        socket.Dispose();
    }
}
/// <summary>
/// Handles a SOCKS4 request: reads command, destination port, destination IPv4
/// address and user id, opens the channel and sends the reply to the client.
/// </summary>
/// <param name="socket">The socket of the SOCKS client.</param>
/// <param name="channel">The channel to open towards the requested destination.</param>
/// <param name="timeout">The timeout applied to each read from the client.</param>
/// <returns>
/// <c>true</c> when the channel was opened; <c>false</c> when the client closed the
/// connection during the handshake or the channel could not be opened.
/// </returns>
private bool HandleSocks4(Socket socket, IChannelDirectTcpip channel, TimeSpan timeout)
{
    // The command code is read but not interpreted. ReadByte signals a closed
    // connection with -1 (as in the other handlers), not 0.
    int commandCode = SocketAbstraction.ReadByte(socket, timeout);
    if (commandCode == -1)
        return false;

    // Destination port, big-endian.
    byte[] portBytes = new byte[2];
    if (SocketAbstraction.Read(socket, portBytes, 0, portBytes.Length, timeout) == 0)
        return false;
    uint port = (uint)(portBytes[0] * 256 + portBytes[1]);

    // Destination IPv4 address.
    byte[] addressBytes = new byte[4];
    if (SocketAbstraction.Read(socket, addressBytes, 0, addressBytes.Length, timeout) == 0)
        return false;
    IPAddress address = new IPAddress(addressBytes);

    // Null-terminated user id; its value is not used.
    if (ReadString(socket, timeout) == null)
        return false;

    string host = address.ToString();
    RaiseRequestReceived(host, port);
    channel.Open(host, port, this, socket);
    bool granted = channel.IsOpen;

    // A SOCKS4 reply is always 8 bytes: a null version byte, the result code
    // (90 = request granted, 91 = request rejected or failed), then port and address.
    SocketAbstraction.SendByte(socket, 0);
    SocketAbstraction.SendByte(socket, (byte)(granted ? 90 : 91));
    SocketAbstraction.Send(socket, portBytes, 0, portBytes.Length);
    SocketAbstraction.Send(socket, addressBytes, 0, addressBytes.Length);
    return granted;
}
/// <summary>
/// Handles a SOCKS5 exchange: negotiates the authentication method, reads the
/// connection request, opens the channel and sends the reply to the client.
/// </summary>
/// <param name="socket">The socket of the SOCKS client.</param>
/// <param name="channel">The channel to open towards the requested destination.</param>
/// <param name="timeout">The timeout applied to each read from the client.</param>
/// <returns>
/// <c>true</c> when the channel was opened; <c>false</c> when the client closed the
/// connection during the handshake or the channel could not be opened.
/// </returns>
/// <exception cref="ProxyException">The request is malformed or uses an unsupported address type.</exception>
private bool HandleSocks5(Socket socket, IChannelDirectTcpip channel, TimeSpan timeout)
{
    // --- Method negotiation: method count followed by that many method ids.
    int methodCount = SocketAbstraction.ReadByte(socket, timeout);
    if (methodCount == -1)
        return false;
    byte[] methods = new byte[methodCount];
    if (SocketAbstraction.Read(socket, methods, 0, methods.Length, timeout) == 0)
        return false;

    // Only "no authentication" (0) is supported. When the client does not offer it,
    // answer "no acceptable methods" (0xFF) and carry on: the client is expected to
    // close the connection, which makes one of the reads below report -1.
    bool offersNoAuthentication = Array.IndexOf(methods, (byte)0) >= 0;
    byte[] methodReply = { 5, offersNoAuthentication ? (byte)0 : byte.MaxValue };
    SocketAbstraction.Send(socket, methodReply, 0, methodReply.Length);

    // --- Request: version, command, reserved, address type, address, port.
    int version = SocketAbstraction.ReadByte(socket, timeout);
    if (version == -1)
        return false;
    if (version != 5)
        throw new ProxyException("SOCKS5: Version 5 is expected.");

    // The command code is read but not interpreted.
    int commandCode = SocketAbstraction.ReadByte(socket, timeout);
    if (commandCode == -1)
        return false;

    int reserved = SocketAbstraction.ReadByte(socket, timeout);
    if (reserved == -1)
        return false;
    if (reserved != 0)
        throw new ProxyException("SOCKS5: 0 is expected for reserved byte.");

    int addressType = SocketAbstraction.ReadByte(socket, timeout);
    IPAddress address;   // null when the destination is a host name
    string host;
    switch (addressType) {
    case -1:
        return false;
    case 1: {
        // IPv4 address.
        byte[] addressBytes = new byte[4];
        if (SocketAbstraction.Read(socket, addressBytes, 0, addressBytes.Length, timeout) == 0)
            return false;
        address = new IPAddress(addressBytes);
        host = address.ToString();
        break;
    }
    case 3: {
        // Length-prefixed host name. Guard the length: -1 means the client went away
        // and must not be used as an array size.
        int nameLength = SocketAbstraction.ReadByte(socket, timeout);
        if (nameLength == -1)
            return false;
        byte[] nameBytes = new byte[nameLength];
        if (SocketAbstraction.Read(socket, nameBytes, 0, nameBytes.Length, timeout) == 0)
            return false;
        string name = SshData.Ascii.GetString(nameBytes);
        if (IPAddress.TryParse(name, out address)) {
            host = address.ToString();
        } else {
            // A real host name: hand it to the channel unchanged instead of failing
            // to parse it as an IP address.
            address = null;
            host = name;
        }
        break;
    }
    case 4: {
        // IPv6 address.
        byte[] addressBytes = new byte[16];
        if (SocketAbstraction.Read(socket, addressBytes, 0, addressBytes.Length, timeout) == 0)
            return false;
        address = new IPAddress(addressBytes);
        host = address.ToString();
        break;
    }
    default:
        throw new ProxyException($"SOCKS5: Address type '{addressType}' is not supported.");
    }

    // Destination port, big-endian.
    byte[] portBytes = new byte[2];
    if (SocketAbstraction.Read(socket, portBytes, 0, portBytes.Length, timeout) == 0)
        return false;
    uint port = (uint)(portBytes[0] * 256 + portBytes[1]);

    RaiseRequestReceived(host, port);
    channel.Open(host, port, this, socket);
    bool channelOpen = channel.IsOpen;

    // --- Reply: version, status (0 = succeeded, 1 = general failure), reserved,
    // address type (1 = IPv4, 4 = IPv6), address, port. The requested address is
    // echoed back; for a host name the unspecified IPv4 address is reported.
    byte[] replyAddress = address != null ? address.GetAddressBytes() : new byte[4];
    byte[] reply = new byte[4 + replyAddress.Length + portBytes.Length];
    reply[0] = 5;
    reply[1] = (byte)(channelOpen ? 0 : 1);
    reply[2] = 0;
    reply[3] = (byte)(replyAddress.Length == 16 ? 4 : 1);
    Buffer.BlockCopy(replyAddress, 0, reply, 4, replyAddress.Length);
    Buffer.BlockCopy(portBytes, 0, reply, 4 + replyAddress.Length, portBytes.Length);
    SocketAbstraction.Send(socket, reply, 0, reply.Length);

    // Report failure when the channel did not open so the caller closes the client
    // socket instead of binding a dead channel (same contract as the SOCKS4 handler).
    return channelOpen;
}
/// <summary>
/// Forwards an exception reported by a channel to this port's exception event.
/// </summary>
private void Channel_Exception(object sender, ExceptionEventArgs e) => RaiseExceptionEvent(e.Exception);
/// <summary>
/// Reads a null-terminated string from the socket, one byte at a time, mapping
/// each byte directly to a character.
/// </summary>
/// <param name="socket">The socket to read from.</param>
/// <param name="timeout">The timeout applied to each single-byte read.</param>
/// <returns>
/// The characters preceding the terminating null byte, or <c>null</c> when a read
/// returned no data before the terminator was seen.
/// </returns>
private static string ReadString(Socket socket, TimeSpan timeout)
{
    var text = new StringBuilder();
    var current = new byte[1];
    for (;;) {
        int bytesRead = SocketAbstraction.Read(socket, current, 0, 1, timeout);
        if (bytesRead == 0) {
            return null;
        }

        if (current[0] == 0) {
            // Terminator reached.
            return text.ToString();
        }

        text.Append((char)current[0]);
    }
}
}
}