PolyVec
using System;
namespace Org.BouncyCastle.Crypto.Kems.MLKem
{
/// <summary>
/// A vector of <c>K</c> polynomials (<c>K</c> taken from the owning <see cref="MLKemEngine"/>), with
/// element-wise NTT/reduction helpers, 384-byte-per-polynomial (de)serialization, and
/// 10- or 11-bit-per-coefficient (de)compression.
/// </summary>
internal sealed class PolyVec
{
    /// <summary>Coefficients per polynomial (every loop over <c>m_coeffs</c> covers 256 entries).</summary>
    private const int N = 256;

    /// <summary>Modulus q used to rescale coefficients during (de)compression.</summary>
    private const int Q = 3329;

    /// <summary>
    /// Bytes per encoded polynomial in <c>FromBytes</c>/<c>ToBytes</c>.
    /// NOTE(review): 384 = 256 * 12 / 8, presumably 12 bits per coefficient in Poly's encoding — confirm.
    /// </summary>
    private const int PolyBytes = 384;

    /// <summary>ceil(2^39 / Q); multiplying by this and shifting right by 32 stands in for a division by Q.</summary>
    private const long DivQMultiplier = 165141429;

    private readonly MLKemEngine m_engine;
    internal readonly Poly[] m_vec;

    /// <summary>Creates a vector of <c>engine.K</c> freshly constructed polynomials.</summary>
    internal PolyVec(MLKemEngine engine)
    {
        m_engine = engine;
        m_vec = new Poly[engine.K];
        for (int i = 0; i < engine.K; i++) {
            m_vec[i] = new Poly(engine);
        }
    }

    /// <summary>Applies <c>Poly.PolyNtt</c> to every element, in place.</summary>
    internal void Ntt()
    {
        for (int i = 0; i < m_engine.K; i++) {
            m_vec[i].PolyNtt();
        }
    }

    /// <summary>Applies <c>Poly.PolyInverseNttToMont</c> to every element, in place.</summary>
    internal void InverseNttToMont()
    {
        for (int i = 0; i < m_engine.K; i++) {
            m_vec[i].PolyInverseNttToMont();
        }
    }

    /// <summary>
    /// Writes into <paramref name="r"/> the sum over i of <c>Poly.BaseMultMontgomery(a[i], b[i])</c>,
    /// then applies <c>PolyReduce</c> to the result.
    /// </summary>
    /// <remarks>
    /// The first product is produced by <c>BaseMultMontgomery(r, ...)</c> and the remaining ones are
    /// added into <paramref name="r"/>. NOTE(review): presumably BaseMultMontgomery overwrites its first
    /// argument and expects NTT-domain operands — confirm against Poly.
    /// </remarks>
    internal static void PointwiseAccountMontgomery(Poly r, PolyVec a, PolyVec b, MLKemEngine engine)
    {
        Poly product = new Poly(engine);
        Poly.BaseMultMontgomery(r, a.m_vec[0], b.m_vec[0]);
        for (int i = 1; i < engine.K; i++) {
            Poly.BaseMultMontgomery(product, a.m_vec[i], b.m_vec[i]);
            r.Add(product);
        }
        r.PolyReduce();
    }

    /// <summary>Adds <paramref name="a"/> to this vector element by element, in place.</summary>
    internal void Add(PolyVec a)
    {
        for (int i = 0; i < m_engine.K; i++) {
            m_vec[i].Add(a.m_vec[i]);
        }
    }

    /// <summary>Applies <c>Poly.PolyReduce</c> to every element, in place.</summary>
    internal void Reduce()
    {
        for (int i = 0; i < m_engine.K; i++) {
            m_vec[i].PolyReduce();
        }
    }

    /// <summary>
    /// Compresses every coefficient to 10 or 11 bits (selected by <c>PolyVecCompressedBytes</c>) and
    /// bit-packs the results, least-significant bit first, into <paramref name="rBuf"/>.
    /// </summary>
    /// <param name="rBuf">Destination; must hold at least <c>PolyVecCompressedBytes</c> bytes.</param>
    /// <exception cref="ArgumentException">
    /// <c>PolyVecCompressedBytes</c> is neither 320 * K nor 352 * K.
    /// </exception>
    /// <remarks>
    /// Normalizes this vector in place via <c>CondSubQ</c> before compressing. The configuration is
    /// validated before the vector is modified.
    /// </remarks>
    internal void CompressPolyVec(Span<byte> rBuf)
    {
        int bits = CompressedBitsPerCoeff();
        ConditionalSubQ();
        if (bits == 10) {
            CompressTo10Bits(rBuf);
        } else {
            CompressTo11Bits(rBuf);
        }
    }

