ChaCha20Poly1305
using Org.BouncyCastle.Crypto.Engines;
using Org.BouncyCastle.Crypto.Macs;
using Org.BouncyCastle.Crypto.Parameters;
using Org.BouncyCastle.Crypto.Utilities;
using Org.BouncyCastle.Utilities;
using System;
namespace Org.BouncyCastle.Crypto.Modes
{
/// <summary>
/// The ChaCha20-Poly1305 AEAD cipher. It combines a <see cref="ChaCha7539Engine"/> keystream for
/// confidentiality with a 128-bit MAC (a <see cref="Poly1305"/> instance by default) for integrity.
/// </summary>
/// <remarks>
/// <para>Call sequence:
/// <list type="number">
/// <item><see cref="Init"/>.</item>
/// <item>Optional associated data (<c>ProcessAad*</c>).</item>
/// <item>Data (<c>ProcessByte</c>/<c>ProcessBytes</c>).</item>
/// <item><c>DoFinal</c>.</item>
/// </list>
/// The tag covers three inputs: the associated data zero-padded to 16 bytes, the ciphertext
/// zero-padded to 16 bytes, and the two byte counts encoded as 64-bit little-endian values.</para>
/// <para>Once encryption has started, finishing or resetting leaves the instance unusable for encryption
/// until <see cref="Init"/> is called again. <see cref="Init"/> also rejects encrypting again under the same
/// key/nonce pair.</para>
/// <para>When decrypting, the most recent 16 input bytes are held back as the candidate tag. Plaintext is
/// therefore only released, in 64-byte blocks, for ciphertext that is followed by at least 16 more bytes.</para>
/// </remarks>
public class ChaCha20Poly1305 : IAeadCipher
{
    /// <summary>Lifecycle of the cipher. Encryption and decryption each have their own track.</summary>
    private enum State
    {
        Uninitialized,
        EncInit,
        EncAad,
        EncData,
        EncFinal,
        DecInit,
        DecAad,
        DecData,
        DecFinal
    }

    // ChaCha20 block size in bytes; data is encrypted/decrypted in multiples of this.
    private const int BufSize = 64;
    // Cipher key length in bytes (256 bits).
    private const int KeySize = 32;
    // Nonce length in bytes (96 bits).
    private const int NonceSize = 12;
    // Authentication tag length in bytes (128 bits).
    private const int MacSize = 16;
    // Length of the MAC one-time key taken from the first keystream block after (re)initialisation.
    private const int Poly1305KeySize = 32;
    // The MAC input is zero-padded to multiples of this many bytes.
    private const int Poly1305BlockSize = 16;

    // Source of zero padding; at most Poly1305BlockSize - 1 bytes are ever needed.
    private static readonly byte[] Zeroes = new byte[Poly1305BlockSize - 1];

    private const ulong AadLimit = ulong.MaxValue;
    // (2^32 - 1) blocks of 64 bytes each.
    private const ulong DataLimit = ((1UL << 32) - 1) * 64;

    private readonly ChaCha7539Engine mChacha20;
    private readonly IMac mPoly1305;

    private readonly byte[] mKey = new byte[KeySize];
    private readonly byte[] mNonce = new byte[NonceSize];
    // One data block plus room for the held-back candidate tag while decrypting.
    private readonly byte[] mBuf = new byte[BufSize + MacSize];
    private readonly byte[] mMac = new byte[MacSize];

    // Associated data supplied through AeadParameters; it is fed to the MAC again on every reset.
    private byte[] mInitialAad;

    private ulong mAadCount;
    private ulong mDataCount;
    private State mState;
    private int mBufPos;

    /// <summary>The algorithm name, <c>"ChaCha20Poly1305"</c>.</summary>
    public virtual string AlgorithmName => "ChaCha20Poly1305";

    /// <summary>Creates the cipher with a new <see cref="Poly1305"/> MAC.</summary>
    public ChaCha20Poly1305()
        : this(new Poly1305())
    {
    }

