Poly
using System;
namespace Org.BouncyCastle.Crypto.Kems.MLKem
{
/// <summary>
/// A single ML-KEM ring polynomial: 256 coefficients stored as 16-bit signed integers, together
/// with the in-place arithmetic, noise sampling, compression and (de)serialization routines that
/// operate on it.
/// </summary>
/// <remarks>
/// Almost every method mutates <see cref="m_coeffs"/> in place. The exact coefficient ranges each
/// method expects on input depend on helpers not defined in this class (<c>Reduce</c>, <c>Ntt</c>,
/// <c>Cbd</c>); where a method relies on such a range it is noted as an assumption.
/// </remarks>
internal sealed class Poly
{
    // Number of coefficients in one polynomial.
    private const int N = 256;

    // The ML-KEM modulus q.
    private const int Q = 3329;

    // 2^32 mod Q. NOTE(review): if Reduce.MontgomeryReduce divides by 2^16 (not visible here),
    // this is R^2 mod Q for Montgomery radix R = 2^16, which is what ToMont needs.
    private const int TwoTo32ModQ = 1353;

    private readonly MLKemEngine m_engine;
    private readonly Symmetric m_symmetric;

    /// <summary>Coefficient storage. Internal so that helper classes in this assembly can operate on it directly.</summary>
    internal readonly short[] m_coeffs = new short[N];

    /// <summary>The same array as <see cref="m_coeffs"/>.</summary>
    internal short[] Coeffs => m_coeffs;

    /// <summary>Creates an all-zero polynomial bound to the given engine.</summary>
    /// <param name="mEngine">Engine supplying the parameter set (<c>Eta1</c>, <c>PolyCompressedBytes</c>) and the <c>Symmetric</c> primitives.</param>
    internal Poly(MLKemEngine mEngine)
    {
        m_engine = mEngine;
        m_symmetric = mEngine.Symmetric;
    }

    /// <summary>
    /// Fills this polynomial with noise: expands <paramref name="seed"/>/<paramref name="nonce"/>
    /// through the PRF into <c>Eta1 * 256 / 4</c> bytes and hands them to <c>Cbd.Eta</c> with <c>Eta1</c>.
    /// </summary>
    /// <param name="seed">PRF seed.</param>
    /// <param name="nonce">PRF nonce (domain separator).</param>
    internal void GetNoiseEta1(ReadOnlySpan<byte> seed, byte nonce)
    {
        int eta = m_engine.Eta1;
        Span<byte> prfOutput = stackalloc byte[eta * N / 4];
        m_symmetric.Prf(seed, nonce, prfOutput);
        Cbd.Eta(this, prfOutput, eta);
    }

    /// <summary>
    /// Fills this polynomial with noise using the fixed parameter eta = 2: expands
    /// <paramref name="seed"/>/<paramref name="nonce"/> through the PRF into 128 bytes and hands
    /// them to <c>Cbd.Eta</c>.
    /// </summary>
    /// <param name="seed">PRF seed.</param>
    /// <param name="nonce">PRF nonce (domain separator).</param>
    internal void GetNoiseEta2(ReadOnlySpan<byte> seed, byte nonce)
    {
        const int Eta2 = 2;
        Span<byte> prfOutput = stackalloc byte[Eta2 * N / 4];
        m_symmetric.Prf(seed, nonce, prfOutput);
        Cbd.Eta(this, prfOutput, Eta2);
    }

    /// <summary>Applies the forward NTT (<c>Ntt.NTT</c>) in place, then Barrett-reduces every coefficient.</summary>
    internal void PolyNtt()
    {
        Ntt.NTT(m_coeffs);
        PolyReduce();
    }

    /// <summary>
    /// Applies the inverse NTT (<c>Ntt.InvNTT</c>) in place. Per the method name the result is
    /// presumably left in Montgomery form — TODO confirm against <c>Ntt.InvNTT</c>.
    /// </summary>
    internal void PolyInverseNttToMont()
    {
        Ntt.InvNTT(m_coeffs);
    }

