DilithiumEngine
class DilithiumEngine
using Org.BouncyCastle.Crypto.Digests;
using Org.BouncyCastle.Security;
using Org.BouncyCastle.Utilities;
using System;
namespace Org.BouncyCastle.Pqc.Crypto.Crystals.Dilithium
{
/// <summary>
/// CRYSTALS-Dilithium engine: selects the parameter set for modes 2, 3 and 5 and implements
/// key generation, signing and verification over the packed byte encodings produced by
/// <c>Packing</c>.
/// </summary>
internal class DilithiumEngine
{
    internal const int N = 256;
    internal const int Q = 8380417;
    internal const int QInv = 58728449;
    internal const int D = 13;
    internal const int RootOfUnity = 1753;
    internal const int SeedBytes = 32;
    internal const int CrhBytes = 64;
    internal const int RndBytes = 32;
    internal const int TrBytes = 64;
    internal const int PolyT1PackedBytes = 320;
    internal const int PolyT0PackedBytes = 416;

    /// <summary>Iteration cap for the signing rejection loop when <c>legacy</c> is false.</summary>
    private const int MaxSignAttempts = 1000;

    internal int Mode { get; set; }
    internal SecureRandom Random { get; set; }
    internal int K { get; set; }
    internal int L { get; set; }
    internal int Eta { get; set; }
    internal int Tau { get; set; }
    internal int Beta { get; set; }
    internal int Gamma1 { get; set; }
    internal int Gamma2 { get; set; }
    internal int Omega { get; set; }
    internal int CTilde { get; set; }
    internal int PolyVecHPackedBytes { get; set; }
    internal int PolyZPackedBytes { get; set; }
    internal int PolyW1PackedBytes { get; set; }
    internal int PolyEtaPackedBytes { get; set; }
    internal int CryptoPublicKeyBytes { get; set; }
    internal int CryptoSecretKeyBytes { get; set; }
    internal int CryptoBytes { get; set; }
    internal int PolyUniformGamma1NBytes { get; set; }
    internal Symmetric Symmetric { get; set; }

    /// <summary>
    /// Creates an engine for the given Dilithium mode and derives all size parameters from it.
    /// </summary>
    /// <param name="mode">Parameter set: 2, 3 or 5.</param>
    /// <param name="random">
    /// Randomness source for key generation and hedged signing. A null value makes signing use
    /// an all-zero <c>rnd</c>.
    /// </param>
    /// <param name="usingAes">Selects <c>Symmetric.AesSymmetric</c> instead of <c>Symmetric.ShakeSymmetric</c>.</param>
    /// <exception cref="ArgumentException">The mode is not 2, 3 or 5.</exception>
    internal DilithiumEngine(int mode, SecureRandom random, bool usingAes)
    {
        Mode = mode;
        Random = random;

        switch (Mode)
        {
            case 2:
                K = 4;
                L = 4;
                Eta = 2;
                Tau = 39;
                Beta = 78;
                Gamma1 = 131072;
                Gamma2 = 95232;
                Omega = 80;
                PolyZPackedBytes = 576;
                PolyW1PackedBytes = 192;
                PolyEtaPackedBytes = 96;
                CTilde = 32;
                break;
            case 3:
                K = 6;
                L = 5;
                Eta = 4;
                Tau = 49;
                Beta = 196;
                Gamma1 = 524288;
                Gamma2 = 261888;
                Omega = 55;
                PolyZPackedBytes = 640;
                PolyW1PackedBytes = 128;
                PolyEtaPackedBytes = 128;
                CTilde = 48;
                break;
            case 5:
                K = 8;
                L = 7;
                Eta = 2;
                Tau = 60;
                Beta = 120;
                Gamma1 = 524288;
                Gamma2 = 261888;
                Omega = 75;
                PolyZPackedBytes = 640;
                PolyW1PackedBytes = 128;
                PolyEtaPackedBytes = 96;
                CTilde = 64;
                break;
            default:
                throw new ArgumentException("The mode " + mode.ToString() + " is not supported by Crystals Dilithium!");
        }

        if (usingAes)
        {
            Symmetric = new Symmetric.AesSymmetric();
        }
        else
        {
            Symmetric = new Symmetric.ShakeSymmetric();
        }

        PolyVecHPackedBytes = Omega + K;
        CryptoPublicKeyBytes = SeedBytes + K * PolyT1PackedBytes;
        // Secret key = rho || K-seed || tr || s1 || s2 || t0.
        CryptoSecretKeyBytes = 2 * SeedBytes + TrBytes
            + L * PolyEtaPackedBytes + K * PolyEtaPackedBytes + K * PolyT0PackedBytes;
        CryptoBytes = CTilde + L * PolyZPackedBytes + PolyVecHPackedBytes;

        // Number of whole stream blocks (rounded up) needed to cover 576 or 640 bytes,
        // depending on Gamma1.
        if (Gamma1 == 131072)
        {
            PolyUniformGamma1NBytes = (576 + Symmetric.Stream256BlockBytes - 1) / Symmetric.Stream256BlockBytes;
        }
        else if (Gamma1 == 524288)
        {
            PolyUniformGamma1NBytes = (640 + Symmetric.Stream256BlockBytes - 1) / Symmetric.Stream256BlockBytes;
        }
        else
        {
            throw new ArgumentException("Wrong Dilithium Gamma1!");
        }
    }