    /// <summary>Creates the cipher with the supplied MAC.</summary>
    /// <param name="poly1305">MAC to authenticate with; it must produce a 16-byte tag.</param>
    /// <exception cref="ArgumentNullException"><paramref name="poly1305"/> is null.</exception>
    /// <exception cref="ArgumentException">The MAC size is not 16 bytes.</exception>
    public ChaCha20Poly1305(IMac poly1305)
    {
        if (poly1305 == null)
            throw new ArgumentNullException(nameof(poly1305));
        if (MacSize != poly1305.GetMacSize())
            throw new ArgumentException("must be a 128-bit MAC", nameof(poly1305));

        mChacha20 = new ChaCha7539Engine();
        mPoly1305 = poly1305;
    }

    /// <summary>Initialises (or re-initialises) the cipher for one message.</summary>
    /// <param name="forEncryption">True to encrypt, false to decrypt.</param>
    /// <param name="parameters">
    /// Either <see cref="AeadParameters"/> (128-bit MAC size, optional associated text) or
    /// <see cref="ParametersWithIV"/> wrapping a <see cref="KeyParameter"/>. The key may be null on
    /// re-initialisation, in which case the previous key is kept.
    /// </param>
    /// <exception cref="ArgumentException">
    /// Thrown in any of these cases:
    /// <list type="bullet">
    /// <item>The parameter type is not supported.</item>
    /// <item>The MAC size is not 128.</item>
    /// <item>The key is missing on the first init, or is not 32 bytes.</item>
    /// <item>The nonce is not 12 bytes.</item>
    /// <item>Encryption is requested with the same nonce as the previous init, and either the same key or no key.</item>
    /// </list>
    /// </exception>
    public virtual void Init(bool forEncryption, ICipherParameters parameters)
    {
        KeyParameter keyParam;
        ReadOnlySpan<byte> nonce;
        ICipherParameters chachaParams;

        if (parameters is AeadParameters aeadParams) {
            int macSize = aeadParams.MacSize;
            if (MacSize * 8 != macSize)
                throw new ArgumentException("Invalid value for MAC size: " + macSize.ToString());

            keyParam = aeadParams.Key;
            nonce = aeadParams.Nonce;
            chachaParams = new ParametersWithIV(keyParam, nonce);
            mInitialAad = aeadParams.GetAssociatedText();
        } else if (parameters is ParametersWithIV ivParams) {
            keyParam = (KeyParameter)ivParams.Parameters;
            nonce = ivParams.IV;
            chachaParams = ivParams;
            mInitialAad = null;
        } else {
            throw new ArgumentException("invalid parameters passed to ChaCha20Poly1305", nameof(parameters));
        }

        if (keyParam == null) {
            if (mState == State.Uninitialized)
                throw new ArgumentException("Key must be specified in initial init");
        } else if (KeySize != keyParam.KeyLength) {
            throw new ArgumentException("Key must be 256 bits");
        }

        if (NonceSize != nonce.Length)
            throw new ArgumentException("Nonce must be 96 bits");

        // Never encrypt twice under the same (key, nonce). A null key means "reuse the previous key".
        if (mState != State.Uninitialized && forEncryption && nonce.SequenceEqual(mNonce)
            && (keyParam == null || keyParam.FixedTimeEquals(mKey))) {
            throw new ArgumentException("cannot reuse nonce for ChaCha20Poly1305 encryption");
        }

        keyParam?.CopyTo(mKey, 0, KeySize);
        nonce.CopyTo(mNonce);

        mChacha20.Init(true, chachaParams);

        mState = forEncryption ? State.EncInit : State.DecInit;

        Reset(true, false);
    }

    /// <summary>
    /// Returns the total output size for <paramref name="len"/> more input bytes followed by <c>DoFinal</c>.
    /// This includes buffered input. Encryption adds the 16-byte tag; decryption subtracts it, floored at 0.
    /// </summary>
    public virtual int GetOutputSize(int len)
    {
        int total = System.Math.Max(0, len);

        switch (mState) {
        case State.DecInit:
        case State.DecAad:
            return System.Math.Max(0, total - MacSize);
        case State.DecData:
        case State.DecFinal:
            return System.Math.Max(0, total + mBufPos - MacSize);
        case State.EncData:
        case State.EncFinal:
            return total + mBufPos + MacSize;
        default:
            return total + MacSize;
        }
    }