    /// <summary>
    /// Computes the NTT-domain product of <paramref name="a"/> and <paramref name="b"/> into
    /// <paramref name="r"/>, as 128 <c>Ntt.BaseMult</c> calls on consecutive coefficient pairs.
    /// Pair group <c>i</c> uses <c>Ntt.Zetas[64 + i]</c> for its first pair and the negated
    /// value for its second pair.
    /// </summary>
    /// <remarks>
    /// Each call reads its inputs before writing the corresponding output pair. Aliasing
    /// <paramref name="r"/> with an input is therefore only safe if <c>Ntt.BaseMult</c> writes
    /// nothing beyond its own pair.
    /// </remarks>
    internal static void BaseMultMontgomery(Poly r, Poly a, Poly b)
    {
        short[] rc = r.m_coeffs;
        short[] ac = a.m_coeffs;
        short[] bc = b.m_coeffs;
        for (int i = 0; i < N / 4; i++) {
            int k = 4 * i;
            var zeta = Ntt.Zetas[64 + i];
            Ntt.BaseMult(rc, k, ac[k], ac[k + 1], bc[k], bc[k + 1], zeta);
            Ntt.BaseMult(rc, k + 2, ac[k + 2], ac[k + 3], bc[k + 2], bc[k + 3], (short)(-zeta));
        }
    }

    /// <summary>
    /// Multiplies every coefficient by 2^32 mod Q and Montgomery-reduces it. If
    /// <c>Reduce.MontgomeryReduce</c> divides by 2^16 (assumption), this maps x to x * 2^16 mod Q,
    /// i.e. into Montgomery form.
    /// </summary>
    internal void ToMont()
    {
        for (int i = 0; i < N; i++) {
            m_coeffs[i] = Reduce.MontgomeryReduce(m_coeffs[i] * TwoTo32ModQ);
        }
    }

    /// <summary>
    /// Coefficient-wise <c>this[i] += a[i]</c>. No modular reduction is performed; the sum is
    /// truncated to 16 bits.
    /// </summary>
    /// <param name="a">Polynomial to add; not modified.</param>
    internal void Add(Poly a)
    {
        short[] other = a.m_coeffs;
        for (int i = 0; i < N; i++) {
            m_coeffs[i] += other[i];
        }
    }

    /// <summary>
    /// Coefficient-wise <c>this[i] = a[i] - this[i]</c>. Note the operand order: the result is
    /// <paramref name="a"/> minus this polynomial, NOT this minus <paramref name="a"/>. Callers
    /// depend on this order, so it must not be "fixed". No modular reduction is performed.
    /// </summary>
    /// <param name="a">Minuend polynomial; not modified.</param>
    internal void Subtract(Poly a)
    {
        short[] other = a.m_coeffs;
        for (int i = 0; i < N; i++) {
            m_coeffs[i] = (short)(other[i] - m_coeffs[i]);
        }
    }

    /// <summary>Applies <c>Reduce.BarrettReduce</c> to every coefficient in place.</summary>
    internal void PolyReduce()
    {
        for (int i = 0; i < N; i++) {
            m_coeffs[i] = Reduce.BarrettReduce(m_coeffs[i]);
        }
    }

    /// <summary>
    /// Conditionally subtracts Q from every coefficient (<c>Reduce.CondSubQ</c>), then compresses
    /// each coefficient to 4 bits (when <c>PolyCompressedBytes</c> is 128) or 5 bits (when it is
    /// 160). The bit-packed result is written into <paramref name="rBuf"/>.
    /// </summary>
    /// <param name="rBuf">Destination; must hold at least <c>PolyCompressedBytes</c> bytes.</param>
    /// <exception cref="ArgumentException">
    /// <c>PolyCompressedBytes</c> is neither 128 nor 160. This is checked before the coefficients
    /// are modified.
    /// </exception>
    internal void CompressPoly(Span<byte> rBuf)
    {
        bool fourBit = IsFourBitCompression();
        CondSubQ();
        if (fourBit) {
            CompressTo4Bits(rBuf);
        } else {
            CompressTo5Bits(rBuf);
        }
    }

