// ForwardedPortLocal
// Provides functionality for local port forwarding.
using Renci.SshNet.Abstractions;
using Renci.SshNet.Channels;
using Renci.SshNet.Common;
using System;
using System.Diagnostics;
using System.Net;
using System.Net.Sockets;
using System.Threading;
namespace Renci.SshNet
{
/// <summary>
/// Provides functionality for local port forwarding: client connections accepted on
/// <see cref="BoundHost"/>:<see cref="BoundPort"/> are forwarded over the SSH session
/// through a direct-tcpip channel opened to <see cref="Host"/>:<see cref="Port"/>.
/// </summary>
public class ForwardedPortLocal : ForwardedPort, IDisposable
{
    private Socket _listener;

    // Number of accepted connections that are still being forwarded (see ProcessAccept).
    private int _pendingRequests;

    // Signalled to release the listener thread when forwarding is stopped.
    private ManualResetEvent _stoppingListener;

    // Non-signalled while the listener thread runs; signalled once that thread has exited.
    private EventWaitHandle _listenerTaskCompleted;

    private bool _isDisposed;

    /// <summary>
    /// Gets or sets the local host name or address that the listener binds to.
    /// </summary>
    public string BoundHost { get; set; }

    /// <summary>
    /// Gets or sets the local port that the listener binds to. Once the port is started this
    /// is updated to the port actually bound, which matters when <c>0</c> was requested.
    /// </summary>
    public uint BoundPort { get; set; }

    /// <summary>
    /// Gets or sets the destination host passed to the direct-tcpip channel for every
    /// accepted connection.
    /// </summary>
    public string Host { get; set; }

    /// <summary>
    /// Gets or sets the destination port passed to the direct-tcpip channel for every
    /// accepted connection.
    /// </summary>
    public uint Port { get; set; }

    /// <summary>
    /// Gets a value indicating whether port forwarding is started, i.e. the listener thread
    /// has been launched and has not yet completed.
    /// </summary>
    public override bool IsStarted {
        get {
            // Read the field once: Dispose may reset it to null concurrently.
            EventWaitHandle listenerTaskCompleted = _listenerTaskCompleted;
            if (listenerTaskCompleted == null)
                return false;
            return !listenerTaskCompleted.WaitOne(0);
        }
    }

    /// <summary>
    /// Issues an asynchronous accept on the listener socket. When the accept completes
    /// synchronously, <see cref="AcceptCompleted"/> is invoked directly on this thread.
    /// </summary>
    /// <exception cref="ObjectDisposedException">The listener has been released or closed.</exception>
    private void StartAccept()
    {
        Socket listener = _listener;
        if (listener == null)
            throw new ObjectDisposedException(GetType().FullName);

        SocketAsyncEventArgs acceptArgs = new SocketAsyncEventArgs();
        acceptArgs.Completed += AcceptCompleted;

        bool completesAsynchronously;
        try {
            completesAsynchronously = listener.AcceptAsync(acceptArgs);
        } catch {
            // The operation never started, so Completed will not fire to release the args.
            acceptArgs.Dispose();
            throw;
        }

        if (!completesAsynchronously)
            AcceptCompleted(null, acceptArgs);
    }

    /// <summary>
    /// Handles the completion of an accept: stops when the listener was closed, otherwise
    /// issues the next accept and then either discards a failed client socket or forwards
    /// the accepted connection.
    /// </summary>
    private void AcceptCompleted(object sender, SocketAsyncEventArgs acceptAsyncEventArgs)
    {
        SocketError socketError = acceptAsyncEventArgs.SocketError;
        Socket clientSocket = acceptAsyncEventArgs.AcceptSocket;

        // Each accept uses its own SocketAsyncEventArgs; release it once the result is captured.
        acceptAsyncEventArgs.Dispose();

        if (socketError == SocketError.OperationAborted) {
            // The listener socket was closed (stop or dispose): do not accept again.
            if (clientSocket != null)
                clientSocket.Dispose();
            return;
        }

        // Issue the next accept before handling this one, since ProcessAccept runs the
        // channel synchronously and may hold this thread for a while.
        ContinueAccepting();

        if (socketError != SocketError.Success) {
            // The accept failed; discard the broken client socket, if one was created.
            if (clientSocket != null)
                clientSocket.Dispose();
            return;
        }

        ProcessAccept(clientSocket);
    }

    /// <summary>
    /// Starts the next accept from within the accept-completion path. Exceptions are not
    /// allowed to escape because this may run on a thread-pool completion thread, where an
    /// unhandled exception would terminate the process.
    /// </summary>
    private void ContinueAccepting()
    {
        try {
            StartAccept();
        } catch (ObjectDisposedException) {
            // The listener is only disposed by StopListener or InternalDispose, so forwarding
            // is stopping or the port is disposed: there is nothing left to accept.
        } catch (Exception exception) {
            RaiseExceptionEvent(exception);
        }
    }