    /// <summary>
    /// Returns the number of bytes a <c>ProcessBytes</c> call with <paramref name="len"/> input bytes can write.
    /// The value is rounded down to whole 64-byte blocks.
    /// </summary>
    public virtual int GetUpdateOutputSize(int len)
    {
        int total = System.Math.Max(0, len);

        switch (mState) {
        case State.DecInit:
        case State.DecAad:
            total = System.Math.Max(0, total - MacSize);
            break;
        case State.DecData:
        case State.DecFinal:
            total = System.Math.Max(0, total + mBufPos - MacSize);
            break;
        case State.EncData:
        case State.EncFinal:
            total += mBufPos;
            break;
        }

        return total - (total % BufSize);
    }

    /// <summary>Adds one byte of associated data. Only valid before any data has been processed.</summary>
    public virtual void ProcessAadByte(byte input)
    {
        CheckAad();

        mAadCount = IncrementCount(mAadCount, 1, AadLimit);
        mPoly1305.Update(input);
    }

    /// <summary>Adds <paramref name="len"/> bytes of associated data from <paramref name="inBytes"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="inBytes"/> is null.</exception>
    /// <exception cref="ArgumentException">An offset or length is negative.</exception>
    public virtual void ProcessAadBytes(byte[] inBytes, int inOff, int len)
    {
        if (inBytes == null)
            throw new ArgumentNullException(nameof(inBytes));
        if (inOff < 0)
            throw new ArgumentException("cannot be negative", nameof(inOff));
        if (len < 0)
            throw new ArgumentException("cannot be negative", nameof(len));
        Check.DataLength(inBytes, inOff, len, "input buffer too short");

        CheckAad();

        if (len > 0) {
            mAadCount = IncrementCount(mAadCount, (uint)len, AadLimit);
            mPoly1305.BlockUpdate(inBytes, inOff, len);
        }
    }

    /// <summary>Adds the bytes of <paramref name="input"/> as associated data.</summary>
    public virtual void ProcessAadBytes(ReadOnlySpan<byte> input)
    {
        CheckAad();

        if (!input.IsEmpty) {
            mAadCount = IncrementCount(mAadCount, (uint)input.Length, AadLimit);
            mPoly1305.BlockUpdate(input);
        }
    }

    /// <summary>Processes one data byte, writing 64 bytes to <paramref name="outBytes"/> when a block completes.</summary>
    /// <returns>The number of bytes written: 0 or 64.</returns>
    public virtual int ProcessByte(byte input, byte[] outBytes, int outOff)
    {
        CheckData();

        switch (mState) {
        case State.DecData:
            mBuf[mBufPos] = input;
            if (++mBufPos == mBuf.Length) {
                // A full block plus a full candidate tag: the block is definitely ciphertext.
                mPoly1305.BlockUpdate(mBuf, 0, BufSize);
                ProcessBlock(mBuf, outBytes.AsSpan(outOff));
                Array.Copy(mBuf, BufSize, mBuf, 0, MacSize);
                mBufPos = MacSize;
                return BufSize;
            }
            return 0;
        case State.EncData:
            mBuf[mBufPos] = input;
            if (++mBufPos == BufSize) {
                ProcessBlock(mBuf, outBytes.AsSpan(outOff));
                mPoly1305.BlockUpdate(outBytes, outOff, BufSize);
                mBufPos = 0;
                return BufSize;
            }
            return 0;
        default:
            throw new InvalidOperationException();
        }
    }

    /// <summary>Processes one data byte, writing 64 bytes to <paramref name="output"/> when a block completes.</summary>
    /// <returns>The number of bytes written: 0 or 64.</returns>
    public virtual int ProcessByte(byte input, Span<byte> output)
    {
        CheckData();

        switch (mState) {
        case State.DecData:
            mBuf[mBufPos] = input;
            if (++mBufPos == mBuf.Length) {
                mPoly1305.BlockUpdate(mBuf.AsSpan(0, BufSize));
                ProcessBlock(mBuf, output);
                Array.Copy(mBuf, BufSize, mBuf, 0, MacSize);
                mBufPos = MacSize;
                return BufSize;
            }
            return 0;
        case State.EncData:
            mBuf[mBufPos] = input;
            if (++mBufPos == BufSize) {
                ProcessBlock(mBuf, output);
                mPoly1305.BlockUpdate(output.Slice(0, BufSize));
                mBufPos = 0;
                return BufSize;
            }
            return 0;
        default:
            throw new InvalidOperationException();
        }
    }