    /// <summary>
    /// Computes tr = SHAKE256(rho || encT1), truncated to <see cref="TrBytes"/> bytes.
    /// </summary>
    internal static byte[] CalculatePublicKeyHash(byte[] rho, byte[] encT1)
    {
        byte[] tr = new byte[TrBytes];
        ShakeDigest shakeDigest = new ShakeDigest(256);
        shakeDigest.BlockUpdate(rho, 0, rho.Length);
        shakeDigest.BlockUpdate(encT1, 0, encT1.Length);
        shakeDigest.OutputFinal(tr, 0, TrBytes);
        return tr;
    }

    /// <summary>
    /// Draws a fresh <see cref="SeedBytes"/>-byte seed from <see cref="Random"/> and generates a
    /// key pair from it.
    /// </summary>
    internal void GenerateKeyPair(bool legacy, out byte[] rho, out byte[] k, out byte[] tr, out byte[] s1, out byte[] s2, out byte[] t0, out byte[] encT1, out byte[] seed)
    {
        seed = SecureRandom.GetNextBytes(Random, SeedBytes);
        GenerateKeyPairInternal(seed, legacy, out rho, out k, out tr, out s1, out s2, out t0, out encT1);
    }

    /// <summary>
    /// Deterministically generates a key pair from <paramref name="seed"/>.
    /// </summary>
    /// <param name="seed">At least <see cref="SeedBytes"/> bytes; only the first <see cref="SeedBytes"/> are read.</param>
    /// <param name="legacy">
    /// When false, K and L are absorbed as single bytes after the seed before expansion. This
    /// changes the derived keys relative to the legacy derivation.
    /// </param>
    internal void GenerateKeyPairInternal(byte[] seed, bool legacy, out byte[] rho, out byte[] k, out byte[] tr, out byte[] s1_, out byte[] s2_, out byte[] t0_, out byte[] encT1)
    {
        // Expand the seed into rho (matrix seed), rho' (secret-vector seed) and the signing seed k.
        byte[] seedBuf = new byte[SeedBytes + CrhBytes + SeedBytes];
        ShakeDigest shakeDigest = new ShakeDigest(256);
        shakeDigest.BlockUpdate(seed, 0, SeedBytes);
        if (!legacy)
        {
            shakeDigest.Update((byte)K);
            shakeDigest.Update((byte)L);
        }
        shakeDigest.OutputFinal(seedBuf, 0, seedBuf.Length);

        rho = Arrays.CopyOfRange(seedBuf, 0, SeedBytes);
        byte[] rhoPrime = Arrays.CopyOfRange(seedBuf, SeedBytes, SeedBytes + CrhBytes);
        k = Arrays.CopyOfRange(seedBuf, SeedBytes + CrhBytes, seedBuf.Length);

        PolyVecMatrix matrix = new PolyVecMatrix(this);
        PolyVec s1 = new PolyVec(this, L);
        PolyVec s2 = new PolyVec(this, K);
        PolyVec t1 = new PolyVec(this, K);
        PolyVec t0 = new PolyVec(this, K);

        matrix.ExpandMatrix(rho);
        // s1 and s2 are sampled from the same seed with disjoint nonce ranges (0.. and L..).
        s1.UniformEta(rhoPrime, 0);
        s2.UniformEta(rhoPrime, (ushort)L);

        ComputeT(matrix, s1, s2, t1, t0);

        encT1 = Packing.PackPublicKey(t1, this);
        tr = CalculatePublicKeyHash(rho, encT1);

        s1_ = new byte[L * PolyEtaPackedBytes];
        s2_ = new byte[K * PolyEtaPackedBytes];
        t0_ = new byte[K * PolyT0PackedBytes];
        Packing.PackSecretKey(t0_, s1_, s2_, t0, s1, s2, this);
    }

