Poly
using System;
namespace Org.BouncyCastle.Crypto.Kems.MLKem
{
    /// <summary>
    /// A polynomial with <see cref="N"/> <c>short</c> coefficients, together with the noise sampling,
    /// NTT, arithmetic and (de)serialisation routines applied to it.
    /// </summary>
    /// <remarks>
    /// <see cref="Add"/> and <see cref="Subtract"/> do not reduce their results; callers decide when to call
    /// <see cref="PolyReduce"/> or <see cref="CondSubQ"/>. <see cref="CompressPoly"/>, <see cref="ToBytes"/>
    /// and <see cref="ToMsg"/> call <see cref="CondSubQ"/> first, so they modify this polynomial in place.
    /// </remarks>
    internal sealed class Poly
    {
        /// <summary>Number of coefficients per polynomial.</summary>
        private const int N = 256;

        /// <summary>The coefficient modulus q.</summary>
        private const int Q = 3329;

        /// <summary>2^32 mod q (evaluates to 1353); used by <see cref="ToMont"/>.</summary>
        private const int MontConversionFactor = (int)((1L << 32) % Q);

        /// <summary>(q + 1) / 2 = 1665: the coefficient that encodes a message bit of 1 in <see cref="FromMsg"/>.</summary>
        private const int HalfQ = (Q + 1) / 2;

        /// <summary>q / 4 = 832: lower decision threshold used by <see cref="ToMsg"/>.</summary>
        private const int QuarterQ = Q / 4;

        private readonly MLKemEngine m_engine;
        private readonly Symmetric m_symmetric;

        /// <summary>Coefficient storage; always exactly <see cref="N"/> entries.</summary>
        internal readonly short[] m_coeffs = new short[N];

        /// <summary>The coefficient array (the same instance as <see cref="m_coeffs"/>).</summary>
        internal short[] Coeffs => m_coeffs;

        /// <param name="mEngine">
        /// Engine supplying <c>Eta1</c>, <c>PolyCompressedBytes</c> and the <c>Symmetric</c> primitives.
        /// </param>
        internal Poly(MLKemEngine mEngine)
        {
            m_engine = mEngine;
            m_symmetric = mEngine.Symmetric;
        }

        /// <summary>
        /// Fills this polynomial with noise: draws <c>Eta1 * N / 4</c> bytes from the PRF keyed by
        /// <paramref name="seed"/> and <paramref name="nonce"/> and passes them to <c>Cbd.Eta</c> with eta = Eta1.
        /// </summary>
        internal void GetNoiseEta1(byte[] seed, byte nonce)
        {
            int eta1 = m_engine.Eta1;
            byte[] prfOutput = new byte[eta1 * N / 4];
            m_symmetric.Prf(seed, nonce, prfOutput);
            Cbd.Eta(this, prfOutput, eta1);
        }

        /// <summary>
        /// Same as <see cref="GetNoiseEta1"/>, but with the fixed parameter eta2 = 2,
        /// which gives 2 * N / 4 = 128 bytes of PRF output.
        /// </summary>
        internal void GetNoiseEta2(byte[] seed, byte nonce)
        {
            const int Eta2 = 2;
            byte[] prfOutput = new byte[Eta2 * N / 4];
            m_symmetric.Prf(seed, nonce, prfOutput);
            Cbd.Eta(this, prfOutput, Eta2);
        }

        /// <summary>
        /// Applies the forward NTT (<c>Ntt.NTT</c>) in place, then reduces every coefficient
        /// with <see cref="PolyReduce"/>.
        /// </summary>
        internal void PolyNtt()
        {
            Ntt.NTT(m_coeffs);
            PolyReduce();
        }

        /// <summary>
        /// Applies <c>Ntt.InvNTT</c> in place. The method name suggests the result is left in
        /// Montgomery form (NOTE(review): depends on Ntt.InvNTT, not visible here).
        /// </summary>
        internal void PolyInverseNttToMont()
        {
            Ntt.InvNTT(m_coeffs);
        }