    /// <summary>Processes <paramref name="len"/> data bytes; only whole 64-byte blocks are written out.</summary>
    /// <returns>The number of bytes written to <paramref name="outBytes"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="inBytes"/> is null.</exception>
    /// <exception cref="ArgumentException">An offset or length is negative.</exception>
    public virtual int ProcessBytes(byte[] inBytes, int inOff, int len, byte[] outBytes, int outOff)
    {
        if (inBytes == null)
            throw new ArgumentNullException(nameof(inBytes));
        if (inOff < 0)
            throw new ArgumentException("cannot be negative", nameof(inOff));
        if (len < 0)
            throw new ArgumentException("cannot be negative", nameof(len));
        Check.DataLength(inBytes, inOff, len, "input buffer too short");
        if (outOff < 0)
            throw new ArgumentException("cannot be negative", nameof(outOff));

        return ProcessBytes(inBytes.AsSpan(inOff, len), Spans.FromNullable(outBytes, outOff));
    }

    /// <summary>Processes <paramref name="input"/>; only whole 64-byte blocks are written out.</summary>
    /// <returns>The number of bytes written to <paramref name="output"/>.</returns>
    public virtual int ProcessBytes(ReadOnlySpan<byte> input, Span<byte> output)
    {
        CheckData();

        int resultLen = 0;

        switch (mState) {
        case State.DecData: {
            int available = mBuf.Length - mBufPos;
            if (input.Length < available) {
                // Everything still fits beside the held-back tag window; nothing can be released yet.
                input.CopyTo(mBuf.AsSpan(mBufPos));
                mBufPos += input.Length;
                break;
            }

            if (mBufPos >= BufSize) {
                // The buffer already holds a full block ahead of the tag window: release it first.
                mPoly1305.BlockUpdate(mBuf.AsSpan(0, BufSize));
                ProcessBlock(mBuf, output);
                Array.Copy(mBuf, BufSize, mBuf, 0, mBufPos -= BufSize);
                resultLen = BufSize;

                available += BufSize;
                if (input.Length < available) {
                    input.CopyTo(mBuf.AsSpan(mBufPos));
                    mBufPos += input.Length;
                    break;
                }
            }

            // A block may only be decrypted while at least MacSize input bytes remain after it.
            int oneBlockLimit = mBuf.Length;
            int twoBlockLimit = oneBlockLimit + BufSize;

            available = BufSize - mBufPos;
            input.Slice(0, available).CopyTo(mBuf.AsSpan(mBufPos));
            mPoly1305.BlockUpdate(mBuf.AsSpan(0, BufSize));
            ProcessBlock(mBuf, output.Slice(resultLen));
            input = input.Slice(available);
            resultLen += BufSize;

            while (input.Length >= twoBlockLimit) {
                mPoly1305.BlockUpdate(input.Slice(0, BufSize * 2));
                ProcessBlocks2(input, output.Slice(resultLen));
                input = input.Slice(BufSize * 2);
                resultLen += BufSize * 2;
            }

            if (input.Length >= oneBlockLimit) {
                mPoly1305.BlockUpdate(input.Slice(0, BufSize));
                ProcessBlock(input, output.Slice(resultLen));
                input = input.Slice(BufSize);
                resultLen += BufSize;
            }

            mBufPos = input.Length;
            input.CopyTo(mBuf);
            break;
        }
        case State.EncData: {
            int available = BufSize - mBufPos;
            if (input.Length < available) {
                input.CopyTo(mBuf.AsSpan(mBufPos));
                mBufPos += input.Length;
                break;
            }

            if (mBufPos > 0) {
                // Top up and flush the partially filled buffer before streaming directly from input.
                input.Slice(0, available).CopyTo(mBuf.AsSpan(mBufPos));
                ProcessBlock(mBuf, output);
                input = input.Slice(available);
                resultLen = BufSize;
            }

            while (input.Length >= BufSize * 2) {
                ProcessBlocks2(input, output.Slice(resultLen));
                input = input.Slice(BufSize * 2);
                resultLen += BufSize * 2;
            }

            if (input.Length >= BufSize) {
                ProcessBlock(input, output.Slice(resultLen));
                input = input.Slice(BufSize);
                resultLen += BufSize;
            }

            // Authenticate all ciphertext written by this call in one MAC update.
            mPoly1305.BlockUpdate(output.Slice(0, resultLen));

            mBufPos = input.Length;
            input.CopyTo(mBuf);
            break;
        }
        default:
            throw new InvalidOperationException();
        }

        return resultLen;
    }