    /// <summary>
    /// Recomputes the packed public value t1 from the matrix seed and the packed secret vectors.
    /// </summary>
    internal byte[] DeriveT1(byte[] rho, byte[] s1Enc, byte[] s2Enc, byte[] t0Enc)
    {
        PolyVecMatrix matrix = new PolyVecMatrix(this);
        PolyVec s1 = new PolyVec(this, L);
        PolyVec s2 = new PolyVec(this, K);
        PolyVec t1 = new PolyVec(this, K);
        PolyVec t0 = new PolyVec(this, K);

        Packing.UnpackSecretKey(t0, s1, s2, t0Enc, s1Enc, s2Enc, this);
        matrix.ExpandMatrix(rho);

        // The unpacked t0 is overwritten here; only t1 is returned.
        ComputeT(matrix, s1, s2, t1, t0);
        return Packing.PackPublicKey(t1, this);
    }

    /// <summary>
    /// Computes t = A * s1 + s2 (A applied in the NTT domain to a copy of s1), then splits it with
    /// Power2Round. The high part is left in <paramref name="t1"/> and the low part is written to
    /// <paramref name="t0"/>. <paramref name="s1"/> itself is not transformed.
    /// </summary>
    private void ComputeT(PolyVecMatrix matrix, PolyVec s1, PolyVec s2, PolyVec t1, PolyVec t0)
    {
        PolyVec s1Hat = new PolyVec(this, L);
        s1.CopyTo(s1Hat);
        s1Hat.Ntt();
        matrix.PointwiseMontgomery(t1, s1Hat);
        t1.Reduce();
        t1.InverseNttToMont();
        t1.Add(s2);
        t1.ConditionalAddQ();
        t1.Power2Round(t0);
    }

    /// <summary>Absorbs the public-key hash <paramref name="tr"/> (<see cref="TrBytes"/> bytes) into <paramref name="d"/>.</summary>
    internal void MsgRepBegin(ShakeDigest d, byte[] tr)
    {
        d.BlockUpdate(tr, 0, TrBytes);
    }

    /// <summary>Creates the SHAKE256 digest used for message representatives.</summary>
    internal static ShakeDigest MsgRepCreateDigest()
    {
        return new ShakeDigest(256);
    }

    /// <summary>
    /// Finishes signing with <see cref="RndBytes"/> bytes of randomness from <see cref="Random"/>.
    /// If <see cref="Random"/> is null, an all-zero rnd is used.
    /// </summary>
    internal void MsgRepEndSign(ShakeDigest d, byte[] sig, int siglen, byte[] rho, byte[] k, byte[] t0Enc, byte[] s1Enc, byte[] s2Enc, bool legacy)
    {
        byte[] rnd = new byte[RndBytes];
        Random?.NextBytes(rnd);
        MsgRepEndSignInternal(d, sig, siglen, rho, k, t0Enc, s1Enc, s2Enc, rnd, legacy);
    }

