Salsa20Engine
Implementation of Daniel J. Bernstein's Salsa20 stream cipher, Snuffle 2005
using Org.BouncyCastle.Crypto.Parameters;
using Org.BouncyCastle.Crypto.Utilities;
using Org.BouncyCastle.Utilities;
using System;
using System.Runtime.CompilerServices;
namespace Org.BouncyCastle.Crypto.Engines
{
public class Salsa20Engine : IStreamCipher
{
/// <summary>Round count used by the parameterless constructor.</summary>
public static readonly int DEFAULT_ROUNDS = 20;
// Number of 32-bit words in the cipher state; not referenced by the code visible in this class.
private const int StateSize = 16;
// Words packed (little-endian, judging by the helper's name) from the ASCII constants
// "expand 16-byte k" followed by "expand 32-byte k". SetKey and PackTauOrSigma read
// four consecutive words starting at index 0 for 16-byte keys and index 4 for 32-byte keys.
private static readonly uint[] TAU_SIGMA = Pack.LE_To_UInt32(Strings.ToAsciiByteArray("expand 16-byte kexpand 32-byte k"), 0, 8);
/// <summary>Number of rounds run by the core; validated in the constructor to be positive and even.</summary>
protected int rounds;
// Position of the next unread byte in keyStream; wraps modulo 64, and 0 means a fresh block must be generated.
internal int index;
// The 16-word cipher state. In this class: words 0, 5, 10, 15 hold the constants, words 8 and 9
// the block counter; SetKey loads key material starting at words 1 and 11 and the IV starting at word 6.
internal uint[] engineState = new uint[16];
// Scratch buffer receiving the core output before it is serialized into keyStream.
internal uint[] x = new uint[16];
// The current 64-byte keystream block.
internal byte[] keyStream = new byte[64];
// Set to true at the end of a successful Init.
internal bool initialised;
// Three-word counter of bytes processed since the last reset (cW0 is the least significant word);
// used by LimitExceeded to enforce the per-IV byte limit.
private uint cW0;
private uint cW1;
private uint cW2;
/// <summary>Required IV length in bytes, enforced by <see cref="Init"/>; overridable by subclasses.</summary>
protected virtual int NonceSize => 8;
/// <summary>
/// Name of the algorithm: "Salsa20" when running the default number of rounds,
/// otherwise "Salsa20/" followed by the configured round count.
/// </summary>
public virtual string AlgorithmName
{
    get
    {
        const string baseName = "Salsa20";
        return rounds == DEFAULT_ROUNDS
            ? baseName
            : baseName + "/" + rounds.ToString();
    }
}
/// <summary>
/// Copies the four constant words matching the key length into <paramref name="state"/>,
/// starting at <paramref name="stateOffset"/>.
/// </summary>
/// <param name="keyLength">Key length in bytes; 16 selects words 0-3 of the constant table, 32 selects words 4-7.</param>
/// <param name="state">Destination array.</param>
/// <param name="stateOffset">Index in <paramref name="state"/> of the first word written.</param>
internal static void PackTauOrSigma(int keyLength, uint[] state, int stateOffset)
{
    // (16 - 16) / 4 == 0 and (32 - 16) / 4 == 4: start of the matching constant in TAU_SIGMA.
    int constantStart = (keyLength - 16) / 4;
    for (int word = 0; word < 4; word++)
    {
        state[stateOffset + word] = TAU_SIGMA[constantStart + word];
    }
}
/// <summary>Creates an engine that runs <see cref="DEFAULT_ROUNDS"/> rounds.</summary>
public Salsa20Engine()
    : this(DEFAULT_ROUNDS)
{
}

/// <summary>Creates an engine that runs the given number of rounds.</summary>
/// <param name="rounds">Round count; must be positive and even.</param>
/// <exception cref="ArgumentException">If <paramref name="rounds"/> is zero, negative or odd.</exception>
public Salsa20Engine(int rounds)
{
    bool isPositive = rounds > 0;
    bool isEven = (rounds & 1) == 0;
    if (!(isPositive && isEven))
    {
        throw new ArgumentException("'rounds' must be a positive, even number");
    }

    this.rounds = rounds;
}
/// <summary>
/// Initialises (or re-initialises) the cipher with an IV and, optionally, a key.
/// </summary>
/// <param name="forEncryption">Not used by this method.</param>
/// <param name="parameters">
/// Must be a <c>ParametersWithIV</c> whose IV is exactly <see cref="NonceSize"/> bytes.
/// Its inner parameters must be a <c>KeyParameter</c>, or null to keep the key from a previous initialisation.
/// </param>
/// <exception cref="ArgumentException">
/// If no IV wrapper is supplied, the IV has the wrong length, or the inner parameters are not a <c>KeyParameter</c>.
/// </exception>
/// <exception cref="InvalidOperationException">If no key is supplied and the engine was never initialised.</exception>
public virtual void Init(bool forEncryption, ICipherParameters parameters)
{
    ParametersWithIV withIV = parameters as ParametersWithIV;
    if (withIV == null)
    {
        throw new ArgumentException(AlgorithmName + " Init requires an IV", "parameters");
    }

    byte[] nonce = withIV.GetIV();
    bool nonceValid = nonce != null && nonce.Length == NonceSize;
    if (!nonceValid)
    {
        throw new ArgumentException(AlgorithmName + " requires exactly " + NonceSize.ToString() + " bytes of IV");
    }

    // Work out the key bytes to install; null means "keep the key already in the state".
    ICipherParameters inner = withIV.Parameters;
    KeyParameter keyParameter = inner as KeyParameter;
    byte[] keyBytes;
    if (keyParameter != null)
    {
        keyBytes = keyParameter.GetKey();
    }
    else if (inner != null)
    {
        throw new ArgumentException(AlgorithmName + " Init parameters must contain a KeyParameter (or null for re-init)");
    }
    else if (initialised)
    {
        keyBytes = null;
    }
    else
    {
        throw new InvalidOperationException(AlgorithmName + " KeyParameter can not be null for first initialisation");
    }

    SetKey(keyBytes, nonce);
    Reset();
    initialised = true;
}
/// <summary>
/// Encrypts or decrypts a single byte by XORing it with the next keystream byte.
/// </summary>
/// <param name="input">The byte to process.</param>
/// <returns><paramref name="input"/> XORed with the next keystream byte.</returns>
/// <exception cref="InvalidOperationException">If the engine has not been initialised.</exception>
/// <exception cref="MaxBytesExceededException">If the per-IV byte limit is reached.</exception>
public virtual byte ReturnByte(byte input)
{
    // Same guard as ProcessBytes: before Init the state holds no key, so any
    // keystream produced here would not depend on a key at all. Checked first
    // so that a rejected call does not advance the byte-limit counter.
    if (!initialised)
    {
        throw new InvalidOperationException(AlgorithmName + " not initialised");
    }

    if (LimitExceeded())
    {
        throw new MaxBytesExceededException("2^70 byte limit per IV; Change IV");
    }

    // index == 0 means the current 64-byte block is used up (or none exists yet).
    if (index == 0)
    {
        GenerateKeyStream(keyStream);
        AdvanceCounter();
    }

    byte result = (byte)(keyStream[index] ^ input);
    index = (index + 1) & 63;
    return result;
}
/// <summary>
/// Increments the block counter held in state words 8 (low) and 9 (high),
/// carrying into the high word when the low word wraps to zero.
/// </summary>
protected virtual void AdvanceCounter()
{
    uint low = engineState[8] + 1;
    engineState[8] = low;
    if (low == 0)
    {
        engineState[9]++;
    }
}
/// <summary>
/// Encrypts or decrypts <paramref name="len"/> bytes by XORing them with the keystream.
/// </summary>
/// <param name="inBytes">Source buffer.</param>
/// <param name="inOff">Offset of the first byte to read from <paramref name="inBytes"/>.</param>
/// <param name="len">Number of bytes to process.</param>
/// <param name="outBytes">Destination buffer.</param>
/// <param name="outOff">Offset of the first byte to write in <paramref name="outBytes"/>.</param>
/// <exception cref="InvalidOperationException">If the engine has not been initialised.</exception>
/// <exception cref="MaxBytesExceededException">If processing would reach the per-IV byte limit.</exception>
public virtual void ProcessBytes(byte[] inBytes, int inOff, int len, byte[] outBytes, int outOff)
{
    if (!initialised)
    {
        throw new InvalidOperationException(AlgorithmName + " not initialised");
    }

    Check.DataLength(inBytes, inOff, len, "input buffer too short");
    Check.OutputLength(outBytes, outOff, len, "output buffer too short");

    if (LimitExceeded((uint)len))
    {
        throw new MaxBytesExceededException("2^70 byte limit per IV would be exceeded; Change IV");
    }

    int readPos = inOff;
    int writePos = outOff;
    for (int remaining = len; remaining > 0; remaining--)
    {
        // A keystream position of 0 means a new 64-byte block is needed.
        if (index == 0)
        {
            GenerateKeyStream(keyStream);
            AdvanceCounter();
        }

        outBytes[writePos] = (byte)(keyStream[index] ^ inBytes[readPos]);
        index = (index + 1) & 63;
        readPos++;
        writePos++;
    }
}
/// <summary>
/// Restarts the keystream: rewinds the position within the current keystream block,
/// clears the processed-byte limit counter and resets the block counter.
/// </summary>
public virtual void Reset()
{
// Position 0 forces the next processed byte to generate a fresh keystream block.
index = 0;
ResetLimitCounter();
// Virtual: in this class it only zeroes state words 8 and 9.
ResetCounter();
}
/// <summary>Sets the block counter (state words 8 and 9) back to zero.</summary>
protected virtual void ResetCounter()
{
    engineState[9] = 0;
    engineState[8] = 0;
}
/// <summary>
/// Loads the key (when supplied) and the IV into the cipher state.
/// </summary>
/// <param name="keyBytes">A 16- or 32-byte key, or null to leave the key words of the state untouched.</param>
/// <param name="ivBytes">The IV; loaded into the state starting at word 6.</param>
/// <exception cref="ArgumentException">If a key is supplied whose length is neither 16 nor 32 bytes.</exception>
protected virtual void SetKey(byte[] keyBytes, byte[] ivBytes)
{
    if (keyBytes != null)
    {
        int keyLength = keyBytes.Length;
        bool supportedLength = keyLength == 16 || keyLength == 32;
        if (!supportedLength)
        {
            throw new ArgumentException(AlgorithmName + " requires 128 bit or 256 bit key");
        }

        // The four constant words sit on the diagonal of the 4x4 state: indices 0, 5, 10, 15.
        int constantStart = (keyLength - 16) / 4;
        for (int word = 0; word < 4; word++)
        {
            engineState[5 * word] = TAU_SIGMA[constantStart + word];
        }

        // First key half starts at state word 1. The second half starts at word 11 and is read
        // from the last 16 key bytes, which for a 16-byte key are the same bytes again.
        Pack.LE_To_UInt32(keyBytes, 0, engineState, 1, 4);
        Pack.LE_To_UInt32(keyBytes, keyLength - 16, engineState, 11, 4);
    }

    Pack.LE_To_UInt32(ivBytes, 0, engineState, 6, 2);
}
/// <summary>
/// Produces one keystream block from the current state and writes it to <paramref name="output"/>.
/// </summary>
/// <param name="output">Buffer that receives the serialized core output, starting at offset 0.</param>
protected virtual void GenerateKeyStream(byte[] output)
{
// Run the core over the state into the scratch words, leaving engineState itself unchanged.
SalsaCore(rounds, engineState, x);
// Serialize the scratch words into bytes (little-endian, judging by the helper's name).
Pack.UInt32_To_LE(x, output, 0);
}
/// <summary>
/// Runs the Salsa core: applies <paramref name="rounds"/> rounds (two per loop iteration)
/// to a copy of the 16 input words, then adds the original input words to the result.
/// </summary>
/// <param name="rounds">Number of rounds; must be even. Values of zero or less run no rounds.</param>
/// <param name="input">At least 16 state words; not modified unless it is also <paramref name="output"/>.</param>
/// <param name="output">At least 16 words; receives the result in words 0-15.</param>
/// <exception cref="ArgumentException">If either array has fewer than 16 words or <paramref name="rounds"/> is odd.</exception>
internal static void SalsaCore(int rounds, uint[] input, uint[] output)
{
    if (input.Length < 16)
    {
        throw new ArgumentException();
    }
    if (output.Length < 16)
    {
        throw new ArgumentException();
    }
    if ((rounds & 1) != 0)
    {
        throw new ArgumentException("Number of rounds must be even");
    }

    // Working copy of the state, named by word index.
    uint x00 = input[0], x01 = input[1], x02 = input[2], x03 = input[3];
    uint x04 = input[4], x05 = input[5], x06 = input[6], x07 = input[7];
    uint x08 = input[8], x09 = input[9], x10 = input[10], x11 = input[11];
    uint x12 = input[12], x13 = input[13], x14 = input[14], x15 = input[15];

    int remaining = rounds;
    while (remaining > 0)
    {
        // Column round: each quarter-round starts on a diagonal word and walks down its column.
        QuarterRound(ref x00, ref x04, ref x08, ref x12);
        QuarterRound(ref x05, ref x09, ref x13, ref x01);
        QuarterRound(ref x10, ref x14, ref x02, ref x06);
        QuarterRound(ref x15, ref x03, ref x07, ref x11);

        // Row round: each quarter-round starts on a diagonal word and walks along its row.
        QuarterRound(ref x00, ref x01, ref x02, ref x03);
        QuarterRound(ref x05, ref x06, ref x07, ref x04);
        QuarterRound(ref x10, ref x11, ref x08, ref x09);
        QuarterRound(ref x15, ref x12, ref x13, ref x14);

        remaining -= 2;
    }

    // Feed-forward: add the original input word to each mixed word.
    // Each output word reads only the input word at the same index.
    output[0] = x00 + input[0];
    output[1] = x01 + input[1];
    output[2] = x02 + input[2];
    output[3] = x03 + input[3];
    output[4] = x04 + input[4];
    output[5] = x05 + input[5];
    output[6] = x06 + input[6];
    output[7] = x07 + input[7];
    output[8] = x08 + input[8];
    output[9] = x09 + input[9];
    output[10] = x10 + input[10];
    output[11] = x11 + input[11];
    output[12] = x12 + input[12];
    output[13] = x13 + input[13];
    output[14] = x14 + input[14];
    output[15] = x15 + input[15];
}
/// <summary>Zeroes the three-word counter of processed bytes.</summary>
internal void ResetLimitCounter()
{
    cW0 = cW1 = cW2 = 0;
}
/// <summary>
/// Counts one more processed byte and reports whether the byte limit has been hit.
/// </summary>
/// <returns>
/// True only on a call where the carry propagates through both lower counter words
/// into the top word and the top word then has bit 5 (value 32) set; false otherwise.
/// </returns>
internal bool LimitExceeded()
{
    cW0++;
    if (cW0 != 0)
    {
        return false;
    }

    // Low word wrapped: carry into the middle word.
    cW1++;
    if (cW1 != 0)
    {
        return false;
    }

    // Middle word wrapped too: carry into the top word and test the limit bit.
    cW2++;
    return (cW2 & 32) != 0;
}
/// <summary>
/// Adds <paramref name="len"/> to the processed-byte counter and reports whether the byte limit has been hit.
/// </summary>
/// <param name="len">Number of bytes about to be processed.</param>
/// <returns>
/// True only on a call where the addition carries through both lower counter words
/// into the top word and the top word then has bit 5 (value 32) set; false otherwise.
/// </returns>
internal bool LimitExceeded(uint len)
{
    uint previousLow = cW0;
    cW0 = previousLow + len;

    // A smaller value after the addition means the low word wrapped around.
    bool lowWordWrapped = cW0 < previousLow;
    if (!lowWordWrapped)
    {
        return false;
    }

    cW1++;
    if (cW1 != 0)
    {
        return false;
    }

    cW2++;
    return (cW2 & 32) != 0;
}
/// <summary>
/// Mixes four state words in place. Each step XORs one word with a left-rotated sum of two
/// others (rotation amounts 7, 9, 13, 18), and each step uses the words updated by the steps before it.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void QuarterRound(ref uint a, ref uint b, ref uint c, ref uint d)
{
    b = b ^ Integers.RotateLeft(a + d, 7);
    c = c ^ Integers.RotateLeft(b + a, 9);
    d = d ^ Integers.RotateLeft(c + b, 13);
    a = a ^ Integers.RotateLeft(d + c, 18);
}
}
}