        /// <summary>
        /// Computes the pointwise product <paramref name="r"/> = <paramref name="a"/> * <paramref name="b"/> of two
        /// polynomials in NTT representation, delegating each coefficient pair to <c>Ntt.BaseMult</c>.
        /// </summary>
        /// <remarks>
        /// Each group of four coefficients uses zeta = <c>Ntt.Zetas[64 + i]</c> for its first pair and
        /// -zeta for its second pair.
        /// </remarks>
        internal static void BaseMultMontgomery(Poly r, Poly a, Poly b)
        {
            short[] rc = r.m_coeffs, ac = a.m_coeffs, bc = b.m_coeffs;
            for (int i = 0; i < N / 4; i++) {
                int j = 4 * i;
                var zeta = Ntt.Zetas[64 + i];
                Ntt.BaseMult(rc, j, ac[j], ac[j + 1], bc[j], bc[j + 1], zeta);
                Ntt.BaseMult(rc, j + 2, ac[j + 2], ac[j + 3], bc[j + 2], bc[j + 3], (short)(-zeta));
            }
        }

        /// <summary>
        /// Multiplies every coefficient by 2^32 mod q and passes the product through
        /// <c>Reduce.MontgomeryReduce</c>.
        /// </summary>
        /// <remarks>
        /// Assuming MontgomeryReduce multiplies by 2^-16 mod q (not visible here — confirm), the net effect
        /// multiplies each coefficient by 2^16 mod q, i.e. converts it into Montgomery form.
        /// </remarks>
        internal void ToMont()
        {
            for (int i = 0; i < N; i++) {
                m_coeffs[i] = Reduce.MontgomeryReduce(m_coeffs[i] * MontConversionFactor);
            }
        }

        /// <summary>Sets this = this + <paramref name="a"/>, coefficient-wise, with no reduction.</summary>
        internal void Add(Poly a)
        {
            short[] other = a.m_coeffs;
            for (int i = 0; i < N; i++) {
                m_coeffs[i] += other[i];
            }
        }

        /// <summary>
        /// Sets this = <paramref name="a"/> - this, coefficient-wise, with no reduction.
        /// </summary>
        /// <remarks>
        /// Note the operand order: the ARGUMENT is the minuend. Callers rely on this order, so it must not be
        /// "fixed" to this - a.
        /// </remarks>
        internal void Subtract(Poly a)
        {
            short[] other = a.m_coeffs;
            for (int i = 0; i < N; i++) {
                m_coeffs[i] = (short)(other[i] - m_coeffs[i]);
            }
        }

        /// <summary>Replaces every coefficient with <c>Reduce.BarrettReduce</c> of itself.</summary>
        internal void PolyReduce()
        {
            for (int i = 0; i < N; i++) {
                m_coeffs[i] = Reduce.BarrettReduce(m_coeffs[i]);
            }
        }