    /// <summary>
    /// Finishes a signature. <paramref name="d"/> is expected to have absorbed tr and the message
    /// (see <see cref="MsgRepPreHash"/>). The signature is written into <paramref name="sig"/>.
    /// </summary>
    /// <param name="sig">
    /// Output buffer. It is also used as scratch space for the packed w1 and the challenge seed
    /// during rejection sampling.
    /// </param>
    /// <param name="siglen">Not read by this method.</param>
    /// <param name="rnd">At least <see cref="RndBytes"/> bytes of signing randomness.</param>
    /// <param name="legacy">When true, the rejection loop has no iteration cap.</param>
    /// <exception cref="InvalidOperationException">
    /// Non-legacy signing found no acceptable candidate within the iteration cap.
    /// </exception>
    internal void MsgRepEndSignInternal(ShakeDigest d, byte[] sig, int siglen, byte[] rho, byte[] k, byte[] t0Enc, byte[] s1Enc, byte[] s2Enc, byte[] rnd, bool legacy)
    {
        byte[] mu = new byte[CrhBytes];
        d.OutputFinal(mu, 0, CrhBytes);

        PolyVecMatrix matrix = new PolyVecMatrix(this);
        PolyVec s1 = new PolyVec(this, L);
        PolyVec y = new PolyVec(this, L);
        PolyVec z = new PolyVec(this, L);
        PolyVec t0 = new PolyVec(this, K);
        PolyVec s2 = new PolyVec(this, K);
        PolyVec w1 = new PolyVec(this, K);
        PolyVec w0 = new PolyVec(this, K);
        PolyVec h = new PolyVec(this, K);
        Poly cp = new Poly(this);

        Packing.UnpackSecretKey(t0, s1, s2, t0Enc, s1Enc, s2Enc, this);

        // rho' = SHAKE256(k || rnd || mu); the digest was reset by OutputFinal above.
        byte[] rhoPrime = new byte[CrhBytes];
        d.BlockUpdate(k, 0, SeedBytes);
        d.BlockUpdate(rnd, 0, RndBytes);
        d.BlockUpdate(mu, 0, CrhBytes);
        d.OutputFinal(rhoPrime, 0, CrhBytes);

        matrix.ExpandMatrix(rho);
        s1.Ntt();
        s2.Ntt();
        t0.Ntt();

        ushort nonce = 0;
        int attempts = 0;
        while (legacy || ++attempts <= MaxSignAttempts)
        {
            // Sample y and compute w = A * y, split into high (w1) and low (w0) parts.
            y.UniformGamma1(rhoPrime, nonce++);
            y.CopyTo(z);
            z.Ntt();
            matrix.PointwiseMontgomery(w1, z);
            w1.Reduce();
            w1.InverseNttToMont();
            w1.ConditionalAddQ();
            w1.Decompose(w0);

            // Challenge seed c~ = SHAKE256(mu || packed w1), written to the front of sig.
            w1.PackW1(this, sig, 0);
            d.BlockUpdate(mu, 0, CrhBytes);
            d.BlockUpdate(sig, 0, K * PolyW1PackedBytes);
            d.OutputFinal(sig, 0, CTilde);
            cp.Challenge(sig, 0, CTilde);
            cp.PolyNtt();

            // z = y + c * s1
            z.PointwisePolyMontgomery(cp, s1);
            z.InverseNttToMont();
            z.Add(y);
            z.Reduce();
            // CheckNorm returning true means the bound is exceeded (the verifier rejects on true).
            if (z.CheckNorm(Gamma1 - Beta))
            {
                continue;
            }

            // w0 - c * s2
            h.PointwisePolyMontgomery(cp, s2);
            h.InverseNttToMont();
            w0.Subtract(h);
            w0.Reduce();
            if (w0.CheckNorm(Gamma2 - Beta))
            {
                continue;
            }

            // c * t0
            h.PointwisePolyMontgomery(cp, t0);
            h.InverseNttToMont();
            h.Reduce();
            if (h.CheckNorm(Gamma2))
            {
                continue;
            }

            w0.Add(h);
            w0.ConditionalAddQ();
            if (h.MakeHint(w0, w1) > Omega)
            {
                continue;
            }

            Packing.PackSignature(sig, z, h, this);
            return;
        }

        throw new InvalidOperationException("Dilithium signing did not succeed within the maximum number of attempts.");
    }