    /// <summary>Finishes the message, writing to <paramref name="outBytes"/> at <paramref name="outOff"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="outBytes"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="outOff"/> is negative.</exception>
    public virtual int DoFinal(byte[] outBytes, int outOff)
    {
        if (outBytes == null)
            throw new ArgumentNullException(nameof(outBytes));
        if (outOff < 0)
            throw new ArgumentException("cannot be negative", nameof(outOff));

        return DoFinal(outBytes.AsSpan(outOff));
    }

    /// <summary>Finishes the message.</summary>
    /// <remarks>
    /// When encrypting, this writes the buffered ciphertext followed by the 16-byte tag.
    /// When decrypting, this writes the remaining plaintext and then verifies the tag in constant time.
    /// The plaintext is written to <paramref name="output"/> before that check. On a mismatch, the
    /// exception is thrown without resetting internal buffers, so call <c>Reset</c> or
    /// <see cref="Init"/> before reuse.
    /// </remarks>
    /// <returns>The number of bytes written.</returns>
    /// <exception cref="InvalidCipherTextException">
    /// Fewer than 16 bytes were buffered during decryption, or the tag does not match.
    /// </exception>
    public virtual int DoFinal(Span<byte> output)
    {
        CheckData();

        Array.Clear(mMac, 0, MacSize);

        int resultLen;

        switch (mState) {
        case State.DecData:
            if (mBufPos < MacSize)
                throw new InvalidCipherTextException("data too short");

            resultLen = mBufPos - MacSize;
            Check.OutputLength(output, resultLen, "output buffer too short");

            if (resultLen > 0) {
                mPoly1305.BlockUpdate(mBuf, 0, resultLen);
                ProcessData(mBuf.AsSpan(0, resultLen), output);
            }

            FinishData(State.DecFinal);

            // The received tag is the trailing MacSize bytes of the buffer.
            if (!Arrays.FixedTimeEquals(MacSize, mMac, 0, mBuf, resultLen))
                throw new InvalidCipherTextException("mac check in ChaCha20Poly1305 failed");
            break;
        case State.EncData:
            resultLen = mBufPos + MacSize;
            Check.OutputLength(output, resultLen, "output buffer too short");

            if (mBufPos > 0) {
                ProcessData(mBuf.AsSpan(0, mBufPos), output);
                mPoly1305.BlockUpdate(output.Slice(0, mBufPos));
            }

            FinishData(State.EncFinal);

            mMac.AsSpan(0, MacSize).CopyTo(output.Slice(mBufPos));
            break;
        default:
            throw new InvalidOperationException();
        }

        Reset(false, true);
        return resultLen;
    }

    /// <summary>Returns a copy of the tag from the last <c>DoFinal</c>. It is cleared by <c>Init</c> and <c>Reset()</c>.</summary>
    public virtual byte[] GetMac()
    {
        return Arrays.Clone(mMac);
    }

    /// <summary>
    /// Clears buffered state and the tag, and resets the keystream. After encryption has started,
    /// the instance is left unusable for encryption until <see cref="Init"/> is called again.
    /// </summary>
    public virtual void Reset()
    {
        Reset(true, true);
    }

    // Moves into the AAD phase, or throws if AAD is no longer allowed.
    private void CheckAad()
    {
        switch (mState) {
        case State.EncAad:
        case State.DecAad:
            break;
        case State.DecInit:
            mState = State.DecAad;
            break;
        case State.EncInit:
            mState = State.EncAad;
            break;
        case State.EncFinal:
            throw new InvalidOperationException(AlgorithmName + " cannot be reused for encryption");
        default:
            throw new InvalidOperationException(AlgorithmName + " needs to be initialized");
        }
    }