        /// <summary>
        /// Applies <see cref="CondSubQ"/> (in place), then writes the coefficients compressed to d bits each
        /// into <paramref name="rBuf"/> starting at <paramref name="rOff"/>. The output is 128 bytes for d = 4
        /// or 160 bytes for d = 5, selected by the engine's <c>PolyCompressedBytes</c>.
        /// </summary>
        /// <exception cref="ArgumentException">PolyCompressedBytes is neither 128 nor 160.</exception>
        internal void CompressPoly(byte[] rBuf, int rOff)
        {
            int pos = rOff;
            byte[] t = new byte[8];
            CondSubQ();
            if (m_engine.PolyCompressedBytes == 128) {
                // d = 4: every 8 coefficients become 4 bytes (two nibbles per byte).
                for (int i = 0; i < N / 8; i++) {
                    for (int j = 0; j < 8; j++) {
                        int c = m_coeffs[8 * i + j];
                        // round(16 * c / q) mod 16 for 0 <= c < q, computed with multiply-and-shift instead of
                        // a division by q (presumably to keep the operation constant-time).
                        t[j] = (byte)(((c + 104) * 315 >> 16) & 15);
                    }
                    rBuf[pos] = (byte)(t[0] | (t[1] << 4));
                    rBuf[pos + 1] = (byte)(t[2] | (t[3] << 4));
                    rBuf[pos + 2] = (byte)(t[4] | (t[5] << 4));
                    rBuf[pos + 3] = (byte)(t[6] | (t[7] << 4));
                    pos += 4;
                }
            } else {
                if (m_engine.PolyCompressedBytes != 160)
                    throw new ArgumentException("PolyCompressedBytes is neither 128 or 160!");
                // d = 5: every 8 coefficients (40 bits) become 5 bytes, packed little-endian.
                for (int i = 0; i < N / 8; i++) {
                    for (int j = 0; j < 8; j++) {
                        int c = m_coeffs[8 * i + j];
                        // round(32 * c / q) mod 32 for 0 <= c < q, division-free as above.
                        t[j] = (byte)(((c + 52) * 630 >> 16) & 31);
                    }
                    rBuf[pos] = (byte)(t[0] | (t[1] << 5));
                    rBuf[pos + 1] = (byte)((t[1] >> 3) | (t[2] << 2) | (t[3] << 7));
                    rBuf[pos + 2] = (byte)((t[3] >> 1) | (t[4] << 4));
                    rBuf[pos + 3] = (byte)((t[4] >> 4) | (t[5] << 1) | (t[6] << 6));
                    rBuf[pos + 4] = (byte)((t[6] >> 2) | (t[7] << 3));
                    pos += 5;
                }
            }
        }

        /// <summary>
        /// Inverse of <see cref="CompressPoly"/>: reads d-bit values from <paramref name="cBuf"/> starting at
        /// <paramref name="cOff"/> and maps each value y to round(q * y / 2^d), computed as (q * y + 2^(d-1)) >> d.
        /// </summary>
        /// <exception cref="ArgumentException">PolyCompressedBytes is neither 128 nor 160.</exception>
        internal void DecompressPoly(byte[] cBuf, int cOff)
        {
            int pos = cOff;
            if (m_engine.PolyCompressedBytes == 128) {
                // d = 4: each byte holds two coefficients, low nibble first.
                for (int i = 0; i < N / 2; i++) {
                    int b = cBuf[pos];
                    m_coeffs[2 * i] = (short)(((b & 15) * Q + 8) >> 4);
                    m_coeffs[2 * i + 1] = (short)(((b >> 4) * Q + 8) >> 4);
                    pos++;
                }
            } else {
                if (m_engine.PolyCompressedBytes != 160)
                    throw new ArgumentException("PolyCompressedBytes is neither 128 or 160!");
                // d = 5: every 5 bytes hold 8 coefficients. The excess high bits in each t[k] are dropped by the & 31 below.
                byte[] t = new byte[8];
                for (int i = 0; i < N / 8; i++) {
                    int b0 = cBuf[pos];
                    int b1 = cBuf[pos + 1];
                    int b2 = cBuf[pos + 2];
                    int b3 = cBuf[pos + 3];
                    int b4 = cBuf[pos + 4];
                    t[0] = (byte)b0;
                    t[1] = (byte)((b0 >> 5) | (b1 << 3));
                    t[2] = (byte)(b1 >> 2);
                    t[3] = (byte)((b1 >> 7) | (b2 << 1));
                    t[4] = (byte)((b2 >> 4) | (b3 << 4));
                    t[5] = (byte)(b3 >> 1);
                    t[6] = (byte)((b3 >> 6) | (b4 << 2));
                    t[7] = (byte)(b4 >> 3);
                    pos += 5;
                    for (int k = 0; k < 8; k++) {
                        m_coeffs[8 * i + k] = (short)(((t[k] & 31) * Q + 16) >> 5);
                    }
                }
            }
        }