    /// <summary>
    /// Inverse of <see cref="CompressPoly"/>. Unpacks 4-bit (<c>PolyCompressedBytes</c> == 128)
    /// or 5-bit (== 160) values from <paramref name="cBuf"/> and maps each d-bit value x back to
    /// <c>round(x * Q / 2^d)</c>.
    /// </summary>
    /// <param name="cBuf">Compressed input; must hold at least <c>PolyCompressedBytes</c> bytes.</param>
    /// <exception cref="ArgumentException"><c>PolyCompressedBytes</c> is neither 128 nor 160.</exception>
    internal void DecompressPoly(ReadOnlySpan<byte> cBuf)
    {
        if (IsFourBitCompression()) {
            DecompressFrom4Bits(cBuf);
        } else {
            DecompressFrom5Bits(cBuf);
        }
    }

    // Returns true for 4-bit compression (128 bytes) and false for 5-bit (160 bytes); throws otherwise.
    private bool IsFourBitCompression()
    {
        var compressedBytes = m_engine.PolyCompressedBytes;
        if (compressedBytes == 128)
            return true;
        if (compressedBytes == 160)
            return false;
        throw new ArgumentException("PolyCompressedBytes is neither 128 or 160!");
    }

    // Division-free form of (((x << 4) + Q / 2) / Q) & 15, i.e. round(16x / Q) mod 16.
    // Exact for x in [0, Q). Presumably chosen to avoid a potentially variable-time division.
    private static int Compress4(int x)
    {
        return (((x + 104) * 315) >> 16) & 15;
    }

    // Division-free form of (((x << 5) + Q / 2) / Q) & 31, i.e. round(32x / Q) mod 32.
    // Exact for x in [0, Q).
    private static int Compress5(int x)
    {
        return (((x + 52) * 630) >> 16) & 31;
    }

    // round(x * Q / 2^d) for a d-bit value x.
    private static short DecompressBits(int x, int d)
    {
        return (short)((x * Q + (1 << (d - 1))) >> d);
    }

    // Packs two 4-bit compressed coefficients per output byte, low nibble first (128 bytes total).
    private void CompressTo4Bits(Span<byte> r)
    {
        for (int i = 0; i < N / 2; i++) {
            int lo = Compress4(m_coeffs[2 * i]);
            int hi = Compress4(m_coeffs[2 * i + 1]);
            r[i] = (byte)(lo | (hi << 4));
        }
    }

    // Packs eight 5-bit compressed coefficients into every 5 output bytes, LSB first (160 bytes total).
    private void CompressTo5Bits(Span<byte> r)
    {
        Span<byte> t = stackalloc byte[8];
        for (int i = 0, pos = 0; i < N / 8; i++, pos += 5) {
            for (int j = 0; j < 8; j++) {
                t[j] = (byte)Compress5(m_coeffs[8 * i + j]);
            }
            r[pos] = (byte)(t[0] | (t[1] << 5));
            r[pos + 1] = (byte)((t[1] >> 3) | (t[2] << 2) | (t[3] << 7));
            r[pos + 2] = (byte)((t[3] >> 1) | (t[4] << 4));
            r[pos + 3] = (byte)((t[4] >> 4) | (t[5] << 1) | (t[6] << 6));
            r[pos + 4] = (byte)((t[6] >> 2) | (t[7] << 3));
        }
    }

    // Unpacks two 4-bit values per input byte (low nibble first) and decompresses them.
    private void DecompressFrom4Bits(ReadOnlySpan<byte> c)
    {
        for (int i = 0; i < N / 2; i++) {
            int b = c[i];
            m_coeffs[2 * i] = DecompressBits(b & 15, 4);
            m_coeffs[2 * i + 1] = DecompressBits(b >> 4, 4);
        }
    }

    // Unpacks eight 5-bit values from every 5 input bytes (LSB first) and decompresses them.
    private void DecompressFrom5Bits(ReadOnlySpan<byte> c)
    {
        Span<byte> t = stackalloc byte[8];
        for (int i = 0, pos = 0; i < N / 8; i++, pos += 5) {
            int b0 = c[pos];
            int b1 = c[pos + 1];
            int b2 = c[pos + 2];
            int b3 = c[pos + 3];
            int b4 = c[pos + 4];
            t[0] = (byte)b0;
            t[1] = (byte)((b0 >> 5) | (b1 << 3));
            t[2] = (byte)(b1 >> 2);
            t[3] = (byte)((b1 >> 7) | (b2 << 1));
            t[4] = (byte)((b2 >> 4) | (b3 << 4));
            t[5] = (byte)(b3 >> 1);
            t[6] = (byte)((b3 >> 6) | (b4 << 2));
            t[7] = (byte)(b4 >> 3);
            for (int j = 0; j < 8; j++) {
                // Only the low 5 bits of each scratch byte belong to the value.
                m_coeffs[8 * i + j] = DecompressBits(t[j] & 31, 5);
            }
        }
    }