    // Moves into the data phase (closing the AAD phase if needed), or throws if not allowed.
    private void CheckData()
    {
        switch (mState) {
        case State.EncData:
        case State.DecData:
            break;
        case State.DecInit:
        case State.DecAad:
            FinishAad(State.DecData);
            break;
        case State.EncInit:
        case State.EncAad:
            FinishAad(State.EncData);
            break;
        case State.EncFinal:
            throw new InvalidOperationException(AlgorithmName + " cannot be reused for encryption");
        default:
            throw new InvalidOperationException(AlgorithmName + " needs to be initialized");
        }
    }

    private void FinishAad(State nextState)
    {
        PadMac(mAadCount);
        mState = nextState;
    }

    private void FinishData(State nextState)
    {
        PadMac(mDataCount);

        // Final MAC block: AAD byte count then data byte count, each 64-bit little-endian.
        byte[] lengths = new byte[16];
        Pack.UInt64_To_LE(mAadCount, lengths, 0);
        Pack.UInt64_To_LE(mDataCount, lengths, 8);
        mPoly1305.BlockUpdate(lengths, 0, 16);

        mPoly1305.DoFinal(mMac, 0);

        mState = nextState;
    }

    // Adds increment to count, throwing instead of exceeding limit.
    private static ulong IncrementCount(ulong count, uint increment, ulong limit)
    {
        if (count > limit - increment)
            throw new InvalidOperationException("Limit exceeded");

        return count + increment;
    }

    // Keys the MAC with the first Poly1305KeySize bytes of the next keystream block, then wipes the scratch copy.
    private void InitMac()
    {
        Span<byte> firstBlock = stackalloc byte[BufSize];
        try {
            // Encrypting the (zero-initialised) scratch block yields raw keystream.
            mChacha20.ProcessBytes(firstBlock, firstBlock);
            mPoly1305.Init(new KeyParameter(firstBlock.Slice(0, Poly1305KeySize)));
        } finally {
            firstBlock.Fill(0);
        }
    }

    // Zero-pads the MAC input to the next Poly1305BlockSize boundary after count bytes.
    private void PadMac(ulong count)
    {
        int partial = (int)count & (Poly1305BlockSize - 1);
        if (partial != 0)
            mPoly1305.BlockUpdate(Zeroes, 0, Poly1305BlockSize - partial);
    }

    private void ProcessBlock(ReadOnlySpan<byte> input, Span<byte> output)
    {
        Check.OutputLength(output, BufSize, "output buffer too short");

        mChacha20.ProcessBlock(input, output);

        mDataCount = IncrementCount(mDataCount, BufSize, DataLimit);
    }

    private void ProcessBlocks2(ReadOnlySpan<byte> input, Span<byte> output)
    {
        Check.OutputLength(output, BufSize * 2, "output buffer too short");

        mChacha20.ProcessBlocks2(input, output);

        mDataCount = IncrementCount(mDataCount, BufSize * 2, DataLimit);
    }

    private void ProcessData(ReadOnlySpan<byte> input, Span<byte> output)
    {
        Check.OutputLength(output, input.Length, "output buffer too short");

        mChacha20.ProcessBytes(input, output);

        mDataCount = IncrementCount(mDataCount, (uint)input.Length, DataLimit);
    }

    // Clears per-message state. Decryption returns to DecInit with a freshly keyed MAC (and initial AAD).
    // Any encryption state is parked in EncFinal so the same key/nonce cannot be used to encrypt again.
    private void Reset(bool clearMac, bool resetCipher)
    {
        Array.Clear(mBuf, 0, mBuf.Length);
        if (clearMac)
            Array.Clear(mMac, 0, mMac.Length);

        mAadCount = 0;
        mDataCount = 0;
        mBufPos = 0;

        switch (mState) {
        case State.EncInit:
        case State.DecInit:
            break;
        case State.DecAad:
        case State.DecData:
        case State.DecFinal:
            mState = State.DecInit;
            break;
        case State.EncAad:
        case State.EncData:
        case State.EncFinal:
            mState = State.EncFinal;
            return;
        default:
            throw new InvalidOperationException(AlgorithmName + " needs to be initialized");
        }

        if (resetCipher)
            mChacha20.Reset();

        InitMac();

        if (mInitialAad != null)
            ProcessAadBytes(mInitialAad);
    }
}
}