    /// <summary>
    /// Forwards one accepted client connection: raises the request-received notification,
    /// then opens, binds and closes a direct-tcpip channel to <see cref="Host"/>:<see cref="Port"/>.
    /// Failures are reported through <c>RaiseExceptionEvent</c> and the client socket is closed.
    /// </summary>
    private void ProcessAccept(Socket clientSocket)
    {
        // Tracked so that InternalStop can wait for in-flight connections to drain.
        Interlocked.Increment(ref _pendingRequests);
        try {
            IPEndPoint remoteEndPoint = (IPEndPoint)clientSocket.RemoteEndPoint;
            RaiseRequestReceived(remoteEndPoint.Address.ToString(), (uint)remoteEndPoint.Port);
            using (IChannelDirectTcpip channelDirectTcpip = base.Session.CreateChannelDirectTcpip()) {
                channelDirectTcpip.Exception += Channel_Exception;
                channelDirectTcpip.Open(Host, Port, this, clientSocket);
                channelDirectTcpip.Bind();
                channelDirectTcpip.Close();
            }
        } catch (Exception exception) {
            RaiseExceptionEvent(exception);
            CloseSocket(clientSocket);
        } finally {
            Interlocked.Decrement(ref _pendingRequests);
        }
    }

    /// <summary>
    /// Shuts down (when still connected) and always disposes the given client socket.
    /// </summary>
    private static void CloseSocket(Socket socket)
    {
        if (socket.Connected) {
            try {
                socket.Shutdown(SocketShutdown.Both);
            } catch (SocketException) {
                // The connection may already be gone; the socket is disposed below regardless.
            } catch (ObjectDisposedException) {
                // Already closed elsewhere; Dispose below is a no-op in that case.
            }
        }
        socket.Dispose();
    }

    /// <summary>Stops the listener when the session reports an error.</summary>
    private void Session_ErrorOccured(object sender, ExceptionEventArgs e)
    {
        StopListener();
    }

    /// <summary>Stops the listener when the session is disconnected.</summary>
    private void Session_Disconnected(object sender, EventArgs e)
    {
        StopListener();
    }

    /// <summary>Re-raises exceptions reported by a forwarding channel.</summary>
    private void Channel_Exception(object sender, ExceptionEventArgs e)
    {
        RaiseExceptionEvent(e.Exception);
    }

    /// <summary>
    /// Initializes a new instance that listens on <paramref name="boundPort"/> with
    /// <see cref="BoundHost"/> set to <see cref="string.Empty"/>.
    /// </summary>
    /// <param name="boundPort">Local port to listen on.</param>
    /// <param name="host">Destination host for forwarded connections.</param>
    /// <param name="port">Destination port for forwarded connections.</param>
    public ForwardedPortLocal(uint boundPort, string host, uint port)
        : this(string.Empty, boundPort, host, port)
    {
    }