    /// <summary>
    /// Inverse of <c>CompressPolyVec</c>: unpacks 10- or 11-bit values from <paramref name="cBuf"/> and
    /// rescales each back to the range [0, Q) (rounded to nearest), overwriting every coefficient.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// <c>PolyVecCompressedBytes</c> is neither 320 * K nor 352 * K.
    /// </exception>
    internal void DecompressPolyVec(ReadOnlySpan<byte> cBuf)
    {
        if (CompressedBitsPerCoeff() == 10) {
            DecompressFrom10Bits(cBuf);
        } else {
            DecompressFrom11Bits(cBuf);
        }
    }

    /// <summary>Reads each polynomial from consecutive <c>PolyBytes</c>-sized slices of <paramref name="pk"/>.</summary>
    internal void FromBytes(ReadOnlySpan<byte> pk)
    {
        for (int i = 0; i < m_engine.K; i++) {
            m_vec[i].FromBytes(pk.Slice(i * PolyBytes));
        }
    }

    /// <summary>Writes each polynomial into consecutive <c>PolyBytes</c>-sized slices of <paramref name="r"/>.</summary>
    internal void ToBytes(Span<byte> r)
    {
        for (int i = 0; i < m_engine.K; i++) {
            m_vec[i].ToBytes(r.Slice(i * PolyBytes));
        }
    }

    /// <summary>Applies <c>Poly.CondSubQ</c> to every element, in place.</summary>
    private void ConditionalSubQ()
    {
        for (int i = 0; i < m_engine.K; i++) {
            m_vec[i].CondSubQ();
        }
    }

    /// <summary>
    /// Maps <c>PolyVecCompressedBytes</c> to the compressed width per coefficient:
    /// 320 * K (= K * N * 10 / 8) gives 10 bits, and 352 * K (= K * N * 11 / 8) gives 11 bits.
    /// </summary>
    /// <exception cref="ArgumentException">Any other size.</exception>
    private int CompressedBitsPerCoeff()
    {
        int compressedBytes = m_engine.PolyVecCompressedBytes;
        if (compressedBytes == m_engine.K * 320) {
            return 10;
        }
        if (compressedBytes == m_engine.K * 352) {
            return 11;
        }
        throw new ArgumentException("ML-KEM PolyVecCompressedBytes neither 320 * K or 352 * K!");
    }

    /// <summary>
    /// Computes ((c &lt;&lt; bits) + Q/2) / Q masked to <paramref name="bits"/> bits, i.e. c scaled to
    /// [0, 2^bits) and rounded to nearest, using multiply-and-shift instead of a division.
    /// </summary>
    /// <remarks>
    /// Q/2 = 1664 = 13 * 2^7, so ((c &lt;&lt; bits) + 1664) / 2^7 is exactly (c &lt;&lt; (bits - 7)) + 13.
    /// Multiplying that by ceil(2^39 / Q) and shifting right by 32 yields the quotient by Q. For
    /// 0 &lt;= c &lt; Q and bits in {10, 11}, the rounding error of the multiplier is far below 1/Q, so
    /// the result is exact. NOTE(review): assumes CondSubQ leaves coefficients in [0, Q) — confirm.
    /// The division is presumably avoided so that timing does not depend on c.
    /// </remarks>
    private static int CompressCoeff(int c, int bits)
    {
        long scaled = (c << (bits - 7)) + 13;
        return (int)((scaled * DivQMultiplier) >> 32) & ((1 << bits) - 1);
    }

    /// <summary>
    /// Maps the low <paramref name="bits"/> bits of <paramref name="x"/> back to [0, Q):
    /// (x * Q + 2^(bits-1)) >> bits, i.e. x * Q / 2^bits rounded to nearest.
    /// </summary>
    private static short DecompressCoeff(int x, int bits)
    {
        int value = x & ((1 << bits) - 1);
        return (short)((value * Q + (1 << (bits - 1))) >> bits);
    }

    /// <summary>Packs every coefficient as 10 bits: each group of 4 coefficients fills 5 bytes.</summary>
    private void CompressTo10Bits(Span<byte> rBuf)
    {
        int pos = 0;
        for (int i = 0; i < m_engine.K; i++) {
            short[] coeffs = m_vec[i].m_coeffs;
            for (int j = 0; j < N; j += 4) {
                int t0 = CompressCoeff(coeffs[j], 10);
                int t1 = CompressCoeff(coeffs[j + 1], 10);
                int t2 = CompressCoeff(coeffs[j + 2], 10);
                int t3 = CompressCoeff(coeffs[j + 3], 10);
                rBuf[pos] = (byte)t0;
                rBuf[pos + 1] = (byte)((t0 >> 8) | (t1 << 2));
                rBuf[pos + 2] = (byte)((t1 >> 6) | (t2 << 4));
                rBuf[pos + 3] = (byte)((t2 >> 4) | (t3 << 6));
                rBuf[pos + 4] = (byte)(t3 >> 2);
                pos += 5;
            }
        }
    }