    /// <summary>
    /// Finishes verification. <paramref name="d"/> is expected to have absorbed tr and the message
    /// (see <see cref="MsgRepPreHash"/>).
    /// </summary>
    /// <returns>
    /// False if the length is wrong, the signature fails to unpack, z exceeds its bound, or the
    /// recomputed challenge seed does not match; otherwise true.
    /// </returns>
    internal bool MsgRepEndVerifyInternal(ShakeDigest d, byte[] sig, int siglen, byte[] rho, byte[] encT1)
    {
        if (siglen != CryptoBytes)
        {
            return false;
        }

        PolyVec h = new PolyVec(this, K);
        PolyVec z = new PolyVec(this, L);
        if (!Packing.UnpackSignature(z, h, sig, this))
        {
            return false;
        }
        if (z.CheckNorm(Gamma1 - Beta))
        {
            return false;
        }

        // Buffer layout: mu (CrhBytes) followed by packed w1; later reused for the recomputed c~.
        byte[] buf = new byte[System.Math.Max(CrhBytes + K * PolyW1PackedBytes, CTilde)];
        d.OutputFinal(buf, 0, CrhBytes);

        Poly cp = new Poly(this);
        PolyVecMatrix matrix = new PolyVecMatrix(this);
        PolyVec t1 = new PolyVec(this, K);
        PolyVec w1 = new PolyVec(this, K);

        Packing.UnpackPublicKey(t1, encT1, this);
        cp.Challenge(sig, 0, CTilde);
        matrix.ExpandMatrix(rho);

        // w1' = UseHint(A * z - c * t1 * 2^D, h)
        z.Ntt();
        matrix.PointwiseMontgomery(w1, z);
        cp.PolyNtt();
        t1.ShiftLeft();
        t1.Ntt();
        t1.PointwisePolyMontgomery(cp, t1);
        w1.Subtract(t1);
        w1.Reduce();
        w1.InverseNttToMont();
        w1.ConditionalAddQ();
        w1.UseHint(w1, h);

        w1.PackW1(this, buf, CrhBytes);
        d.BlockUpdate(buf, 0, CrhBytes + K * PolyW1PackedBytes);
        d.OutputFinal(buf, 0, CTilde);
        return Arrays.FixedTimeEquals(CTilde, sig, 0, buf, 0);
    }

    /// <summary>Creates a digest that has absorbed tr followed by the given message span.</summary>
    internal ShakeDigest MsgRepPreHash(byte[] tr, byte[] msg, int msgOff, int msgLen)
    {
        ShakeDigest shakeDigest = MsgRepCreateDigest();
        MsgRepBegin(shakeDigest, tr);
        shakeDigest.BlockUpdate(msg, msgOff, msgLen);
        return shakeDigest;
    }

    /// <summary>Signs a message span using randomness from <see cref="Random"/>.</summary>
    internal void Sign(byte[] sig, int siglen, byte[] msg, int msgOff, int msgLen, byte[] rho, byte[] k, byte[] tr, byte[] t0Enc, byte[] s1Enc, byte[] s2Enc, bool legacy)
    {
        ShakeDigest d = MsgRepPreHash(tr, msg, msgOff, msgLen);
        MsgRepEndSign(d, sig, siglen, rho, k, t0Enc, s1Enc, s2Enc, legacy);
    }

    /// <summary>Signs a message span with caller-supplied randomness <paramref name="rnd"/>.</summary>
    internal void SignInternal(byte[] sig, int siglen, byte[] msg, int msgOff, int msgLen, byte[] rho, byte[] k, byte[] tr, byte[] t0Enc, byte[] s1Enc, byte[] s2Enc, byte[] rnd, bool legacy)
    {
        ShakeDigest d = MsgRepPreHash(tr, msg, msgOff, msgLen);
        MsgRepEndSignInternal(d, sig, siglen, rho, k, t0Enc, s1Enc, s2Enc, rnd, legacy);
    }

    /// <summary>Verifies a signature over a message span.</summary>
    internal bool VerifyInternal(byte[] sig, int siglen, byte[] msg, int msgOff, int msgLen, byte[] rho, byte[] encT1, byte[] tr)
    {
        ShakeDigest d = MsgRepPreHash(tr, msg, msgOff, msgLen);
        return MsgRepEndVerifyInternal(d, sig, siglen, rho, encT1);
    }
}
}