Salsa20Engine
Implementation of Daniel J. Bernstein's Salsa20 stream cipher, Snuffle 2005
using Org.BouncyCastle.Crypto.Parameters;
using Org.BouncyCastle.Crypto.Utilities;
using Org.BouncyCastle.Runtime.Intrinsics;
using Org.BouncyCastle.Runtime.Intrinsics.X86;
using Org.BouncyCastle.Utilities;
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
namespace Org.BouncyCastle.Crypto.Engines
{
public class Salsa20Engine : IStreamCipher
{
/// <summary>
/// Round count used by the parameterless constructor. <see cref="AlgorithmName"/>
/// only appends a "/rounds" suffix when the configured count differs from this value.
/// </summary>
public static readonly int DEFAULT_ROUNDS = 20;
// Size of the cipher state in 32-bit words (not referenced by the members shown in this class).
private const int StateSize = 16;
// Constant words decoded from the ASCII text "expand 16-byte k" followed by
// "expand 32-byte k" (presumably 8 little-endian words; see Pack.LE_To_UInt32).
// Callers index it with (keyLength - 16) / 4, i.e. offset 0 for 16-byte keys
// and offset 4 for 32-byte keys.
private static readonly uint[] TAU_SIGMA = Pack.LE_To_UInt32(Strings.ToAsciiByteArray("expand 16-byte kexpand 32-byte k"), 0, 8);
// Number of rounds handed to SalsaCore; validated (positive, even) by the constructor.
protected int rounds;
// Position of the next unused byte in keyStream (0..63). A value of 0 makes the
// processing methods generate a fresh block before consuming a byte.
internal int index;
// Cipher state: words 0/5/10/15 hold the constants, 1..4 and 11..14 the key,
// 6..7 the IV and 8..9 the block counter (see SetKey, AdvanceCounter, ResetCounter).
internal uint[] engineState = new uint[16];
// Scratch buffer that receives the SalsaCore output in GenerateKeyStream.
internal uint[] x = new uint[16];
// Current 64-byte key stream block.
internal byte[] keyStream = new byte[64];
// Set to true at the end of a successful Init.
internal bool initialised;
// Three-word counter of processed bytes (cW0 is the low word), maintained by
// LimitExceeded and cleared by ResetLimitCounter.
private uint cW0;
private uint cW1;
private uint cW2;
/// <summary>Required IV length in bytes, enforced by <see cref="Init"/>.</summary>
protected virtual int NonceSize => 8;
/// <summary>
/// Name of the algorithm: "Salsa20" when the engine runs <see cref="DEFAULT_ROUNDS"/>
/// rounds, otherwise "Salsa20/" followed by the configured round count.
/// </summary>
public virtual string AlgorithmName
{
    get
    {
        const string baseName = "Salsa20";

        return rounds == DEFAULT_ROUNDS
            ? baseName
            : baseName + "/" + rounds.ToString();
    }
}
/// <summary>
/// Copies the four constant words that belong to the given key length into
/// <paramref name="state"/>, starting at <paramref name="stateOffset"/>.
/// </summary>
/// <param name="keyLength">Key length in bytes; 16 selects words 0-3 of the constant table, 32 selects words 4-7.</param>
/// <param name="state">Destination array.</param>
/// <param name="stateOffset">Index in <paramref name="state"/> of the first word written.</param>
internal static void PackTauOrSigma(int keyLength, uint[] state, int stateOffset)
{
    int constantsOffset = (keyLength - 16) / 4;

    for (int word = 0; word < 4; ++word)
    {
        state[stateOffset + word] = TAU_SIGMA[constantsOffset + word];
    }
}
/// <summary>
/// Creates an engine that runs <see cref="DEFAULT_ROUNDS"/> rounds.
/// </summary>
public Salsa20Engine() : this(DEFAULT_ROUNDS) { }
/// <summary>
/// Creates an engine that runs the given number of rounds.
/// </summary>
/// <param name="rounds">Number of rounds; must be greater than zero and even.</param>
/// <exception cref="ArgumentException"><paramref name="rounds"/> is zero, negative or odd.</exception>
public Salsa20Engine(int rounds)
{
    bool isPositive = rounds > 0;
    bool isEven = (rounds & 1) == 0;

    if (!(isPositive && isEven))
    {
        throw new ArgumentException("'rounds' must be a positive, even number");
    }

    this.rounds = rounds;
}
/// <summary>
/// Initialises (or re-initialises) the engine with an IV and, optionally, a key,
/// then resets it and marks it as initialised.
/// </summary>
/// <param name="forEncryption">Not read by this method.</param>
/// <param name="parameters">
/// Must be a <see cref="ParametersWithIV"/> carrying an IV of exactly <see cref="NonceSize"/> bytes.
/// Its inner parameters are either a <see cref="KeyParameter"/>, or null to keep the key
/// words already in the state and replace only the IV (allowed only after a previous
/// successful initialisation).
/// </param>
/// <exception cref="ArgumentException">
/// <paramref name="parameters"/> is not a <see cref="ParametersWithIV"/>, the IV is null or
/// has the wrong length, or the inner parameters are neither null nor a <see cref="KeyParameter"/>.
/// </exception>
/// <exception cref="InvalidOperationException">No key was supplied and the engine was never initialised.</exception>
public virtual void Init(bool forEncryption, ICipherParameters parameters)
{
    if (!(parameters is ParametersWithIV withIV))
    {
        throw new ArgumentException(AlgorithmName + " Init requires an IV", "parameters");
    }

    byte[] nonce = withIV.GetIV();
    if (nonce == null || nonce.Length != NonceSize)
    {
        throw new ArgumentException(AlgorithmName + " requires exactly " + NonceSize.ToString() + " bytes of IV");
    }

    // Work out which key bytes (if any) accompany the IV; null means "keep the current key".
    byte[] key;
    switch (withIV.Parameters)
    {
        case null:
            if (!initialised)
            {
                throw new InvalidOperationException(AlgorithmName + " KeyParameter can not be null for first initialisation");
            }
            key = null;
            break;

        case KeyParameter keyParameter:
            key = keyParameter.GetKey();
            break;

        default:
            throw new ArgumentException(AlgorithmName + " Init parameters must contain a KeyParameter (or null for re-init)");
    }

    SetKey(key, nonce);
    Reset();
    initialised = true;
}
/// <summary>
/// Processes a single byte by XOR-ing it with the next byte of the key stream.
/// </summary>
/// <param name="input">The byte to encrypt or decrypt.</param>
/// <returns><paramref name="input"/> XOR-ed with the next key stream byte.</returns>
/// <exception cref="InvalidOperationException">The engine has not been initialised.</exception>
/// <exception cref="MaxBytesExceededException">The per-IV byte counter signalled that its limit was reached.</exception>
public virtual byte ReturnByte(byte input)
{
    // Same guard as the ProcessBytes overloads: without it an engine that was never
    // given a key would produce key stream from the all-zero state. The check comes
    // first so a rejected call does not advance the byte counter.
    if (!initialised)
    {
        throw new InvalidOperationException(AlgorithmName + " not initialised");
    }

    if (LimitExceeded())
    {
        throw new MaxBytesExceededException("2^70 byte limit per IV; Change IV");
    }

    // index == 0 means the current 64-byte block is used up (or none was generated yet).
    if (index == 0)
    {
        GenerateKeyStream(keyStream);
        AdvanceCounter();
    }

    byte result = (byte)(keyStream[index] ^ input);
    index = (index + 1) & 63; // wrap to 0 after the last byte of the block
    return result;
}
/// <summary>
/// Increments the block counter held in state words 8 (low) and 9 (high),
/// carrying into word 9 when word 8 wraps around to zero.
/// </summary>
protected virtual void AdvanceCounter()
{
    uint low = engineState[8] + 1;
    engineState[8] = low;

    if (low == 0)
    {
        engineState[9]++;
    }
}
/// <summary>
/// XORs <paramref name="len"/> bytes of <paramref name="inBytes"/> with the key stream
/// and stores the result in <paramref name="outBytes"/>.
/// </summary>
/// <param name="inBytes">Source buffer.</param>
/// <param name="inOff">Index of the first source byte.</param>
/// <param name="len">Number of bytes to process.</param>
/// <param name="outBytes">Destination buffer.</param>
/// <param name="outOff">Index of the first destination byte.</param>
/// <exception cref="InvalidOperationException">The engine has not been initialised.</exception>
/// <exception cref="MaxBytesExceededException">The per-IV byte counter signalled that its limit would be exceeded.</exception>
public virtual void ProcessBytes(byte[] inBytes, int inOff, int len, byte[] outBytes, int outOff)
{
    if (!initialised)
    {
        throw new InvalidOperationException(AlgorithmName + " not initialised");
    }

    // Buffer validation is delegated to the project's Check helpers.
    Check.DataLength(inBytes, inOff, len, "input buffer too short");
    Check.OutputLength(outBytes, outOff, len, "output buffer too short");

    if (LimitExceeded((uint)len))
    {
        throw new MaxBytesExceededException("2^70 byte limit per IV would be exceeded; Change IV");
    }

    int src = inOff;
    int dst = outOff;

    for (int remaining = len; remaining > 0; --remaining)
    {
        // A zero index means the current block is exhausted: produce the next one.
        if (index == 0)
        {
            GenerateKeyStream(keyStream);
            AdvanceCounter();
        }

        outBytes[dst++] = (byte)(keyStream[index] ^ inBytes[src++]);
        index = (index + 1) & 63;
    }
}
/// <summary>
/// XORs every byte of <paramref name="input"/> with the key stream and stores the
/// result at the same position in <paramref name="output"/>.
/// </summary>
/// <param name="input">Bytes to encrypt or decrypt.</param>
/// <param name="output">Destination; validated against <c>input.Length</c> by the project's Check helper.</param>
/// <exception cref="InvalidOperationException">The engine has not been initialised.</exception>
/// <exception cref="MaxBytesExceededException">The per-IV byte counter signalled that its limit would be exceeded.</exception>
public virtual void ProcessBytes(ReadOnlySpan<byte> input, Span<byte> output)
{
    if (!initialised)
    {
        throw new InvalidOperationException(AlgorithmName + " not initialised");
    }

    int length = input.Length;
    Check.OutputLength(output, length, "output buffer too short");

    if (LimitExceeded((uint)length))
    {
        throw new MaxBytesExceededException("2^70 byte limit per IV would be exceeded; Change IV");
    }

    for (int pos = 0; pos < length; ++pos)
    {
        // A zero index means the current block is exhausted: produce the next one.
        if (index == 0)
        {
            GenerateKeyStream(keyStream);
            AdvanceCounter();
        }

        output[pos] = (byte)(input[pos] ^ keyStream[index++]);
        index &= 63;
    }
}
/// <summary>
/// Returns the engine to the start of the key stream for the key and IV currently
/// held in the state: the position inside the block, the processed-byte counter and
/// the block counter are all cleared. Key and IV words are left untouched.
/// </summary>
public virtual void Reset()
{
// index == 0 makes the next processed byte trigger generation of a fresh block.
index = 0;
ResetLimitCounter();
// Virtual: derived engines may keep their block counter elsewhere.
ResetCounter();
}
/// <summary>
/// Clears the block counter held in state words 8 and 9.
/// </summary>
protected virtual void ResetCounter()
{
    engineState[9] = 0;
    engineState[8] = 0;
}
/// <summary>
/// Loads the IV and, when a key is supplied, the key and its matching constants
/// into the cipher state.
/// </summary>
/// <param name="keyBytes">A 16- or 32-byte key, or null to leave the key and constant words unchanged.</param>
/// <param name="ivBytes">IV bytes; two words are read from offset 0 into state words 6 and 7.</param>
/// <exception cref="ArgumentException"><paramref name="keyBytes"/> is neither 16 nor 32 bytes long.</exception>
protected virtual void SetKey(byte[] keyBytes, byte[] ivBytes)
{
    if (keyBytes != null)
    {
        int keyLength = keyBytes.Length;
        if (keyLength != 16 && keyLength != 32)
        {
            throw new ArgumentException(AlgorithmName + " requires 128 bit or 256 bit key");
        }

        // Constants go to state words 0, 5, 10 and 15.
        int constantsOffset = (keyLength - 16) / 4;
        for (int i = 0; i < 4; ++i)
        {
            engineState[5 * i] = TAU_SIGMA[constantsOffset + i];
        }

        // The first 16 key bytes fill words 1..4 and the last 16 fill words 11..14;
        // for a 16-byte key both reads start at offset 0.
        Pack.LE_To_UInt32(keyBytes, 0, engineState, 1, 4);
        Pack.LE_To_UInt32(keyBytes, keyLength - 16, engineState, 11, 4);
    }

    Pack.LE_To_UInt32(ivBytes, 0, engineState, 6, 2);
}
/// <summary>
/// Produces one key stream block from the current state: runs the core into the
/// scratch words and serialises them into <paramref name="output"/> from offset 0.
/// </summary>
/// <param name="output">Receives the key stream block.</param>
protected virtual void GenerateKeyStream(byte[] output)
{
    ReadOnlySpan<uint> state = engineState;
    Span<uint> block = x;

    SalsaCore(rounds, state, block);
    Pack.UInt32_To_LE(x, output, 0);
}
/// <summary>
/// Runs the Salsa20 core on a 16-word state: a working copy of <paramref name="input"/>
/// is put through the round loop (two rounds per iteration) and then added word by word
/// to the original input; the sums are written to <paramref name="output"/>.
/// </summary>
/// <param name="rounds">Number of rounds; must be even. Even values of zero or less skip the round loop entirely.</param>
/// <param name="input">Source state; at least 16 words, of which the first 16 are read.</param>
/// <param name="output">Destination; at least 16 words, of which the first 16 are written.</param>
/// <exception cref="ArgumentException">
/// <paramref name="input"/> or <paramref name="output"/> holds fewer than 16 words, or
/// <paramref name="rounds"/> is odd.
/// </exception>
internal static void SalsaCore(int rounds, ReadOnlySpan<uint> input, Span<uint> output)
{
if (input.Length < 16)
throw new ArgumentException();
if (output.Length < 16)
throw new ArgumentException();
if (rounds % 2 != 0)
throw new ArgumentException("Number of rounds must be even");
// Vectorised path: taken only when the project's Sse41 wrapper reports it is enabled
// and Vector.IsPackedLittleEndian holds; otherwise the scalar path below is used.
if (Org.BouncyCastle.Runtime.Intrinsics.X86.Sse41.IsEnabled && Vector.IsPackedLittleEndian) {
// View the 16 words as 64 bytes and load them as four 128-bit rows (words 0-3, 4-7, 8-11, 12-15).
ReadOnlySpan<byte> readOnlySpan = MemoryMarshal.AsBytes(input.Slice(0, 16));
Vector128<ushort> left = MemoryMarshal.Read<Vector128<ushort>>(readOnlySpan.Slice(0, 16));
Vector128<ushort> left2 = MemoryMarshal.Read<Vector128<ushort>>(readOnlySpan.Slice(16, 16));
Vector128<ushort> right = MemoryMarshal.Read<Vector128<ushort>>(readOnlySpan.Slice(32, 16));
Vector128<ushort> right2 = MemoryMarshal.Read<Vector128<ushort>>(readOnlySpan.Slice(48, 16));
// Two stages of 16-bit-lane blends regroup the words. Each mask selects whole 32-bit
// words (pairs of 16-bit lanes), so afterwards:
//   vector  = words { 0,  5, 10, 15}
//   vector2 = words {12,  1,  6, 11}
//   vector3 = words { 8, 13,  2,  7}
//   vector4 = words { 4,  9, 14,  3}
// i.e. lane j of the four vectors holds the operands of one quarter round.
Vector128<ushort> left3 = System.Runtime.Intrinsics.X86.Sse41.Blend(left, right, 240);
Vector128<ushort> right3 = System.Runtime.Intrinsics.X86.Sse41.Blend(left2, right2, 195);
Vector128<ushort> left4 = System.Runtime.Intrinsics.X86.Sse41.Blend(left, right, 15);
Vector128<ushort> right4 = System.Runtime.Intrinsics.X86.Sse41.Blend(left2, right2, 60);
Vector128<uint> vector = System.Runtime.Intrinsics.X86.Sse41.Blend(left3, right3, 204).AsUInt32();
Vector128<uint> vector2 = System.Runtime.Intrinsics.X86.Sse41.Blend(left3, right3, 51).AsUInt32();
Vector128<uint> vector3 = System.Runtime.Intrinsics.X86.Sse41.Blend(left4, right4, 204).AsUInt32();
Vector128<uint> vector4 = System.Runtime.Intrinsics.X86.Sse41.Blend(left4, right4, 51).AsUInt32();
// Working copies; the regrouped originals are kept for the final addition.
Vector128<uint> a = vector;
Vector128<uint> d = vector2;
Vector128<uint> c = vector3;
Vector128<uint> b = vector4;
// Each iteration performs two rounds. QuarterRound_Sse2 permutes the lanes of its
// b, c and d arguments before returning, so the second call (with b and d exchanged)
// works on a different grouping of words and leaves the lanes back in place.
for (int num = rounds; num > 0; num -= 2) {
QuarterRound_Sse2(ref a, ref b, ref c, ref d);
QuarterRound_Sse2(ref a, ref d, ref c, ref b);
}
// Add the original (regrouped) input to the result of the rounds.
vector = System.Runtime.Intrinsics.X86.Sse2.Add(vector, a);
vector2 = System.Runtime.Intrinsics.X86.Sse2.Add(vector2, d);
vector3 = System.Runtime.Intrinsics.X86.Sse2.Add(vector3, c);
vector4 = System.Runtime.Intrinsics.X86.Sse2.Add(vector4, b);
// Undo the regrouping with the same blend masks applied in the opposite stage order,
// restoring the four rows in natural word order.
Vector128<ushort> left5 = vector.AsUInt16();
Vector128<ushort> right5 = vector2.AsUInt16();
Vector128<ushort> left6 = vector3.AsUInt16();
Vector128<ushort> right6 = vector4.AsUInt16();
Vector128<ushort> left7 = System.Runtime.Intrinsics.X86.Sse41.Blend(left5, right5, 204);
Vector128<ushort> left8 = System.Runtime.Intrinsics.X86.Sse41.Blend(left5, right5, 51);
Vector128<ushort> right7 = System.Runtime.Intrinsics.X86.Sse41.Blend(left6, right6, 204);
Vector128<ushort> right8 = System.Runtime.Intrinsics.X86.Sse41.Blend(left6, right6, 51);
Vector128<ushort> value = System.Runtime.Intrinsics.X86.Sse41.Blend(left7, right7, 240);
Vector128<ushort> value2 = System.Runtime.Intrinsics.X86.Sse41.Blend(left8, right8, 195);
Vector128<ushort> value3 = System.Runtime.Intrinsics.X86.Sse41.Blend(left7, right7, 15);
Vector128<ushort> value4 = System.Runtime.Intrinsics.X86.Sse41.Blend(left8, right8, 60);
// Store the four rows into the first 64 bytes of the output.
Span<byte> span = MemoryMarshal.AsBytes(output.Slice(0, 16));
MemoryMarshal.Write(span.Slice(0, 16), ref value);
MemoryMarshal.Write(span.Slice(16, 16), ref value2);
MemoryMarshal.Write(span.Slice(32, 16), ref value3);
MemoryMarshal.Write(span.Slice(48, 16), ref value4);
} else {
// Scalar path. The local names record the role (a/b/c/d) each word plays in the
// first four quarter rounds of an iteration; the numeric suffix has no meaning.
uint a2 = input[0];
uint d2 = input[1];
uint c2 = input[2];
uint b2 = input[3];
uint b3 = input[4];
uint a3 = input[5];
uint d3 = input[6];
uint c3 = input[7];
uint c4 = input[8];
uint b4 = input[9];
uint a4 = input[10];
uint d4 = input[11];
uint d5 = input[12];
uint c5 = input[13];
uint b5 = input[14];
uint a5 = input[15];
for (int num2 = rounds; num2 > 0; num2 -= 2) {
// First round of the pair: words (0,4,8,12), (5,9,13,1), (10,14,2,6), (15,3,7,11).
QuarterRound(ref a2, ref b3, ref c4, ref d5);
QuarterRound(ref a3, ref b4, ref c5, ref d2);
QuarterRound(ref a4, ref b5, ref c2, ref d3);
QuarterRound(ref a5, ref b2, ref c3, ref d4);
// Second round of the pair: words (0,1,2,3), (5,6,7,4), (10,11,8,9), (15,12,13,14).
QuarterRound(ref a2, ref d2, ref c2, ref b2);
QuarterRound(ref a3, ref d3, ref c3, ref b3);
QuarterRound(ref a4, ref d4, ref c4, ref b4);
QuarterRound(ref a5, ref d5, ref c5, ref b5);
}
// Add the untouched input words to the worked state, in natural word order.
output[0] = a2 + input[0];
output[1] = d2 + input[1];
output[2] = c2 + input[2];
output[3] = b2 + input[3];
output[4] = b3 + input[4];
output[5] = a3 + input[5];
output[6] = d3 + input[6];
output[7] = c3 + input[7];
output[8] = c4 + input[8];
output[9] = b4 + input[9];
output[10] = a4 + input[10];
output[11] = d4 + input[11];
output[12] = d5 + input[12];
output[13] = c5 + input[13];
output[14] = b5 + input[14];
output[15] = a5 + input[15];
}
}
/// <summary>
/// Clears all three words of the processed-byte counter.
/// </summary>
internal void ResetLimitCounter()
{
    cW0 = cW1 = cW2 = 0;
}
/// <summary>
/// Adds one to the processed-byte counter (cW0 is the low word, cW2 the high word).
/// </summary>
/// <returns>
/// True only when this increment carries out of both lower words and the incremented
/// high word has bit 0x20 set; false otherwise.
/// </returns>
internal bool LimitExceeded()
{
    if (++cW0 != 0)
    {
        return false; // no carry out of the low word
    }

    if (++cW1 != 0)
    {
        return false; // no carry out of the middle word
    }

    return (++cW2 & 32) != 0;
}
/// <summary>
/// Adds <paramref name="len"/> to the processed-byte counter (cW0 is the low word,
/// cW2 the high word).
/// </summary>
/// <param name="len">Number of bytes about to be processed.</param>
/// <returns>
/// True only when this addition carries out of both lower words and the incremented
/// high word has bit 0x20 set; false otherwise.
/// </returns>
internal bool LimitExceeded(uint len)
{
    uint before = cW0;
    cW0 = before + len;

    // The low word wrapped around exactly when the sum is smaller than the old value.
    bool carried = cW0 < before;
    if (!carried)
    {
        return false;
    }

    if (++cW1 != 0)
    {
        return false;
    }

    return (++cW2 & 32) != 0;
}
/// <summary>
/// Scalar quarter round: updates b, c, d and a, in that order, each by XOR-ing in a
/// left rotation (by 7, 9, 13 and 18 bits respectively) of the sum of two other words.
/// Every step uses the values already updated by the previous steps.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void QuarterRound(ref uint a, ref uint b, ref uint c, ref uint d)
{
    b = b ^ Integers.RotateLeft(a + d, 7);
    c = c ^ Integers.RotateLeft(b + a, 9);
    d = d ^ Integers.RotateLeft(c + b, 13);
    a = a ^ Integers.RotateLeft(d + c, 18);
}
/// <summary>
/// Vector quarter round: performs the same b, c, d, a update sequence as the scalar
/// quarter round on all four 32-bit lanes at once, then permutes the lanes of
/// b, c and d (shuffle controls 0x93, 0x4E and 0x39) ready for the caller's next call.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void QuarterRound_Sse2(ref Vector128<uint> a, ref Vector128<uint> b, ref Vector128<uint> c, ref Vector128<uint> d)
{
    Vector128<uint> sum = System.Runtime.Intrinsics.X86.Sse2.Add(a, d);
    b = System.Runtime.Intrinsics.X86.Sse2.Xor(b, Rotate_Sse2(sum, 7));

    sum = System.Runtime.Intrinsics.X86.Sse2.Add(b, a);
    c = System.Runtime.Intrinsics.X86.Sse2.Xor(c, Rotate_Sse2(sum, 9));

    sum = System.Runtime.Intrinsics.X86.Sse2.Add(c, b);
    d = System.Runtime.Intrinsics.X86.Sse2.Xor(d, Rotate_Sse2(sum, 13));

    sum = System.Runtime.Intrinsics.X86.Sse2.Add(d, c);
    a = System.Runtime.Intrinsics.X86.Sse2.Xor(a, Rotate_Sse2(sum, 18));

    // Lane permutations; a keeps its lane order.
    b = System.Runtime.Intrinsics.X86.Sse2.Shuffle(b, 0x93);
    c = System.Runtime.Intrinsics.X86.Sse2.Shuffle(c, 0x4E);
    d = System.Runtime.Intrinsics.X86.Sse2.Shuffle(d, 0x39);
}
/// <summary>
/// Rotates each 32-bit lane of <paramref name="x"/> left by <paramref name="sl"/> bits,
/// built from a left shift by <paramref name="sl"/> and a right shift by 32 - <paramref name="sl"/>.
/// </summary>
/// <param name="x">Vector whose lanes are rotated.</param>
/// <param name="sl">Left rotation amount in bits.</param>
/// <returns>The rotated vector.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector128<uint> Rotate_Sse2(Vector128<uint> x, byte sl)
{
    byte sr = (byte)(32 - sl);

    Vector128<uint> high = System.Runtime.Intrinsics.X86.Sse2.ShiftLeftLogical(x, sl);
    Vector128<uint> low = System.Runtime.Intrinsics.X86.Sse2.ShiftRightLogical(x, sr);

    // The two shifted halves occupy disjoint bits, so XOR merges them.
    return System.Runtime.Intrinsics.X86.Sse2.Xor(high, low);
}
}
}