    /// <summary>Packs every coefficient as 11 bits: each group of 8 coefficients fills 11 bytes.</summary>
    private void CompressTo11Bits(Span<byte> rBuf)
    {
        int pos = 0;
        for (int i = 0; i < m_engine.K; i++) {
            short[] coeffs = m_vec[i].m_coeffs;
            for (int j = 0; j < N; j += 8) {
                int t0 = CompressCoeff(coeffs[j], 11);
                int t1 = CompressCoeff(coeffs[j + 1], 11);
                int t2 = CompressCoeff(coeffs[j + 2], 11);
                int t3 = CompressCoeff(coeffs[j + 3], 11);
                int t4 = CompressCoeff(coeffs[j + 4], 11);
                int t5 = CompressCoeff(coeffs[j + 5], 11);
                int t6 = CompressCoeff(coeffs[j + 6], 11);
                int t7 = CompressCoeff(coeffs[j + 7], 11);
                rBuf[pos] = (byte)t0;
                rBuf[pos + 1] = (byte)((t0 >> 8) | (t1 << 3));
                rBuf[pos + 2] = (byte)((t1 >> 5) | (t2 << 6));
                rBuf[pos + 3] = (byte)(t2 >> 2);
                rBuf[pos + 4] = (byte)((t2 >> 10) | (t3 << 1));
                rBuf[pos + 5] = (byte)((t3 >> 7) | (t4 << 4));
                rBuf[pos + 6] = (byte)((t4 >> 4) | (t5 << 7));
                rBuf[pos + 7] = (byte)(t5 >> 1);
                rBuf[pos + 8] = (byte)((t5 >> 9) | (t6 << 2));
                rBuf[pos + 9] = (byte)((t6 >> 6) | (t7 << 5));
                rBuf[pos + 10] = (byte)(t7 >> 3);
                pos += 11;
            }
        }
    }

    /// <summary>Unpacks 5-byte groups into 4 coefficients of 10 bits each and decompresses them.</summary>
    private void DecompressFrom10Bits(ReadOnlySpan<byte> cBuf)
    {
        int pos = 0;
        for (int i = 0; i < m_engine.K; i++) {
            short[] coeffs = m_vec[i].m_coeffs;
            for (int j = 0; j < N; j += 4) {
                int b0 = cBuf[pos], b1 = cBuf[pos + 1], b2 = cBuf[pos + 2], b3 = cBuf[pos + 3],
                    b4 = cBuf[pos + 4];
                pos += 5;
                // DecompressCoeff masks to 10 bits, so stray high bits from the shifts are discarded.
                coeffs[j] = DecompressCoeff(b0 | (b1 << 8), 10);
                coeffs[j + 1] = DecompressCoeff((b1 >> 2) | (b2 << 6), 10);
                coeffs[j + 2] = DecompressCoeff((b2 >> 4) | (b3 << 4), 10);
                coeffs[j + 3] = DecompressCoeff((b3 >> 6) | (b4 << 2), 10);
            }
        }
    }

    /// <summary>Unpacks 11-byte groups into 8 coefficients of 11 bits each and decompresses them.</summary>
    private void DecompressFrom11Bits(ReadOnlySpan<byte> cBuf)
    {
        int pos = 0;
        for (int i = 0; i < m_engine.K; i++) {
            short[] coeffs = m_vec[i].m_coeffs;
            for (int j = 0; j < N; j += 8) {
                int b0 = cBuf[pos], b1 = cBuf[pos + 1], b2 = cBuf[pos + 2], b3 = cBuf[pos + 3],
                    b4 = cBuf[pos + 4], b5 = cBuf[pos + 5], b6 = cBuf[pos + 6], b7 = cBuf[pos + 7],
                    b8 = cBuf[pos + 8], b9 = cBuf[pos + 9], b10 = cBuf[pos + 10];
                pos += 11;
                // DecompressCoeff masks to 11 bits, so stray high bits from the shifts are discarded.
                coeffs[j] = DecompressCoeff(b0 | (b1 << 8), 11);
                coeffs[j + 1] = DecompressCoeff((b1 >> 3) | (b2 << 5), 11);
                coeffs[j + 2] = DecompressCoeff((b2 >> 6) | (b3 << 2) | (b4 << 10), 11);
                coeffs[j + 3] = DecompressCoeff((b4 >> 1) | (b5 << 7), 11);
                coeffs[j + 4] = DecompressCoeff((b5 >> 4) | (b6 << 4), 11);
                coeffs[j + 5] = DecompressCoeff((b6 >> 7) | (b7 << 1) | (b8 << 9), 11);
                coeffs[j + 6] = DecompressCoeff((b8 >> 2) | (b9 << 6), 11);
                coeffs[j + 7] = DecompressCoeff((b9 >> 5) | (b10 << 3), 11);
            }
        }
    }
}
}