        /// <summary>
        /// Decodes N 12-bit coefficients from 3 * N / 2 = 384 bytes of <paramref name="a"/> starting at
        /// <paramref name="off"/>. Every 3 bytes yield two coefficients, little-endian.
        /// </summary>
        internal void FromBytes(byte[] a, int off)
        {
            for (int i = 0; i < N / 2; i++) {
                int o = off + 3 * i;
                // Each coefficient is written before the next byte is read, as in the straightforward form.
                m_coeffs[2 * i] = (short)((a[o] | (a[o + 1] << 8)) & 0xFFF);
                m_coeffs[2 * i + 1] = (short)(((a[o + 1] >> 4) | (a[o + 2] << 4)) & 0xFFF);
            }
        }

        /// <summary>
        /// Applies <see cref="CondSubQ"/> (in place), then encodes the coefficients as 12-bit values into
        /// 384 bytes of <paramref name="r"/> starting at <paramref name="off"/>. This is the inverse layout of
        /// <see cref="FromBytes"/>.
        /// </summary>
        internal void ToBytes(byte[] r, int off)
        {
            CondSubQ();
            for (int i = 0; i < N / 2; i++) {
                ushort t0 = (ushort)m_coeffs[2 * i];
                ushort t1 = (ushort)m_coeffs[2 * i + 1];
                int o = off + 3 * i;
                r[o] = (byte)t0;
                r[o + 1] = (byte)((t0 >> 8) | (ushort)(t1 << 4));
                r[o + 2] = (byte)(t1 >> 4);
            }
        }

        /// <summary>
        /// Applies <see cref="CondSubQ"/> (in place), then writes one message bit per coefficient into the
        /// first N / 8 = 32 bytes of <paramref name="msg"/>, least significant bit first.
        /// </summary>
        /// <remarks>
        /// A bit is 1 exactly when q/4 &lt; c &lt; q - q/4 (832 &lt; c &lt; 2497), i.e. when c is closer to q/2
        /// than to 0 modulo q.
        /// </remarks>
        internal void ToMsg(byte[] msg)
        {
            CondSubQ();
            for (int i = 0; i < N / 8; i++) {
                uint bits = 0;
                for (int j = 0; j < 8; j++) {
                    int c = m_coeffs[8 * i + j];
                    // Branch-free test: the AND has its sign bit set only if both
                    // (QuarterQ - c) and (c - (Q - QuarterQ)) are negative.
                    uint bit = (uint)((QuarterQ - c) & (c - (Q - QuarterQ))) >> 31;
                    bits |= bit << j;
                }
                msg[i] = (byte)bits;
            }
        }

        /// <summary>
        /// Decodes a 32-byte message into coefficients: bit j of byte i (LSB first) becomes coefficient
        /// 8 * i + j, with value (q + 1) / 2 = 1665 for a 1 bit and 0 for a 0 bit.
        /// </summary>
        /// <exception cref="ArgumentException"><paramref name="m"/> is not exactly N / 8 = 32 bytes long.</exception>
        internal void FromMsg(byte[] m)
        {
            if (m.Length != N / 8)
                throw new ArgumentException("ML_KEM_INDCPA_MSGBYTES must be equal to ML_KEM_N/8 bytes!");
            for (int i = 0; i < N / 8; i++) {
                for (int j = 0; j < 8; j++) {
                    // mask is 0 or -1 (all ones), so the coefficient is selected without branching.
                    short mask = (short)(-((m[i] >> j) & 1));
                    m_coeffs[8 * i + j] = (short)(mask & HalfQ);
                }
            }
        }

        /// <summary>
        /// Replaces every coefficient with <c>Reduce.CondSubQ</c> of itself.
        /// NOTE(review): presumably this subtracts q from coefficients &gt;= q, yielding [0, q) — confirm in Reduce.
        /// </summary>
        internal void CondSubQ()
        {
            for (int i = 0; i < N; i++) {
                m_coeffs[i] = Reduce.CondSubQ(m_coeffs[i]);
            }
        }
    }
}