    /// <summary>
    /// Decodes 384 bytes into 256 12-bit coefficients: every 3 bytes yield two coefficients,
    /// little-endian. No reduction modulo Q is done, so coefficients may be as large as 4095.
    /// </summary>
    /// <param name="a">Encoded input; must hold at least 384 bytes.</param>
    internal void FromBytes(ReadOnlySpan<byte> a)
    {
        for (int i = 0; i < N / 2; i++) {
            int b1 = a[3 * i + 1];
            m_coeffs[2 * i] = (short)((a[3 * i] | (b1 << 8)) & 0xFFF);
            m_coeffs[2 * i + 1] = (short)(((b1 >> 4) | (a[3 * i + 2] << 4)) & 0xFFF);
        }
    }

    /// <summary>
    /// Applies <c>Reduce.CondSubQ</c> in place, then encodes the 256 coefficients as 12-bit
    /// values, two per 3 bytes, little-endian (inverse of <see cref="FromBytes"/>).
    /// NOTE(review): assumes CondSubQ leaves every coefficient in [0, 4096); higher bits are dropped.
    /// </summary>
    /// <param name="r">Destination; must hold at least 384 bytes.</param>
    internal void ToBytes(Span<byte> r)
    {
        CondSubQ();
        for (int i = 0; i < N / 2; i++) {
            int c0 = (ushort)m_coeffs[2 * i];
            int c1 = (ushort)m_coeffs[2 * i + 1];
            r[3 * i] = (byte)c0;
            r[3 * i + 1] = (byte)((c0 >> 8) | (c1 << 4));
            r[3 * i + 2] = (byte)(c1 >> 4);
        }
    }

    /// <summary>
    /// Applies <c>Reduce.CondSubQ</c> in place, then decodes one message bit per coefficient
    /// (bit j of byte i comes from coefficient 8i + j). A bit is 1 exactly when the coefficient
    /// lies in [833, 2496], i.e. is closer to Q/2 than to 0 for values in [0, Q).
    /// </summary>
    /// <param name="msg">Destination; must hold at least 32 bytes.</param>
    internal void ToMsg(Span<byte> msg)
    {
        CondSubQ();
        for (int i = 0; i < N / 8; i++) {
            uint bits = 0;
            for (int j = 0; j < 8; j++) {
                int x = m_coeffs[8 * i + j];
                // Branch-free range test: both differences are negative (sign bit set) iff 832 < x < 2497.
                // For x in [0, Q) this equals the reference ((x << 1) + Q / 2) / Q & 1.
                uint bit = (uint)((832 - x) & (x - 2497)) >> 31;
                bits |= bit << j;
            }
            msg[i] = (byte)bits;
        }
    }

    /// <summary>
    /// Encodes a 32-byte message: coefficient 8i + j becomes (Q + 1) / 2 = 1665 if bit j of byte i
    /// is set, and 0 otherwise.
    /// </summary>
    /// <param name="m">Message; must be exactly 32 bytes.</param>
    /// <exception cref="ArgumentException"><paramref name="m"/> is not 32 bytes long.</exception>
    internal void FromMsg(ReadOnlySpan<byte> m)
    {
        if (m.Length != N / 8)
            throw new ArgumentException("ML_KEM_INDCPA_MSGBYTES must be equal to ML_KEM_N/8 bytes!");
        for (int i = 0; i < N / 8; i++) {
            int b = m[i];
            for (int j = 0; j < 8; j++) {
                // All-ones mask when the bit is set, zero otherwise (branch-free select).
                int mask = -((b >> j) & 1);
                m_coeffs[8 * i + j] = (short)(mask & ((Q + 1) / 2));
            }
        }
    }

    /// <summary>Applies <c>Reduce.CondSubQ</c> to every coefficient in place.</summary>
    internal void CondSubQ()
    {
        for (int i = 0; i < N; i++) {
            m_coeffs[i] = Reduce.CondSubQ(m_coeffs[i]);
        }
    }
}
}