    /// <summary>
    /// Initializes a new instance that listens on <paramref name="boundHost"/> with a bound
    /// port of <c>0</c>; the actual port is stored in <see cref="BoundPort"/> once started.
    /// </summary>
    /// <param name="boundHost">Local host name or address to listen on.</param>
    /// <param name="host">Destination host for forwarded connections.</param>
    /// <param name="port">Destination port for forwarded connections.</param>
    public ForwardedPortLocal(string boundHost, string host, uint port)
        : this(boundHost, 0, host, port)
    {
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="ForwardedPortLocal"/> class.
    /// </summary>
    /// <param name="boundHost">Local host name or address to listen on.</param>
    /// <param name="boundPort">Local port to listen on (validated via <c>ValidatePort</c>).</param>
    /// <param name="host">Destination host for forwarded connections.</param>
    /// <param name="port">Destination port for forwarded connections (validated via <c>ValidatePort</c>).</param>
    /// <exception cref="ArgumentNullException"><paramref name="boundHost"/> or <paramref name="host"/> is null.</exception>
    public ForwardedPortLocal(string boundHost, uint boundPort, string host, uint port)
    {
        if (boundHost == null)
            throw new ArgumentNullException("boundHost");
        if (host == null)
            throw new ArgumentNullException("host");
        boundPort.ValidatePort("boundPort");
        port.ValidatePort("port");
        BoundHost = boundHost;
        BoundPort = boundPort;
        Host = host;
        Port = port;
    }

    /// <summary>Starts local port forwarding.</summary>
    protected override void StartPort()
    {
        InternalStart();
    }

    /// <summary>
    /// Stops the listener (if started), lets the base class run its stop logic, then waits up
    /// to <paramref name="timeout"/> for in-flight connections to finish.
    /// </summary>
    protected override void StopPort(TimeSpan timeout)
    {
        if (IsStarted) {
            StopListener();
            base.StopPort(timeout);
        }
        InternalStop(timeout);
    }

    /// <summary>
    /// Throws <see cref="ObjectDisposedException"/> when this instance has been disposed.
    /// </summary>
    protected override void CheckDisposed()
    {
        if (_isDisposed)
            throw new ObjectDisposedException(GetType().FullName);
    }

    /// <summary>
    /// Binds the listener socket, subscribes to session failure events and launches the
    /// listener thread, which issues the first accept and then waits for the stop signal.
    /// </summary>
    private void InternalStart()
    {
        IPEndPoint localEndPoint = new IPEndPoint(DnsAbstraction.GetHostAddresses(BoundHost)[0], (int)BoundPort);
        Socket listener = new Socket(localEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
        try {
            listener.Bind(localEndPoint);
            listener.Listen(1);
        } catch {
            // Do not leak the socket when the endpoint cannot be bound or listened on.
            listener.Dispose();
            throw;
        }
        _listener = listener;

        // Publish the port actually bound (relevant when BoundPort was 0).
        BoundPort = (uint)((IPEndPoint)listener.LocalEndPoint).Port;

        base.Session.ErrorOccured += Session_ErrorOccured;
        base.Session.Disconnected += Session_Disconnected;

        // Create the stop signal before IsStarted can report true: StopListener() sets
        // _stoppingListener as soon as IsStarted is true, so creating it on the listener
        // thread would race with an early stop and could leave that thread waiting forever.
        ManualResetEvent stoppingListener = new ManualResetEvent(false);
        _stoppingListener = stoppingListener;
        ManualResetEvent listenerTaskCompleted = new ManualResetEvent(false);
        _listenerTaskCompleted = listenerTaskCompleted;

        // The thread uses the captured handles rather than the fields, which Dispose may null.
        ThreadAbstraction.ExecuteThread(delegate {
            try {
                StartAccept();
                stoppingListener.WaitOne();
            } catch (ObjectDisposedException) {
                // The listener or stop signal was disposed by StopListener/Dispose.
            } catch (Exception exception) {
                RaiseExceptionEvent(exception);
            } finally {
                listenerTaskCompleted.Set();
            }
        });
    }

    /// <summary>
    /// If started: unsubscribes from session events, signals the listener thread to stop,
    /// closes the listener socket and waits for the listener thread to complete.
    /// </summary>
    private void StopListener()
    {
        if (IsStarted) {
            base.Session.Disconnected -= Session_Disconnected;
            base.Session.ErrorOccured -= Session_ErrorOccured;
            _stoppingListener.Set();
            // Closing the listener aborts the pending accept (OperationAborted in AcceptCompleted).
            _listener.Dispose();
            _listenerTaskCompleted.WaitOne();
        }
    }

    /// <summary>
    /// Polls every 50 ms until no forwarded connection is in flight or <paramref name="timeout"/>
    /// elapses. <see cref="TimeSpan.Zero"/> returns immediately; <c>Session.InfiniteTimeSpan</c>
    /// waits without limit.
    /// </summary>
    private void InternalStop(TimeSpan timeout)
    {
        if (timeout == TimeSpan.Zero)
            return;

        bool waitForever = timeout == Renci.SshNet.Session.InfiniteTimeSpan;
        Stopwatch stopwatch = Stopwatch.StartNew();
        while (Interlocked.CompareExchange(ref _pendingRequests, 0, 0) != 0) {
            if (!waitForever && stopwatch.Elapsed >= timeout)
                break;
            ThreadAbstraction.Sleep(50);
        }
    }

    /// <summary>Releases all resources used by this instance.</summary>
    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Disposes the listener socket and the stop signal when called from <see cref="Dispose()"/>.
    /// </summary>
    private void InternalDispose(bool disposing)
    {
        if (disposing) {
            if (_listener != null) {
                _listener.Dispose();
                _listener = null;
            }
            if (_stoppingListener != null) {
                _stoppingListener.Dispose();
                _stoppingListener = null;
            }
        }
    }

    /// <summary>
    /// Runs the base disposal, then releases this instance's wait handles and socket.
    /// Subsequent calls do nothing.
    /// </summary>
    protected override void Dispose(bool disposing)
    {
        if (!_isDisposed) {
            base.Dispose(disposing);
            if (disposing) {
                EventWaitHandle listenerTaskCompleted = _listenerTaskCompleted;
                if (listenerTaskCompleted != null) {
                    listenerTaskCompleted.Dispose();
                    _listenerTaskCompleted = null;
                }
            }
            InternalDispose(disposing);
            _isDisposed = true;
        }
    }

    /// <summary>Finalizer; only the unmanaged-safe part of disposal runs here.</summary>
    ~ForwardedPortLocal()
    {
        Dispose(false);
    }
}
}