diff --git a/prover/src/constraints/cpu.rs b/prover/src/constraints/cpu.rs index 546f2f2a4..1efc56228 100644 --- a/prover/src/constraints/cpu.rs +++ b/prover/src/constraints/cpu.rs @@ -1,18 +1,19 @@ //! CPU table constraints for the 64-bit VM. //! -//! This module defines the constraints for the CPU table, including: -//! - Range checks (IS_BIT) for all flag columns -//! - ALU constraints (ADD, SUB templates) -//! - Extension constraints (arg1, arg2, rvd computation) -//! - Branch condition computation -//! - next_pc computation +//! Translates the `cpu.toml` constraint groups onto the shrunk CPU layout +//! (`tables::cpu::cols`). Byte/half range checks (`IS_BYTE`/`IS_HALF`) and all +//! lookups (`DECODE`/`ALU`/`MEMORY`/`CPU32`/`MEMW`/`BRANCH`/`ECALL`) live in +//! `tables::cpu::bus_interactions`; this module holds only the algebraic +//! (transition) constraints: //! -//! ## Constraint Groups (from spec) +//! - **decode**: `word_instr · {MEMORY,BRANCH,ECALL} = 0` mutex. +//! - **range**: `IS_BIT` for the flag columns + the inline-PC bits + `non_padding`. +//! - **alu**: `arg2` multiplex, `ADD`/`SUB` fast-path templates on `rv1`/`arg2`. +//! - **mem**: `¬read_registerN ⇒ rvN = 0`, `¬MEMORY ∧ ¬BRANCH ⇒ rvd = cast(res, WL)`. +//! - **branch**: `branch_cond = BRANCH·(JALR + (1−JALR)·res[0])`, `next_pc = pc + len`. //! -//! 1. **Range checks**: IS_BIT for all bit flags (~25 constraints) -//! 2. **ALU**: ADD/SUB templates conditional on selectors -//! 3. **Extension**: arg1/arg2/rvd from rv1/rv2/res with sign extension -//! 4. **Misc**: branch_cond, next_pc computation +//! `JALR` is the `mem_flags` byte read directly: under `BRANCH` only the JALR bit +//! of `mem_flags` can be set, so `mem_flags ∈ {0,1} = JALR` there. 
use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; @@ -20,191 +21,81 @@ use stark::constraints::transition::{TransitionConstraint, TransitionConstraintE use stark::table::TableView; use crate::tables::cpu::cols; -use crate::tables::types::{GoldilocksExtension, GoldilocksField}; +use crate::tables::types::{GoldilocksExtension, GoldilocksField, SHIFT_16}; -use super::templates::{AddConstraint, AddLinearTerm, AddOperand, IsBitConstraint}; - -/// Pack 4 consecutive byte-column values into a 32-bit word field element. -/// `col0 + col1*2^8 + col2*2^16 + col3*2^24` -#[inline] -fn pack_bytes_to_word( - step: &TableView, - col0: usize, - col1: usize, - col2: usize, - col3: usize, -) -> FieldElement -where - F: IsSubFieldOf, - E: IsField, -{ - let b0 = step.get_main_evaluation_element(0, col0); - let b1 = step.get_main_evaluation_element(0, col1); - let b2 = step.get_main_evaluation_element(0, col2); - let b3 = step.get_main_evaluation_element(0, col3); - - let shift_8: FieldElement = FieldElement::from(1u64 << 8); - let shift_16: FieldElement = FieldElement::from(1u64 << 16); - let shift_24: FieldElement = FieldElement::from(1u64 << 24); - - b0 + b1 * &shift_8 + b2 * &shift_16 + b3 * shift_24 -} +use super::templates::{AddConstraint, AddOperand, IsBitConstraint}; // ========================================================================= -// CPU Constraint Collection +// Range: IS_BIT flag columns // ========================================================================= -/// All bit flag columns that need IS_BIT constraints. +/// Bit columns that need `IS_BIT` (`x·(x−1) = 0`) constraints. 
pub const BIT_FLAG_COLUMNS: &[usize] = &[ cols::READ_REGISTER1, cols::READ_REGISTER2, cols::WRITE_REGISTER, - cols::MEMORY_2BYTES, - cols::MEMORY_4BYTES, - cols::MEMORY_8BYTES, - cols::C_TYPE_INSTRUCTION, - cols::SIGNED, - cols::MP_SELECTOR, - cols::MULDIV_SELECTOR, cols::WORD_INSTR, - // ALU selectors + cols::ALU, cols::ADD, cols::SUB, - cols::SLT, - cols::AND, - cols::OR, - cols::XOR, - cols::SHIFT, - cols::JALR, - cols::BEQ, - cols::BLT, - cols::LOAD, - cols::STORE, - cols::MUL, - cols::DIVREM, + cols::MEMORY, + cols::BRANCH, cols::ECALL, - cols::EBREAK, - // Sign bits - cols::RV1_EXT_BIT, - cols::RV2_EXT_BIT, - cols::RES_EXT_BIT, - // Computed flags - cols::IS_EQUAL, - cols::BRANCH_COND, - // Inline PC columns - cols::PREV_PC_TIMESTAMP_BORROW, cols::PC_DOUBLE_READ, + cols::PREV_PC_TIMESTAMP_BORROW, ]; /// Creates all IS_BIT constraints for CPU flag columns. -/// -/// Returns the constraints and the next available constraint index. pub fn create_is_bit_constraints(constraint_idx_start: usize) -> (Vec, usize) { super::templates::new_is_bit_constraints(BIT_FLAG_COLUMNS, constraint_idx_start) } // ========================================================================= -// ALU ADD Constraints +// Generic helpers // ========================================================================= -/// Creates ADD constraints for the CPU table. -/// -/// ADD template is used when: ADD + LOAD + STORE > 0 -/// - ADD: arg1 + arg2 = res (arithmetic addition) -/// - LOAD/STORE: base_address + offset = effective_address (in res) -/// -/// Returns the constraints and the next available constraint index. 
-pub fn create_add_constraints(constraint_idx_start: usize) -> (Vec, usize) { - // For ADD/LOAD operations, we compute: arg1 + arg2 = res - // All operands are DWordBL (8 bytes), need to cast to DWordWL (2 words) - - let lhs = AddOperand::from_dword_bl(cols::ARG1_0); - let rhs = AddOperand::from_dword_bl(cols::ARG2_0); - let sum = AddOperand::from_dword_bl(cols::RES_0); - - // Condition: ADD + LOAD (active when any of these flags is set) - let cond_cols = vec![cols::ADD, cols::LOAD]; - - let (add_c0, add_c1) = AddConstraint::new_pair(cond_cols, lhs, rhs, sum, constraint_idx_start); - - // STORE: res = arg1 + imm (separate ADD, because arg2 now holds rv2) - // arg1 is DWordBL, imm is DWordWL, res is DWordBL - let store_lhs = AddOperand::from_dword_bl(cols::ARG1_0); - let store_rhs = AddOperand::dword(cols::IMM_0); - let store_sum = AddOperand::from_dword_bl(cols::RES_0); - let store_cond = vec![cols::STORE]; - let (store_c0, store_c1) = AddConstraint::new_pair( - store_cond, - store_lhs, - store_rhs, - store_sum, - constraint_idx_start + 2, - ); - - ( - vec![add_c0, add_c1, store_c0, store_c1], - constraint_idx_start + 4, - ) +/// `cast(res, DWordWL)` low/high words from the four `res` halves (DWordHL). +#[inline] +fn res_word(step: &TableView, high: bool) -> FieldElement +where + F: IsSubFieldOf, + E: IsField, +{ + let (lo_col, hi_col) = if high { + (cols::RES_2, cols::RES_3) + } else { + (cols::RES_0, cols::RES_1) + }; + let shift_16: FieldElement = FieldElement::from(SHIFT_16); + step.get_main_evaluation_element(0, lo_col) + + step.get_main_evaluation_element(0, hi_col) * shift_16 } // ========================================================================= -// Branch Condition Constraint +// decode group: word_instr mutex // ========================================================================= -/// Constraint for branch_cond computation. 
-/// -/// From spec: -/// branch_cond = JALR -/// + BLT * (res[0] XOR mp_selector) -/// + BEQ * (is_equal XOR mp_selector) -/// -/// Where XOR is computed as: a XOR b = a + b - 2*a*b -pub struct BranchCondConstraint { +/// Constraint `col_a · col_b = 0`. Used for the decode mutexes +/// `word_instr · {MEMORY, BRANCH, ECALL} = 0`. +pub struct ProductZeroConstraint { + col_a: usize, + col_b: usize, constraint_idx: usize, } -impl BranchCondConstraint { - pub fn new(constraint_idx: usize) -> Self { - Self { constraint_idx } - } - - fn compute(&self, step: &TableView) -> FieldElement - where - F: IsSubFieldOf, - E: IsField, - { - let jalr = step.get_main_evaluation_element(0, cols::JALR).clone(); - let blt = step.get_main_evaluation_element(0, cols::BLT).clone(); - let beq = step.get_main_evaluation_element(0, cols::BEQ).clone(); - let mp_selector = step - .get_main_evaluation_element(0, cols::MP_SELECTOR) - .clone(); - let res_0 = step.get_main_evaluation_element(0, cols::RES_0).clone(); - let is_equal = step.get_main_evaluation_element(0, cols::IS_EQUAL).clone(); - let branch_cond = step - .get_main_evaluation_element(0, cols::BRANCH_COND) - .clone(); - - let two = FieldElement::::from(2u64); - - // XOR computation: a XOR b = a + b - 2*a*b - // res[0] XOR mp_selector - let res_xor_mp = &res_0 + &mp_selector - &two * &res_0 * &mp_selector; - // is_equal XOR mp_selector - let eq_xor_mp = &is_equal + &mp_selector - &two * &is_equal * &mp_selector; - - // branch_cond = JALR + BLT * res_xor_mp + BEQ * eq_xor_mp - let expected = jalr + &blt * res_xor_mp + &beq * eq_xor_mp; - - // Constraint: branch_cond - expected = 0 - branch_cond - expected +impl ProductZeroConstraint { + pub fn new(col_a: usize, col_b: usize, constraint_idx: usize) -> Self { + Self { + col_a, + col_b, + constraint_idx, + } } } -impl TransitionConstraint for BranchCondConstraint { +impl TransitionConstraint for ProductZeroConstraint { fn degree(&self) -> usize { - // BLT * res_0 * mp_selector has degree 3 
- 3 + 2 } fn constraint_idx(&self) -> usize { @@ -216,39 +107,31 @@ impl TransitionConstraint for BranchCondCo F: IsSubFieldOf, E: IsField, { - self.compute(step) + step.get_main_evaluation_element(0, self.col_a) + * step.get_main_evaluation_element(0, self.col_b) } } -// ========================================================================= -// EBREAK Constraint -// ========================================================================= - -/// Constraint that EBREAK must be 0 (unprovable trap). -/// -/// From spec: !EBREAK (we treat EBREAK as an unprovable trap) -pub struct EbreakConstraint { +/// `(1 - MEMORY - BRANCH) · read_register2 · imm[i] = 0`: when neither MEMORY nor +/// BRANCH is set, the `arg2` multiplex needs at most one of `rv2`/`imm` nonzero. +/// Decoding already guarantees this; a spec defense-in-depth assumption. +pub struct Arg2ExclusiveConstraint { + imm_col: usize, constraint_idx: usize, } -impl EbreakConstraint { - pub fn new(constraint_idx: usize) -> Self { - Self { constraint_idx } - } - - fn compute(&self, step: &TableView) -> FieldElement - where - F: IsSubFieldOf, - E: IsField, - { - // EBREAK must be 0 - step.get_main_evaluation_element(0, cols::EBREAK).clone() +impl Arg2ExclusiveConstraint { + pub fn new(imm_col: usize, constraint_idx: usize) -> Self { + Self { + imm_col, + constraint_idx, + } } } -impl TransitionConstraint for EbreakConstraint { +impl TransitionConstraint for Arg2ExclusiveConstraint { fn degree(&self) -> usize { - 1 + 3 } fn constraint_idx(&self) -> usize { @@ -260,52 +143,31 @@ impl TransitionConstraint for EbreakConstr F: IsSubFieldOf, E: IsField, { - self.compute(step) + let one = FieldElement::::one(); + let memory = step.get_main_evaluation_element(0, cols::MEMORY).clone(); + let branch = step.get_main_evaluation_element(0, cols::BRANCH).clone(); + let rr2 = step.get_main_evaluation_element(0, cols::READ_REGISTER2); + let imm = step.get_main_evaluation_element(0, self.imm_col); + (one - memory - branch) * rr2 
* imm } } -// ========================================================================= -// Extension Constraints -// ========================================================================= - -/// Constraint: arg1[0:4] = rv1[0:2] (lower 32 bits match) -/// -/// arg1 is DWordBL (8 bytes), rv1 is DWordWHH [Half, Half, Word] -/// arg1[:4] as word = rv1[0] + rv1[1] * 2^16 (two halves make a word) -/// -/// Spec (CPU-CE54): arg1::DWordWL[0] - rv1::DWordWL[0] = 0 -pub struct Arg1LowerConstraint { +/// `IS_BIT` on non-MEMORY rows: `(1 - MEMORY) · mem_flags · (1 - mem_flags) = 0`. +/// On non-memory rows `mem_flags` carries only the JALR bit, so it must be 0/1. +/// A spec defense-in-depth assumption (the DECODE lookup already enforces it). +pub struct MemFlagsBitConstraint { constraint_idx: usize, } -impl Arg1LowerConstraint { +impl MemFlagsBitConstraint { pub fn new(constraint_idx: usize) -> Self { Self { constraint_idx } } - - fn compute(&self, step: &TableView) -> FieldElement - where - F: IsSubFieldOf, - E: IsField, - { - let arg1_lo = - pack_bytes_to_word(step, cols::ARG1_0, cols::ARG1_1, cols::ARG1_2, cols::ARG1_3); - - // rv1 is DWordWHH: [Half(0-15), Half(16-31), Word(32-63)] - // rv1::DWordWL[0] = rv1[0] + rv1[1] * 2^16 - let rv1_0 = step.get_main_evaluation_element(0, cols::RV1_0); - let rv1_1 = step.get_main_evaluation_element(0, cols::RV1_1); - let shift_16: FieldElement = FieldElement::from(1u64 << 16); - let rv1_lower = rv1_0 + rv1_1 * shift_16; - - // Constraint: arg1_lo - rv1_lower = 0 - arg1_lo - rv1_lower - } } -impl TransitionConstraint for Arg1LowerConstraint { +impl TransitionConstraint for MemFlagsBitConstraint { fn degree(&self) -> usize { - 1 + 3 } fn constraint_idx(&self) -> usize { @@ -317,56 +179,38 @@ impl TransitionConstraint for Arg1LowerCon F: IsSubFieldOf, E: IsField, { - self.compute(step) + let one = FieldElement::::one(); + let memory = step.get_main_evaluation_element(0, cols::MEMORY).clone(); + let mem_flags = 
step.get_main_evaluation_element(0, cols::MEM_FLAGS).clone(); + (one.clone() - memory) * &mem_flags * (one - &mem_flags) } } -/// Constraint: arg1[4:8] = rv1[2] * (1 - word_instr) + (2^32 - 1) * rv1_ext_bit * signed -/// -/// Upper 32 bits of arg1 depends on word_instr and sign extension. -pub struct Arg1UpperConstraint { +// ========================================================================= +// mem group: register zero-forcing +// ========================================================================= + +/// Constraint `(1 − flag) · value = 0`: when `flag = 0`, `value` must be 0. +/// Used for `¬read_registerN ⇒ rvN[i] = 0`. +pub struct RegNotReadIsZeroConstraint { + flag_col: usize, + value_col: usize, constraint_idx: usize, } -impl Arg1UpperConstraint { - pub fn new(constraint_idx: usize) -> Self { - Self { constraint_idx } - } - - fn compute(&self, step: &TableView) -> FieldElement - where - F: IsSubFieldOf, - E: IsField, - { - let arg1_hi = - pack_bytes_to_word(step, cols::ARG1_4, cols::ARG1_5, cols::ARG1_6, cols::ARG1_7); - - // rv1 is DWordWHH: rv1[2] IS the upper 32 bits directly (Word) - let rv1_upper = step.get_main_evaluation_element(0, cols::RV1_2); - - let word_instr = step - .get_main_evaluation_element(0, cols::WORD_INSTR) - .clone(); - let signed = step.get_main_evaluation_element(0, cols::SIGNED).clone(); - let rv1_ext_bit = step - .get_main_evaluation_element(0, cols::RV1_EXT_BIT) - .clone(); - - let one = FieldElement::::one(); - let mask_32: FieldElement = FieldElement::from((1u64 << 32) - 1); // 2^32 - 1 - - // Expected: rv1_upper * (1 - word_instr) + mask_32 * rv1_ext_bit * signed - let expected = rv1_upper * (one - &word_instr) + mask_32 * rv1_ext_bit * signed; - - // Constraint: arg1_hi - expected = 0 - arg1_hi - expected +impl RegNotReadIsZeroConstraint { + pub fn new(flag_col: usize, value_col: usize, constraint_idx: usize) -> Self { + Self { + flag_col, + value_col, + constraint_idx, + } } } -impl TransitionConstraint for 
Arg1UpperConstraint { +impl TransitionConstraint for RegNotReadIsZeroConstraint { fn degree(&self) -> usize { - // rv1_ext_bit * signed * word_instr has degree 3 - 3 + 2 } fn constraint_idx(&self) -> usize { @@ -378,50 +222,51 @@ impl TransitionConstraint for Arg1UpperCon F: IsSubFieldOf, E: IsField, { - self.compute(step) + let one = FieldElement::::one(); + let flag = step.get_main_evaluation_element(0, self.flag_col).clone(); + let value = step.get_main_evaluation_element(0, self.value_col); + (one - flag) * value } } // ========================================================================= -// SLT/BLT Zero Upper Bytes Constraint +// alu group: arg2 multiplex // ========================================================================= -/// Constraint: when SLT + BLT = 1, res[i] = 0 for i in 1..8 +/// `arg2` multiplex (`cpu.toml` CPU-A1), for word index +/// `word_idx ∈ {0,1}`: +/// +/// ```text +/// arg2[i] = MEMORY·imm[i] +/// + BRANCH·rv2[i] +/// + (1−MEMORY−BRANCH)·(rv2[i] + imm[i]) +/// ``` /// -/// The LT result is a single bit stored in res[0], upper bytes must be zero. -pub struct SltResZeroConstraint { - /// Which byte index (1-7) this constraint applies to - byte_idx: usize, +/// For BRANCH rows `arg2 = rv2` (JAL/JALR read no rs2, so `rv2 = 0`; conditional +/// branches feed `rv2` to the EQ/LT comparison). The final `rv2 + imm` term has +/// no inter-word carry because decode assumption A2 guarantees at most one of +/// `rv2`/`imm` is nonzero when `MEMORY+BRANCH = 0`. `MEMORY` and `BRANCH` are +/// mutually exclusive (enforced by the live `MEMORY·BRANCH = 0` constraint), so +/// `1−MEMORY−BRANCH ∈ {0,1}` and matches the degree-2 spec form. +pub struct Arg2Constraint { + /// 0 = low word, 1 = high word. 
+ word_idx: usize, constraint_idx: usize, } -impl SltResZeroConstraint { - pub fn new(byte_idx: usize, constraint_idx: usize) -> Self { - assert!((1..=7).contains(&byte_idx)); +impl Arg2Constraint { + pub fn new(word_idx: usize, constraint_idx: usize) -> Self { Self { - byte_idx, + word_idx, constraint_idx, } } - - fn compute(&self, step: &TableView) -> FieldElement - where - F: IsSubFieldOf, - E: IsField, - { - let slt = step.get_main_evaluation_element(0, cols::SLT).clone(); - let blt = step.get_main_evaluation_element(0, cols::BLT).clone(); - let res_i = step - .get_main_evaluation_element(0, cols::RES[self.byte_idx]) - .clone(); - - // (SLT + BLT) * res[i] = 0 - (slt + blt) * res_i - } } -impl TransitionConstraint for SltResZeroConstraint { +impl TransitionConstraint for Arg2Constraint { fn degree(&self) -> usize { + // (1 - MEMORY - BRANCH) [deg 1] · (rv2 + imm) [deg 1] = 2. The degree-2 + // form relies on the live MEMORY·BRANCH = 0 mutex. 2 } @@ -434,65 +279,58 @@ impl TransitionConstraint for SltResZeroCo F: IsSubFieldOf, E: IsField, { - self.compute(step) - } -} + let (arg2_col, imm_col, rv2_col) = if self.word_idx == 0 { + (cols::ARG2_0, cols::IMM_0, cols::RV2_0) + } else { + (cols::ARG2_1, cols::IMM_1, cols::RV2_1) + }; + + let one = FieldElement::::one(); + let arg2 = step.get_main_evaluation_element(0, arg2_col).clone(); + let imm = step.get_main_evaluation_element(0, imm_col).clone(); + let rv2 = step.get_main_evaluation_element(0, rv2_col).clone(); + let memory = step.get_main_evaluation_element(0, cols::MEMORY).clone(); + let branch = step.get_main_evaluation_element(0, cols::BRANCH).clone(); -/// Creates all SLT/BLT zero constraints for res[1..8]. 
-pub fn create_slt_res_zero_constraints( - constraint_idx_start: usize, -) -> (Vec, usize) { - let constraints: Vec<_> = (1..8) - .enumerate() - .map(|(i, byte_idx)| SltResZeroConstraint::new(byte_idx, constraint_idx_start + i)) - .collect(); + // MEMORY · imm + let mut expected = &memory * &imm; + // BRANCH · rv2 + expected += &branch * &rv2; + // (1 - MEMORY - BRANCH) · (rv2 + imm) + expected += (&one - &memory - &branch) * (&rv2 + &imm); - (constraints, constraint_idx_start + 7) + arg2 - expected + } } // ========================================================================= -// Extension Bit Constraints (SIGN template from spec) +// mem group: ¬MEMORY ∧ ¬BRANCH ⇒ rvd = cast(res, WL) // ========================================================================= -/// Constraint: ext_bit must be zero when word_instr = 0 +/// `(1 − MEMORY − BRANCH) · (rvd[i] − cast(res, WL)[i]) = 0` (`cpu.toml` CPU-M*). /// -/// (1 - word_instr) * ext_bit = 0 -/// -/// One instance per extension bit (rv1_ext_bit, rv2_ext_bit, res_ext_bit). -pub struct ExtBitZeroConstraint { +/// On plain ALU rows `rvd = res`. BRANCH rows are exempt: their `rvd` is the +/// return address `pc + instruction_length`, pinned by [`BranchRvdConstraint`]. +/// `MEMORY` and `BRANCH` are mutually exclusive (decode assumption), so +/// `1 − MEMORY − BRANCH ∈ {0,1}`. For LOAD/STORE `rvd` comes from the MEMORY bus. +pub struct RvdEqResConstraint { + /// 0 = low word, 1 = high word. 
+ word_idx: usize, constraint_idx: usize, - ext_bit_col: usize, } -impl ExtBitZeroConstraint { - pub fn new(constraint_idx: usize, ext_bit_col: usize) -> Self { +impl RvdEqResConstraint { + pub fn new(word_idx: usize, constraint_idx: usize) -> Self { Self { + word_idx, constraint_idx, - ext_bit_col, } } - - fn compute(&self, step: &TableView) -> FieldElement - where - F: IsSubFieldOf, - E: IsField, - { - let ext_bit = step - .get_main_evaluation_element(0, self.ext_bit_col) - .clone(); - let word_instr = step - .get_main_evaluation_element(0, cols::WORD_INSTR) - .clone(); - - let one = FieldElement::::one(); - - // (1 - word_instr) * ext_bit = 0 - (one - word_instr) * ext_bit - } } -impl TransitionConstraint for ExtBitZeroConstraint { +impl TransitionConstraint for RvdEqResConstraint { fn degree(&self) -> usize { + // (1 - MEMORY - BRANCH) [deg 1] · (rvd - cast(res, WL)) [deg 1] = 2. 2 } @@ -505,28 +343,39 @@ impl TransitionConstraint for ExtBitZeroCo F: IsSubFieldOf, E: IsField, { - self.compute(step) + let high = self.word_idx == 1; + let rvd_col = if high { cols::RVD_1 } else { cols::RVD_0 }; + let one = FieldElement::::one(); + let memory = step.get_main_evaluation_element(0, cols::MEMORY).clone(); + let branch = step.get_main_evaluation_element(0, cols::BRANCH).clone(); + let rvd = step.get_main_evaluation_element(0, rvd_col).clone(); + let res_w = res_word(step, high); + (&one - &memory - &branch) * (rvd - res_w) } } // ========================================================================= -// Next PC (Non-Branching) Constraint +// branch group: BRANCH ⇒ rvd = pc + instruction_length // ========================================================================= -/// Constraint: when branch_cond = 0, next_pc = pc + instr_size -/// -/// where instr_size = 4 - 2 * c_type_instruction -/// (4 bytes for normal instructions, 2 bytes for compressed) +/// `BRANCH · carry · (1 − carry) = 0` for the 64-bit addition +/// `rvd = pc + instruction_length` (the JAL/JALR 
return address), in two +/// instances (`carry_0` / `carry_1`). Mirrors [`NextPcAddConstraint`] so the +/// low→high carry is propagated: the spec computes `rvd` with the same +/// carry-correct `ADD` template as `next_pc` (`cpu.toml` branch group), so the +/// high word must include the carry out of `pc[0] + instruction_length`. /// -/// Uses the same carry-based approach as AddConstraint but with -/// condition `(1 - branch_cond)` instead of a column value. -pub struct NextPcAddConstraint { - /// Which carry constraint this is (0 or 1) +/// On every BRANCH row `rvd` holds the return address `pc + instruction_length` +/// (written to `rd` only by JAL/JALR; conditional branches compute it but never +/// write it). See [`RvdEqResConstraint`] for the complementary +/// `¬MEMORY ∧ ¬BRANCH ⇒ rvd = res` case. +pub struct BranchRvdConstraint { + /// 0 = low-word carry, 1 = high-word carry. carry_idx: usize, constraint_idx: usize, } -impl NextPcAddConstraint { +impl BranchRvdConstraint { pub fn new(carry_idx: usize, constraint_idx: usize) -> Self { assert!(carry_idx <= 1); Self { @@ -535,7 +384,6 @@ impl NextPcAddConstraint { } } - /// Creates constraints for both carries. 
pub fn new_pair(constraint_idx_start: usize) -> (Self, Self) { ( Self::new(0, constraint_idx_start), @@ -543,69 +391,37 @@ impl NextPcAddConstraint { ) } - /// Compute carry_0 = (pc_lo + instr_size - next_pc_lo) / 2^32 fn compute_carry_0(&self, step: &TableView) -> FieldElement where F: IsSubFieldOf, E: IsField, { let pc_lo = step.get_main_evaluation_element(0, cols::PC_0).clone(); - let next_pc_lo = step.get_main_evaluation_element(0, cols::NEXT_PC_0).clone(); - let c_type = step - .get_main_evaluation_element(0, cols::C_TYPE_INSTRUCTION) + let rvd_lo = step.get_main_evaluation_element(0, cols::RVD_0).clone(); + let half_len = step + .get_main_evaluation_element(0, cols::HALF_INSTRUCTION_LENGTH) .clone(); - - // instr_size = 4 - 2 * c_type_instruction - let four: FieldElement = FieldElement::from(4u64); - let two: FieldElement = FieldElement::from(2u64); - let instr_size = four - two * c_type; - - // carry_0 = (pc_lo + instr_size - next_pc_lo) * 2^(-32) + let instr_len = &half_len + &half_len; // real byte length = 2 * half let inv_2_32 = FieldElement::::from(super::templates::INV_SHIFT_32); - (pc_lo + instr_size - next_pc_lo) * inv_2_32 + (pc_lo + instr_len - rvd_lo) * inv_2_32 } - /// Compute carry_1 = (pc_hi + carry_0 - next_pc_hi) / 2^32 fn compute_carry_1(&self, step: &TableView) -> FieldElement where F: IsSubFieldOf, E: IsField, { let pc_hi = step.get_main_evaluation_element(0, cols::PC_1).clone(); - let next_pc_hi = step.get_main_evaluation_element(0, cols::NEXT_PC_1).clone(); + let rvd_hi = step.get_main_evaluation_element(0, cols::RVD_1).clone(); let carry_0 = self.compute_carry_0(step); - - // rhs_hi = 0 (instruction size fits in low word) - // carry_1 = (pc_hi + 0 + carry_0 - next_pc_hi) * 2^(-32) let inv_2_32 = FieldElement::::from(super::templates::INV_SHIFT_32); - (pc_hi + carry_0 - next_pc_hi) * inv_2_32 - } - - fn compute(&self, step: &TableView) -> FieldElement - where - F: IsSubFieldOf, - E: IsField, - { - let branch_cond = step - 
.get_main_evaluation_element(0, cols::BRANCH_COND) - .clone(); - let one = FieldElement::::one(); - let not_branch = &one - branch_cond; - - let carry = match self.carry_idx { - 0 => self.compute_carry_0(step), - 1 => self.compute_carry_1(step), - _ => panic!("Invalid carry index"), - }; - - // (1 - branch_cond) * carry * (1 - carry) - not_branch * &carry * (one - carry) + (pc_hi + carry_0 - rvd_hi) * inv_2_32 } } -impl TransitionConstraint for NextPcAddConstraint { +impl TransitionConstraint for BranchRvdConstraint { fn degree(&self) -> usize { - // (1 - branch_cond) * carry * (1 - carry) has degree 3 + // BRANCH (deg 1) · carry · (1 − carry) = 3. 3 } @@ -618,68 +434,36 @@ impl TransitionConstraint for NextPcAddCon F: IsSubFieldOf, E: IsField, { - self.compute(step) + let one = FieldElement::::one(); + let branch = step.get_main_evaluation_element(0, cols::BRANCH).clone(); + let carry = match self.carry_idx { + 0 => self.compute_carry_0(step), + 1 => self.compute_carry_1(step), + _ => unreachable!("carry_idx validated <= 1 at construction"), + }; + branch * &carry * (&one - &carry) } } // ========================================================================= -// Arg2 Constraints +// branch group: branch_cond // ========================================================================= -/// Constraint: arg2[:4] = (1-LOAD)*rv2[:2] + (1-BEQ-BLT-STORE)*imm[0] -/// -/// arg2 lower 32 bits comes from either rv2 or imm depending on instruction type. -pub struct Arg2LowerConstraint { +/// `branch_cond = BRANCH·JALR + BRANCH·(1−JALR)·res[0]` (`cpu.toml` CPU-B1). +/// `JALR = mem_flags` (bit, under BRANCH); `res[0]` is the low half of `res`. 
+pub struct BranchCondConstraint { constraint_idx: usize, } -impl Arg2LowerConstraint { +impl BranchCondConstraint { pub fn new(constraint_idx: usize) -> Self { Self { constraint_idx } } - - fn compute(&self, step: &TableView) -> FieldElement - where - F: IsSubFieldOf, - E: IsField, - { - let arg2_lo = pack_bytes_to_word( - step, - cols::ARG2[0], - cols::ARG2[1], - cols::ARG2[2], - cols::ARG2[3], - ); - - // rv2 is DWordWHH: rv2[:2] = rv2[0] + rv2[1] * 2^16 - let rv2_0 = step.get_main_evaluation_element(0, cols::RV2_0); - let rv2_1 = step.get_main_evaluation_element(0, cols::RV2_1); - let shift_16: FieldElement = FieldElement::from(1u64 << 16); - let rv2_lower = rv2_0 + rv2_1 * shift_16; - - // imm[0] is lower word of immediate - let imm_0 = step.get_main_evaluation_element(0, cols::IMM_0); - - // Selectors - let store = step.get_main_evaluation_element(0, cols::STORE); - let load = step.get_main_evaluation_element(0, cols::LOAD); - let beq = step.get_main_evaluation_element(0, cols::BEQ); - let blt = step.get_main_evaluation_element(0, cols::BLT); - - let one = FieldElement::::one(); - - // (1-LOAD) * rv2_lower + (1-BEQ-BLT-STORE) * imm[0] - // STORE now gets rv2 (via rv2_lower), not imm - let expected = (&one - load) * rv2_lower + (&one - beq - blt - store) * imm_0; - - // Constraint: arg2_lo - expected = 0 - arg2_lo - expected - } } -impl TransitionConstraint for Arg2LowerConstraint { +impl TransitionConstraint for BranchCondConstraint { fn degree(&self) -> usize { - 2 + 3 } fn constraint_idx(&self) -> usize { @@ -691,182 +475,76 @@ impl TransitionConstraint for Arg2LowerCon F: IsSubFieldOf, E: IsField, { - self.compute(step) - } -} - -/// Constraint: arg2[4:] = (1-LOAD)*((1-word_instr)*rv2[2] + signed*rv2_ext_bit*(2^32-1)) + (1-BEQ-BLT-STORE)*imm[1] -/// -/// arg2 upper 32 bits with sign extension logic. 
-pub struct Arg2UpperConstraint { - constraint_idx: usize, -} - -impl Arg2UpperConstraint { - pub fn new(constraint_idx: usize) -> Self { - Self { constraint_idx } - } - - fn compute(&self, step: &TableView) -> FieldElement - where - F: IsSubFieldOf, - E: IsField, - { - let arg2_hi = pack_bytes_to_word( - step, - cols::ARG2[4], - cols::ARG2[5], - cols::ARG2[6], - cols::ARG2[7], - ); - - // rv2 is DWordWHH: rv2[2] IS the upper 32 bits directly (Word) - let rv2_upper = step.get_main_evaluation_element(0, cols::RV2_2); - - // imm[1] is upper word of immediate - let imm_1 = step.get_main_evaluation_element(0, cols::IMM_1); - - // Flags - let store = step.get_main_evaluation_element(0, cols::STORE); - let load = step.get_main_evaluation_element(0, cols::LOAD); - let beq = step.get_main_evaluation_element(0, cols::BEQ); - let blt = step.get_main_evaluation_element(0, cols::BLT); - let word_instr = step.get_main_evaluation_element(0, cols::WORD_INSTR); - let signed = step.get_main_evaluation_element(0, cols::SIGNED); - let rv2_ext_bit = step.get_main_evaluation_element(0, cols::RV2_EXT_BIT); - let one = FieldElement::::one(); - let mask_32: FieldElement = FieldElement::from((1u64 << 32) - 1); - - // rv2_term = (1 - word_instr) * rv2[2] + signed * rv2_ext_bit * (2^32 - 1) - let rv2_term = (&one - word_instr) * rv2_upper + signed * rv2_ext_bit * &mask_32; - - // expected = (1-LOAD) * rv2_term + (1-BEQ-BLT-STORE) * imm[1] - // STORE now gets rv2_term (with sign extension), not imm - let expected = (&one - load) * rv2_term + (&one - beq - blt - store) * imm_1; - - // Constraint: arg2_hi - expected = 0 - arg2_hi - expected - } -} - -impl TransitionConstraint for Arg2UpperConstraint { - fn degree(&self) -> usize { - // (1-LOAD) * signed * rv2_ext_bit has degree 3 - 3 - } - - fn constraint_idx(&self) -> usize { - self.constraint_idx - } + let branch = step.get_main_evaluation_element(0, cols::BRANCH).clone(); + let jalr = step.get_main_evaluation_element(0, 
cols::MEM_FLAGS).clone(); + let res0 = step.get_main_evaluation_element(0, cols::RES_0).clone(); + let branch_cond = step + .get_main_evaluation_element(0, cols::BRANCH_COND) + .clone(); - fn evaluate(&self, step: &TableView) -> FieldElement - where - F: IsSubFieldOf, - E: IsField, - { - self.compute(step) + let expected = &branch * &jalr + &branch * (&one - &jalr) * res0; + branch_cond - expected } } // ========================================================================= -// RVD Constraints +// branch group: next_pc = pc + instruction_length (when not branching) // ========================================================================= -/// Constraint: (1-LOAD) * (rvd[0] - res[:4]) = 0 -/// -/// When not LOAD, rvd lower 32 bits equals res lower 32 bits. -/// For LOAD: rvd is the loaded value, not res (which is the address). -/// For non-LOAD ops (including STORE): rvd must equal res in the trace. -pub struct RvdLowerConstraint { +/// `(1 − branch_cond) · carry · (1 − carry) = 0` for the 64-bit addition +/// `next_pc = pc + instruction_length`. Two instances (carry_0/carry_1). 
+pub struct NextPcAddConstraint { + carry_idx: usize, constraint_idx: usize, } -impl RvdLowerConstraint { - pub fn new(constraint_idx: usize) -> Self { - Self { constraint_idx } - } - - fn compute(&self, step: &TableView) -> FieldElement - where - F: IsSubFieldOf, - E: IsField, - { - // rvd[0] is lower word - let rvd_0 = step.get_main_evaluation_element(0, cols::RVD_0); - - let res_lo = - pack_bytes_to_word(step, cols::RES[0], cols::RES[1], cols::RES[2], cols::RES[3]); - - let load = step.get_main_evaluation_element(0, cols::LOAD); - let one = FieldElement::::one(); - - // (1 - LOAD) * (rvd[0] - res_lo) = 0 - (one - load) * (rvd_0 - res_lo) - } -} - -impl TransitionConstraint for RvdLowerConstraint { - fn degree(&self) -> usize { - 2 +impl NextPcAddConstraint { + pub fn new(carry_idx: usize, constraint_idx: usize) -> Self { + assert!(carry_idx <= 1); + Self { + carry_idx, + constraint_idx, + } } - fn constraint_idx(&self) -> usize { - self.constraint_idx + pub fn new_pair(constraint_idx_start: usize) -> (Self, Self) { + ( + Self::new(0, constraint_idx_start), + Self::new(1, constraint_idx_start + 1), + ) } - fn evaluate(&self, step: &TableView) -> FieldElement + fn compute_carry_0(&self, step: &TableView) -> FieldElement where F: IsSubFieldOf, E: IsField, { - self.compute(step) - } -} - -/// Constraint: (1-LOAD) * (rvd[1] - ((1-word_instr)*res[4:] + res_ext_bit*(2^32-1))) = 0 -/// -/// When not LOAD, rvd upper 32 bits equals res upper with sign extension. -/// For LOAD: rvd is the loaded value, not res (which is the address). -/// For non-LOAD ops (including STORE): rvd must equal res in the trace. 
-pub struct RvdUpperConstraint { - constraint_idx: usize, -} - -impl RvdUpperConstraint { - pub fn new(constraint_idx: usize) -> Self { - Self { constraint_idx } + let pc_lo = step.get_main_evaluation_element(0, cols::PC_0).clone(); + let next_pc_lo = step.get_main_evaluation_element(0, cols::NEXT_PC_0).clone(); + let half_len = step + .get_main_evaluation_element(0, cols::HALF_INSTRUCTION_LENGTH) + .clone(); + let instr_len = &half_len + &half_len; // real byte length = 2 * half + let inv_2_32 = FieldElement::::from(super::templates::INV_SHIFT_32); + (pc_lo + instr_len - next_pc_lo) * inv_2_32 } - fn compute(&self, step: &TableView) -> FieldElement + fn compute_carry_1(&self, step: &TableView) -> FieldElement where F: IsSubFieldOf, E: IsField, { - // rvd[1] is upper word - let rvd_1 = step.get_main_evaluation_element(0, cols::RVD_1); - - let res_hi = - pack_bytes_to_word(step, cols::RES[4], cols::RES[5], cols::RES[6], cols::RES[7]); - - let load = step.get_main_evaluation_element(0, cols::LOAD); - let word_instr = step.get_main_evaluation_element(0, cols::WORD_INSTR); - let res_ext_bit = step.get_main_evaluation_element(0, cols::RES_EXT_BIT); - - let one = FieldElement::::one(); - let mask_32: FieldElement = FieldElement::from((1u64 << 32) - 1); - - // expected = (1 - word_instr) * res_hi + res_ext_bit * (2^32 - 1) - let expected = (&one - word_instr) * res_hi + res_ext_bit * mask_32; - - // (1 - LOAD) * (rvd[1] - expected) = 0 - (one - load) * (rvd_1 - expected) + let pc_hi = step.get_main_evaluation_element(0, cols::PC_1).clone(); + let next_pc_hi = step.get_main_evaluation_element(0, cols::NEXT_PC_1).clone(); + let carry_0 = self.compute_carry_0(step); + let inv_2_32 = FieldElement::::from(super::templates::INV_SHIFT_32); + (pc_hi + carry_0 - next_pc_hi) * inv_2_32 } } -impl TransitionConstraint for RvdUpperConstraint { +impl TransitionConstraint for NextPcAddConstraint { fn degree(&self) -> usize { - // (1-LOAD) * (1-word_instr) * res_hi has degree 3 3 } @@ 
-879,189 +557,64 @@ impl TransitionConstraint for RvdUpperCons F: IsSubFieldOf, E: IsField, { - self.compute(step) - } -} - -// ========================================================================= -// read_register - register Constraints (CM48, CM50) -// ========================================================================= - -/// Constraint: `(1 - flag_col) * value_col = 0` -/// -/// Forces `value_col` to zero whenever `flag_col` is 0. -/// -/// Used for: -/// - CPU-CM48.i: `(1 - read_register1) * rv1[i] = 0` for i ∈ [0, 2] -/// When read_register1 = 0 (rs1 is x0), rv1 is not loaded from memory, -/// so it must be forced to zero by a polynomial constraint. -/// - CPU-CM50.i: `(1 - read_register2) * rv2[i] = 0` for i ∈ [0, 2] -/// Same logic for rv2 when read_register2 = 0 (I-type instructions). -pub struct RegNotReadIsZeroConstraint { - flag_col: usize, - value_col: usize, - constraint_idx: usize, -} - -impl RegNotReadIsZeroConstraint { - pub fn new(flag_col: usize, value_col: usize, constraint_idx: usize) -> Self { - Self { - flag_col, - value_col, - constraint_idx, - } - } - - fn compute(&self, step: &TableView) -> FieldElement - where - F: IsSubFieldOf, - E: IsField, - { - let flag = step.get_main_evaluation_element(0, self.flag_col).clone(); - let value = step.get_main_evaluation_element(0, self.value_col).clone(); + let branch_cond = step + .get_main_evaluation_element(0, cols::BRANCH_COND) + .clone(); let one = FieldElement::::one(); - // (1 - flag) * value = 0 - (one - flag) * value - } -} - -impl TransitionConstraint for RegNotReadIsZeroConstraint { - fn degree(&self) -> usize { - 2 - } - - fn constraint_idx(&self) -> usize { - self.constraint_idx - } - - fn evaluate(&self, step: &TableView) -> FieldElement - where - F: IsSubFieldOf, - E: IsField, - { - self.compute(step) + let not_branch = &one - branch_cond; + let carry = match self.carry_idx { + 0 => self.compute_carry_0(step), + 1 => self.compute_carry_1(step), + _ => panic!("Invalid carry 
index"), + }; + not_branch * &carry * (one - carry) } } // ========================================================================= -// SUB Constraints +// alu group: ADD / SUB fast-path templates // ========================================================================= -/// Creates SUB constraints for the CPU table. -/// -/// SUB template is used when: SUB + BEQ > 0 -/// - SUB: res = arg1 - arg2 -/// - BEQ: computes arg1 - arg2 to check equality (res = 0 means equal) -/// -/// Verifies: arg2 + res = arg1 (subtraction expressed as addition) -/// -/// Returns the constraints and the next available constraint index. -pub fn create_sub_constraints(constraint_idx_start: usize) -> (Vec, usize) { - // SUB is verified as: arg2 + res = arg1 - // This is the ADD template with swapped roles: - // - lhs = arg2 - // - rhs = res - // - sum = arg1 - - let lhs = AddOperand::from_dword_bl(cols::ARG2_0); // First addend - let rhs = AddOperand::from_dword_bl(cols::RES_0); // Second addend (the difference) - let sum = AddOperand::from_dword_bl(cols::ARG1_0); // Result of addition (original minuend) - - // Condition: SUB + BEQ (active when either flag is set) - let cond_cols = vec![cols::SUB, cols::BEQ]; - - let (sub_c0, sub_c1) = AddConstraint::new_pair(cond_cols, lhs, rhs, sum, constraint_idx_start); - - (vec![sub_c0, sub_c1], constraint_idx_start + 2) +/// ADD fast-path: `cond = ADD`, `rv1 + arg2 = cast(res, WL)`. Covers ADD, LOAD, +/// STORE and JAL(R) (all set `ADD`). 
+pub fn create_add_constraints(constraint_idx_start: usize) -> (Vec, usize) { + let lhs = AddOperand::dword(cols::RV1_0); + let rhs = AddOperand::dword(cols::ARG2_0); + let sum = AddOperand::from_dword_hl(cols::RES_0); + let (c0, c1) = AddConstraint::new_pair(vec![cols::ADD], lhs, rhs, sum, constraint_idx_start); + (vec![c0, c1], constraint_idx_start + 2) } -// ========================================================================= -// JALR Result Constraint -// ========================================================================= - -/// Creates JALR result constraints using the ADD template. -/// -/// JALR: res = pc + instr_size (return address) -/// where instr_size = 4 - 2 * c_type_instruction -/// -/// This uses proper 64-bit addition with carry handling. -pub fn create_jalr_constraints(constraint_idx_start: usize) -> (Vec, usize) { - // pc is stored as DWordWL (2 consecutive columns) - let pc = AddOperand::dword(cols::PC_0); - - // instr_size = 4 - 2 * c_type_instruction - // This is a linear expression with only a low word (hi = 0) - let instr_size = AddOperand::linear( - vec![ - AddLinearTerm::Constant(4), - AddLinearTerm::Column { - coefficient: -2, - column: cols::C_TYPE_INSTRUCTION, - }, - ], - vec![], // hi = 0 - ); - - // res is stored as DWordBL (8 bytes) - let res = AddOperand::from_dword_bl(cols::RES_0); - - // Condition: JALR - let cond_cols = vec![cols::JALR]; - - let (jalr_c0, jalr_c1) = - AddConstraint::new_pair(cond_cols, pc, instr_size, res, constraint_idx_start); - - (vec![jalr_c0, jalr_c1], constraint_idx_start + 2) +/// SUB fast-path: `cond = SUB`, `res = rv1 − arg2`, verified as `arg2 + res = rv1`. 
+pub fn create_sub_constraints(constraint_idx_start: usize) -> (Vec, usize) { + let lhs = AddOperand::dword(cols::ARG2_0); + let rhs = AddOperand::from_dword_hl(cols::RES_0); + let sum = AddOperand::dword(cols::RV1_0); + let (c0, c1) = AddConstraint::new_pair(vec![cols::SUB], lhs, rhs, sum, constraint_idx_start); + (vec![c0, c1], constraint_idx_start + 2) } // ========================================================================= -// Inline PC Constraints -// ========================================================================= -// -// Per spec/cpu.typ: "Constraints on `pc_double_read` corresponding to an `AUIPC` -// instruction are not necessary, as regardless of its value, the old timestamp is -// guaranteed smaller than the new timestamp, and the integrity of the memory -// argument therefore ensures the correctness of this bit." -// -// The IS_BIT constraints on PC_DOUBLE_READ and PREV_PC_TIMESTAMP_BORROW are -// sufficient; no extra algebraic constraints linking them to rs1/read_register1 -// or to each other are required. - -// ========================================================================= -// Constraint Summary +// Assembly // ========================================================================= -/// Total number of CPU constraints. 
-/// -/// - IS_BIT: 34 (all bit flags, including read_register1/2 and inline-PC columns) -/// - ADD carry: 2 (for ADD + LOAD) -/// - STORE ADD carry: 2 (for STORE: res = arg1 + imm) -/// - SUB carry: 2 (for SUB + BEQ) -/// - JALR carry: 2 (res = pc + instr_size) -/// - Branch cond: 1 -/// - EBREAK: 1 -/// - Arg1 lower: 1 -/// - Arg1 upper: 1 -/// - Arg2 lower: 1 -/// - Arg2 upper: 1 -/// - Rvd lower: 1 -/// - Rvd upper: 1 -/// - SLT res zero: 7 (bytes 1-7) -/// - Ext bit zero (SIGN template): 3 (rv1_ext_bit, rv2_ext_bit, res_ext_bit) -/// - rv1 zero-forcing (CM48): 3 (rv1[0..2] when read_register1 = 0) -/// - rv2 zero-forcing (CM50): 3 (rv2[0..2] when read_register2 = 0) -/// - Next PC (non-branching): 2 +/// Total number of CPU transition constraints (excludes bus lookups): +/// - IS_BIT: 12 +/// - decode mutex: 6 (`word_instr · {MEMORY, BRANCH, ECALL, WRITE_REGISTER, +/// READ_REGISTER1, READ_REGISTER2}`) +/// - ADD pair: 2, SUB pair: 2 +/// - arg2 multiplex: 2 +/// - register zero-forcing: 4 (`rv1[0..1]`, `rv2[0..1]`) +/// - rvd = res: 2 +/// - branch rvd (`pc + len`): 2 +/// - branch_cond: 1 +/// - next_pc: 2 +/// - assumptions: 4 (MEMORY·BRANCH mutex 1 + arg2 exclusivity 2 + mem_flags IS_BIT 1) +pub const NUM_CPU_CONSTRAINTS: usize = 12 + 6 + 2 + 2 + 2 + 4 + 2 + 2 + 1 + 2 + 4; + +/// Creates all CPU transition constraints. /// -/// Total: 68 constraints (34 IS_BIT + 8 ADD + 26 other) -/// (The inline PC columns PC_DOUBLE_READ and PREV_PC_TIMESTAMP_BORROW are -/// IS_BIT-constrained; per spec/cpu.typ no additional algebraic constraints -/// are required.) -pub const NUM_CPU_CONSTRAINTS: usize = - 34 + 2 + 2 + 2 + 2 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 7 + 3 + 3 + 3 + 2; - -/// Creates all CPU constraints. -/// -/// Returns a tuple of (is_bit_constraints, add_constraints, other_constraints, next_idx) +/// Returns `(is_bit_constraints, add_constraints, other_constraints, next_idx)`. 
#[allow(clippy::type_complexity)] pub fn create_all_cpu_constraints() -> ( Vec, @@ -1071,91 +624,88 @@ pub fn create_all_cpu_constraints() -> ( ) { let mut next_idx = 0; - // IS_BIT constraints + // range: IS_BIT let (is_bit, next) = create_is_bit_constraints(next_idx); next_idx = next; - // ADD constraints (for ADD + LOAD + STORE) + // alu: ADD + SUB fast-paths let (mut add_constraints, next) = create_add_constraints(next_idx); next_idx = next; - - // SUB constraints (for SUB + BEQ) let (sub, next) = create_sub_constraints(next_idx); next_idx = next; add_constraints.extend(sub); - // JALR constraints (res = pc + instr_size) - let (jalr, next) = create_jalr_constraints(next_idx); - next_idx = next; - add_constraints.extend(jalr); - - // Other constraints let mut other: Vec< Box>, > = Vec::new(); - // Branch condition - other.push(BranchCondConstraint::new(next_idx).boxed()); - next_idx += 1; + // decode: word_instr mutex with MEMORY / BRANCH / ECALL, plus word_instr ⇒ + // {write,read1,read2}_register = 0 (word instructions are delegated to CPU32 + // and must not touch the main register file — leaving these free is unsound). + // The register-read gates are spec-mandated ("out of caution"). 
+ for &col in &[ + cols::MEMORY, + cols::BRANCH, + cols::ECALL, + cols::WRITE_REGISTER, + cols::READ_REGISTER1, + cols::READ_REGISTER2, + ] { + other.push(ProductZeroConstraint::new(cols::WORD_INSTR, col, next_idx).boxed()); + next_idx += 1; + } - // EBREAK - other.push(EbreakConstraint::new(next_idx).boxed()); + // alu: arg2 multiplex (low, high words) + other.push(Arg2Constraint::new(0, next_idx).boxed()); + next_idx += 1; + other.push(Arg2Constraint::new(1, next_idx).boxed()); next_idx += 1; - // rv1 zero-forcing (CM48): (1 - read_register1) * rv1[i] = 0 for i ∈ [0, 2] - for &value_col in &[cols::RV1_0, cols::RV1_1, cols::RV1_2] { + // mem: register zero-forcing (rv1/rv2 are DWordWL → 2 words each) + for &value_col in &[cols::RV1_0, cols::RV1_1] { other.push( RegNotReadIsZeroConstraint::new(cols::READ_REGISTER1, value_col, next_idx).boxed(), ); next_idx += 1; } - - // rv2 zero-forcing (CM50): (1 - read_register2) * rv2[i] = 0 for i ∈ [0, 2] - for &value_col in &[cols::RV2_0, cols::RV2_1, cols::RV2_2] { + for &value_col in &[cols::RV2_0, cols::RV2_1] { other.push( RegNotReadIsZeroConstraint::new(cols::READ_REGISTER2, value_col, next_idx).boxed(), ); next_idx += 1; } - // Arg1 constraints - other.push(Arg1LowerConstraint::new(next_idx).boxed()); - next_idx += 1; - other.push(Arg1UpperConstraint::new(next_idx).boxed()); - next_idx += 1; - - // Arg2 constraints - other.push(Arg2LowerConstraint::new(next_idx).boxed()); - next_idx += 1; - other.push(Arg2UpperConstraint::new(next_idx).boxed()); - next_idx += 1; - - // Rvd constraints - other.push(RvdLowerConstraint::new(next_idx).boxed()); + // mem: ¬MEMORY ∧ ¬BRANCH ⇒ rvd = cast(res, WL) + other.push(RvdEqResConstraint::new(0, next_idx).boxed()); next_idx += 1; - other.push(RvdUpperConstraint::new(next_idx).boxed()); + other.push(RvdEqResConstraint::new(1, next_idx).boxed()); next_idx += 1; - // SLT res zero constraints - let (slt_zero, next) = create_slt_res_zero_constraints(next_idx); - next_idx = next; - for c in 
slt_zero { - other.push(c.boxed()); - } + // branch: BRANCH ⇒ rvd = pc + instruction_length (JAL/JALR return), carry-aware + let (branch_rvd_0, branch_rvd_1) = BranchRvdConstraint::new_pair(next_idx); + other.push(branch_rvd_0.boxed()); + other.push(branch_rvd_1.boxed()); + next_idx += 2; - // Extension bit zero constraints (SIGN template: !word_instr => ext_bit = 0) - other.push(ExtBitZeroConstraint::new(next_idx, cols::RV1_EXT_BIT).boxed()); - next_idx += 1; - other.push(ExtBitZeroConstraint::new(next_idx, cols::RV2_EXT_BIT).boxed()); - next_idx += 1; - other.push(ExtBitZeroConstraint::new(next_idx, cols::RES_EXT_BIT).boxed()); + // branch: branch_cond + next_pc + other.push(BranchCondConstraint::new(next_idx).boxed()); next_idx += 1; - - // Next PC (non-branching) constraints let (next_pc_0, next_pc_1) = NextPcAddConstraint::new_pair(next_idx); other.push(next_pc_0.boxed()); other.push(next_pc_1.boxed()); next_idx += 2; + // assumptions (spec defense-in-depth, redundant with the DECODE lookup): + // MEMORY/BRANCH mutex, arg2 multiplex exclusivity, and IS_BIT on + // non-memory rows. 
+ other.push(ProductZeroConstraint::new(cols::MEMORY, cols::BRANCH, next_idx).boxed()); + next_idx += 1; + for &imm_col in &[cols::IMM_0, cols::IMM_1] { + other.push(Arg2ExclusiveConstraint::new(imm_col, next_idx).boxed()); + next_idx += 1; + } + other.push(MemFlagsBitConstraint::new(next_idx).boxed()); + next_idx += 1; + (is_bit, add_constraints, other, next_idx) } diff --git a/prover/src/lib.rs b/prover/src/lib.rs index aaefc60ed..8c8cbed17 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -48,11 +48,12 @@ use crate::tables::trace_builder::Traces; use crate::tables::trace_builder::count_table_lengths; use crate::tables::types::BusId; use crate::test_utils::{ - E, F, VmAir, create_bitwise_air, create_branch_air, create_commit_air, create_cpu_air, - create_decode_air, create_dvrm_air, create_halt_air, create_keccak_air, create_keccak_rc_air, - create_keccak_rnd_air, create_load_air, create_lt_air, create_memw_air, - create_memw_aligned_air, create_memw_register_air, create_mul_air, create_page_air, - create_register_air, create_shift_air, + E, F, VmAir, create_bitwise_air, create_branch_air, create_bytewise_air, create_commit_air, + create_cpu_air, create_cpu32_air, create_decode_air, create_dvrm_air, create_eq_air, + create_halt_air, create_keccak_air, create_keccak_rc_air, create_keccak_rnd_air, + create_load_air, create_lt_air, create_memw_air, create_memw_aligned_air, + create_memw_register_air, create_mul_air, create_page_air, create_register_air, + create_shift_air, create_store_air, }; use stark::proof::options::{GoldilocksCubicProofOptions, ProofOptions}; @@ -84,6 +85,11 @@ pub struct TableCounts { pub shift: usize, pub branch: usize, pub memw_register: usize, + // Auxiliary ALU / memory / CPU32 dispatch chips + pub eq: usize, + pub bytewise: usize, + pub store: usize, + pub cpu32: usize, } impl TableCounts { @@ -99,6 +105,10 @@ impl TableCounts { + self.shift + self.branch + self.memw_register + + self.eq + + self.bytewise + + self.store + + 
self.cpu32 } /// Validate that all required tables have at least one chunk. @@ -117,6 +127,10 @@ impl TableCounts { ("shift", self.shift), ("branch", self.branch), ("memw_register", self.memw_register), + ("eq", self.eq), + ("bytewise", self.bytewise), + ("store", self.store), + ("cpu32", self.cpu32), ]; for (name, count) in checks { if count == 0 { @@ -212,6 +226,11 @@ pub(crate) struct VmAirs { pub register: VmAir, pub pages: Vec, pub memw_registers: Vec, + // Auxiliary ALU / memory / CPU32 dispatch chips + pub eqs: Vec, + pub bytewises: Vec, + pub stores: Vec, + pub cpu32s: Vec, } impl VmAirs { @@ -269,6 +288,18 @@ impl VmAirs { { pairs.push((air, trace, &())); } + for (air, trace) in self.eqs.iter().zip(traces.eqs.iter_mut()) { + pairs.push((air, trace, &())); + } + for (air, trace) in self.bytewises.iter().zip(traces.bytewises.iter_mut()) { + pairs.push((air, trace, &())); + } + for (air, trace) in self.stores.iter().zip(traces.stores.iter_mut()) { + pairs.push((air, trace, &())); + } + for (air, trace) in self.cpu32s.iter().zip(traces.cpu32s.iter_mut()) { + pairs.push((air, trace, &())); + } pairs } @@ -319,6 +350,18 @@ impl VmAirs { for air in &self.memw_registers { refs.push(air); } + for air in &self.eqs { + refs.push(air); + } + for air in &self.bytewises { + refs.push(air); + } + for air in &self.stores { + refs.push(air); + } + for air in &self.cpu32s { + refs.push(air); + } refs } @@ -424,6 +467,18 @@ impl VmAirs { let memw_registers: Vec<_> = (0..table_counts.memw_register) .map(|i| create_memw_register_air(proof_options).with_name(&format!("MEMW_R[{}]", i))) .collect(); + let eqs: Vec<_> = (0..table_counts.eq) + .map(|i| create_eq_air(proof_options).with_name(&format!("EQ[{}]", i))) + .collect(); + let bytewises: Vec<_> = (0..table_counts.bytewise) + .map(|i| create_bytewise_air(proof_options).with_name(&format!("BYTEWISE[{}]", i))) + .collect(); + let stores: Vec<_> = (0..table_counts.store) + .map(|i| 
create_store_air(proof_options).with_name(&format!("STORE[{}]", i))) + .collect(); + let cpu32s: Vec<_> = (0..table_counts.cpu32) + .map(|i| create_cpu32_air(proof_options).with_name(&format!("CPU32[{}]", i))) + .collect(); #[cfg(feature = "debug-checks")] debug_report::print_bus_legend(); @@ -448,6 +503,10 @@ impl VmAirs { register, pages, memw_registers, + eqs, + bytewises, + stores, + cpu32s, } } } diff --git a/prover/src/statement.rs b/prover/src/statement.rs index 82c41861c..7935abe66 100644 --- a/prover/src/statement.rs +++ b/prover/src/statement.rs @@ -16,7 +16,7 @@ use crate::test_utils::E; use crate::{RuntimePageRange, TableCounts}; /// Domain-separation tag. Bump the suffix (`_V2`, ...) on any encoding change. -const DOMAIN_TAG: &[u8] = b"LAMBDAVM_STARK_STATEMENT_V1"; +const DOMAIN_TAG: &[u8] = b"LAMBDAVM_STARK_STATEMENT_V2"; fn elf_digest(elf: &[u8]) -> [u8; 32] { let mut h = Keccak256::new(); @@ -55,6 +55,10 @@ pub(crate) fn absorb_statement( shift, branch, memw_register, + eq, + bytewise, + store, + cpu32, } = table_counts; for count in [ cpu, @@ -67,6 +71,10 @@ pub(crate) fn absorb_statement( shift, branch, memw_register, + eq, + bytewise, + store, + cpu32, ] { t.append_bytes(&(count as u64).to_le_bytes()); } diff --git a/prover/src/tables/bitwise.rs b/prover/src/tables/bitwise.rs index bdf7cfc99..db0236987 100644 --- a/prover/src/tables/bitwise.rs +++ b/prover/src/tables/bitwise.rs @@ -38,7 +38,7 @@ use stark::trace::{TraceTable, columns2rows}; #[cfg(feature = "parallel")] use rayon::prelude::*; -use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField}; +use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField, alu_op}; // ========================================================================= // Column indices for BITWISE table @@ -92,8 +92,14 @@ pub mod cols { pub const MU_IS_B20: usize = 19; /// Multiplicity for HWSL lookups pub const MU_HWSL: usize = 20; + /// Multiplicity for `BYTE_ALU[opsel=AND]` lookups + pub const 
MU_BYTE_ALU_AND: usize = 21; + /// Multiplicity for `BYTE_ALU[opsel=OR]` lookups + pub const MU_BYTE_ALU_OR: usize = 22; + /// Multiplicity for `BYTE_ALU[opsel=XOR]` lookups + pub const MU_BYTE_ALU_XOR: usize = 23; /// Total number of columns - pub const NUM_COLUMNS: usize = 21; + pub const NUM_COLUMNS: usize = 24; } /// Number of rows in the BITWISE table: 256 * 256 * 16 = 2^20 @@ -442,6 +448,9 @@ pub fn update_multiplicities( BitwiseOperationType::IsHalf => cols::MU_IS_HALF, BitwiseOperationType::IsB20 => cols::MU_IS_B20, BitwiseOperationType::Hwsl => cols::MU_HWSL, + BitwiseOperationType::ByteAluAnd => cols::MU_BYTE_ALU_AND, + BitwiseOperationType::ByteAluOr => cols::MU_BYTE_ALU_OR, + BitwiseOperationType::ByteAluXor => cols::MU_BYTE_ALU_XOR, }; // Increment multiplicity @@ -477,8 +486,10 @@ pub(crate) fn trim_zero_rows( let kept_rows: Vec = (0..num_rows) .filter(|&row| { let row_data = trace.main_table.get_row(row); - // Check all multiplicity columns (indices 11-20) - (cols::MU_AND..=cols::MU_HWSL).any(|col| row_data[col] != FE::zero()) + // Check all multiplicity columns (MU_AND..=MU_BYTE_ALU_XOR), including + // the BYTE_ALU columns (rows used only by a BYTE_ALU lookup + // must not be trimmed). + (cols::MU_AND..=cols::MU_BYTE_ALU_XOR).any(|col| row_data[col] != FE::zero()) }) .collect(); @@ -519,6 +530,10 @@ pub enum BitwiseOperationType { IsHalf, IsB20, Hwsl, + // Unified `BYTE_ALU` lookups, keyed by opsel. + ByteAluAnd, + ByteAluOr, + ByteAluXor, } /// A lookup request to the BITWISE precomputed table. @@ -807,5 +822,67 @@ pub fn bus_interactions() -> Vec { }, ], ), + // BYTE_ALU[opsel, X, Y] -> out. + // Unifies AND/OR/XOR into one bus keyed by the `alu_op` descriptor. + // Implemented as one receiver per opsel, reusing the precomputed + // AND/OR/XOR result columns (the "single 2^20 column" in bitwise.typ is + // an optimization note, not a requirement). 
+ BusInteraction::receiver( + BusId::ByteAlu, + Multiplicity::Column(cols::MU_BYTE_ALU_AND), + vec![ + BusValue::constant(alu_op::AND as u64), + BusValue::Packed { + start_column: cols::X, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::Y, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::AND, + packing: Packing::Direct, + }, + ], + ), + BusInteraction::receiver( + BusId::ByteAlu, + Multiplicity::Column(cols::MU_BYTE_ALU_OR), + vec![ + BusValue::constant(alu_op::OR as u64), + BusValue::Packed { + start_column: cols::X, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::Y, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::OR, + packing: Packing::Direct, + }, + ], + ), + BusInteraction::receiver( + BusId::ByteAlu, + Multiplicity::Column(cols::MU_BYTE_ALU_XOR), + vec![ + BusValue::constant(alu_op::XOR as u64), + BusValue::Packed { + start_column: cols::X, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::Y, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::XOR, + packing: Packing::Direct, + }, + ], + ), ] } diff --git a/prover/src/tables/branch.rs b/prover/src/tables/branch.rs index 1a4cff20c..b2e53e91b 100644 --- a/prover/src/tables/branch.rs +++ b/prover/src/tables/branch.rs @@ -395,6 +395,8 @@ pub enum BranchConstraintKind { /// `(1 - JALR) * carry_1_pc * (1 - carry_1_pc) = 0` /// where carry_1_pc = (pc[1] + offset[1] + carry_0_pc - next_pc_unmasked[1]) / 2^32 PcCarry1IsBit, + /// `IS_BIT`: `JALR * (1 - JALR) = 0` (spec defense-in-depth assumption) + JalrIsBit, /// `JALR * carry_0_reg * (1 - carry_0_reg) = 0` /// where carry_0_reg = (register[0] + offset[0] - next_pc_unmasked[0]) / 2^32 RegCarry0IsBit, @@ -494,6 +496,7 @@ impl BranchConstraint { let one = FieldElement::::one(); match self.kind { + BranchConstraintKind::JalrIsBit => &jalr * (&one - &jalr), BranchConstraintKind::PcCarry0IsBit => { let cond = &one 
- &jalr; let c = Self::compute_carry_0_for(cols::PC_0, step); @@ -520,8 +523,12 @@ impl BranchConstraint { impl TransitionConstraint for BranchConstraint { fn degree(&self) -> usize { - // cond (degree 1) * carry (degree 1) * (1 - carry) (degree 1) = degree 3 - 3 + match self.kind { + // JALR * (1 - JALR) = degree 2 + BranchConstraintKind::JalrIsBit => 2, + // cond (degree 1) * carry (degree 1) * (1 - carry) (degree 1) = degree 3 + _ => 3, + } } fn constraint_idx(&self) -> usize { @@ -539,11 +546,13 @@ impl TransitionConstraint for BranchConstr /// Creates all constraints for the BRANCH table. /// -/// Returns 4 constraints (two conditional ADD templates × 2 carries each): +/// Returns 5 constraints (two conditional ADD templates × 2 carries each, plus +/// the `IS_BIT` defense-in-depth assumption): /// - PcCarry0IsBit: `(1 - JALR) * carry_0 * (1 - carry_0) = 0` (pc path) /// - PcCarry1IsBit: `(1 - JALR) * carry_1 * (1 - carry_1) = 0` (pc path) /// - RegCarry0IsBit: `JALR * carry_0 * (1 - carry_0) = 0` (register path) /// - RegCarry1IsBit: `JALR * carry_1 * (1 - carry_1) = 0` (register path) +/// - JalrIsBit: `JALR * (1 - JALR) = 0` pub fn branch_constraints(constraint_idx_start: usize) -> (Vec, usize) { let mut idx = constraint_idx_start; let mut next = || { @@ -556,6 +565,7 @@ pub fn branch_constraints(constraint_idx_start: usize) -> (Vec BranchConstraint::new(BranchConstraintKind::PcCarry1IsBit, next()), BranchConstraint::new(BranchConstraintKind::RegCarry0IsBit, next()), BranchConstraint::new(BranchConstraintKind::RegCarry1IsBit, next()), + BranchConstraint::new(BranchConstraintKind::JalrIsBit, next()), ]; (constraints, idx) } diff --git a/prover/src/tables/bytewise.rs b/prover/src/tables/bytewise.rs new file mode 100644 index 000000000..16c811cfb --- /dev/null +++ b/prover/src/tables/bytewise.rs @@ -0,0 +1,188 @@ +//! BYTEWISE ALU table. +//! +//! Computes a full-word bitwise `AND`/`OR`/`XOR` of two 64-bit values by +//! 
decomposing them into bytes and delegating each byte to the `BYTE_ALU` +//! lookup. The CPU dispatches here on the unified `ALU` bus for `alu_op` +//! `AND`(0)/`OR`(1)/`XOR`(2); `alu_flags` for these ops equals just the opcode. +//! +//! Spec: `spec/src/bytewise.toml`. The chip has no polynomial constraints — +//! correctness is entirely enforced by the lookups (the `BYTE_ALU` lookup also +//! range-checks each input byte). +//! +//! ## Columns +//! - `a`: DWordBL (8 bytes) — first input +//! - `b`: DWordBL (8 bytes) — second input +//! - `op`: Byte — the `alu_op` opcode (AND/OR/XOR) +//! - `res`: DWordBL (8 bytes) — output +//! - `μ`: multiplicity + +use stark::lookup::{BusInteraction, BusValue, Multiplicity, Packing}; +use stark::trace::TraceTable; + +use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField, alu_op}; + +// ========================================================================= +// Column indices for BYTEWISE table +// ========================================================================= + +/// Column definitions for the BYTEWISE table. +pub mod cols { + /// a as 8 bytes (DWordBL), little-endian. + pub const A: [usize; 8] = [0, 1, 2, 3, 4, 5, 6, 7]; + /// b as 8 bytes (DWordBL), little-endian. + pub const B: [usize; 8] = [8, 9, 10, 11, 12, 13, 14, 15]; + /// op: Byte (alu_op opcode: AND/OR/XOR) + pub const OP: usize = 16; + /// res as 8 bytes (DWordBL), little-endian. + pub const RES: [usize; 8] = [17, 18, 19, 20, 21, 22, 23, 24]; + /// μ: multiplicity + pub const MU: usize = 25; + + /// Total number of columns + pub const NUM_COLUMNS: usize = 26; +} + +// ========================================================================= +// Trace generation +// ========================================================================= + +/// A single BYTEWISE operation. `op` is an [`alu_op`] opcode in {AND, OR, XOR}. 
+#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub struct BytewiseOperation { + pub a: u64, + pub b: u64, + pub op: u8, +} + +impl BytewiseOperation { + /// Create a new BYTEWISE operation. + pub fn new(a: u64, b: u64, op: u8) -> Self { + Self { a, b, op } + } + + /// The result of applying `op` to `a` and `b` (byte-wise == full-word). + pub fn compute_res(&self) -> u64 { + match self.op { + alu_op::AND => self.a & self.b, + alu_op::OR => self.a | self.b, + alu_op::XOR => self.a ^ self.b, + other => panic!("BYTEWISE only handles AND/OR/XOR, got opcode {other}"), + } + } + + /// The 8 `BYTE_ALU` lookups this op sends, for the BITWISE table's + /// multiplicity bookkeeping (one per byte, keyed by opsel). + pub fn collect_bitwise_ops(&self) -> Vec { + use super::bitwise::{BitwiseOperation, BitwiseOperationType}; + let kind = match self.op { + alu_op::AND => BitwiseOperationType::ByteAluAnd, + alu_op::OR => BitwiseOperationType::ByteAluOr, + alu_op::XOR => BitwiseOperationType::ByteAluXor, + other => panic!("BYTEWISE only handles AND/OR/XOR, got opcode {other}"), + }; + (0..8) + .map(|i| { + let a = ((self.a >> (i * 8)) & 0xFF) as u8; + let b = ((self.b >> (i * 8)) & 0xFF) as u8; + BitwiseOperation::byte_op(kind, a, b) + }) + .collect() + } +} + +/// Generates the BYTEWISE trace from a list of operations. +/// +/// Duplicate operations are merged with summed multiplicities, then padded to +/// the next power of two (minimum 4). 
+pub fn generate_bytewise_trace( + operations: &[BytewiseOperation], +) -> TraceTable { + use std::collections::HashMap; + + let mut op_map: HashMap = HashMap::new(); + for op in operations { + *op_map.entry(op.clone()).or_insert(0) += 1; + } + + let unique_ops: Vec<_> = op_map.into_iter().collect(); + let num_rows = unique_ops.len().next_power_of_two().max(4); + let mut data = vec![FE::zero(); num_rows * cols::NUM_COLUMNS]; + + for (row_idx, (op, multiplicity)) in unique_ops.iter().enumerate() { + let base = row_idx * cols::NUM_COLUMNS; + let res = op.compute_res(); + + for i in 0..8 { + data[base + cols::A[i]] = FE::from((op.a >> (8 * i)) & 0xFF); + data[base + cols::B[i]] = FE::from((op.b >> (8 * i)) & 0xFF); + data[base + cols::RES[i]] = FE::from((res >> (8 * i)) & 0xFF); + } + data[base + cols::OP] = FE::from(op.op as u64); + data[base + cols::MU] = FE::from(*multiplicity); + } + + TraceTable::new_main(data, cols::NUM_COLUMNS, 1) +} + +// ========================================================================= +// Bus interactions +// ========================================================================= + +/// All bus interactions for the BYTEWISE table: +/// - **Sends** `BYTE_ALU[op, a[i], b[i]] -> res[i]` for each of the 8 bytes. +/// - **Receives** `ALU[a, b, op] -> res` (operands packed DWordBL -> 2 words). +pub fn bus_interactions() -> Vec { + let mut interactions = Vec::with_capacity(9); + + for i in 0..8 { + interactions.push(BusInteraction::sender( + BusId::ByteAlu, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::OP, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::A[i], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::B[i], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::RES[i], + packing: Packing::Direct, + }, + ], + )); + } + + // ALU[a, b, op] -> res (receiver). 
a/b/res are DWordBL (8 bytes) packed + // into 2 words each, matching the CPU's DWordWL operands. + interactions.push(BusInteraction::receiver( + BusId::Alu, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::A[0], + packing: Packing::DWordBL, + }, + BusValue::Packed { + start_column: cols::B[0], + packing: Packing::DWordBL, + }, + BusValue::Packed { + start_column: cols::OP, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::RES[0], + packing: Packing::DWordBL, + }, + ], + )); + + interactions +} diff --git a/prover/src/tables/cpu.rs b/prover/src/tables/cpu.rs index 5f1a759b1..ea5fc94dc 100644 --- a/prover/src/tables/cpu.rs +++ b/prover/src/tables/cpu.rs @@ -1,59 +1,30 @@ //! CPU table for the 64-bit VM. //! -//! The CPU table is the central execution table that: -//! - Fetches instructions via DECODE interaction -//! - Dispatches ALU operations to specialized tables (ADD, SUB, LT, BITWISE, SHIFT, MUL, DIVREM) -//! - Handles memory operations (LOAD, STORE, register read/write) -//! - Computes branch conditions and next_pc +//! The CPU table is the central execution table. Following `spec/src/cpu.toml` +//! it is narrow (~39 columns): there are no per-opcode one-hot ALU selectors and +//! no `*_ext_bit`/`arg1` columns. Instead each row carries: +//! - top-level flags `ALU/ADD/SUB/MEMORY/BRANCH/ECALL` (+ `word_instr`), +//! - the packed `alu_flags`/`mem_flags` bytes (the chips unpack them), and +//! - register indices + read/write flags. //! -//! ## Column Layout +//! Dispatch happens over a small set of buses: +//! - `DECODE[pc, imm, packed_decode]` (mult `1 - word_instr`): instruction fetch. +//! - `ALU[rv1, arg2, alu_flags] -> res` (mult `ALU`): unified ALU lookup; the +//! lt/mul/dvrm/shift/eq/bytewise chips receive on it, keyed by `alu_flags`. +//! - `MEMORY[timestamp, address, rv2, mem_flags] -> rvd` (mult `MEMORY`): high +//! level LOAD/STORE dispatch (the LOAD/STORE chips receive on it). +//! 
- `CPU32[timestamp, pc, half_instruction_length]` (mult `word_instr`): every word +//! (`*W`) instruction is delegated to the CPU32 table, which does its own +//! register I/O and sign-extension. On a `word_instr` row the main CPU is a +//! pure delegate: all operational flags are 0 and only the PC advances. +//! - `MEMW` register read/write (×3), `BRANCH`, `ECALL`, inline-PC `memory` +//! tokens, and `ARE_BYTES`/`IS_HALF` range checks. //! -//! ### Input (from DECODE) -//! - `timestamp`: Timestamp (1 col) -//! - `pc`: DWordWL (2 cols) - program counter -//! - `rs1`, `rs2`, `rd`: Byte (3 cols) - register indices -//! - Flags: `write_register`, `memory_2bytes`, `memory_4bytes`, `memory_8bytes`, -//! `c_type_instruction`, `signed`, `mp_selector`, `muldiv_selector`, `word_instr` -//! - `imm`: DWordWL (2 cols) - fully extended immediate -//! - ALU selectors: `ADD`, `SUB`, `SLT`, `AND`, `OR`, `XOR`, `SHIFT`, `JALR`, -//! `BEQ`, `BLT`, `LOAD`, `STORE`, `MUL`, `DIVREM`, `ECALL`, `EBREAK` -//! -//! ### Output -//! - `next_pc`: DWordWL (2 cols) -//! - `rvd`: DWordWL (2 cols) - value to write to destination register -//! -//! ### Auxiliary -//! - `rv1`: DWordWHH (3 cols) - value of register rs1 -//! - `rv2`: DWordWHH (3 cols) - value of register rs2 -//! - `rv1_ext_bit`, `rv2_ext_bit`, `res_ext_bit`: Bit (for word instruction extension) -//! - `arg1`: DWordBL (8 cols) - extended rv1 -//! - `arg2`: DWordBL (8 cols) - multiplexed rv2/imm -//! - `res`: DWordBL (8 cols) - ALU result -//! - `is_equal`: Bit - whether arg1 == arg2 -//! - `branch_cond`: Bit - whether branch is taken -//! -//! ## Bus Interactions -//! -//! ### Senders (CPU sends to other tables) -//! - DECODE: instruction fetch -//! - ARE_BYTES: range checks for rs1, rs2, rd, and arg1/arg2/res byte pairs -//! - IS_BIT: range checks for flags (via templates) -//! - ADD: for ADD, LOAD, JALR operations -//! - STORE ADD: for STORE (res = arg1 + imm, separate from main ADD) -//! - SUB: for SUB, BEQ operations -//! 
- LT: for SLT, BLT operations -//! - AND_BYTE, OR_BYTE, XOR_BYTE: for bitwise operations (×8 each) -//! - SHIFT: for shift operations -//! - MUL: for multiplication -//! - DIVREM: for division/remainder -//! - MEMW: for register and memory access -//! - MSB16: for sign/extension bit extraction (rv1, rv2, res) -//! - ZERO: for equality check -//! - BRANCH: for branch target calculation -//! - ECALL: for system calls - -use super::dvrm::DvrmOperation; -use super::types::{BusId, DecodeEntry, FE, GoldilocksExtension, GoldilocksField}; +//! `JALR` is virtual: under `BRANCH` the `mem_flags` byte only ever holds the +//! JALR bit (the memory-width bits are 0), so `mem_flags ∈ {0,1} = JALR` and the +//! `mem_flags` column is used directly as `JALR` wherever it is gated by `BRANCH`. + +use super::types::{BusId, DecodeEntry, FE, GoldilocksExtension, GoldilocksField, alu_op}; use crate::Error; use executor::vm::{ instruction::{decoding::Instruction, execution::SyscallNumbers}, @@ -63,13 +34,13 @@ use executor::vm::{ use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::trace::TraceTable; -/// PC value used for CPU padding rows. Per spec, this is an odd address (unreachable -/// during normal execution) with all flags=0. The DECODE table must contain a -/// corresponding entry at this PC. +/// PC value used for CPU padding rows. Per spec this is an odd address +/// (unreachable during normal execution); the DECODE table contains a matching +/// padding entry at this PC (all flags 0, `half_instruction_length = 0`). pub const CPU_PADDING_PC: u64 = 1; // ========================================================================= -// Column indices for CPU table +// Column indices for the CPU table // ========================================================================= /// Column definitions for the CPU table. 
@@ -78,188 +49,99 @@ pub mod cols { // Input columns (from DECODE) // ------------------------------------------------------------------------- - /// timestamp: Timestamp for memory argument coordination + /// timestamp: Timestamp for memory argument coordination. pub const TIMESTAMP: usize = 0; - /// pc[0]: Program counter (low word) + /// pc: program counter (DWordWL, 2 words). pub const PC_0: usize = 1; - /// pc[1]: Program counter (high word) pub const PC_1: usize = 2; - /// rs1: Source register 1 index (Byte) + /// rs1/rs2/rd: register indices (Byte). pub const RS1: usize = 3; - /// rs2: Source register 2 index (Byte) pub const RS2: usize = 4; - /// rd: Destination register index (Byte) pub const RD: usize = 5; - /// read_register1: Whether to read from rs1 (Bit) + /// read_register1/2, write_register (Bit). pub const READ_REGISTER1: usize = 6; - /// read_register2: Whether to read from rs2 (Bit) pub const READ_REGISTER2: usize = 7; - /// write_register: Whether to write back to rd (Bit) pub const WRITE_REGISTER: usize = 8; - /// memory_2bytes: Memory access is 2 bytes (Bit) - pub const MEMORY_2BYTES: usize = 9; - /// memory_4bytes: Memory access is 4 bytes (Bit) - pub const MEMORY_4BYTES: usize = 10; - /// memory_8bytes: Memory access is 8 bytes (Bit) - pub const MEMORY_8BYTES: usize = 11; - /// c_type_instruction: Instruction is 2 bytes (compressed) instead of 4 (Bit) - pub const C_TYPE_INSTRUCTION: usize = 12; - - /// imm[0]: Immediate value (low word) - pub const IMM_0: usize = 13; - /// imm[1]: Immediate value (high word) - pub const IMM_1: usize = 14; - - /// signed: Signed operation flag (Bit) - pub const SIGNED: usize = 15; - /// mp_selector: Multi-purpose selector (branch invert, shift direction, MUL variant) - pub const MP_SELECTOR: usize = 16; - /// muldiv_selector: Select MUL/DIV output variant - pub const MULDIV_SELECTOR: usize = 17; - /// word_instr: 32-bit word instruction (requires sign extension) - pub const WORD_INSTR: usize = 18; - - // ALU 
selector flags (one-hot encoded) - /// ADD operation - pub const ADD: usize = 19; - /// SUB operation - pub const SUB: usize = 20; - /// SLT (Set Less Than) operation - pub const SLT: usize = 21; - /// AND operation - pub const AND: usize = 22; - /// OR operation - pub const OR: usize = 23; - /// XOR operation - pub const XOR: usize = 24; - /// SHIFT operation - pub const SHIFT: usize = 25; - /// JALR (Jump And Link Register) - pub const JALR: usize = 26; - /// BEQ (Branch if Equal) - pub const BEQ: usize = 27; - /// BLT (Branch if Less Than) - pub const BLT: usize = 28; - /// LOAD operation - pub const LOAD: usize = 29; - /// STORE operation - pub const STORE: usize = 30; - /// MUL operation - pub const MUL: usize = 31; - /// DIVREM (Division/Remainder) operation - pub const DIVREM: usize = 32; - /// ECALL (Environment Call) - pub const ECALL: usize = 33; - /// EBREAK (Environment Break) - pub const EBREAK: usize = 34; + + /// imm: fully extended immediate (DWordWL, 2 words). + pub const IMM_0: usize = 9; + pub const IMM_1: usize = 10; + + /// half_instruction_length: half the bytes consumed (Byte; 1 or 2). The real + /// length is `2 * half_instruction_length`. + pub const HALF_INSTRUCTION_LENGTH: usize = 11; + /// word_instr: `*W` instruction (delegated to CPU32) (Bit). + pub const WORD_INSTR: usize = 12; + + /// ALU: use the unified ALU for this instruction (Bit). + pub const ALU: usize = 13; + /// alu_flags: packed ALU op + flags byte (Byte). + pub const ALU_FLAGS: usize = 14; + /// ADD/SUB: arithmetic fast-paths bypassing the ALU (Bit). + pub const ADD: usize = 15; + pub const SUB: usize = 16; + /// MEMORY: touches memory (LOAD/STORE) (Bit). + pub const MEMORY: usize = 17; + /// mem_flags: packed memory op + width + signed byte (Byte). Under BRANCH + /// this column doubles as the virtual `JALR` bit. + pub const MEM_FLAGS: usize = 18; + /// BRANCH: conditional branch or jump (Bit). + pub const BRANCH: usize = 19; + /// ECALL: environment call (Bit). 
+ pub const ECALL: usize = 20; // ------------------------------------------------------------------------- // Output columns // ------------------------------------------------------------------------- - /// next_pc[0]: Next program counter (low word) - pub const NEXT_PC_0: usize = 35; - /// next_pc[1]: Next program counter (high word) - pub const NEXT_PC_1: usize = 36; + /// next_pc: program counter for the next instruction (DWordWL, 2 words). + pub const NEXT_PC_0: usize = 21; + pub const NEXT_PC_1: usize = 22; - /// rvd[0]: Value to write to destination register (low word) - pub const RVD_0: usize = 37; - /// rvd[1]: Value to write to destination register (high word) - pub const RVD_1: usize = 38; + /// rvd: value to (maybe) write back to rd (DWordWL, 2 words). + pub const RVD_0: usize = 23; + pub const RVD_1: usize = 24; // ------------------------------------------------------------------------- // Auxiliary columns // ------------------------------------------------------------------------- - /// rv1[0]: Register rs1 value (Half - bits 0-15) [DWordWHH] - pub const RV1_0: usize = 39; - /// rv1[1]: Register rs1 value (Half - bits 16-31) [DWordWHH] - pub const RV1_1: usize = 40; - /// rv1[2]: Register rs1 value (Word - bits 32-63) [DWordWHH] - pub const RV1_2: usize = 41; - - /// rv2[0]: Register rs2 value (Half - bits 0-15) [DWordWHH] - pub const RV2_0: usize = 42; - /// rv2[1]: Register rs2 value (Half - bits 16-31) [DWordWHH] - pub const RV2_1: usize = 43; - /// rv2[2]: Register rs2 value (Word - bits 32-63) [DWordWHH] - pub const RV2_2: usize = 44; - - /// rv1_ext_bit: Sign bit of rv1 as 32-bit word (for word_instr sign extension) - pub const RV1_EXT_BIT: usize = 45; - - /// arg1[0..8]: Extended rv1 as DWordBL (8 bytes) - pub const ARG1_0: usize = 46; - pub const ARG1_1: usize = 47; - pub const ARG1_2: usize = 48; - pub const ARG1_3: usize = 49; - pub const ARG1_4: usize = 50; - pub const ARG1_5: usize = 51; - pub const ARG1_6: usize = 52; - pub const 
ARG1_7: usize = 53; - - /// rv2_ext_bit: Sign bit of rv2 as 32-bit word (bit 31 of rv2; used for arg2 sign extension) - pub const RV2_EXT_BIT: usize = 54; - - /// arg2[0..8]: Extended rv2/imm as DWordBL (8 bytes) - pub const ARG2_0: usize = 55; - pub const ARG2_1: usize = 56; - pub const ARG2_2: usize = 57; - pub const ARG2_3: usize = 58; - pub const ARG2_4: usize = 59; - pub const ARG2_5: usize = 60; - pub const ARG2_6: usize = 61; - pub const ARG2_7: usize = 62; - - /// res_ext_bit: Sign bit of res as 32-bit word (for rvd sign extension) - pub const RES_EXT_BIT: usize = 63; - - /// res[0..8]: ALU result as DWordBL (8 bytes) - pub const RES_0: usize = 64; - pub const RES_1: usize = 65; - pub const RES_2: usize = 66; - pub const RES_3: usize = 67; - pub const RES_4: usize = 68; - pub const RES_5: usize = 69; - pub const RES_6: usize = 70; - pub const RES_7: usize = 71; - - /// is_equal: Whether rv1 == arg2 (for BEQ) - pub const IS_EQUAL: usize = 72; - - /// branch_cond: Whether branch is taken - pub const BRANCH_COND: usize = 73; - - /// prev_pc_timestamp_borrow: Borrow bit for the 32-bit subtraction timestamp_lo - 3 - /// in the inline PC prev_ts formula. Fires only when timestamp_lo < 3 and - /// pc_double_read = 0 (i.e. after timestamp wraps past 2^32 into values 0..2). - pub const PREV_PC_TIMESTAMP_BORROW: usize = 74; - - /// pc_double_read: Whether PC is read as rs1 this cycle (AUIPC/JAL) - pub const PC_DOUBLE_READ: usize = 75; - - /// Total number of columns - pub const NUM_COLUMNS: usize = 76; + /// prev_pc_timestamp_borrow: borrow bit for the inline-PC `timestamp - 3` + /// subtraction (fires when `timestamp_lo < 3` and `pc_double_read = 0`). + pub const PREV_PC_TIMESTAMP_BORROW: usize = 25; + /// pc_double_read: PC is read as a general register (`rs1 = 255`) this cycle + /// (AUIPC/JAL) (Bit). 
+ pub const PC_DOUBLE_READ: usize = 26; - // ------------------------------------------------------------------------- - // Helper ranges for iteration - // ------------------------------------------------------------------------- + /// rv1: value of register rs1 (DWordWL, 2 words). + pub const RV1_0: usize = 27; + pub const RV1_1: usize = 28; - /// ARG1 byte columns as array - pub const ARG1: [usize; 8] = [ - ARG1_0, ARG1_1, ARG1_2, ARG1_3, ARG1_4, ARG1_5, ARG1_6, ARG1_7, - ]; + /// rv2: value of register rs2 (DWordWL, 2 words). + pub const RV2_0: usize = 29; + pub const RV2_1: usize = 30; - /// ARG2 byte columns as array - pub const ARG2: [usize; 8] = [ - ARG2_0, ARG2_1, ARG2_2, ARG2_3, ARG2_4, ARG2_5, ARG2_6, ARG2_7, - ]; + /// arg2: multiplexed second ALU argument (DWordWL, 2 words). + pub const ARG2_0: usize = 31; + pub const ARG2_1: usize = 32; - /// RES byte columns as array - pub const RES: [usize; 8] = [RES_0, RES_1, RES_2, RES_3, RES_4, RES_5, RES_6, RES_7]; + /// res: ALU result (DWordHL, 4 halves → 2 words via `cast`). + pub const RES_0: usize = 33; + pub const RES_1: usize = 34; + pub const RES_2: usize = 35; + pub const RES_3: usize = 36; + + /// branch_cond: whether the branch/jump is taken (Bit). + pub const BRANCH_COND: usize = 37; + + /// Total number of columns. + pub const NUM_COLUMNS: usize = 38; + + /// res half columns as an array (DWordHL). + pub const RES: [usize; 4] = [RES_0, RES_1, RES_2, RES_3]; } // ========================================================================= @@ -268,50 +150,40 @@ pub mod cols { /// A single CPU cycle to be added to the trace. /// -/// Contains static decode information (from DecodeEntry) plus runtime values -/// from execution (register values, computed results, etc.). +/// Holds the decoded instruction (`DecodeEntry`) plus the runtime values needed +/// to fill a row: register values, the multiplexed `arg2`, the ALU result, and +/// the branch decision. 
For `word_instr` rows all operational values are 0 (the +/// row is a pure CPU32 delegate). #[derive(Debug, Clone, Default)] pub struct CpuOperation { - /// Static decode information (shared with DECODE table) + /// Static decode information (shared with the DECODE table). pub decode: DecodeEntry, - - /// Timestamp for memory argument coordination + /// Timestamp for memory argument coordination. pub timestamp: u64, - - /// Next program counter (from execution) + /// Next program counter. pub next_pc: u64, - - /// Value to write to destination register (from execution) + /// Value to write back to rd. pub rvd: u64, - - /// Value of register rs1 (from execution) + /// Value of register rs1. pub rv1: u64, - - /// Value of register rs2 (from execution) + /// Value of register rs2. pub rv2: u64, - - /// ALU result or memory address (computed) + /// Multiplexed second ALU argument. + pub arg2: u64, + /// ALU result (or memory address for LOAD/STORE). pub res: u64, - - /// Whether rv1 == rv2 (for BEQ) - pub is_equal: bool, - - /// Whether branch is taken + /// Whether the branch/jump is taken. pub branch_cond: bool, - /// Whether this ECALL is a Commit syscall + /// Whether this ECALL is a Commit syscall. pub ecall_commit: bool, - - /// For Commit ECALLs: buffer address from x11 + /// For Commit ECALLs: buffer address from x11. pub commit_buf_addr: u64, - - /// For Commit ECALLs: byte count from x12 + /// For Commit ECALLs: byte count from x12. pub commit_count: u64, - - /// Whether this ECALL is a KeccakPermute syscall + /// Whether this ECALL is a KeccakPermute syscall. pub ecall_keccak: bool, - - /// For KeccakPermute ECALLs: state address from x10 + /// For KeccakPermute ECALLs: state address from x10. 
pub keccak_state_addr: u64, } @@ -321,448 +193,229 @@ impl CpuOperation { Self::default() } - // ========================================================================= - // Convenience accessors for decode fields (reduces verbosity) - // ========================================================================= - + // ------- convenience accessors ------- #[inline] pub fn pc(&self) -> u64 { self.decode.pc } #[inline] - pub fn rs1(&self) -> u8 { - self.decode.rs1 - } - #[inline] - pub fn rs2(&self) -> u8 { - self.decode.rs2 - } - #[inline] - pub fn rd(&self) -> u8 { - self.decode.rd - } - #[inline] pub fn imm(&self) -> u64 { self.decode.imm } #[inline] pub fn word_instr(&self) -> bool { - self.decode.word_instr + self.decode.fields.word_instr } + /// Virtual `JALR` bit: bit 0 of `mem_flags` (only meaningful under BRANCH). #[inline] - pub fn signed(&self) -> bool { - self.decode.signed + pub fn jalr(&self) -> bool { + self.decode.fields.mem_flags & 1 == 1 } - // ========================================================================= - // Computation methods - // ========================================================================= - - /// Compute arg1 from rv1 based on word_instr and signed flags. - /// - /// Per spec constraint: arg1[4:] = rv1[2] * (1 - word_instr) + (2^32 - 1) * rv1_ext_bit * signed - /// - /// For 64-bit instructions: pass through full rv1 - /// For unsigned word instructions: zero-extend from 32 bits - /// For signed word instructions: sign-extend from 32 bits - pub fn compute_arg1(&self) -> u64 { - if self.decode.word_instr { - let lower_32 = self.rv1 & 0xFFFF_FFFF; - if self.decode.signed && Self::sign_bit_32(self.rv1) { - // Sign extend: set upper 32 bits to all 1s - lower_32 | (0xFFFF_FFFF_u64 << 32) - } else { - // Zero extend: upper 32 bits are 0 - lower_32 - } - } else { - self.rv1 - } - } + /// Creates a CpuOperation from an executor Log and a DecodeEntry. 
+ pub fn from_log(log: &Log, timestamp: u64, decode: DecodeEntry) -> Self { + let f = decode.fields; + // Real byte length: the column stores half. + let instruction_length = 2 * f.half_instruction_length as u64; - /// Compute arg2 following the spec formula exactly (CPU-CE62/CE63). - /// - /// arg2[:4] = (1-LOAD)*rv2[:2] + (1-BEQ-BLT-STORE)*imm[0] - /// arg2[4:] = (1-LOAD)*((1-word_instr)*rv2[2] + signed*rv2_ext_bit*(2^32-1)) - /// + (1-BEQ-BLT-STORE)*imm[1] - /// - /// Per CPU-A2, the decode guarantees that at most one of rv2/imm is non-zero - /// when STORE+LOAD+BEQ+BLT=0, so the addition acts as a selection. - pub fn compute_arg2(&self) -> u64 { - let d = &self.decode; - - // rv2 contribution: zeroed when LOAD (spec: (1-LOAD) factor) - let rv2_extended = if d.op_load { - 0 - } else if d.word_instr { - // Word-instruction sign/zero extension on upper 32 bits - let lower_32 = self.rv2 & 0xFFFF_FFFF; - if d.signed && Self::sign_bit_32(self.rv2) { - lower_32 | (0xFFFF_FFFF_u64 << 32) - } else { - lower_32 - } + // ECALL syscall classification (rv1 = a7 = syscall number). + let ecall_commit = f.ecall && log.src1_val == SyscallNumbers::Commit as u64; + let (commit_buf_addr, commit_count) = if ecall_commit { + (log.src2_val, log.dst_val) } else { - self.rv2 + (0, 0) }; + let ecall_keccak = + f.ecall && log.src1_val == executor::vm::instruction::execution::KECCAK_SYSCALL_NUMBER; + let keccak_state_addr = if ecall_keccak { log.src2_val } else { 0 }; + + // Word instructions are fully handled by CPU32; the main CPU row is a + // delegate that only advances the PC and sends the CPU32 lookup. We still + // carry the real register values (rv1/rv2/rvd) so the CPU32 op-generation + // and its register MEMW accesses can use them — `generate_cpu_trace` + // zeroes the operational columns on the delegate row. 
+ if f.word_instr { + return Self { + next_pc: decode.pc.wrapping_add(instruction_length), + rv1: log.src1_val, + rv2: if f.read_register2 { log.src2_val } else { 0 }, + rvd: log.dst_val, + ecall_commit, + commit_buf_addr, + commit_count, + ecall_keccak, + keccak_state_addr, + decode, + timestamp, + ..Default::default() + }; + } - // imm contribution: zeroed when BEQ, BLT, or STORE (spec: (1-BEQ-BLT-STORE) factor) - let imm_contrib = if d.op_beq || d.op_blt || d.op_store { + // Register values. x255 is the PC register (read by AUIPC/JAL via rs1). + let rv1 = if f.rs1 == 255 { + log.current_pc + } else if f.read_register1 { + log.src1_val + } else { 0 + }; + let rv2 = if f.read_register2 { log.src2_val } else { 0 }; + + let jalr = f.mem_flags & 1 == 1; + + // arg2 multiplex (CPU-A1), matching `cpu.toml`: + // MEMORY -> imm + // BRANCH -> rv2 (JAL/JALR read no rs2, so rv2 = 0) + // else -> rv2 + imm (≤1 nonzero by decode A2) + let arg2 = if f.memory { + decode.imm + } else if f.branch { + rv2 } else { - d.imm + rv2.wrapping_add(decode.imm) }; - rv2_extended.wrapping_add(imm_contrib) - } - - /// Extract sign bit of a 32-bit word (bit 31). - pub fn sign_bit_32(val: u64) -> bool { - (val >> 31) & 1 == 1 - } - - /// Compute rvd (destination register value) based on res and word_instr. - /// - /// According to spec constraints: - /// - rvd[0] = res[:4] (lower 32 bits of res) - /// - rvd[1] = (1 - word_instr) * res[4:] + res_ext_bit * (2^32 - 1) - /// - /// For LOAD: rvd comes from the executor (loaded value), not this method. - /// For all other operations: rvd is computed from res with sign extension. - pub fn compute_rvd(&self) -> u64 { - let res = self.compute_res(); - let res_lo = res & 0xFFFF_FFFF; - - if self.decode.word_instr { - // Sign extend from 32 bits - let res_ext_bit = Self::sign_bit_32(res); - if res_ext_bit { - // Upper 32 bits = 0xFFFF_FFFF (sign extension) - res_lo | (0xFFFF_FFFF_u64 << 32) + // Branch decision. 
JAL/JALR always jump; conditional branches evaluate + // the EQ/LT comparison (with invert) encoded in `alu_flags`. + let branch_cond = if f.branch { + if jalr { + true } else { - // Upper 32 bits = 0 (zero extension) - res_lo + Self::branch_taken(&f, rv1, rv2) } } else { - // rvd = res (full 64-bit value) - res - } - } + false + }; - /// Compute the result based on operation type. - /// - /// For ADD: res = arg1 + arg2 (64-bit wrapping) - /// For SUB: res = arg1 - arg2 (64-bit wrapping) - /// For SHIFT: res = raw 64-bit shift of arg1 by arg2 (no word sign extension; - /// rvd handles sign extension for word instructions) - /// For SLT: res = 0 or 1 (comparison result from executor) - /// For other operations: uses the executor's result (self.res) - /// - /// This ensures the ADD/SUB constraints are satisfied. - /// The rvd column holds the actual sign-extended result for word instructions. - pub fn compute_res(&self) -> u64 { - let arg1 = self.compute_arg1(); - let arg2 = self.compute_arg2(); - - if self.decode.op_add || self.decode.op_load { - // ADD constraint: arg1 + arg2 = res - // For ADD: computes arithmetic result - // For LOAD: computes memory address (rv1 + imm) - arg1.wrapping_add(arg2) - } else if self.decode.op_store { - // STORE: res = arg1 + imm (address), not arg1 + arg2 (which is now rv2) - arg1.wrapping_add(self.decode.imm) - } else if self.decode.op_sub { - // SUB constraint checks: res + arg2 = arg1, so res = arg1 - arg2 - arg1.wrapping_sub(arg2) - } else if self.decode.op_shift { - // SHIFT: raw 64-bit shift matching the SHIFT chip's computation. - // The SHIFT chip shifts the full 64-bit arg1 by (shift mod 32*(2-word_instr)). - // Sign extension for word instructions is handled by rvd, not res. 
- let shift = (arg2 & 0xFF) as u32; - let modulus = if self.decode.word_instr { 32 } else { 64 }; - let effective = shift % modulus; - if !self.decode.mp_selector { - // Left shift - arg1.wrapping_shl(effective) - } else if !self.decode.signed { - // Logical right shift - arg1.wrapping_shr(effective) + // res = ALU result / address. ADD covers add/load/store/JAL(R); SUB the + // subtraction fast-path; ALU the comparison (branch) or the chip result. + let res = if f.add { + rv1.wrapping_add(arg2) + } else if f.sub { + rv1.wrapping_sub(arg2) + } else if f.alu { + if f.branch { + branch_cond as u64 } else { - // Arithmetic right shift - (arg1 as i64).wrapping_shr(effective) as u64 - } - } else if self.decode.op_mul && self.decode.word_instr { - // MULW: low 64 bits of arg1 * arg2 (signedness doesn't affect the low bits). - arg1.wrapping_mul(arg2) - } else if self.decode.op_divrem && self.decode.word_instr { - // DIVUW/DIVW/REMUW/REMW. Reuse the DVRM spec implementation so the CPU - // and DVRM tables stay in lockstep on division semantics. - let dvrm = DvrmOperation::new(arg1, arg2, self.decode.signed); - if self.decode.muldiv_selector { - dvrm.compute_remainder() - } else { - dvrm.compute_quotient() + log.dst_val } } else { - // For SLT and other operations, use the executor's result - // SLT res is 0 or 1, verified by SltResZeroConstraint - self.res - } - } - - /// Collects CPU range-check lookups for register indices and byte pairs. - /// - /// The CPU sends: - /// - 1 ARE_BYTES lookup for (RS1, RS2) batched as a pair - /// - 1 ARE_BYTES lookup for RD encoded as (RD, 0) - /// - 12 ARE_BYTES lookups for adjacent byte pairs in ARG1, ARG2, and RES - pub fn collect_byte_check_ops(&self) -> Vec { - use super::bitwise::{BitwiseOperation, BitwiseOperationType}; - - let arg1 = self.compute_arg1(); - let arg2 = self.compute_arg2(); - let res = self.compute_res(); - - let mut ops = Vec::with_capacity(14); - - // Batch RS1+RS2 as a pair; RD stays single with Y=0. 
- ops.push(BitwiseOperation::byte_op( - BitwiseOperationType::AreBytes, - self.decode.rs1, - self.decode.rs2, - )); - ops.push(BitwiseOperation::single_byte( - BitwiseOperationType::AreBytes, - self.decode.rd, - )); - - // 12 ARE_BYTES lookups for ARG1/ARG2/RES byte pairs - // Each pair sends [lo, hi] as two separate bus values, so the LogUp - // fingerprint forces each byte to match individually against BITWISE X, Y. - for value in [arg1, arg2, res] { - for i in 0..4 { - let lo = ((value >> (i * 16)) & 0xFF) as u8; - let hi = ((value >> (i * 16 + 8)) & 0xFF) as u8; - ops.push(BitwiseOperation::byte_op( - BitwiseOperationType::AreBytes, - lo, - hi, - )); - } - } - - ops - } - - /// Collects Bitwise table lookups generated by this CPU operation. - pub fn collect_bitwise_ops(&self) -> Vec { - use super::bitwise::{BitwiseOperation, BitwiseOperationType}; - let mut lookups = Vec::new(); - - // Range checks: 14 ARE_BYTES ops (RS1+RS2 paired, RD single with Y=0, - // plus 12 ARG1/ARG2/RES byte pairs). 
- lookups.extend(self.collect_byte_check_ops()); - - // MSB16 lookups for sign bit extraction (when word_instr=1) - if self.decode.word_instr { - // rv1[1] is bits 16-31, extract as halfword for MSB16 lookup - let rv1_half = ((self.rv1 >> 16) & 0xFFFF) as u16; - let lo = (rv1_half & 0xFF) as u8; - let hi = ((rv1_half >> 8) & 0xFF) as u8; - lookups.push(BitwiseOperation::halfword( - BitwiseOperationType::Msb16, - lo, - hi, - )); - - // rv2[1] for rv2_ext_bit - let rv2_half = ((self.rv2 >> 16) & 0xFFFF) as u16; - let lo = (rv2_half & 0xFF) as u8; - let hi = ((rv2_half >> 8) & 0xFF) as u8; - lookups.push(BitwiseOperation::halfword( - BitwiseOperationType::Msb16, - lo, - hi, - )); - - // res::DWordHL[1] for res_ext_bit (MSB16 on half at bits 16-31) - let res_half = ((self.res >> 16) & 0xFFFF) as u16; - lookups.push(BitwiseOperation::halfword( - BitwiseOperationType::Msb16, - (res_half & 0xFF) as u8, - (res_half >> 8) as u8, - )); - } - - // ZERO lookup for is_equal (when BEQ=1) - if self.decode.op_beq { - // Sum of all result bytes - let mut sum: u64 = 0; - for i in 0..8 { - sum += (self.res >> (i * 8)) & 0xFF; - } - // Sum fits in 11 bits (max 8 * 255 = 2040), well within ZERO's 20-bit range - lookups.push(BitwiseOperation::zero(sum as u32)); - } - - // AND/OR/XOR lookups (×8 each for each byte) - let arg1 = self.compute_arg1(); - let arg2 = self.compute_arg2(); - - if self.decode.op_and { - for i in 0..8 { - let a = ((arg1 >> (i * 8)) & 0xFF) as u8; - let b = ((arg2 >> (i * 8)) & 0xFF) as u8; - lookups.push(BitwiseOperation::byte_op( - BitwiseOperationType::AndByte, - a, - b, - )); - } - } - - if self.decode.op_or { - for i in 0..8 { - let a = ((arg1 >> (i * 8)) & 0xFF) as u8; - let b = ((arg2 >> (i * 8)) & 0xFF) as u8; - lookups.push(BitwiseOperation::byte_op( - BitwiseOperationType::OrByte, - a, - b, - )); - } - } - - if self.decode.op_xor { - for i in 0..8 { - let a = ((arg1 >> (i * 8)) & 0xFF) as u8; - let b = ((arg2 >> (i * 8)) & 0xFF) as u8; - 
lookups.push(BitwiseOperation::byte_op( - BitwiseOperationType::XorByte, - a, - b, - )); - } - } - - lookups - } + 0 + }; - /// Creates a CpuOperation from an executor Log and DecodeEntry. - /// - /// The DecodeEntry contains static instruction information. This method - /// adds runtime values from the Log (register values, branch decisions, etc.). - pub fn from_log(log: &Log, timestamp: u64, decode: DecodeEntry) -> Self { - let ecall_commit = decode.op_ecall && log.src1_val == SyscallNumbers::Commit as u64; - let (commit_buf_addr, commit_count) = if ecall_commit { - (log.src2_val, log.dst_val) + // rvd: loaded value for LOAD; 0 for STORE (output unused); the return + // address `pc + instruction_length` on every BRANCH row (written to `rd` + // only by JAL/JALR — `cpu.toml` branch group); `res` + // otherwise. The spec computes this `pc + len` via the ADD chip gated on + // `BRANCH`; we pin it with [`BranchRvdConstraint`] (carry-omitting, like + // `next_pc`). For conditional branches `rvd` is computed but never + // written (`write_register = 0`). + let store = f.memory && jalr; // under MEMORY, mem_flags bit 0 = memory_op (1 = store) + let rvd = if f.memory { + if store { 0 } else { log.dst_val } + } else if f.branch { + decode.pc.wrapping_add(instruction_length) } else { - (0, 0) + res }; - let ecall_keccak = decode.op_ecall - && log.src1_val == executor::vm::instruction::execution::KECCAK_SYSCALL_NUMBER; - let keccak_state_addr = if ecall_keccak { log.src2_val } else { 0 }; - // CM50: (1 - read_register2) * rv2[i] = 0. When read_register2=0, rv2 must be 0. - // For example, ECALL has read_register2=0 (rs2 defaults to 0). The commit buf_addr is - // carried separately in commit_buf_addr and does not go through rv2. - let rv2 = if !decode.read_register2 { - 0 + + // next_pc: branch target for taken branches/jumps; otherwise pc + len. 
+ // ECALL keeps next_pc = pc + len (CO69) even though the executor sets 0 + // to signal halt; the HALT table proves termination separately. + let next_pc = if f.ecall { + decode.pc.wrapping_add(instruction_length) + } else if branch_cond { + log.next_pc } else { - log.src2_val + decode.pc.wrapping_add(instruction_length) }; - let mut op = Self { + Self { decode, timestamp, - next_pc: log.next_pc, - rv1: log.src1_val, + next_pc, + rvd, + rv1, rv2, - rvd: log.dst_val, - res: log.dst_val, // Default: result is destination value - is_equal: false, - branch_cond: false, + arg2, + res, + branch_cond, ecall_commit, commit_buf_addr, commit_count, ecall_keccak, keccak_state_addr, - }; + } + } - // Compute runtime-specific values based on instruction type - op.compute_runtime_values(log); - op + /// Evaluate a conditional-branch comparison `(rv1 ? rv2)` from `alu_flags`. + /// `alu_flags = alu_op + 32·signed + 64·invert` for branches. + fn branch_taken(f: &super::types::ShrunkDecode, rv1: u64, rv2: u64) -> bool { + let op = f.alu_flags & 0x1F; + let signed = (f.alu_flags >> 5) & 1 == 1; + let invert = (f.alu_flags >> 6) & 1 == 1; + let cmp = match op { + x if x == alu_op::EQ => rv1 == rv2, + x if x == alu_op::LT => { + if signed { + (rv1 as i64) < (rv2 as i64) + } else { + rv1 < rv2 + } + } + _ => false, + }; + cmp ^ invert } - /// Creates a CpuOperation from Log and Instruction (convenience method). - /// - /// This creates the DecodeEntry internally. Use `from_log` with a pre-built - /// DecodeEntry when possible to avoid redundant decoding. + /// Creates a CpuOperation from Log and Instruction (convenience). pub fn from_log_and_instruction(log: &Log, timestamp: u64, instruction: Instruction) -> Self { - let decode = DecodeEntry::from_instruction(log.current_pc, instruction); + let decode = DecodeEntry::from_instruction(log.current_pc, instruction, 4); Self::from_log(log, timestamp, decode) } - /// Computes runtime-specific values based on the instruction type. 
- /// - /// This handles: - /// - Memory address computation for LOAD/STORE - /// - Branch condition and result computation for BEQ/BLT - /// - AUIPC special case (rv1 = current_pc) - /// - JALR branch_cond = true - fn compute_runtime_values(&mut self, log: &Log) { - // JALR: always jumps - if self.decode.op_jalr { - self.branch_cond = true; - } - - // LOAD/STORE: res = memory address = rv1 + imm - if self.decode.op_load || self.decode.op_store { - self.res = (log.src1_val as i64 + self.decode.imm as i64) as u64; - } + /// Collects the BITWISE-table range-check lookups generated by this row, so + /// the BITWISE table can account for the matching multiplicities: + /// 3 `ARE_BYTES` (rs1/rs2, rd/half_instruction_length, alu_flags/mem_flags) and + /// 4 `IS_HALF` (the four halves of `res`). + pub fn collect_bitwise_ops(&self) -> Vec { + use super::bitwise::{BitwiseOperation, BitwiseOperationType}; + let f = self.decode.fields; + let mut ops = Vec::with_capacity(7); - // BEQ: res = rv1 - rv2, branch if equal (or not equal for BNE) - if self.decode.op_beq { - self.is_equal = log.src1_val == log.src2_val; - self.res = log.src1_val.wrapping_sub(log.src2_val); - // mp_selector inverts the condition (BNE vs BEQ) - self.branch_cond = if self.decode.mp_selector { - log.src1_val != log.src2_val - } else { - log.src1_val == log.src2_val - }; - } + // Must mirror the trace columns exactly. On word delegate rows the CPU + // zeroes rs1/rs2/rd/alu_flags/mem_flags and res (half_instruction_length stays); + // CPU32 emits its own range checks for the real decoded values. 
+ let word = f.word_instr; + let z = |v: u8| if word { 0 } else { v }; + let res = if word { 0 } else { self.res }; - // BLT: res = comparison result (0 or 1) - if self.decode.op_blt { - self.is_equal = log.src1_val == log.src2_val; - let lt_result = if self.decode.signed { - (log.src1_val as i64) < (log.src2_val as i64) - } else { - log.src1_val < log.src2_val - }; - self.res = lt_result as u64; - // mp_selector inverts the condition (BGE/BGEU vs BLT/BLTU) - self.branch_cond = if self.decode.mp_selector { - !lt_result - } else { - lt_result - }; - } + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::AreBytes, + z(f.rs1), + z(f.rs2), + )); + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::AreBytes, + z(f.rd), + f.half_instruction_length, + )); + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::AreBytes, + z(f.alu_flags), + z(f.mem_flags), + )); - // AUIPC/JAL: rv1 should be current_pc (special case) - // Per spec, these instructions use rs1=255 (virtual PC register) - if self.decode.rs1 == 255 { - self.rv1 = log.current_pc; + for i in 0..4 { + let half = ((res >> (i * 16)) & 0xFFFF) as u16; + ops.push(BitwiseOperation::halfword( + BitwiseOperationType::IsHalf, + (half & 0xFF) as u8, + (half >> 8) as u8, + )); } - // ECALL: Per spec constraint CO69, next_pc = pc + instr_size for all instructions, - // including ECALL. The CPU transition constraint enforces next_pc = pc + 4 on every - // row, so the trace must satisfy this even though the executor sets next_pc=0 to - // signal halt. The HALT table separately proves program termination via the ECALL bus. - if self.decode.op_ecall { - self.next_pc = self.decode.pc + 4; - } + ops } } @@ -772,150 +425,122 @@ impl CpuOperation { /// Generates the CPU trace table from a list of operations. /// -/// Each operation becomes one row in the table. The table is then -/// padded to the next power of 2. +/// Each operation becomes one row; the table is padded to the next power of 2. 
pub fn generate_cpu_trace( operations: &[CpuOperation], ) -> TraceTable { let n = operations.len(); - let num_rows = n.next_power_of_two().max(4); let mut data = vec![FE::zero(); num_rows * cols::NUM_COLUMNS]; for (row_idx, op) in operations.iter().enumerate() { let base = row_idx * cols::NUM_COLUMNS; - let d = &op.decode; // Shorthand for decode fields + let f = &op.decode.fields; + let word = f.word_instr; + + // For a word_instr delegate row the operational flags/register I/O are + // suppressed (CPU32 owns them); only the PC-advancing columns are set. + let effective = |flag: bool| (!word && flag) as u64; - // Input columns (from decode) data[base + cols::TIMESTAMP] = FE::from(op.timestamp); - data[base + cols::PC_0] = FE::from(d.pc & 0xFFFF_FFFF); - data[base + cols::PC_1] = FE::from(d.pc >> 32); - data[base + cols::RS1] = FE::from(d.rs1 as u64); - data[base + cols::RS2] = FE::from(d.rs2 as u64); - data[base + cols::RD] = FE::from(d.rd as u64); - // Skip x0 (hardwired zero). x255 is the register where the pc is stored - // (per spec decode.md). read_register1=1 for rs1=255 ensures the CM47 MEMW - // interaction is sent and rv1 is not forced to zero by CM48. 
- data[base + cols::READ_REGISTER1] = FE::from((d.read_register1 && d.rs1 != 0) as u64); - data[base + cols::READ_REGISTER2] = FE::from((d.read_register2 && d.rs2 != 0) as u64); - data[base + cols::WRITE_REGISTER] = FE::from((d.write_register && d.rd != 0) as u64); - data[base + cols::MEMORY_2BYTES] = FE::from(d.memory_2bytes as u64); - data[base + cols::MEMORY_4BYTES] = FE::from(d.memory_4bytes as u64); - data[base + cols::MEMORY_8BYTES] = FE::from(d.memory_8bytes as u64); - data[base + cols::C_TYPE_INSTRUCTION] = FE::from(d.c_type as u64); - data[base + cols::IMM_0] = FE::from(d.imm & 0xFFFF_FFFF); - data[base + cols::IMM_1] = FE::from(d.imm >> 32); - data[base + cols::SIGNED] = FE::from(d.signed as u64); - data[base + cols::MP_SELECTOR] = FE::from(d.mp_selector as u64); - data[base + cols::MULDIV_SELECTOR] = FE::from(d.muldiv_selector as u64); - data[base + cols::WORD_INSTR] = FE::from(d.word_instr as u64); - - // ALU selector flags - data[base + cols::ADD] = FE::from(d.op_add as u64); - data[base + cols::SUB] = FE::from(d.op_sub as u64); - data[base + cols::SLT] = FE::from(d.op_slt as u64); - data[base + cols::AND] = FE::from(d.op_and as u64); - data[base + cols::OR] = FE::from(d.op_or as u64); - data[base + cols::XOR] = FE::from(d.op_xor as u64); - data[base + cols::SHIFT] = FE::from(d.op_shift as u64); - data[base + cols::JALR] = FE::from(d.op_jalr as u64); - data[base + cols::BEQ] = FE::from(d.op_beq as u64); - data[base + cols::BLT] = FE::from(d.op_blt as u64); - data[base + cols::LOAD] = FE::from(d.op_load as u64); - data[base + cols::STORE] = FE::from(d.op_store as u64); - data[base + cols::MUL] = FE::from(d.op_mul as u64); - data[base + cols::DIVREM] = FE::from(d.op_divrem as u64); - data[base + cols::ECALL] = FE::from(d.op_ecall as u64); - data[base + cols::EBREAK] = FE::from(d.op_ebreak as u64); - - // Output columns - data[base + cols::NEXT_PC_0] = FE::from(op.next_pc & 0xFFFF_FFFF); - data[base + cols::NEXT_PC_1] = FE::from(op.next_pc >> 32); + 
data[base + cols::PC_0] = FE::from(op.decode.pc & 0xFFFF_FFFF); + data[base + cols::PC_1] = FE::from(op.decode.pc >> 32); - // rvd: For LOAD, use the executor's loaded value (op.rvd). - // For all other operations (including STORE), compute from res with sign extension. - // This satisfies spec constraint: (1-LOAD) * (rvd - res_extended) = 0 - let rvd = if d.op_load { - op.rvd // Loaded value from executor + // rs1/rs2/rd and read/write flags are only present on non-word rows. + let (rs1, rs2, rd) = if word { + (0, 0, 0) } else { - op.compute_rvd() // res with sign extension for word instructions + (f.rs1, f.rs2, f.rd) }; + data[base + cols::RS1] = FE::from(rs1 as u64); + data[base + cols::RS2] = FE::from(rs2 as u64); + data[base + cols::RD] = FE::from(rd as u64); + + // x0 is hardwired zero (never read/written); x255 is the PC register and + // must be read (read_register1=1) so its MEMW interaction fires. + data[base + cols::READ_REGISTER1] = FE::from(effective(f.read_register1 && f.rs1 != 0)); + data[base + cols::READ_REGISTER2] = FE::from(effective(f.read_register2 && f.rs2 != 0)); + data[base + cols::WRITE_REGISTER] = FE::from(effective(f.write_register && f.rd != 0)); + + // On word delegate rows, all operational data columns are 0 (CPU32 owns + // the real values); the register-zero / arg2 / rvd=res constraints all + // hold with read flags = 0. `op` still carries the real rv1/rv2/rvd for + // the CPU32 op-generation, so we mask the columns here. 
+ let (imm, rvd, rv1, rv2, arg2, res) = if word { + (0, 0, 0, 0, 0, 0) + } else { + (op.decode.imm, op.rvd, op.rv1, op.rv2, op.arg2, op.res) + }; + + data[base + cols::IMM_0] = FE::from(imm & 0xFFFF_FFFF); + data[base + cols::IMM_1] = FE::from(imm >> 32); + + data[base + cols::HALF_INSTRUCTION_LENGTH] = FE::from(f.half_instruction_length as u64); + data[base + cols::WORD_INSTR] = FE::from(word as u64); + + data[base + cols::ALU] = FE::from(effective(f.alu)); + data[base + cols::ALU_FLAGS] = FE::from(if word { 0 } else { f.alu_flags as u64 }); + data[base + cols::ADD] = FE::from(effective(f.add)); + data[base + cols::SUB] = FE::from(effective(f.sub)); + data[base + cols::MEMORY] = FE::from(effective(f.memory)); + data[base + cols::MEM_FLAGS] = FE::from(if word { 0 } else { f.mem_flags as u64 }); + data[base + cols::BRANCH] = FE::from(effective(f.branch)); + data[base + cols::ECALL] = FE::from(effective(f.ecall)); + + data[base + cols::NEXT_PC_0] = FE::from(op.next_pc & 0xFFFF_FFFF); + data[base + cols::NEXT_PC_1] = FE::from(op.next_pc >> 32); + data[base + cols::RVD_0] = FE::from(rvd & 0xFFFF_FFFF); data[base + cols::RVD_1] = FE::from(rvd >> 32); - // Auxiliary: rv1 as DWordWHH [Half, Half, Word] - Word is MSB (bits 32-63) - data[base + cols::RV1_0] = FE::from(op.rv1 & 0xFFFF); // bits 0-15 (Half) - data[base + cols::RV1_1] = FE::from((op.rv1 >> 16) & 0xFFFF); // bits 16-31 (Half) - data[base + cols::RV1_2] = FE::from(op.rv1 >> 32); // bits 32-63 (Word) - - // Auxiliary: rv2 as DWordWHH [Half, Half, Word] - Word is MSB (bits 32-63) - data[base + cols::RV2_0] = FE::from(op.rv2 & 0xFFFF); // bits 0-15 (Half) - data[base + cols::RV2_1] = FE::from((op.rv2 >> 16) & 0xFFFF); // bits 16-31 (Half) - data[base + cols::RV2_2] = FE::from(op.rv2 >> 32); // bits 32-63 (Word) - - // Extension bits - only set when word_instr=1, per SIGN template - // The constraint enforces: (1 - word_instr) * ext_bit = 0 for each ext bit - let rv1_ext_bit = d.word_instr && 
CpuOperation::sign_bit_32(op.rv1); - data[base + cols::RV1_EXT_BIT] = FE::from(rv1_ext_bit as u64); - - // Compute and store arg1 as DWordBL (8 bytes) - let arg1 = op.compute_arg1(); - for i in 0..8 { - data[base + cols::ARG1[i]] = FE::from((arg1 >> (i * 8)) & 0xFF); - } - - // Compute and store arg2 - let arg2 = op.compute_arg2(); - let rv2_ext_bit = d.word_instr && CpuOperation::sign_bit_32(op.rv2); - data[base + cols::RV2_EXT_BIT] = FE::from(rv2_ext_bit as u64); - for i in 0..8 { - data[base + cols::ARG2[i]] = FE::from((arg2 >> (i * 8)) & 0xFF); - } + // rv1/rv2/arg2 as DWordWL (2 × 32-bit words). + data[base + cols::RV1_0] = FE::from(rv1 & 0xFFFF_FFFF); + data[base + cols::RV1_1] = FE::from(rv1 >> 32); + data[base + cols::RV2_0] = FE::from(rv2 & 0xFFFF_FFFF); + data[base + cols::RV2_1] = FE::from(rv2 >> 32); + data[base + cols::ARG2_0] = FE::from(arg2 & 0xFFFF_FFFF); + data[base + cols::ARG2_1] = FE::from(arg2 >> 32); - // Result - computed from arg1/arg2 for ADD/SUB to satisfy constraints - let res = op.compute_res(); - let res_ext_bit = d.word_instr && CpuOperation::sign_bit_32(res); - data[base + cols::RES_EXT_BIT] = FE::from(res_ext_bit as u64); - for i in 0..8 { - data[base + cols::RES[i]] = FE::from((res >> (i * 8)) & 0xFF); + // res as DWordHL (4 × 16-bit halves). + for i in 0..4 { + data[base + cols::RES[i]] = FE::from((res >> (i * 16)) & 0xFFFF); } - // Branch columns - data[base + cols::IS_EQUAL] = FE::from(op.is_equal as u64); data[base + cols::BRANCH_COND] = FE::from(op.branch_cond as u64); - // Inline PC columns - let pc_double_read = (d.read_register1 && d.rs1 == 255) as u64; + // Inline-PC coordination columns. 
+ let pc_double_read = (!word && f.read_register1 && f.rs1 == 255) as u64; let ts_lo = op.timestamp & 0xFFFF_FFFF; let prev_pc_ts_borrow = if pc_double_read == 0 && ts_lo < 3 { - 1u64 + 1 } else { - 0u64 + 0 }; data[base + cols::PC_DOUBLE_READ] = FE::from(pc_double_read); data[base + cols::PREV_PC_TIMESTAMP_BORROW] = FE::from(prev_pc_ts_borrow); } - // Padding rows: per spec, padding uses pc=1 (odd address, unreachable during - // normal execution) with all flags=0, so pad=1 and no bus interactions fire. - // next_pc=5 satisfies the NextPcAdd constraint: carry=(1+4-5)/2^32=0. - // The DECODE table must contain a corresponding entry at pc=1. + // Padding rows: pc = next_pc = 1 (odd, unreachable), half_instruction_length = 0 so + // next_pc = pc + 0 = pc, all flags 0. The DECODE table has the matching padding + // entry at pc = 1. Per spec, padding rows participate in the inline-PC `memory` + // chain: each reads pc=1 at `timestamp - 3` and writes pc=1 at `timestamp + 1`, + // so their timestamps must continue the +4 cadence from the last real row (the + // halting ECALL). pc_double_read and prev_pc_timestamp_borrow stay 0, giving + // prev_ts = timestamp - 3. The first padding read (timestamp = last_ts + 4) then + // lands on last_ts + 1, where the HALT chip's emit_pc deposited pc = 1. + let last_ts = operations.last().map(|op| op.timestamp).unwrap_or(0); for row_idx in n..num_rows { let base = row_idx * cols::NUM_COLUMNS; + let j = (row_idx - n + 1) as u64; + data[base + cols::TIMESTAMP] = FE::from(last_ts + 4 * j); data[base + cols::PC_0] = FE::from(CPU_PADDING_PC); - data[base + cols::NEXT_PC_0] = FE::from(CPU_PADDING_PC + 4); + data[base + cols::NEXT_PC_0] = FE::from(CPU_PADDING_PC); } TraceTable::new_main(data, cols::NUM_COLUMNS, 1) } /// Generates the CPU trace table directly from executor logs. -/// -/// This is a convenience function that converts logs to CpuOperations -/// and then generates the trace. 
-/// -/// Returns an error if an instruction is not found for a PC. -/// Panics if logs.len() is not a power of 2 >= 4. pub fn generate_cpu_trace_from_logs( logs: &[Log], instructions: &U64HashMap, @@ -934,7 +559,7 @@ pub fn generate_cpu_trace_from_logs( Ok(generate_cpu_trace(&operations)) } -/// Collects all Bitwise lookups from a list of CPU operations. +/// Collects all BITWISE lookups generated by these CPU operations. pub fn collect_bitwise_ops(operations: &[CpuOperation]) -> Vec { operations .iter() @@ -942,9 +567,7 @@ pub fn collect_bitwise_ops(operations: &[CpuOperation]) -> Vec, @@ -967,654 +590,156 @@ pub fn collect_bitwise_ops_from_logs( // Bus interactions // ========================================================================= -/// Helper to create a LinearTerm with coefficient 2^bit for a column. -fn linear_term(bit: u32, column: usize) -> LinearTerm { +/// LinearTerm with coefficient 2^bit for a column (packed_decode reconstruction). +fn pow2_term(bit: u32, column: usize) -> LinearTerm { LinearTerm::Column { - coefficient: 1 << bit, + coefficient: 1i64 << bit, column, } } +/// `BusValue` for the low 32-bit word and high 32-bit word of `res` (DWordHL), +/// i.e. `cast(res, DWordWL)` as 2 bus elements. +fn res_cast_wl() -> BusValue { + BusValue::Packed { + start_column: cols::RES_0, + packing: Packing::DWordHL, + } +} + /// Returns the bus interactions for the CPU table. -/// -/// The CPU table sends to: -/// - DECODE: instruction fetch (every row) -/// - AND_BYTE, OR_BYTE, XOR_BYTE: for bitwise operations (×8 each) -/// -/// Note: LT interaction is TODO - needs proper DWordHHW packing to match LT table receiver. 
pub fn bus_interactions() -> Vec { - use super::types::packed_decode as bits; + use super::types::packed_decode_shrunk as pd; - let mut interactions = Vec::new(); + let mut interactions = Vec::with_capacity(24); // ------------------------------------------------------------------------- - // DECODE interaction (instruction fetch) + // DECODE: instruction fetch (mult = 1 - word_instr; word rows go to CPU32). // ------------------------------------------------------------------------- - // Every CPU row looks up the DECODE table once to verify instruction decoding. - // Format: DECODE[pc::DWordWL, imm::DWordWL, packed_decode] - // - // packed_decode is computed as a linear combination of all decode columns. - // Bit positions are defined in types::packed_decode (single source of truth). interactions.push(BusInteraction::sender( BusId::Decode, - Multiplicity::One, // Every row sends exactly once + Multiplicity::Negated(cols::WORD_INSTR), vec![ - // pc as DWordWL (2 bus elements) BusValue::Packed { start_column: cols::PC_0, packing: Packing::DWordWL, }, - // imm as DWordWL (2 bus elements) BusValue::Packed { start_column: cols::IMM_0, packing: Packing::DWordWL, }, - // packed_decode as linear combination of decode columns BusValue::linear(vec![ - // Control flags (bits 0-10) - linear_term(bits::READ_REG1, cols::READ_REGISTER1), - linear_term(bits::READ_REG2, cols::READ_REGISTER2), - linear_term(bits::WRITE_REG, cols::WRITE_REGISTER), - linear_term(bits::MEMORY_2BYTES, cols::MEMORY_2BYTES), - linear_term(bits::MEMORY_4BYTES, cols::MEMORY_4BYTES), - linear_term(bits::MEMORY_8BYTES, cols::MEMORY_8BYTES), - linear_term(bits::C_TYPE, cols::C_TYPE_INSTRUCTION), - linear_term(bits::SIGNED, cols::SIGNED), - linear_term(bits::MP_SELECTOR, cols::MP_SELECTOR), - linear_term(bits::MULDIV_SELECTOR, cols::MULDIV_SELECTOR), - linear_term(bits::WORD_INSTR, cols::WORD_INSTR), - // ALU selector flags (bits 11-26) - linear_term(bits::OP_ADD, cols::ADD), - linear_term(bits::OP_SUB, 
cols::SUB), - linear_term(bits::OP_SLT, cols::SLT), - linear_term(bits::OP_AND, cols::AND), - linear_term(bits::OP_OR, cols::OR), - linear_term(bits::OP_XOR, cols::XOR), - linear_term(bits::OP_SHIFT, cols::SHIFT), - linear_term(bits::OP_JALR, cols::JALR), - linear_term(bits::OP_BEQ, cols::BEQ), - linear_term(bits::OP_BLT, cols::BLT), - linear_term(bits::OP_LOAD, cols::LOAD), - linear_term(bits::OP_STORE, cols::STORE), - linear_term(bits::OP_MUL, cols::MUL), - linear_term(bits::OP_DIVREM, cols::DIVREM), - linear_term(bits::OP_ECALL, cols::ECALL), - linear_term(bits::OP_EBREAK, cols::EBREAK), - // Register indices (bits 27-50) - linear_term(bits::RS1, cols::RS1), - linear_term(bits::RS2, cols::RS2), - linear_term(bits::RD, cols::RD), + pow2_term(pd::READ_REG1, cols::READ_REGISTER1), + pow2_term(pd::READ_REG2, cols::READ_REGISTER2), + pow2_term(pd::WRITE_REG, cols::WRITE_REGISTER), + pow2_term(pd::WORD_INSTR, cols::WORD_INSTR), + pow2_term(pd::ALU, cols::ALU), + pow2_term(pd::ADD, cols::ADD), + pow2_term(pd::SUB, cols::SUB), + pow2_term(pd::MEMORY, cols::MEMORY), + pow2_term(pd::BRANCH, cols::BRANCH), + pow2_term(pd::ECALL, cols::ECALL), + pow2_term(pd::RS1, cols::RS1), + pow2_term(pd::RS2, cols::RS2), + pow2_term(pd::RD, cols::RD), + pow2_term(pd::HALF_INSTRUCTION_LENGTH, cols::HALF_INSTRUCTION_LENGTH), + pow2_term(pd::ALU_FLAGS, cols::ALU_FLAGS), + pow2_term(pd::MEM_FLAGS, cols::MEM_FLAGS), ]), ], )); // ------------------------------------------------------------------------- - // LT interaction (for SLT, BLT) - TODO: Re-add when properly implemented - // ------------------------------------------------------------------------- - // The LT table receiver expects: lhs (DWordHHW: 3 cols), rhs (DWordHHW: 3 cols), signed, lt - // The CPU has arg1/arg2 as DWordBL (8 bytes), needs Linear bus values to repack to HHW format - // For now, commented out until we implement the proper packing. 
- // - // interactions.push(BusInteraction::sender( - // BusId::Lt, - // Multiplicity::Column(cols::SLT), - // vec![...], // Need Linear to repack DWordBL -> DWordHHW - // )); - - // ------------------------------------------------------------------------- - // AND_BYTE interactions (×8 for each byte) - // ------------------------------------------------------------------------- - for i in 0..8 { - interactions.push(BusInteraction::sender( - BusId::AndByte, - Multiplicity::Column(cols::AND), - vec![ - BusValue::Packed { - start_column: cols::ARG1[i], - packing: Packing::Direct, - }, - BusValue::Packed { - start_column: cols::ARG2[i], - packing: Packing::Direct, - }, - BusValue::Packed { - start_column: cols::RES[i], - packing: Packing::Direct, - }, - ], - )); - } - - // ------------------------------------------------------------------------- - // OR_BYTE interactions (×8) - // ------------------------------------------------------------------------- - for i in 0..8 { - interactions.push(BusInteraction::sender( - BusId::OrByte, - Multiplicity::Column(cols::OR), - vec![ - BusValue::Packed { - start_column: cols::ARG1[i], - packing: Packing::Direct, - }, - BusValue::Packed { - start_column: cols::ARG2[i], - packing: Packing::Direct, - }, - BusValue::Packed { - start_column: cols::RES[i], - packing: Packing::Direct, - }, - ], - )); - } - - // ------------------------------------------------------------------------- - // XOR_BYTE interactions (×8) - // ------------------------------------------------------------------------- - for i in 0..8 { - interactions.push(BusInteraction::sender( - BusId::XorByte, - Multiplicity::Column(cols::XOR), - vec![ - BusValue::Packed { - start_column: cols::ARG1[i], - packing: Packing::Direct, - }, - BusValue::Packed { - start_column: cols::ARG2[i], - packing: Packing::Direct, - }, - BusValue::Packed { - start_column: cols::RES[i], - packing: Packing::Direct, - }, - ], - )); - } - - // 
------------------------------------------------------------------------- - // SIGN template: MSB16 interactions for extension bit extraction + // ALU: unified dispatch ALU[rv1, arg2, alu_flags] -> cast(res, WL). // ------------------------------------------------------------------------- - // SIGN(rv1[1], word_instr) -> rv1_ext_bit - // rv1[1] is a Half (bits 16-31), MSB16 extracts bit 31 interactions.push(BusInteraction::sender( - BusId::Msb16, - Multiplicity::Column(cols::WORD_INSTR), + BusId::Alu, + Multiplicity::Column(cols::ALU), vec![ BusValue::Packed { - start_column: cols::RV1_1, - packing: Packing::Direct, + start_column: cols::RV1_0, + packing: Packing::DWordWL, }, BusValue::Packed { - start_column: cols::RV1_EXT_BIT, + start_column: cols::ARG2_0, + packing: Packing::DWordWL, + }, + BusValue::Packed { + start_column: cols::ALU_FLAGS, packing: Packing::Direct, }, + res_cast_wl(), ], )); - // SIGN(rv2[1], word_instr) -> rv2_ext_bit + // ------------------------------------------------------------------------- + // CPU32: delegate word (`*W`) instructions (mult = word_instr). + // CPU32[timestamp::DWordWL, pc::DWordWL, half_instruction_length]. + // ------------------------------------------------------------------------- interactions.push(BusInteraction::sender( - BusId::Msb16, + BusId::Cpu32, Multiplicity::Column(cols::WORD_INSTR), vec![ BusValue::Packed { - start_column: cols::RV2_1, + start_column: cols::TIMESTAMP, packing: Packing::Direct, }, + BusValue::constant(0), // timestamp_hi (CPU timestamps fit in 32 bits) + BusValue::Packed { + start_column: cols::PC_0, + packing: Packing::DWordWL, + }, BusValue::Packed { - start_column: cols::RV2_EXT_BIT, + start_column: cols::HALF_INSTRUCTION_LENGTH, packing: Packing::Direct, }, ], )); // ------------------------------------------------------------------------- - // MSB16 interaction for res extension bit extraction + // Register reads/writes via MEMW (24-element read, 16-element write). 
+ // rv1/rv2/rvd are DWordWL, so the value words are emitted directly. // ------------------------------------------------------------------------- - // MSB16[res::DWordHL[1]] -> res_ext_bit, multiplicity = word_instr - // res::DWordHL[1] is the half at bits 16-31 = res[2] + 256*res[3] + interactions.push(memw_register_read( + cols::READ_REGISTER1, + cols::RS1, + cols::RV1_0, + cols::RV1_1, + 0, + )); + interactions.push(memw_register_read( + cols::READ_REGISTER2, + cols::RS2, + cols::RV2_0, + cols::RV2_1, + 1, + )); + // Register write of rvd at timestamp+2 (16 elements, no `old`). interactions.push(BusInteraction::sender( - BusId::Msb16, - Multiplicity::Column(cols::WORD_INSTR), + BusId::Memw, + Multiplicity::Column(cols::WRITE_REGISTER), vec![ - BusValue::linear(vec![ - LinearTerm::Column { - coefficient: 1, - column: cols::RES[2], - }, - LinearTerm::Column { - coefficient: 256, - column: cols::RES[3], - }, - ]), + BusValue::constant(1), // is_register + BusValue::linear(vec![LinearTerm::Column { + coefficient: 2, + column: cols::RD, + }]), // base_address[0] = 2*rd + BusValue::constant(0), // base_address[1] BusValue::Packed { - start_column: cols::RES_EXT_BIT, + start_column: cols::RVD_0, packing: Packing::Direct, }, - ], - )); - - // ------------------------------------------------------------------------- - // ZERO interaction for is_equal (BEQ) - // ------------------------------------------------------------------------- - // ZERO[sum(res[0..7])] -> is_equal, multiplicity = BEQ - // If all 8 bytes of res are zero, sum = 0, is_equal = 1 - interactions.push(BusInteraction::sender( - BusId::Zero, - Multiplicity::Column(cols::BEQ), - vec![ - // Sum of all 8 result bytes as linear combination - BusValue::linear(vec![ - stark::lookup::LinearTerm::Column { - coefficient: 1, - column: cols::RES[0], - }, - stark::lookup::LinearTerm::Column { - coefficient: 1, - column: cols::RES[1], - }, - stark::lookup::LinearTerm::Column { - coefficient: 1, - column: 
cols::RES[2], - }, - stark::lookup::LinearTerm::Column { - coefficient: 1, - column: cols::RES[3], - }, - stark::lookup::LinearTerm::Column { - coefficient: 1, - column: cols::RES[4], - }, - stark::lookup::LinearTerm::Column { - coefficient: 1, - column: cols::RES[5], - }, - stark::lookup::LinearTerm::Column { - coefficient: 1, - column: cols::RES[6], - }, - stark::lookup::LinearTerm::Column { - coefficient: 1, - column: cols::RES[7], - }, - ]), - BusValue::Packed { - start_column: cols::IS_EQUAL, - packing: Packing::Direct, - }, - ], - )); - - // ------------------------------------------------------------------------- - // LT interaction (for SLT, BLT) - // ------------------------------------------------------------------------- - // LT[arg1, arg2, signed] -> res[0] - // multiplicity = SLT + BLT - // - // LT bus uses 2 elements per 64-bit operand: [lo32, hi32] - // arg1/arg2 are DWordBL (8 bytes) - use Packing::DWordBL to produce 2 elements - interactions.push(BusInteraction::sender( - BusId::Lt, - // SLT + BLT using Multiplicity::Sum - Multiplicity::Sum(cols::SLT, cols::BLT), - vec![ - // arg1 as DWordBL (8 bytes → 2 elements: [lo32, hi32]) - BusValue::Packed { - start_column: cols::ARG1[0], - packing: Packing::DWordBL, - }, - // arg2 as DWordBL (8 bytes → 2 elements: [lo32, hi32]) - BusValue::Packed { - start_column: cols::ARG2[0], - packing: Packing::DWordBL, - }, - // signed flag - BusValue::Packed { - start_column: cols::SIGNED, - packing: Packing::Direct, - }, - // lt result (res[0]) - BusValue::Packed { - start_column: cols::RES[0], - packing: Packing::Direct, - }, - ], - )); - - // ------------------------------------------------------------------------- - // MUL interaction (for MUL, MULH, MULHSU, MULHU) - // ------------------------------------------------------------------------- - // MUL[arg1, signed, arg2, mp_selector, rvd, muldiv_selector] per spec CPU-CA44 - // multiplicity = MUL - // - // The MUL table expects DWordHL (4 halfwords), but CPU has 
DWordBL (8 bytes). - // Both pack to 2 words (lo32, hi32), so the signatures match for the same values. - // - // rhs_signed = mp_selector per spec: - // - MUL/MULH: mp_selector=1 (both operands signed) - // - MULHU/MULHSU: mp_selector=0 (rhs unsigned) - // - // muldiv_selector distinguishes lo (0) from hi (1) result - interactions.push(BusInteraction::sender( - BusId::Mul, - Multiplicity::Column(cols::MUL), - vec![ - // arg1 (lhs) as DWordBL (8 bytes → 2 elements) - BusValue::Packed { - start_column: cols::ARG1[0], - packing: Packing::DWordBL, - }, - // lhs_signed = signed - BusValue::Packed { - start_column: cols::SIGNED, - packing: Packing::Direct, - }, - // arg2 (rhs) as DWordBL (8 bytes → 2 elements) - BusValue::Packed { - start_column: cols::ARG2[0], - packing: Packing::DWordBL, - }, - // rhs_signed = mp_selector - BusValue::Packed { - start_column: cols::MP_SELECTOR, - packing: Packing::Direct, - }, - // result (res) as DWordBL (8 bytes → 2 elements) per spec CPU-CA44. - // Must send res (raw MUL output), not rvd. For MULW, rvd = sign_extend(res[31:0]), - // which can differ from res when bits [63:32] ≠ sign_extend(bit31) of res. 
- BusValue::Packed { - start_column: cols::RES[0], - packing: Packing::DWordBL, - }, - // muldiv_selector: 0=lo (MUL), 1=hi (MULH/MULHSU/MULHU) - BusValue::Packed { - start_column: cols::MULDIV_SELECTOR, - packing: Packing::Direct, - }, - ], - )); - - // ------------------------------------------------------------------------- - // DVRM interaction (for DIV, DIVU, REM, REMU) — CPU-CA45 - // ------------------------------------------------------------------------- - // DVRM[rvd; arg1, arg2, signed, muldiv_selector] - // multiplicity = DIVREM - interactions.push(BusInteraction::sender( - BusId::Dvrm, - Multiplicity::Column(cols::DIVREM), - vec![ - // arg1 (numerator n) as DWordBL (8 bytes → 2 elements) - BusValue::Packed { - start_column: cols::ARG1[0], - packing: Packing::DWordBL, - }, - // arg2 (denominator d) as DWordBL (8 bytes → 2 elements) - BusValue::Packed { - start_column: cols::ARG2[0], - packing: Packing::DWordBL, - }, - // signed - BusValue::Packed { - start_column: cols::SIGNED, - packing: Packing::Direct, - }, - // result (res) as DWordBL (8 bytes → 2 elements) per spec CPU-CA45. - // Must send res (raw DVRM output), not rvd. For DIVW/REMW, rvd = sign_extend(res[31:0]), - // which can differ from res when bits [63:32] ≠ sign_extend(bit31) of res. 
- BusValue::Packed { - start_column: cols::RES[0], - packing: Packing::DWordBL, - }, - // muldiv_selector: 0=quotient (DIV), 1=remainder (REM) - BusValue::Packed { - start_column: cols::MULDIV_SELECTOR, - packing: Packing::Direct, - }, - ], - )); - - // ------------------------------------------------------------------------- - // SHIFT interaction (for SLL, SRL, SRA) — CPU-CA43 - // ------------------------------------------------------------------------- - // SHIFT[res::DWordWL; arg1::DWordHL, arg2[0], mp_selector, signed, word_instr] - // multiplicity = SHIFT - interactions.push(BusInteraction::sender( - BusId::Shift, - Multiplicity::Column(cols::SHIFT), - vec![ - // res (result) as DWordBL (8 bytes → 2 elements, same as DWordWL) - BusValue::Packed { - start_column: cols::RES[0], - packing: Packing::DWordBL, - }, - // arg1 (input) as DWordBL (8 bytes → 2 elements) - BusValue::Packed { - start_column: cols::ARG1[0], - packing: Packing::DWordBL, - }, - // arg2[0] (shift amount byte) - BusValue::Packed { - start_column: cols::ARG2[0], - packing: Packing::Direct, - }, - // mp_selector (direction: 0=left, 1=right) - BusValue::Packed { - start_column: cols::MP_SELECTOR, - packing: Packing::Direct, - }, - // signed - BusValue::Packed { - start_column: cols::SIGNED, - packing: Packing::Direct, - }, - // word_instr - BusValue::Packed { - start_column: cols::WORD_INSTR, - packing: Packing::Direct, - }, - ], - )); - - // ========================================================================= - // MEMW and LOAD bus interactions (M1, M3, M5, M6, M7) - // ========================================================================= - // M1 and M3: Register read interactions (CPU → MEMW μ_read) - // ------------------------------------------------------------------------- - // M1: MEMW[rv1; 1, 2*rs1, rv1, timestamp+0, 1, 0, 0] | read_register1 - // ------------------------------------------------------------------------- - // Read from rs1 register via MEMW. 
Format: 24 elements - // [old[8], is_register, base_addr[2], value[8], timestamp[2], write2, write4, write8] - // - // Registers are stored as WL (2 words), remaining 6 values are unconstrained (zeros). - // rv1 is DWordWHH (3 cols: Half, Half, Word) -> pack as WL: lo32 = rv1[0] + 2^16*rv1[1], hi32 = rv1[2] - interactions.push(BusInteraction::sender( - BusId::Memw, - Multiplicity::Column(cols::READ_REGISTER1), - vec![ - // old[0] = lo32 = RV1_0 + 2^16 * RV1_1 - BusValue::linear(vec![ - LinearTerm::Column { - coefficient: 1, - column: cols::RV1_0, - }, - LinearTerm::Column { - coefficient: 65536, - column: cols::RV1_1, - }, - ]), - // old[1] = hi32 = RV1_2 - BusValue::Packed { - start_column: cols::RV1_2, - packing: Packing::Direct, - }, - // old[2..7] = 0 (unconstrained for registers) - BusValue::constant(0), - BusValue::constant(0), - BusValue::constant(0), - BusValue::constant(0), - BusValue::constant(0), - BusValue::constant(0), - // is_register = 1 - BusValue::constant(1), - // base_address[0] = 2 * rs1 - BusValue::linear(vec![LinearTerm::Column { - coefficient: 2, - column: cols::RS1, - }]), - // base_address[1] = 0 - BusValue::constant(0), - // value[0..7] = same as old (rv1 as WL + 6 zeros) - BusValue::linear(vec![ - LinearTerm::Column { - coefficient: 1, - column: cols::RV1_0, - }, - LinearTerm::Column { - coefficient: 65536, - column: cols::RV1_1, - }, - ]), - BusValue::Packed { - start_column: cols::RV1_2, - packing: Packing::Direct, - }, - BusValue::constant(0), - BusValue::constant(0), - BusValue::constant(0), - BusValue::constant(0), - BusValue::constant(0), - BusValue::constant(0), - // timestamp[0] = timestamp, timestamp[1] = 0 - BusValue::Packed { - start_column: cols::TIMESTAMP, - packing: Packing::Direct, - }, - BusValue::constant(0), - // write2=1, write4=0, write8=0 (register access = 2 Words / 64 bits) - BusValue::constant(1), - BusValue::constant(0), - BusValue::constant(0), - ], - )); - - // 
------------------------------------------------------------------------- - // M3: MEMW[rv2; 1, 2*rs2, rv2, timestamp+1, 0, 0, 1] | read_register2 - // ------------------------------------------------------------------------- - // Same pattern as M1 but with RV2 and timestamp+1 - interactions.push(BusInteraction::sender( - BusId::Memw, - Multiplicity::Column(cols::READ_REGISTER2), - vec![ - // old[0] = lo32 = RV2_0 + 2^16 * RV2_1 - BusValue::linear(vec![ - LinearTerm::Column { - coefficient: 1, - column: cols::RV2_0, - }, - LinearTerm::Column { - coefficient: 65536, - column: cols::RV2_1, - }, - ]), - // old[1] = hi32 = RV2_2 - BusValue::Packed { - start_column: cols::RV2_2, - packing: Packing::Direct, - }, - // old[2..7] = 0 - BusValue::constant(0), - BusValue::constant(0), - BusValue::constant(0), - BusValue::constant(0), - BusValue::constant(0), - BusValue::constant(0), - // is_register = 1 - BusValue::constant(1), - // base_address[0] = 2 * rs2 - BusValue::linear(vec![LinearTerm::Column { - coefficient: 2, - column: cols::RS2, - }]), - // base_address[1] = 0 - BusValue::constant(0), - // value[0..7] = rv2 as WL + 6 zeros - BusValue::linear(vec![ - LinearTerm::Column { - coefficient: 1, - column: cols::RV2_0, - }, - LinearTerm::Column { - coefficient: 65536, - column: cols::RV2_1, - }, - ]), - BusValue::Packed { - start_column: cols::RV2_2, - packing: Packing::Direct, - }, - BusValue::constant(0), - BusValue::constant(0), - BusValue::constant(0), - BusValue::constant(0), - BusValue::constant(0), - BusValue::constant(0), - // timestamp[0] = timestamp + 1, timestamp[1] = 0 - BusValue::linear(vec![ - LinearTerm::Column { - coefficient: 1, - column: cols::TIMESTAMP, - }, - LinearTerm::Constant(1), - ]), - BusValue::constant(0), - // write2=1, write4=0, write8=0 (register access = 2 Words / 64 bits) - BusValue::constant(1), - BusValue::constant(0), - BusValue::constant(0), - ], - )); - - // ------------------------------------------------------------------------- - 
// M5: MEMW[1, 2*rd, rvd, timestamp+2, 0, 0, 1] | write_register - // ------------------------------------------------------------------------- - // Write to rd register via MEMW. Format: 16 elements (write, no old) - // [is_register, base_addr[2], value[8], timestamp[2], write2, write4, write8] - // - // rvd is DWordWL (2 cols: Word, Word) - // MEMW uses EXCLUSIVE encoding for write flags: (0, 0, 1) for 8-byte access - // ("exactly N bytes" semantics, not "at least N bytes") - interactions.push(BusInteraction::sender( - BusId::Memw, - Multiplicity::Column(cols::WRITE_REGISTER), - vec![ - // is_register = 1 - BusValue::constant(1), - // base_address[0] = 2 * rd - BusValue::linear(vec![LinearTerm::Column { - coefficient: 2, - column: cols::RD, - }]), - // base_address[1] = 0 - BusValue::constant(0), - // value[0] = rvd_lo = RVD_0 - BusValue::Packed { - start_column: cols::RVD_0, - packing: Packing::Direct, - }, - // value[1] = rvd_hi = RVD_1 BusValue::Packed { start_column: cols::RVD_1, packing: Packing::Direct, }, - // value[2..7] = 0 BusValue::constant(0), BusValue::constant(0), BusValue::constant(0), BusValue::constant(0), BusValue::constant(0), BusValue::constant(0), - // timestamp[0] = timestamp + 2, timestamp[1] = 0 + // timestamp+2 BusValue::linear(vec![ LinearTerm::Column { coefficient: 1, @@ -1623,219 +748,50 @@ pub fn bus_interactions() -> Vec { LinearTerm::Constant(2), ]), BusValue::constant(0), - // write2=1, write4=0, write8=0 (EXCLUSIVE encoding for 2-Word register access) - BusValue::constant(1), + BusValue::constant(1), // write2 (register access = 2 words) BusValue::constant(0), BusValue::constant(0), ], )); // ------------------------------------------------------------------------- - // M6: LOAD[rvd; base_address, timestamp, read2, read4, read8, signed] | LOAD + // MEMORY: high-level LOAD/STORE dispatch (mult = MEMORY). + // MEMORY[timestamp, cast(res, WL) = address, rv2, mem_flags] -> rvd. 
// ------------------------------------------------------------------------- - // LOAD receiver expects: [res::DWordBL(2), base_address::DWordWL(2), timestamp::DWordWL(2), flags(3), signed(1)] = 10 elements - // - // For CPU LOAD: - // - rvd (the loaded result) corresponds to res - // - res (computed address = rv1 + imm) corresponds to base_address - // - memory_Xbytes flags use EXCLUSIVE encoding per spec ("exactly N bytes") interactions.push(BusInteraction::sender( - BusId::Load, - Multiplicity::Column(cols::LOAD), + BusId::MemoryOp, + Multiplicity::Column(cols::MEMORY), vec![ - // rvd as DWordWL (2 words) - this is the loaded value - // CPU RVD is already WL format - BusValue::Packed { - start_column: cols::RVD_0, - packing: Packing::DWordWL, - }, - // base_address = res (computed address) as DWordBL (8 bytes → 2 elements) - BusValue::Packed { - start_column: cols::RES[0], - packing: Packing::DWordBL, - }, - // timestamp as DWordWL: [timestamp, 0] BusValue::Packed { start_column: cols::TIMESTAMP, packing: Packing::Direct, }, - BusValue::constant(0), - // read flags: exclusive encoding (pass through directly) - BusValue::Packed { - start_column: cols::MEMORY_2BYTES, - packing: Packing::Direct, - }, + BusValue::constant(0), // timestamp_hi + res_cast_wl(), // address (2 words) BusValue::Packed { - start_column: cols::MEMORY_4BYTES, - packing: Packing::Direct, - }, + start_column: cols::RV2_0, + packing: Packing::DWordWL, + }, // value to store (2 words) BusValue::Packed { - start_column: cols::MEMORY_8BYTES, + start_column: cols::MEM_FLAGS, packing: Packing::Direct, }, - // signed flag BusValue::Packed { - start_column: cols::SIGNED, - packing: Packing::Direct, - }, + start_column: cols::RVD_0, + packing: Packing::DWordWL, + }, // loaded value (output) ], )); // ------------------------------------------------------------------------- - // M7: MEMW[0, res, rv2, timestamp+1, memory_2bytes, memory_4bytes, memory_8bytes] | STORE + // Inline PC memory tokens (mult = 
1, per spec): read PC at the coordinated + // previous timestamp, write next_pc at timestamp+1. x255 lives at addresses + // 510/511. Padding rows participate too (they carry PC=1 and chain their + // timestamps); the HALT chip's consume_pc/emit_pc bridges the last real write + // to the padding chain. See `docs/cpu-rework-deviations.md` (D-PAD). // ------------------------------------------------------------------------- - // Write to memory via MEMW. Format: 16 elements - // [is_register, base_addr[2], value[8], timestamp[2], write2, write4, write8] - // - // For STORE: - // - is_register = 0 (memory access) - // - base_address = res (computed address = rv1 + imm) - // - value = rv2 (the value being stored) - interactions.push(BusInteraction::sender( - BusId::Memw, - Multiplicity::Column(cols::STORE), - vec![ - // is_register = 0 (memory access) - BusValue::constant(0), - // base_address = res as DWordBL → 2 elements [lo32, hi32] - BusValue::Packed { - start_column: cols::RES[0], - packing: Packing::DWordBL, - }, - // value[0..7] = arg2 bytes (8 individual Direct elements) - BusValue::Packed { - start_column: cols::ARG2[0], - packing: Packing::Direct, - }, - BusValue::Packed { - start_column: cols::ARG2[1], - packing: Packing::Direct, - }, - BusValue::Packed { - start_column: cols::ARG2[2], - packing: Packing::Direct, - }, - BusValue::Packed { - start_column: cols::ARG2[3], - packing: Packing::Direct, - }, - BusValue::Packed { - start_column: cols::ARG2[4], - packing: Packing::Direct, - }, - BusValue::Packed { - start_column: cols::ARG2[5], - packing: Packing::Direct, - }, - BusValue::Packed { - start_column: cols::ARG2[6], - packing: Packing::Direct, - }, - BusValue::Packed { - start_column: cols::ARG2[7], - packing: Packing::Direct, - }, - // timestamp[0] = timestamp + 1, timestamp[1] = 0 - BusValue::linear(vec![ - LinearTerm::Column { - coefficient: 1, - column: cols::TIMESTAMP, - }, - LinearTerm::Constant(1), - ]), - BusValue::constant(0), - // write flags: 
exclusive encoding (pass through directly) - BusValue::Packed { - start_column: cols::MEMORY_2BYTES, - packing: Packing::Direct, - }, - BusValue::Packed { - start_column: cols::MEMORY_4BYTES, - packing: Packing::Direct, - }, - BusValue::Packed { - start_column: cols::MEMORY_8BYTES, - packing: Packing::Direct, - }, - ], - )); - - // ========================================================================= - // Inline PC memory interactions (replaces CM54 MEMW interaction) - // ========================================================================= - // CPU directly talks to the low-level memory bus for PC register (x255, - // addresses 510 and 511), bypassing MEMW_R. - - // Non-padding multiplicity: sum of all ALU selector flags - let non_pad_mult = Multiplicity::Linear(vec![ - LinearTerm::Column { - coefficient: 1, - column: cols::ADD, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::SUB, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::SLT, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::AND, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::OR, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::XOR, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::SHIFT, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::JALR, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::BEQ, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::BLT, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::LOAD, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::STORE, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::MUL, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::DIVREM, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::ECALL, - }, - LinearTerm::Column { - coefficient: 1, - column: cols::EBREAK, - }, - ]); - + let pc_mult = Multiplicity::One; // prev_ts_lo = timestamp - 3*(1 - pc_double_read) + 2^32 * borrow - // = timestamp 
- 3 + 3*pc_double_read + 2^32 * borrow let prev_ts_lo = BusValue::linear(vec![ LinearTerm::Column { coefficient: 1, @@ -1851,21 +807,21 @@ pub fn bus_interactions() -> Vec { column: cols::PREV_PC_TIMESTAMP_BORROW, }, ]); - - // prev_ts_hi = 0 - borrow - // The -1 cancels the +2^32 added to prev_ts_lo when borrow fires, keeping the - // 64-bit timestamp correct: (prev_ts_hi * 2^32 + prev_ts_lo) = timestamp - 3. let prev_ts_hi = BusValue::linear(vec![LinearTerm::Column { coefficient: -1, column: cols::PREV_PC_TIMESTAMP_BORROW, }]); - for i in 0..2u64 { - // PC read (sender, +1): consume old token - // memory[1, 510+i, 0, prev_ts_lo, prev_ts_hi, pc[i]] + let pc_col = if i == 0 { cols::PC_0 } else { cols::PC_1 }; + let next_pc_col = if i == 0 { + cols::NEXT_PC_0 + } else { + cols::NEXT_PC_1 + }; + // PC read (sender): consume the existing token. interactions.push(BusInteraction::sender( BusId::Memory, - non_pad_mult.clone(), + pc_mult.clone(), vec![ BusValue::constant(1), BusValue::constant(510 + i), @@ -1873,17 +829,15 @@ pub fn bus_interactions() -> Vec { prev_ts_lo.clone(), prev_ts_hi.clone(), BusValue::Packed { - start_column: if i == 0 { cols::PC_0 } else { cols::PC_1 }, + start_column: pc_col, packing: Packing::Direct, }, ], )); - - // PC write (receiver, -1): emit new token - // memory[1, 510+i, 0, timestamp+1, 0, next_pc[i]] + // PC write (receiver): emit the next token at timestamp+1. 
interactions.push(BusInteraction::receiver( BusId::Memory, - non_pad_mult.clone(), + pc_mult.clone(), vec![ BusValue::constant(1), BusValue::constant(510 + i), @@ -1897,11 +851,7 @@ pub fn bus_interactions() -> Vec { ]), BusValue::constant(0), BusValue::Packed { - start_column: if i == 0 { - cols::NEXT_PC_0 - } else { - cols::NEXT_PC_1 - }, + start_column: next_pc_col, packing: Packing::Direct, }, ], @@ -1909,159 +859,92 @@ pub fn bus_interactions() -> Vec { } // ------------------------------------------------------------------------- - // BRANCH interaction (for branch/jump target calculation) + // BRANCH: target computation (mult = branch_cond). + // BRANCH[pc, imm, rv1, JALR] -> next_pc. JALR ≡ mem_flags under BRANCH. + // Order matches the BRANCH table receiver: [next_pc, pc, imm, register, JALR]. // ------------------------------------------------------------------------- - // CPU-CO68: BRANCH[next_pc; pc, imm, arg1::DWordWL, JALR] | branch_cond - // - // Sends to BRANCH table when branch_cond is true. 
- // Bus signature: [next_pc[0], next_pc[1], pc[0], pc[1], offset[0], offset[1], register[0], register[1], JALR] - // - next_pc: DWordWL (2 words) from NEXT_PC_0, NEXT_PC_1 - // - pc: DWordWL (2 words) from PC_0, PC_1 - // - offset: DWordWL (2 words) from IMM_0, IMM_1 (already sign-extended) - // - register: DWordWL (2 words) - arg1 (DWordBL: 8 bytes) repacked as 2 words - // - JALR: Bit flag interactions.push(BusInteraction::sender( BusId::Branch, Multiplicity::Column(cols::BRANCH_COND), vec![ - // next_pc[0] (Word) - low 32 bits BusValue::Packed { start_column: cols::NEXT_PC_0, packing: Packing::Direct, }, - // next_pc[1] (Word) - high 32 bits BusValue::Packed { start_column: cols::NEXT_PC_1, packing: Packing::Direct, }, - // pc[0] (Word) BusValue::Packed { start_column: cols::PC_0, packing: Packing::Direct, }, - // pc[1] (Word) BusValue::Packed { start_column: cols::PC_1, packing: Packing::Direct, }, - // offset[0] = imm[0] (Word) - low 32 bits of immediate BusValue::Packed { start_column: cols::IMM_0, packing: Packing::Direct, }, - // offset[1] = imm[1] (Word) - high 32 bits of immediate (sign-extended) BusValue::Packed { start_column: cols::IMM_1, packing: Packing::Direct, }, - // register[0] = arg1[0..4] repacked as Word - // arg1_word0 = arg1[0] + 2^8*arg1[1] + 2^16*arg1[2] + 2^24*arg1[3] - BusValue::linear(vec![ - LinearTerm::Column { - coefficient: 1, - column: cols::ARG1[0], - }, - LinearTerm::Column { - coefficient: 256, - column: cols::ARG1[1], - }, - LinearTerm::Column { - coefficient: 65536, - column: cols::ARG1[2], - }, - LinearTerm::Column { - coefficient: 16777216, - column: cols::ARG1[3], - }, - ]), - // register[1] = arg1[4..8] repacked as Word - // arg1_word1 = arg1[4] + 2^8*arg1[5] + 2^16*arg1[6] + 2^24*arg1[7] - BusValue::linear(vec![ - LinearTerm::Column { - coefficient: 1, - column: cols::ARG1[4], - }, - LinearTerm::Column { - coefficient: 256, - column: cols::ARG1[5], - }, - LinearTerm::Column { - coefficient: 65536, - column: 
cols::ARG1[6], - }, - LinearTerm::Column { - coefficient: 16777216, - column: cols::ARG1[7], - }, - ]), - // JALR flag BusValue::Packed { - start_column: cols::JALR, + start_column: cols::RV1_0, packing: Packing::Direct, }, - ], - )); - - // ------------------------------------------------------------------------- - // Range checks (14 total): - // CPU-CR29: ARE_BYTES[rs1, rs2], CPU-CR30: ARE_BYTES[rd, 0] - // CPU-CR31.i: ARE_BYTES[arg1[2i], arg1[2i+1]] (i=0..3) - // CPU-CR32.i: ARE_BYTES[arg2[2i], arg2[2i+1]] (i=0..3) - // CPU-CR33.i: ARE_BYTES[res[2i], res[2i+1]] (i=0..3) - // ------------------------------------------------------------------------- - // RS1 and RS2 share one ARE_BYTES check; RD uses 0 as the second argument. - // ARG1/ARG2/RES are 8-byte little-endian values — adjacent byte pairs are - // batched into ARE_BYTES checks. Each pair sends two separate bus values - // [lo, hi], so the LogUp fingerprint forces each byte to match individually - // against the BITWISE table's X in [0,255] and Y in [0,255]. - // Every CPU row (including padding) sends with Multiplicity::One. - interactions.push(BusInteraction::sender( - BusId::AreBytes, - Multiplicity::One, - vec![ BusValue::Packed { - start_column: cols::RS1, + start_column: cols::RV1_1, packing: Packing::Direct, }, BusValue::Packed { - start_column: cols::RS2, + start_column: cols::MEM_FLAGS, packing: Packing::Direct, - }, + }, // JALR ], )); - interactions.push(BusInteraction::sender( - BusId::AreBytes, - Multiplicity::One, - vec![ - BusValue::Packed { - start_column: cols::RD, + + // ------------------------------------------------------------------------- + // Range checks: ARE_BYTES (rs1/rs2, rd/half_instruction_length, alu_flags/mem_flags) + // and IS_HALF on each `res` half. Every row sends (incl. padding: all 0). 
+ // ------------------------------------------------------------------------- + for (a, b) in [ + (cols::RS1, cols::RS2), + (cols::RD, cols::HALF_INSTRUCTION_LENGTH), + (cols::ALU_FLAGS, cols::MEM_FLAGS), + ] { + interactions.push(BusInteraction::sender( + BusId::AreBytes, + Multiplicity::One, + vec![ + BusValue::Packed { + start_column: a, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: b, + packing: Packing::Direct, + }, + ], + )); + } + for &res_col in &cols::RES { + interactions.push(BusInteraction::sender( + BusId::IsHalfword, + Multiplicity::One, + vec![BusValue::Packed { + start_column: res_col, packing: Packing::Direct, - }, - BusValue::constant(0), - ], - )); - for arr in [&cols::ARG1, &cols::ARG2, &cols::RES] { - for i in 0..4 { - interactions.push(BusInteraction::sender( - BusId::AreBytes, - Multiplicity::One, - vec![ - BusValue::Packed { - start_column: arr[2 * i], - packing: Packing::Direct, - }, - BusValue::Packed { - start_column: arr[2 * i + 1], - packing: Packing::Direct, - }, - ], - )); - } + }], + )); } - // ECALL interaction (shared bus for HALT, COMMIT, and KECCAK) // ------------------------------------------------------------------------- - // multiplicity = ECALL (all ECALLs, each receiver matches on syscall number) + // ECALL: system-call bus (HALT/COMMIT/KECCAK receive). mult = ECALL. + // ECALL[timestamp, rv1]. 
+ // ------------------------------------------------------------------------- interactions.push(BusInteraction::sender( BusId::Ecall, Multiplicity::Column(cols::ECALL), @@ -2070,22 +953,10 @@ pub fn bus_interactions() -> Vec { start_column: cols::TIMESTAMP, packing: Packing::Direct, }, - BusValue::constant(0), // timestamp_hi = 0 (CPU timestamps fit in u32) - // cast(rv1, DWordWL)[0] = rv1_lo32 = RV1_0 + 2^16 * RV1_1 - BusValue::linear(vec![ - LinearTerm::Column { - coefficient: 1, - column: cols::RV1_0, - }, - LinearTerm::Column { - coefficient: 65536, - column: cols::RV1_1, - }, - ]), - // cast(rv1, DWordWL)[1] = rv1_hi32 = RV1_2 + BusValue::constant(0), BusValue::Packed { - start_column: cols::RV1_2, - packing: Packing::Direct, + start_column: cols::RV1_0, + packing: Packing::DWordWL, }, ], )); @@ -2093,18 +964,75 @@ pub fn bus_interactions() -> Vec { interactions } -// ========================================================================= -// Constraints (placeholder - will be implemented in constraints/) -// ========================================================================= - -// The CPU constraints include: -// 1. Range checks (IS_BIT) for all bit flags - via templates -// 2. ALU dispatch constraints (conditional on selector flags) -// 3. Extension constraints (arg1, arg2, rvd from rv1, rv2, res) -// 4. Branch condition computation -// 5. next_pc computation (increment or branch target) -// -// These will be implemented using: -// - IsBitConstraint template for flags -// - AddConstraint template for ADD, SUB, next_pc -// - Custom constraints for extension logic +/// MEMW register-read interaction (24 elements: `old(8), is_register, base(2), +/// value(8), timestamp(2), w2, w4, w8`). Register values are DWordWL (the two +/// value words are read directly; the remaining 6 byte slots are 0). 
+fn memw_register_read( + read_flag_col: usize, + rs_col: usize, + rv_lo_col: usize, + rv_hi_col: usize, + ts_offset: i64, +) -> BusInteraction { + let value_lo = || BusValue::Packed { + start_column: rv_lo_col, + packing: Packing::Direct, + }; + let value_hi = || BusValue::Packed { + start_column: rv_hi_col, + packing: Packing::Direct, + }; + let ts = if ts_offset == 0 { + BusValue::Packed { + start_column: cols::TIMESTAMP, + packing: Packing::Direct, + } + } else { + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::TIMESTAMP, + }, + LinearTerm::Constant(ts_offset), + ]) + }; + BusInteraction::sender( + BusId::Memw, + Multiplicity::Column(read_flag_col), + vec![ + // old[0..8] = rv (2 words) + 6 zeros + value_lo(), + value_hi(), + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + // is_register = 1 + BusValue::constant(1), + // base_address[0] = 2*rs, base_address[1] = 0 + BusValue::linear(vec![LinearTerm::Column { + coefficient: 2, + column: rs_col, + }]), + BusValue::constant(0), + // value[0..8] = rv (2 words) + 6 zeros + value_lo(), + value_hi(), + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + BusValue::constant(0), + // timestamp[0..2] + ts, + BusValue::constant(0), + // write2 = 1, write4 = 0, write8 = 0 (register = 2 words) + BusValue::constant(1), + BusValue::constant(0), + BusValue::constant(0), + ], + ) +} diff --git a/prover/src/tables/cpu32.rs b/prover/src/tables/cpu32.rs new file mode 100644 index 000000000..802d94851 --- /dev/null +++ b/prover/src/tables/cpu32.rs @@ -0,0 +1,808 @@ +//! CPU32 table. +//! +//! Handles all 32-bit word (`*W`) instructions delegated by the main CPU via +//! the `CPU32[timestamp, pc, half_instruction_length]` interaction. All `*W` +//! instructions are ALU-only, so there is no BRANCH/MEMORY/ECALL path. The chip +//! 
does its own DECODE lookup, reads the registers, sign-extends the inputs to +//! 64 bits, runs the ALU (or the ADD/SUB fast-path) and sign-extends the 32-bit +//! result back to 64 bits before writing `rd`. +//! +//! Spec: `spec/src/cpu32.toml`. +//! +//! ## Sign extension +//! `*W` instructions operate on the low 32 bits of the registers and produce a +//! sign-extended 64-bit result. `signed` (extracted from `alu_flags` bit 5) +//! selects sign- vs zero-extension of the inputs; the output `rvd` is always +//! sign-extended (RV64 `*W` semantics). +//! +//! Register reads use the cast-to-`DWordWL` encoding. + +use math::field::element::FieldElement; +use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; +use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; +use stark::table::TableView; +use stark::trace::TraceTable; + +use super::types::{ + BusId, FE, GoldilocksExtension, GoldilocksField, SHIFT_16, alu_op, packed_decode_shrunk, +}; +use crate::constraints::templates::{AddConstraint, AddOperand, new_is_bit_constraints}; + +// ========================================================================= +// Column indices for CPU32 table +// ========================================================================= + +/// Column definitions for the CPU32 table. 
+pub mod cols { + // Inputs (from the CPU32 interaction) + pub const TIMESTAMP_0: usize = 0; + pub const TIMESTAMP_1: usize = 1; + pub const PC_0: usize = 2; + pub const PC_1: usize = 3; + + // rs1 read + pub const RS1: usize = 4; + pub const READ_REGISTER1: usize = 5; + // rv1: DWordWHH = [Half, Half, Word] (low word as 2 halves + high word) + pub const RV1_0: usize = 6; + pub const RV1_1: usize = 7; + pub const RV1_2: usize = 8; + pub const RV1_SIGN: usize = 9; + // arg1: DWordWL = sign/zero-extended low word of rv1 + pub const ARG1_0: usize = 10; + pub const ARG1_1: usize = 11; + + // rs2 read + pub const RS2: usize = 12; + pub const READ_REGISTER2: usize = 13; + pub const RV2_0: usize = 14; + pub const RV2_1: usize = 15; + pub const RV2_2: usize = 16; + pub const RV2_SIGN: usize = 17; + // imm: DWordWL (fully sign-extended immediate) + pub const IMM_0: usize = 18; + pub const IMM_1: usize = 19; + // arg2: DWordWL = ext(rv2) or imm + pub const ARG2_0: usize = 20; + pub const ARG2_1: usize = 21; + + // res: DWordHL = ALU result (4 halves) + pub const RES_0: usize = 22; + pub const RES_1: usize = 23; + pub const RES_2: usize = 24; + pub const RES_3: usize = 25; + pub const RES_SIGN: usize = 26; + + // rd write + pub const RD: usize = 27; + pub const WRITE_REGISTER: usize = 28; + // rvd: DWordWL = sign-extended low word of res + pub const RVD_0: usize = 29; + pub const RVD_1: usize = 30; + + // ALU control + pub const ALU: usize = 31; + pub const ALU_FLAGS: usize = 32; + pub const ADD: usize = 33; + pub const SUB: usize = 34; + /// half the byte length (1 or 2); real length = `2 * half`. + pub const HALF_INSTRUCTION_LENGTH: usize = 35; + /// signed: extracted from `alu_flags` bit 5 (via BYTE_ALU[AND, 32, alu_flags]). + pub const SIGNED: usize = 36; + + /// μ: multiplicity + pub const MU: usize = 37; + + /// Total number of columns + pub const NUM_COLUMNS: usize = 38; +} + +/// Mask selecting `signed` from the `alu_flags` byte (bit 5). 
+const SIGNED_MASK: u64 = 1 << packed_decode_shrunk::ALU_FLAGS_SIGNED; +/// `2^32 - 1`, the sign-extension fill for the high word. +const HI_FILL: u64 = 0xFFFF_FFFF; + +// ========================================================================= +// Trace generation +// ========================================================================= + +/// A single CPU32 operation (a delegated `*W` instruction). +/// +/// `res` is the raw 64-bit ALU result (computed by the executor); `rvd` is +/// derived from it by sign-extending the low 32 bits. +#[derive(Debug, Clone, Default, Hash, PartialEq, Eq)] +pub struct Cpu32Operation { + pub timestamp: u64, + pub pc: u64, + pub rs1: u8, + pub read_register1: bool, + pub rv1: u64, + pub rs2: u8, + pub read_register2: bool, + pub rv2: u64, + pub imm: u64, + /// Raw 64-bit ALU result. + pub res: u64, + pub rd: u8, + pub write_register: bool, + pub alu: bool, + pub alu_flags: u8, + pub add: bool, + pub sub: bool, + pub half_instruction_length: u8, +} + +/// Derived auxiliary values for a CPU32 row. +pub struct Cpu32Aux { + pub signed: bool, + pub rv1_sign: bool, + pub arg1: u64, + pub rv2_sign: bool, + pub arg2: u64, + pub res_sign: bool, + pub rvd: u64, +} + +impl Cpu32Operation { + /// Computes the derived auxiliary values (signs, extended args, rvd). + pub fn compute_aux(&self) -> Cpu32Aux { + let signed = (self.alu_flags as u64 & SIGNED_MASK) != 0; + + // Sign bits = MSB (bit 31) of the low word of each value. + let rv1_sign = (self.rv1 >> 31) & 1 == 1; + let rv2_sign = (self.rv2 >> 31) & 1 == 1; + let res_sign = (self.res >> 31) & 1 == 1; + + // arg1 = ext(rv1 low word): low word as-is, high word = (2^32-1) if + // (signed AND rv1_sign) else 0. + let arg1_hi = if signed && rv1_sign { HI_FILL } else { 0 }; + let arg1 = (self.rv1 & 0xFFFF_FFFF) | (arg1_hi << 32); + + // arg2 = ext(rv2 low word) + imm. By the decoding assumption exactly one + // of rv2 / imm is non-zero, so the per-word sums never overflow. 
+ let arg2_lo = (self.rv2 & 0xFFFF_FFFF) + (self.imm & 0xFFFF_FFFF); + let arg2_hi = if signed && rv2_sign { HI_FILL } else { 0 } + (self.imm >> 32); + let arg2 = (arg2_lo & 0xFFFF_FFFF) | (arg2_hi << 32); + + // rvd = sign-extend(res low word) — always sign-extended for *W. + let rvd_hi = if res_sign { HI_FILL } else { 0 }; + let rvd = (self.res & 0xFFFF_FFFF) | (rvd_hi << 32); + + Cpu32Aux { + signed, + rv1_sign, + arg1, + rv2_sign, + arg2, + res_sign, + rvd, + } + } +} + +/// Generates the CPU32 trace from a list of operations. +/// +/// Each operation occupies its own row (μ = 1); the table is padded to the next +/// power of two (minimum 4). +pub fn generate_cpu32_trace( + operations: &[Cpu32Operation], +) -> TraceTable { + let num_rows = operations.len().next_power_of_two().max(4); + let mut data = vec![FE::zero(); num_rows * cols::NUM_COLUMNS]; + + for (row_idx, op) in operations.iter().enumerate() { + let base = row_idx * cols::NUM_COLUMNS; + let aux = op.compute_aux(); + + // Inputs + data[base + cols::TIMESTAMP_0] = FE::from(op.timestamp & 0xFFFF_FFFF); + data[base + cols::TIMESTAMP_1] = FE::from(op.timestamp >> 32); + data[base + cols::PC_0] = FE::from(op.pc & 0xFFFF_FFFF); + data[base + cols::PC_1] = FE::from(op.pc >> 32); + + // rv1 as DWordWHH: [Half, Half, Word] + data[base + cols::RS1] = FE::from(op.rs1 as u64); + data[base + cols::READ_REGISTER1] = FE::from(op.read_register1 as u64); + data[base + cols::RV1_0] = FE::from(op.rv1 & 0xFFFF); + data[base + cols::RV1_1] = FE::from((op.rv1 >> 16) & 0xFFFF); + data[base + cols::RV1_2] = FE::from(op.rv1 >> 32); + data[base + cols::RV1_SIGN] = FE::from(aux.rv1_sign as u64); + data[base + cols::ARG1_0] = FE::from(aux.arg1 & 0xFFFF_FFFF); + data[base + cols::ARG1_1] = FE::from(aux.arg1 >> 32); + + // rv2 as DWordWHH + data[base + cols::RS2] = FE::from(op.rs2 as u64); + data[base + cols::READ_REGISTER2] = FE::from(op.read_register2 as u64); + data[base + cols::RV2_0] = FE::from(op.rv2 & 0xFFFF); + data[base + 
cols::RV2_1] = FE::from((op.rv2 >> 16) & 0xFFFF); + data[base + cols::RV2_2] = FE::from(op.rv2 >> 32); + data[base + cols::RV2_SIGN] = FE::from(aux.rv2_sign as u64); + data[base + cols::IMM_0] = FE::from(op.imm & 0xFFFF_FFFF); + data[base + cols::IMM_1] = FE::from(op.imm >> 32); + data[base + cols::ARG2_0] = FE::from(aux.arg2 & 0xFFFF_FFFF); + data[base + cols::ARG2_1] = FE::from(aux.arg2 >> 32); + + // res as DWordHL: 4 halves + data[base + cols::RES_0] = FE::from(op.res & 0xFFFF); + data[base + cols::RES_1] = FE::from((op.res >> 16) & 0xFFFF); + data[base + cols::RES_2] = FE::from((op.res >> 32) & 0xFFFF); + data[base + cols::RES_3] = FE::from((op.res >> 48) & 0xFFFF); + data[base + cols::RES_SIGN] = FE::from(aux.res_sign as u64); + + // rd write + data[base + cols::RD] = FE::from(op.rd as u64); + data[base + cols::WRITE_REGISTER] = FE::from(op.write_register as u64); + data[base + cols::RVD_0] = FE::from(aux.rvd & 0xFFFF_FFFF); + data[base + cols::RVD_1] = FE::from(aux.rvd >> 32); + + // ALU control + data[base + cols::ALU] = FE::from(op.alu as u64); + data[base + cols::ALU_FLAGS] = FE::from(op.alu_flags as u64); + data[base + cols::ADD] = FE::from(op.add as u64); + data[base + cols::SUB] = FE::from(op.sub as u64); + data[base + cols::HALF_INSTRUCTION_LENGTH] = FE::from(op.half_instruction_length as u64); + data[base + cols::SIGNED] = FE::from(aux.signed as u64); + + data[base + cols::MU] = FE::one(); + } + + TraceTable::new_main(data, cols::NUM_COLUMNS, 1) +} + +// ========================================================================= +// Bus interactions +// ========================================================================= + +/// 2^16, to combine two halves into a word. +const HALF_SHIFT: i64 = 1 << 16; + +/// The 8-element MEMW value/old for a register read: `[lo_word, hi_word, 0×6]` +/// where `lo_word = lo0 + 2^16·lo1` (Q9: cast `DWordWHH` → `DWordWL`). 
+fn register_dword(lo0: usize, lo1: usize, hi: usize) -> Vec { + let mut v = vec![ + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: lo0, + }, + LinearTerm::Column { + coefficient: HALF_SHIFT, + column: lo1, + }, + ]), + BusValue::Packed { + start_column: hi, + packing: Packing::Direct, + }, + ]; + v.extend(std::iter::repeat_n(BusValue::constant(0), 6)); + v +} + +/// `timestamp + offset` as DWordWL: `[TIMESTAMP_0 + offset, TIMESTAMP_1]`. +fn timestamp_plus(offset: i64) -> Vec { + vec![ + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::TIMESTAMP_0, + }, + LinearTerm::Constant(offset), + ]), + BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }, + ] +} + +/// MEMW register **read** (24 elements: `old == value`, `is_register=1`, `write2=1`). +fn reg_read( + rs: usize, + lo0: usize, + lo1: usize, + hi: usize, + ts_offset: i64, + mult: usize, +) -> BusInteraction { + let mut values = register_dword(lo0, lo1, hi); // old + values.push(BusValue::constant(1)); // is_register + values.push(BusValue::linear(vec![LinearTerm::Column { + coefficient: 2, + column: rs, + }])); // base_address[0] = 2*rs + values.push(BusValue::constant(0)); // base_address[1] + values.extend(register_dword(lo0, lo1, hi)); // value + values.extend(timestamp_plus(ts_offset)); + values.push(BusValue::constant(1)); // write2 = 1 (register = 2 words) + values.push(BusValue::constant(0)); // write4 + values.push(BusValue::constant(0)); // write8 + BusInteraction::sender(BusId::Memw, Multiplicity::Column(mult), values) +} + +/// MEMW register **write** (16 elements: `value = [val_lo, val_hi, 0×6]`, `write2=1`). 
+fn reg_write( + rd: usize, + val_lo: usize, + val_hi: usize, + ts_offset: i64, + mult: usize, +) -> BusInteraction { + let mut values = vec![ + BusValue::constant(1), // is_register + BusValue::linear(vec![LinearTerm::Column { + coefficient: 2, + column: rd, + }]), // base_address[0] = 2*rd + BusValue::constant(0), // base_address[1] + BusValue::Packed { + start_column: val_lo, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: val_hi, + packing: Packing::Direct, + }, + ]; + values.extend(std::iter::repeat_n(BusValue::constant(0), 6)); // value[2..8] + values.extend(timestamp_plus(ts_offset)); + values.push(BusValue::constant(1)); // write2 = 1 + values.push(BusValue::constant(0)); // write4 + values.push(BusValue::constant(0)); // write8 + BusInteraction::sender(BusId::Memw, Multiplicity::Column(mult), values) +} + +/// All bus interactions for the CPU32 table. +pub fn bus_interactions() -> Vec { + use packed_decode_shrunk as pd; + let mut interactions = Vec::new(); + + // DECODE[pc, imm, packed_decode] (sender, mult μ); word_instr is constant 1, + // and there are no MEMORY/BRANCH/ECALL/mem_flags terms (CPU32 is ALU-only). 
+ interactions.push(BusInteraction::sender( + BusId::Decode, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::PC_0, + packing: Packing::DWordWL, + }, + BusValue::Packed { + start_column: cols::IMM_0, + packing: Packing::DWordWL, + }, + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1 << pd::READ_REG1, + column: cols::READ_REGISTER1, + }, + LinearTerm::Column { + coefficient: 1 << pd::READ_REG2, + column: cols::READ_REGISTER2, + }, + LinearTerm::Column { + coefficient: 1 << pd::WRITE_REG, + column: cols::WRITE_REGISTER, + }, + LinearTerm::Constant(1 << pd::WORD_INSTR), // word_instr = 1 + LinearTerm::Column { + coefficient: 1 << pd::ALU, + column: cols::ALU, + }, + LinearTerm::Column { + coefficient: 1 << pd::ADD, + column: cols::ADD, + }, + LinearTerm::Column { + coefficient: 1 << pd::SUB, + column: cols::SUB, + }, + LinearTerm::Column { + coefficient: 1 << pd::RS1, + column: cols::RS1, + }, + LinearTerm::Column { + coefficient: 1 << pd::RS2, + column: cols::RS2, + }, + LinearTerm::Column { + coefficient: 1 << pd::RD, + column: cols::RD, + }, + LinearTerm::Column { + coefficient: 1 << pd::HALF_INSTRUCTION_LENGTH, + column: cols::HALF_INSTRUCTION_LENGTH, + }, + LinearTerm::Column { + coefficient: 1 << pd::ALU_FLAGS, + column: cols::ALU_FLAGS, + }, + ]), + ], + )); + + // Byte range checks: ARE_BYTES[x, 0]. + for col in [ + cols::HALF_INSTRUCTION_LENGTH, + cols::ALU_FLAGS, + cols::RS1, + cols::RS2, + cols::RD, + ] { + interactions.push(BusInteraction::sender( + BusId::AreBytes, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: col, + packing: Packing::Direct, + }, + BusValue::constant(0), + ], + )); + } + + // IS_HALF for the rv1/rv2 low-word halves and the res halves. 
+ for col in [ + cols::RV1_0, + cols::RV1_1, + cols::RV2_0, + cols::RV2_1, + cols::RES_0, + cols::RES_1, + cols::RES_2, + cols::RES_3, + ] { + interactions.push(BusInteraction::sender( + BusId::IsHalfword, + Multiplicity::Column(cols::MU), + vec![BusValue::Packed { + start_column: col, + packing: Packing::Direct, + }], + )); + } + + // Register reads (rv1 @ ts+0, rv2 @ ts+1) and write (rvd @ ts+2). + interactions.push(reg_read( + cols::RS1, + cols::RV1_0, + cols::RV1_1, + cols::RV1_2, + 0, + cols::READ_REGISTER1, + )); + interactions.push(reg_read( + cols::RS2, + cols::RV2_0, + cols::RV2_1, + cols::RV2_2, + 1, + cols::READ_REGISTER2, + )); + interactions.push(reg_write( + cols::RD, + cols::RVD_0, + cols::RVD_1, + 2, + cols::WRITE_REGISTER, + )); + + // ALU[arg1, arg2, alu_flags] -> res (sender, mult ALU). res is DWordHL cast to DWordWL. + interactions.push(BusInteraction::sender( + BusId::Alu, + Multiplicity::Column(cols::ALU), + vec![ + BusValue::Packed { + start_column: cols::ARG1_0, + packing: Packing::DWordWL, + }, + BusValue::Packed { + start_column: cols::ARG2_0, + packing: Packing::DWordWL, + }, + BusValue::Packed { + start_column: cols::ALU_FLAGS, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::RES_0, + packing: Packing::DWordHL, + }, + ], + )); + + // BYTE_ALU[AND, 32, alu_flags] -> 32·signed (extracts the signed bit). + interactions.push(BusInteraction::sender( + BusId::ByteAlu, + Multiplicity::Column(cols::MU), + vec![ + BusValue::constant(alu_op::AND as u64), + BusValue::constant(1u64 << pd::ALU_FLAGS_SIGNED), // 32 + BusValue::Packed { + start_column: cols::ALU_FLAGS, + packing: Packing::Direct, + }, + BusValue::linear(vec![LinearTerm::Column { + coefficient: 1 << pd::ALU_FLAGS_SIGNED, + column: cols::SIGNED, + }]), + ], + )); + + // MSB16 sign extraction (high half of each low word). 
+ for (half_col, sign_col) in [ + (cols::RV1_1, cols::RV1_SIGN), + (cols::RV2_1, cols::RV2_SIGN), + (cols::RES_1, cols::RES_SIGN), + ] { + interactions.push(BusInteraction::sender( + BusId::Msb16, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: half_col, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: sign_col, + packing: Packing::Direct, + }, + ], + )); + } + + // CPU32[timestamp, pc, half_instruction_length] (receiver from the main CPU). + interactions.push(BusInteraction::receiver( + BusId::Cpu32, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::DWordWL, + }, + BusValue::Packed { + start_column: cols::PC_0, + packing: Packing::DWordWL, + }, + BusValue::Packed { + start_column: cols::HALF_INSTRUCTION_LENGTH, + packing: Packing::Direct, + }, + ], + )); + + interactions +} + +// ========================================================================= +// Constraints +// ========================================================================= + +/// Arithmetic constraints for CPU32: the sign-extension `ext` group plus the +/// register-zero checks. (`IS_BIT` flags and the ADD/SUB carries are produced +/// by the template helpers in [`cpu32_constraints`].) +pub struct Cpu32Constraint { + constraint_idx: usize, + kind: Cpu32ConstraintKind, +} + +#[derive(Debug, Clone, Copy)] +pub enum Cpu32ConstraintKind { + /// `arg1[0] = rv1[0] + 2^16·rv1[1]` (low word of `arg1`). + Arg1Lo, + /// `arg1[1] = (2^32-1)·signed·rv1_sign` (sign/zero extension of the high word). + Arg1Hi, + /// `arg2[0] = rv2[0] + 2^16·rv2[1] + imm[0]`. + Arg2Lo, + /// `arg2[1] = (2^32-1)·signed·rv2_sign + imm[1]`. + Arg2Hi, + /// `rvd[0] = res[0] + 2^16·res[1]`. + RvdLo, + /// `rvd[1] = (2^32-1)·res_sign` (the `*W` result is always sign-extended). + RvdHi, + /// `(1 - read_col)·value_col = 0` (an unread register half is zero). 
+ RegZero { read_col: usize, value_col: usize }, + /// `read_register2·imm[i] = 0` (decoding guarantees at most one is nonzero; + /// spec defense-in-depth assumption). `usize` is the `imm` limb column. + Arg2Exclusive { imm_col: usize }, + /// `(1 - μ)·flag = 0`: a register flag that gates a `MEMW` bus interaction + /// must be 0 on a padding row (`μ = 0`), otherwise a disconnected row could + /// emit a forged register read/write token (no DECODE binding, no CPU32 + /// delegation). Spec `cpu32.toml` (PR #646). `usize` is the flag column. + FlagImpliesMu { flag_col: usize }, +} + +impl Cpu32Constraint { + pub fn new(kind: Cpu32ConstraintKind, constraint_idx: usize) -> Self { + Self { + constraint_idx, + kind, + } + } +} + +impl TransitionConstraint for Cpu32Constraint { + fn degree(&self) -> usize { + match self.kind { + Cpu32ConstraintKind::Arg1Lo + | Cpu32ConstraintKind::Arg2Lo + | Cpu32ConstraintKind::RvdLo + | Cpu32ConstraintKind::RvdHi => 1, + // signed·sign (degree 2) and (1-read)·value (degree 2) + Cpu32ConstraintKind::Arg1Hi + | Cpu32ConstraintKind::Arg2Hi + | Cpu32ConstraintKind::RegZero { .. } + | Cpu32ConstraintKind::Arg2Exclusive { .. } + | Cpu32ConstraintKind::FlagImpliesMu { .. 
} => 2, + } + } + + fn constraint_idx(&self) -> usize { + self.constraint_idx + } + + fn evaluate(&self, step: &TableView) -> FieldElement + where + F: IsSubFieldOf, + E: IsField, + { + let get = |c: usize| step.get_main_evaluation_element(0, c).clone(); + let shift16 = FieldElement::::from(SHIFT_16); + let hi_fill = FieldElement::::from(HI_FILL); + let one = FieldElement::::one(); + + match self.kind { + Cpu32ConstraintKind::Arg1Lo => { + get(cols::ARG1_0) - get(cols::RV1_0) - &shift16 * get(cols::RV1_1) + } + Cpu32ConstraintKind::Arg1Hi => { + get(cols::ARG1_1) - hi_fill * get(cols::SIGNED) * get(cols::RV1_SIGN) + } + Cpu32ConstraintKind::Arg2Lo => { + get(cols::ARG2_0) + - get(cols::RV2_0) + - &shift16 * get(cols::RV2_1) + - get(cols::IMM_0) + } + Cpu32ConstraintKind::Arg2Hi => { + get(cols::ARG2_1) + - hi_fill * get(cols::SIGNED) * get(cols::RV2_SIGN) + - get(cols::IMM_1) + } + Cpu32ConstraintKind::RvdLo => { + get(cols::RVD_0) - get(cols::RES_0) - &shift16 * get(cols::RES_1) + } + Cpu32ConstraintKind::RvdHi => get(cols::RVD_1) - hi_fill * get(cols::RES_SIGN), + Cpu32ConstraintKind::RegZero { + read_col, + value_col, + } => (one - get(read_col)) * get(value_col), + Cpu32ConstraintKind::Arg2Exclusive { imm_col } => { + get(cols::READ_REGISTER2) * get(imm_col) + } + Cpu32ConstraintKind::FlagImpliesMu { flag_col } => { + (one - get(cols::MU)) * get(flag_col) + } + } + } +} + +/// Creates all transition constraints for the CPU32 table: +/// `IS_BIT` on the flag columns, the `ADD`/`SUB` fast-path carries, the +/// register-zero checks, and the sign-extension `ext` arithmetic. +pub fn cpu32_constraints( + constraint_idx_start: usize, +) -> ( + Vec>>, + usize, +) { + let mut constraints: Vec< + Box>, + > = Vec::new(); + + // IS_BIT on the flag columns and the multiplicity. 
+ let (is_bit, mut idx) = new_is_bit_constraints( + &[ + cols::READ_REGISTER1, + cols::READ_REGISTER2, + cols::WRITE_REGISTER, + cols::ALU, + cols::ADD, + cols::SUB, + cols::MU, + ], + constraint_idx_start, + ); + for c in is_bit { + constraints.push(c.boxed()); + } + + // ADD fast-path: arg1 + arg2 = res (cond = ADD). + let (add_lo, add_hi) = AddConstraint::new_pair( + vec![cols::ADD], + AddOperand::dword(cols::ARG1_0), + AddOperand::dword(cols::ARG2_0), + AddOperand::from_dword_hl(cols::RES_0), + idx, + ); + idx += 2; + constraints.push(add_lo.boxed()); + constraints.push(add_hi.boxed()); + + // SUB fast-path: res = arg1 - arg2, encoded as arg2 + res = arg1 (cond = SUB). + let (sub_lo, sub_hi) = AddConstraint::new_pair( + vec![cols::SUB], + AddOperand::dword(cols::ARG2_0), + AddOperand::from_dword_hl(cols::RES_0), + AddOperand::dword(cols::ARG1_0), + idx, + ); + idx += 2; + constraints.push(sub_lo.boxed()); + constraints.push(sub_hi.boxed()); + + // Unread register limbs are zero. `rv1`/`rv2` span three limbs + // (low halfword, high halfword, high word), so all three must be forced to + // zero when the register is not read — the bus reads the full word + // `[lo0 + 2^16·lo1, hi]`, leaving `RV*_2` free otherwise. + for (read_col, value_col) in [ + (cols::READ_REGISTER1, cols::RV1_0), + (cols::READ_REGISTER1, cols::RV1_1), + (cols::READ_REGISTER1, cols::RV1_2), + (cols::READ_REGISTER2, cols::RV2_0), + (cols::READ_REGISTER2, cols::RV2_1), + (cols::READ_REGISTER2, cols::RV2_2), + ] { + constraints.push( + Cpu32Constraint::new( + Cpu32ConstraintKind::RegZero { + read_col, + value_col, + }, + idx, + ) + .boxed(), + ); + idx += 1; + } + + // Sign-extension (`ext`) arithmetic for arg1, arg2, rvd. 
+ for kind in [ + Cpu32ConstraintKind::Arg1Lo, + Cpu32ConstraintKind::Arg1Hi, + Cpu32ConstraintKind::Arg2Lo, + Cpu32ConstraintKind::Arg2Hi, + Cpu32ConstraintKind::RvdLo, + Cpu32ConstraintKind::RvdHi, + ] { + constraints.push(Cpu32Constraint::new(kind, idx).boxed()); + idx += 1; + } + + // arg2 multiplex exclusivity (spec assumption): read_register2·imm[i] = 0. + for imm_col in [cols::IMM_0, cols::IMM_1] { + constraints.push( + Cpu32Constraint::new(Cpu32ConstraintKind::Arg2Exclusive { imm_col }, idx).boxed(), + ); + idx += 1; + } + + // flag ⇒ μ: a register flag gating a MEMW interaction must be 0 on padding + // rows (μ = 0), else a disconnected row injects a forged register access + // (spec `cpu32.toml`, PR #646). ALU is not gated: with `write_register = 0` + // its ALU-lookup result is never written back, so it has no side effect. + for flag_col in [ + cols::READ_REGISTER1, + cols::READ_REGISTER2, + cols::WRITE_REGISTER, + ] { + constraints.push( + Cpu32Constraint::new(Cpu32ConstraintKind::FlagImpliesMu { flag_col }, idx).boxed(), + ); + idx += 1; + } + + (constraints, idx) +} diff --git a/prover/src/tables/decode.rs b/prover/src/tables/decode.rs index 4805ffc42..f1fe14e03 100644 --- a/prover/src/tables/decode.rs +++ b/prover/src/tables/decode.rs @@ -10,25 +10,21 @@ //! - `imm`: DWordWL (2 cols) - fully extended 64-bit immediate //! - `μ`: BaseField (1 col) - multiplicity //! -//! ## packed_decode Format (51 bits) +//! ## packed_decode Format +//! +//! A single base-field element packing the control flags, register indices, and +//! the `alu_flags`/`mem_flags` bytes. The authoritative bit layout lives in +//! `packed_decode_shrunk` and is produced by `ShrunkDecode::pack` (both in +//! `tables/types.rs`) — consult those for the exact bit position of every field. +//! Summary (low → high bits): //! //! ```text -//! Bits [0]: read_register1 -//! Bits [1]: read_register2 -//! Bits [2]: write_register -//! Bits [3]: memory_2bytes -//! Bits [4]: memory_4bytes -//! 
Bits [5]: memory_8bytes -//! Bits [6]: c_type -//! Bits [7]: signed -//! Bits [8]: mp_selector -//! Bits [9]: muldiv_selector -//! Bits [10]: word_instr -//! Bits [11-26]: ALU flags (ADD, SUB, SLT, AND, OR, XOR, SHIFT, JALR, -//! BEQ, BLT, LOAD, STORE, MUL, DIVREM, ECALL, EBREAK) -//! Bits [27:35]: rs1 (8 bits) -//! Bits [35:43]: rs2 (8 bits) -//! Bits [43:51]: rd (8 bits) +//! Bits [0..10]: read_register1, read_register2, write_register, word_instr, +//! ALU, ADD, SUB, MEMORY, BRANCH, ECALL (one bit each) +//! Bits [10..34]: rs1, rs2, rd (8 bits each) +//! Bits [34..42]: half_instruction_length (Byte: byte length / 2) +//! Bits [42..50]: alu_flags (Byte: alu_op in bits 0-4, then signed / signed2|invert / muldiv) +//! Bits [50..58]: mem_flags (Byte: JALR|memory_op, signed, 2B, 4B, 8B) //! ``` //! //! ## Bus Interactions @@ -113,7 +109,8 @@ pub fn generate_decode_trace( .enumerate() .map(|(row_idx, (&pc, &instr))| { pc_to_row.insert(pc, row_idx); - DecodeEntry::from_instruction(pc, instr) + // instruction_length = 4 (RV64C compressed decode is a separate workstream). + DecodeEntry::from_instruction(pc, instr, 4) }) .collect(); @@ -161,7 +158,8 @@ pub fn generate_decode_trace( data[base + cols::IMM_1] = FE::from(cpu_padding_entry.imm >> 32); } - // Fill padding rows with DECODE padding pattern: pc=7, EBREAK=1 + // Fill padding rows with the DECODE padding pattern: odd pc=1, all flags 0 + // (unprovable as a fetch target; same row the CPU pads to). 
let padding_entry = DecodeEntry::padding_entry(); for row_idx in num_entries..num_rows { let base = row_idx * cols::NUM_COLUMNS; @@ -377,7 +375,7 @@ pub fn tables_from_elf(elf: &Elf) -> Result { let addr = segment.base_addr + (i as u64 * 4); let instruction = Instruction::parse(word)?; pc_to_row.insert(addr, decode_entries.len()); - decode_entries.push(DecodeEntry::from_instruction(addr, instruction)); + decode_entries.push(DecodeEntry::from_instruction(addr, instruction, 4)); } } } diff --git a/prover/src/tables/dvrm.rs b/prover/src/tables/dvrm.rs index 30352e125..3329bf273 100644 --- a/prover/src/tables/dvrm.rs +++ b/prover/src/tables/dvrm.rs @@ -24,8 +24,8 @@ //! ## Bus Interactions //! - Sender: IS_HALF (×16: n, d, r, n_sub_r, q) //! - Sender: MSB16 (×3 for sign extraction: n, d, r) -//! - Sender: LT (×1 for abs_r < abs_d) -//! - Sender: MUL (×2 for n_sub_r = d * q verification) +//! - Sender: ALU (×3, on the unified bus: ×1 LT-flavored for `|r| < |d|`, +//! ×2 MUL-flavored for `n - r = d * q` lo/hi) //! - Sender: ZERO (×5 for div_by_zero, overflow, NEG template) //! - Receiver: DVRM (×2 for quotient and remainder results) @@ -40,7 +40,7 @@ use stark::trace::TraceTable; use super::types::{ BusId, FE, GoldilocksExtension, GoldilocksField, NEG_INV_2_16, NEG_INV_2_32, NEG_INV_2_48, - NEG_INV_2_64, SHIFT_16, + NEG_INV_2_64, SHIFT_16, alu_op, }; // ========================================================================= @@ -492,12 +492,14 @@ pub fn bus_interactions() -> Vec { )); // ------------------------------------------------------------------------- - // DVRM-C2: LT[1-div_by_zero; abs_r, abs_d, 0] - // Verify |r| < |d| when d != 0 + // DVRM-C2: ALU[abs_r, abs_d, opsel(LT), 1-div_by_zero, 0] + // Verify |r| < |d| when d != 0 (the ALU output is 1 iff abs_r < abs_d). + // This lookup is dispatched on the unified ALU bus with signed=0/invert=0 + // (there is no dedicated `Lt` bus). 
// multiplicity: μ_q + μ_r // ------------------------------------------------------------------------- interactions.push(BusInteraction::sender( - BusId::Lt, + BusId::Alu, Multiplicity::Sum(cols::MU_Q, cols::MU_R), vec![ // abs_r as DWordWL (2 words → 2 elements) @@ -510,9 +512,9 @@ pub fn bus_interactions() -> Vec { start_column: cols::ABS_D_0, packing: Packing::DWordWL, }, - // signed = 0 (unsigned comparison of absolute values) - BusValue::constant(0), - // lt_result = 1 - div_by_zero + // flags = opsel(LT) (signed=0, invert=0) + BusValue::constant(alu_op::LT as u64), + // out_lo = 1 - div_by_zero (LT result fits in the low word) BusValue::linear(vec![ LinearTerm::Constant(1), LinearTerm::Column { @@ -520,81 +522,81 @@ pub fn bus_interactions() -> Vec { column: cols::DIV_BY_ZERO, }, ]), + // out_hi = 0 + BusValue::constant(0), ], )); // ------------------------------------------------------------------------- - // DVRM-C9: MUL[n_sub_r::DWordWL; d, signed, q, sign_q, 0] - // Verify n - r = d * q (lower 64 bits) + // DVRM-C9: ALU[d, q, opsel(MUL)+32*signed+64*sign_q, n_sub_r] + // Verify n - r = d * q (lower 64 bits). The lookup is dispatched on the + // unified ALU bus with the lo selector (flags `+0`); there is no dedicated + // `Mul` bus. 
// multiplicity: μ_q + μ_r // ------------------------------------------------------------------------- + let mul_flags = |hi: i64| { + BusValue::linear(vec![ + LinearTerm::Constant(alu_op::MUL as i64 + hi), + LinearTerm::Column { + coefficient: 32, + column: cols::SIGNED, + }, + LinearTerm::Column { + coefficient: 64, + column: cols::SIGN_Q, + }, + ]) + }; interactions.push(BusInteraction::sender( - BusId::Mul, + BusId::Alu, Multiplicity::Sum(cols::MU_Q, cols::MU_R), vec![ - // d as DWordHL (lhs) + // lhs = d as DWordHL BusValue::Packed { start_column: cols::D_0, packing: Packing::DWordHL, }, - // lhs_signed = signed - BusValue::Packed { - start_column: cols::SIGNED, - packing: Packing::Direct, - }, - // q as DWordHL (rhs) + // rhs = q as DWordHL BusValue::Packed { start_column: cols::Q_0, packing: Packing::DWordHL, }, - // rhs_signed = sign_q - BusValue::Packed { - start_column: cols::SIGN_Q, - packing: Packing::Direct, - }, - // result: n_sub_r as DWordHL (lower 64 bits of d*q) + // flags = opsel(MUL) + 32*signed + 64*sign_q (lo half) + mul_flags(0), + // result = n_sub_r as DWordHL (lower 64 bits of d*q) BusValue::Packed { start_column: cols::N_SUB_R_0, packing: Packing::DWordHL, }, - // muldiv_selector = 0 (lo) - BusValue::constant(0), ], )); // ------------------------------------------------------------------------- - // DVRM-C10: MUL[extension_n_sub_r::DWordWL; d, signed, q, sign_q, 1] - // Verify upper 64 bits of d * q = sign extension of n_sub_r + // DVRM-C10: ALU[d, q, opsel(MUL)+32*signed+64*sign_q+128, sign_ext(n_sub_r)] + // Verify upper 64 bits of d * q = sign extension of n_sub_r. + // Dispatched on the unified ALU bus with the hi selector (flags `+128`). 
// multiplicity: μ_q + μ_r // ------------------------------------------------------------------------- interactions.push(BusInteraction::sender( - BusId::Mul, + BusId::Alu, Multiplicity::Sum(cols::MU_Q, cols::MU_R), vec![ - // d as DWordHL (lhs) + // lhs = d as DWordHL BusValue::Packed { start_column: cols::D_0, packing: Packing::DWordHL, }, - // lhs_signed = signed - BusValue::Packed { - start_column: cols::SIGNED, - packing: Packing::Direct, - }, - // q as DWordHL (rhs) + // rhs = q as DWordHL BusValue::Packed { start_column: cols::Q_0, packing: Packing::DWordHL, }, - // rhs_signed = sign_q - BusValue::Packed { - start_column: cols::SIGN_Q, - packing: Packing::Direct, - }, - // result: sign extension of n_sub_r as DWordHL - // Each halfword = sign_n_sub_r * 65535 - // lo32 = sign_n_sub_r * (65535 + 65535 * 2^16) = sign_n_sub_r * 0xFFFFFFFF - // hi32 = same + // flags = opsel(MUL) + 32*signed + 64*sign_q + 128 (hi half) + mul_flags(128), + // result: sign extension of n_sub_r. + // The MUL Alu receiver consumes the result as `Packed{HI_0, DWordHL}` + // → 2 elements `[HI_0 + 2^16*HI_1, HI_2 + 2^16*HI_3]`. Both equal + // SIGN_N_SUB_R * 0xFFFFFFFF (each halfword is SIGN_FILL when negative). BusValue::linear(vec![LinearTerm::Column { coefficient: (SIGN_FILL + SIGN_FILL * SHIFT_16) as i64, column: cols::SIGN_N_SUB_R, @@ -603,8 +605,6 @@ pub fn bus_interactions() -> Vec { coefficient: (SIGN_FILL + SIGN_FILL * SHIFT_16) as i64, column: cols::SIGN_N_SUB_R, }]), - // muldiv_selector = 1 (hi) - BusValue::constant(1), ], )); @@ -893,11 +893,11 @@ pub fn bus_interactions() -> Vec { )); // ------------------------------------------------------------------------- - // DVRM-C21: Receiver for quotient result - // DVRM[q::DWordWL; n, d, signed, 0] with multiplicity -μ_q + // DVRM-C21: Quotient result on the unified ALU bus. 
+ // ALU[q::DWordWL; n, d, opsel(DIVREM) + 32*signed] | μ_q (muldiv bit 7 = 0) // ------------------------------------------------------------------------- interactions.push(BusInteraction::receiver( - BusId::Dvrm, + BusId::Alu, Multiplicity::Column(cols::MU_Q), vec![ // n as DWordHL (4 halfwords → 2 words) @@ -910,27 +910,28 @@ pub fn bus_interactions() -> Vec { start_column: cols::D_0, packing: Packing::DWordHL, }, - // signed - BusValue::Packed { - start_column: cols::SIGNED, - packing: Packing::Direct, - }, + // flags = DIVREM + 32*signed (quotient: muldiv selector = 0) + BusValue::linear(vec![ + LinearTerm::Constant(alu_op::DIVREM as i64), + LinearTerm::Column { + coefficient: 32, + column: cols::SIGNED, + }, + ]), // q as DWordHL (result) BusValue::Packed { start_column: cols::Q_0, packing: Packing::DWordHL, }, - // muldiv_selector = 0 (quotient) - BusValue::constant(0), ], )); // ------------------------------------------------------------------------- - // DVRM-C22: Receiver for remainder result - // DVRM[r::DWordWL; n, d, signed, 1] with multiplicity -μ_r + // DVRM-C22: Remainder result on the unified ALU bus. 
+ // ALU[r::DWordWL; n, d, opsel(DIVREM) + 32*signed + 128] | μ_r (muldiv bit 7 = 1) // ------------------------------------------------------------------------- interactions.push(BusInteraction::receiver( - BusId::Dvrm, + BusId::Alu, Multiplicity::Column(cols::MU_R), vec![ // n as DWordHL @@ -943,18 +944,19 @@ pub fn bus_interactions() -> Vec { start_column: cols::D_0, packing: Packing::DWordHL, }, - // signed - BusValue::Packed { - start_column: cols::SIGNED, - packing: Packing::Direct, - }, + // flags = DIVREM + 32*signed + 128 (remainder: muldiv selector = 1) + BusValue::linear(vec![ + LinearTerm::Constant(alu_op::DIVREM as i64 + 128), + LinearTerm::Column { + coefficient: 32, + column: cols::SIGNED, + }, + ]), // r as DWordHL (result) BusValue::Packed { start_column: cols::R_0, packing: Packing::DWordHL, }, - // muldiv_selector = 1 (remainder) - BusValue::constant(1), ], )); diff --git a/prover/src/tables/eq.rs b/prover/src/tables/eq.rs new file mode 100644 index 000000000..f60ed2e58 --- /dev/null +++ b/prover/src/tables/eq.rs @@ -0,0 +1,328 @@ +//! EQ (equality) comparison table. +//! +//! Computes `res = (a == b) XOR invert` for 64-bit `a`, `b`. Used by `BEQ` +//! (`invert = 0`) and `BNE` (`invert = 1`); the CPU dispatches to it on the +//! unified `ALU` bus with `alu_flags = opsel(EQ) + 64*invert`. +//! +//! Spec: `spec/src/eq.toml`. +//! +//! ## Columns +//! - `a`: DWordWL (2 words) — first input +//! - `b`: DWordWL (2 words) — second input +//! - `invert`: Bit — invert the result +//! - `res`: Bit — output, `(a == b) XOR invert` +//! - `diff`: DWordHL (4 halves) — `a - b` (aux) +//! - `eq`: Bit — `a == b` (aux) +//! - `μ`: multiplicity +//! +//! ## Method +//! `diff = a - b` is enforced via the `ADD` template (`b + diff = a`), its +//! halves range-checked via `IS_HALF`. Then `eq = ZERO[Σ diff[i]]` (the sum of +//! four range-checked halves is `0` iff `diff == 0` iff `a == b`), and +//! `res = eq XOR invert`. 
+ +use math::field::element::FieldElement; +use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; +use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; +use stark::table::TableView; +use stark::trace::TraceTable; + +use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField, alu_op}; +use crate::constraints::templates::{AddConstraint, AddOperand, new_is_bit_constraints}; + +// ========================================================================= +// Column indices for EQ table +// ========================================================================= + +/// Column definitions for the EQ table. +pub mod cols { + // Input: a (DWordWL = 2 words) + pub const A_0: usize = 0; + pub const A_1: usize = 1; + // Input: b (DWordWL = 2 words) + pub const B_0: usize = 2; + pub const B_1: usize = 3; + /// invert: Bit + pub const INVERT: usize = 4; + /// res: Bit (output) = (a == b) XOR invert + pub const RES: usize = 5; + // Auxiliary: diff (DWordHL = 4 halves) = a - b + pub const DIFF_0: usize = 6; + pub const DIFF_1: usize = 7; + pub const DIFF_2: usize = 8; + pub const DIFF_3: usize = 9; + /// eq: Bit (auxiliary) = (a == b) + pub const EQ: usize = 10; + /// μ: multiplicity + pub const MU: usize = 11; + + /// Total number of columns + pub const NUM_COLUMNS: usize = 12; +} + +// ========================================================================= +// Trace generation +// ========================================================================= + +/// A single EQ operation. +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub struct EqOperation { + /// First operand (64-bit) + pub a: u64, + /// Second operand (64-bit) + pub b: u64, + /// Whether to invert the equality result + pub invert: bool, +} + +impl EqOperation { + /// Create a new EQ operation. 
+ pub fn new(a: u64, b: u64, invert: bool) -> Self { + Self { a, b, invert } + } + + /// `a == b` (before inversion). + pub fn compute_eq(&self) -> bool { + self.a == self.b + } + + /// The output: `(a == b) XOR invert`. + pub fn compute_res(&self) -> bool { + self.compute_eq() ^ self.invert + } + + /// The BITWISE lookups this op sends (4× `IS_HALF` on the `diff` halves and + /// one `ZERO` on their sum), for the BITWISE table's multiplicity bookkeeping. + pub fn collect_bitwise_ops(&self) -> Vec { + use super::bitwise::{BitwiseOperation, BitwiseOperationType}; + let diff = self.a.wrapping_sub(self.b); + let mut ops = Vec::with_capacity(5); + let mut sum = 0u32; + for i in 0..4 { + let half = ((diff >> (i * 16)) & 0xFFFF) as u32; + sum += half; + ops.push(BitwiseOperation::halfword( + BitwiseOperationType::IsHalf, + (half & 0xFF) as u8, + (half >> 8) as u8, + )); + } + ops.push(BitwiseOperation::zero(sum)); + ops + } +} + +/// Generates the EQ trace from a list of operations. +/// +/// Duplicate operations are merged into a single row with summed multiplicities, +/// then padded to the next power of two (minimum 4). 
+pub fn generate_eq_trace( + operations: &[EqOperation], +) -> TraceTable { + use std::collections::HashMap; + + let mut op_map: HashMap = HashMap::new(); + for op in operations { + *op_map.entry(op.clone()).or_insert(0) += 1; + } + + let unique_ops: Vec<_> = op_map.into_iter().collect(); + let num_rows = unique_ops.len().next_power_of_two().max(4); + let mut data = vec![FE::zero(); num_rows * cols::NUM_COLUMNS]; + + for (row_idx, (op, multiplicity)) in unique_ops.iter().enumerate() { + let base = row_idx * cols::NUM_COLUMNS; + + // a, b as DWordWL (2 words each) + data[base + cols::A_0] = FE::from(op.a & 0xFFFF_FFFF); + data[base + cols::A_1] = FE::from(op.a >> 32); + data[base + cols::B_0] = FE::from(op.b & 0xFFFF_FFFF); + data[base + cols::B_1] = FE::from(op.b >> 32); + + data[base + cols::INVERT] = FE::from(op.invert as u64); + data[base + cols::RES] = FE::from(op.compute_res() as u64); + + // diff = a - b (wrapping) as DWordHL (4 halves) + let diff = op.a.wrapping_sub(op.b); + data[base + cols::DIFF_0] = FE::from(diff & 0xFFFF); + data[base + cols::DIFF_1] = FE::from((diff >> 16) & 0xFFFF); + data[base + cols::DIFF_2] = FE::from((diff >> 32) & 0xFFFF); + data[base + cols::DIFF_3] = FE::from((diff >> 48) & 0xFFFF); + + data[base + cols::EQ] = FE::from(op.compute_eq() as u64); + data[base + cols::MU] = FE::from(*multiplicity); + } + + TraceTable::new_main(data, cols::NUM_COLUMNS, 1) +} + +// ========================================================================= +// Bus interactions +// ========================================================================= + +/// All bus interactions for the EQ table: +/// - **Sends** `IS_HALF[diff[i]]` (×4) to range-check the difference halves. +/// - **Sends** `ZERO[Σ diff[i]] -> eq`. +/// - **Receives** `ALU[a, b, opsel(EQ) + 64*invert] -> res`. 
+pub fn bus_interactions() -> Vec { + let mut interactions = Vec::with_capacity(6); + + // IS_HALF[diff[i]] for i in 0..3 + for diff_col in [cols::DIFF_0, cols::DIFF_1, cols::DIFF_2, cols::DIFF_3] { + interactions.push(BusInteraction::sender( + BusId::IsHalfword, + Multiplicity::Column(cols::MU), + vec![BusValue::Packed { + start_column: diff_col, + packing: Packing::Direct, + }], + )); + } + + // ZERO[diff[0] + diff[1] + diff[2] + diff[3]] -> eq + // The sum of four range-checked halves is in [0, 2^18) < 2^20, so it is 0 + // iff diff == 0 iff a == b. Matches the BITWISE ZERO lookup domain. + interactions.push(BusInteraction::sender( + BusId::Zero, + Multiplicity::Column(cols::MU), + vec![ + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::DIFF_0, + }, + LinearTerm::Column { + coefficient: 1, + column: cols::DIFF_1, + }, + LinearTerm::Column { + coefficient: 1, + column: cols::DIFF_2, + }, + LinearTerm::Column { + coefficient: 1, + column: cols::DIFF_3, + }, + ]), + BusValue::Packed { + start_column: cols::EQ, + packing: Packing::Direct, + }, + ], + )); + + // ALU[a, b, opsel(EQ) + 64*invert] -> res (receiver). + // The ALU output is DWordWL (2 elements); for a comparison it is [res, 0] + // (the bit in the low word, 0 in the high word). 
+ interactions.push(BusInteraction::receiver( + BusId::Alu, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::A_0, + packing: Packing::DWordWL, + }, + BusValue::Packed { + start_column: cols::B_0, + packing: Packing::DWordWL, + }, + BusValue::linear(vec![ + LinearTerm::Constant(alu_op::EQ as i64), + LinearTerm::Column { + coefficient: 64, + column: cols::INVERT, + }, + ]), + // out = [res, 0] (DWordWL) + BusValue::Packed { + start_column: cols::RES, + packing: Packing::Direct, + }, + BusValue::constant(0), + ], + )); + + interactions +} + +// ========================================================================= +// Constraints +// ========================================================================= + +/// Enforces `res = eq XOR invert`, i.e. `res = eq + invert - 2*eq*invert`. +pub struct EqXorConstraint { + constraint_idx: usize, +} + +impl EqXorConstraint { + pub fn new(constraint_idx: usize) -> Self { + Self { constraint_idx } + } +} + +impl TransitionConstraint for EqXorConstraint { + fn degree(&self) -> usize { + 2 // eq * invert + } + + fn constraint_idx(&self) -> usize { + self.constraint_idx + } + + fn evaluate(&self, step: &TableView) -> FieldElement + where + F: IsSubFieldOf, + E: IsField, + { + let res = step.get_main_evaluation_element(0, cols::RES).clone(); + let eq = step.get_main_evaluation_element(0, cols::EQ).clone(); + let invert = step.get_main_evaluation_element(0, cols::INVERT).clone(); + let two = FieldElement::::from(2u64); + // res - (eq + invert - 2*eq*invert) + res - (&eq + &invert - two * &eq * &invert) + } +} + +/// Creates all transition constraints for the EQ table. +/// +/// Returns the boxed constraints and the next available constraint index: +/// - `ADD` template pair enforcing `b + diff = a` (i.e. `diff = a - b`); +/// - `IS_BIT(invert)`; +/// - `res = eq XOR invert`. 
+pub fn eq_constraints( + constraint_idx_start: usize, +) -> ( + Vec>>, + usize, +) { + let mut idx = constraint_idx_start; + let mut constraints: Vec< + Box>, + > = Vec::new(); + + // diff = a - b, encoded as b + diff = a (unconditional). + let (add_lo, add_hi) = AddConstraint::new_pair( + vec![], + AddOperand::dword(cols::B_0), + AddOperand::from_dword_hl(cols::DIFF_0), + AddOperand::dword(cols::A_0), + idx, + ); + idx += 2; + constraints.push(add_lo.boxed()); + constraints.push(add_hi.boxed()); + + // IS_BIT(invert) + let (is_bit, next) = new_is_bit_constraints(&[cols::INVERT], idx); + idx = next; + for c in is_bit { + constraints.push(c.boxed()); + } + + // res = eq XOR invert + constraints.push(EqXorConstraint::new(idx).boxed()); + idx += 1; + + (constraints, idx) +} diff --git a/prover/src/tables/halt.rs b/prover/src/tables/halt.rs index 5d76bc157..946268e24 100644 --- a/prover/src/tables/halt.rs +++ b/prover/src/tables/halt.rs @@ -5,23 +5,29 @@ //! //! ## Columns //! - `timestamp`: DWordWL (2 columns) - timestamp at which to halt the program +//! - `pc`: DWordWL (2 columns) - the `next_pc` the CPU wrote during the halting +//! instruction (consumed off the `memory` bus and replaced by the padding PC=1) //! //! ## Bus Interactions //! - **Receiver**: ECALL bus - receives `[timestamp, cast(rv1, DWordWL)]` from CPU //! when the ECALL flag is set (rv1 must be 93 = sys_exit) -//! - **Sender**: MEMW bus - 32 register finalization interactions at `ts = 2^64-1`: +//! - **Sender**: MEMW bus - 31 register finalization interactions at `ts = 2^64-1`: //! - x1-x9: write 0 (zeroize lo GPRs) //! - x10: read with old=0 (enforce exit_code=0; non-zero → bus imbalance → proof failure) //! - x11-x31: write 0 (zeroize hi GPRs) -//! - x255: write 1 (PC halted sentinel) +//! - **`memory` bus (PC finalization, per spec halt:c:consume_pc/emit_pc)**: at +//! `ts = timestamp + 1` the chip *consumes* the real `next_pc` the CPU wrote for +//! 
the halting instruction and *re-emits* `pc = 1`. This bridges the last real PC +//! write to the CPU padding rows (which all carry PC=1); the padding chain then +//! carries PC=1 to the REGISTER table's final token. x255 is therefore NOT +//! finalized via MEMW at `2^64-1` anymore. //! -//! All MEMW interactions use constant values only (no additional columns needed). //! Corresponding MEMW table rows are generated in trace_builder. //! //! ## Padding //! Single-row table (2^0 = 1), no padding needed. -use stark::lookup::{BusInteraction, BusValue, Multiplicity, Packing}; +use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::trace::TraceTable; use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField}; @@ -37,8 +43,13 @@ pub mod cols { /// timestamp[1]: Word (upper 32 bits of halt timestamp) pub const TIMESTAMP_1: usize = 1; + /// pc[0]: Word (lower 32 bits of the halting instruction's next_pc) + pub const PC_0: usize = 2; + /// pc[1]: Word (upper 32 bits of the halting instruction's next_pc) + pub const PC_1: usize = 3; + /// Total number of columns - pub const NUM_COLUMNS: usize = 2; + pub const NUM_COLUMNS: usize = 4; } // ========================================================================= @@ -52,7 +63,10 @@ pub mod cols { /// first ECALL, so a valid trace always contains exactly one. If a program had multiple /// ECALLs, the CPU would send multiple bus interactions but HALT only receives one, /// causing a bus imbalance and proof failure. 
-pub fn generate_halt_trace(timestamp: u64) -> TraceTable { +pub fn generate_halt_trace( + timestamp: u64, + next_pc: u64, +) -> TraceTable { // CPU timestamps must fit in u32 (timestamp_hi should be 0) debug_assert!( timestamp <= u32::MAX as u64, @@ -61,7 +75,12 @@ pub fn generate_halt_trace(timestamp: u64) -> TraceTable> 32; - let data = vec![FE::from(timestamp_lo), FE::from(timestamp_hi)]; + let data = vec![ + FE::from(timestamp_lo), + FE::from(timestamp_hi), + FE::from(next_pc & 0xFFFF_FFFF), + FE::from(next_pc >> 32), + ]; TraceTable::new_main(data, cols::NUM_COLUMNS, 1) } @@ -134,13 +153,14 @@ fn halt_write_bus_values(base_addr: u64, value_lo: u64) -> Vec { /// Creates all bus interactions for the HALT table. /// /// - **ECALL receiver**: receives `[timestamp, cast(rv1, DWordWL)]` from CPU -/// - **MEMW senders** (32 total): register finalization at `ts = 2^64-1` +/// - **MEMW senders** (31 total): register finalization at `ts = 2^64-1` /// - x1-x9: write 0 (zeroize lo GPRs) /// - x10: read with old=0 (enforce exit_code=0) /// - x11-x31: write 0 (zeroize hi GPRs) -/// - x255: write 1 (PC halted sentinel) +/// - **`memory` bus (4 total)**: consume_pc (x2) + emit_pc (x2) at `ts = timestamp+1`, +/// bridging the last real PC write to the PC=1 padding chain. pub fn bus_interactions() -> Vec { - let mut interactions = Vec::with_capacity(33); + let mut interactions = Vec::with_capacity(36); // ECALL receiver: receives [timestamp, cast(rv1, DWordWL)] from CPU // rv1 must be 93 (sys_exit) for bus to balance; otherwise proof fails. @@ -188,12 +208,58 @@ pub fn bus_interactions() -> Vec { )); } - // x255 (PC): write 1 at ts=2^64-1 (halted sentinel) - interactions.push(BusInteraction::sender( - BusId::Memw, - Multiplicity::One, - halt_write_bus_values(510, 1), - )); + // PC finalization on the low-level `memory` token bus at ts = timestamp + 1 + // (per spec halt:c:consume_pc / halt:c:emit_pc). 
The CPU's halting row wrote + // its real `next_pc` to x255 (addresses 510/511) at this same timestamp; we + // consume it (sender, +1) and re-emit pc=1 (receiver, -1) so the CPU padding + // rows — which all carry pc=1 — chain cleanly to the REGISTER final token. + // `value` layout on the bus: [is_register, addr_lo, addr_hi, ts_lo, ts_hi, value]. + let ts_plus_one_lo = || { + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::TIMESTAMP_0, + }, + LinearTerm::Constant(1), + ]) + }; + let ts_hi = || BusValue::Packed { + start_column: cols::TIMESTAMP_1, + packing: Packing::Direct, + }; + for (addr, pc_col) in [(510u64, cols::PC_0), (511u64, cols::PC_1)] { + // consume_pc (sender, +1): consume the real next_pc the CPU wrote. + interactions.push(BusInteraction::sender( + BusId::Memory, + Multiplicity::One, + vec![ + BusValue::constant(1), + BusValue::constant(addr), + BusValue::constant(0), + ts_plus_one_lo(), + ts_hi(), + BusValue::Packed { + start_column: pc_col, + packing: Packing::Direct, + }, + ], + )); + } + for (addr, value) in [(510u64, 1u64), (511u64, 0u64)] { + // emit_pc (receiver, -1): re-emit pc = 1 (value [1, 0]). + interactions.push(BusInteraction::receiver( + BusId::Memory, + Multiplicity::One, + vec![ + BusValue::constant(1), + BusValue::constant(addr), + BusValue::constant(0), + ts_plus_one_lo(), + ts_hi(), + BusValue::constant(value), + ], + )); + } interactions } diff --git a/prover/src/tables/load.rs b/prover/src/tables/load.rs index 32c945a41..77e1db628 100644 --- a/prover/src/tables/load.rs +++ b/prover/src/tables/load.rs @@ -425,48 +425,52 @@ pub fn bus_interactions() -> Vec { )); // ------------------------------------------------------------------------- - // LOAD receiver (from CPU) + // MEMORY receiver (from CPU) — unified high-level memory op. 
// ------------------------------------------------------------------------- - // Spec: LOAD[res::DWordWL; base_address, timestamp, read2, read4, read8, signed] | -μ - // - // res is DWordBL (8 bytes) but packed as DWordWL (2 words) for the bus. - // DWordBL packing: 8 bytes → 2 bus elements [lo32, hi32] + // MEMORY[out=res::DWordWL; timestamp, address, value, mem_flags] | -μ + // The CPU dispatches LOAD here (mem_flags bit 0 = memory_op = 0). The `value` + // field carries the store value and is 0 for loads; `out` is the loaded res. + // mem_flags = 2*signed + 4*read2 + 8*read4 + 16*read8 (memory_op = 0). interactions.push(BusInteraction::receiver( - BusId::Load, + BusId::MemoryOp, Multiplicity::Column(cols::MU), vec![ - // res::DWordWL - pack 8 bytes as 2 words - BusValue::Packed { - start_column: cols::RES[0], - packing: Packing::DWordBL, - }, - // base_address (DWordWL = 2 words) - BusValue::Packed { - start_column: cols::BASE_ADDRESS_0, - packing: Packing::DWordWL, - }, // timestamp (DWordWL = 2 words) BusValue::Packed { start_column: cols::TIMESTAMP_0, packing: Packing::DWordWL, }, - // read flags - BusValue::Packed { - start_column: cols::READ2, - packing: Packing::Direct, - }, + // address = base_address (DWordWL = 2 words) BusValue::Packed { - start_column: cols::READ4, - packing: Packing::Direct, - }, - BusValue::Packed { - start_column: cols::READ8, - packing: Packing::Direct, + start_column: cols::BASE_ADDRESS_0, + packing: Packing::DWordWL, }, - // signed flag + // value (store value) = 0 for loads + BusValue::constant(0), + BusValue::constant(0), + // mem_flags byte + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 2, + column: cols::SIGNED, + }, + LinearTerm::Column { + coefficient: 4, + column: cols::READ2, + }, + LinearTerm::Column { + coefficient: 8, + column: cols::READ4, + }, + LinearTerm::Column { + coefficient: 16, + column: cols::READ8, + }, + ]), + // out = res::DWordWL (8 bytes packed as 2 words) — the loaded value 
BusValue::Packed { - start_column: cols::SIGNED, - packing: Packing::Direct, + start_column: cols::RES[0], + packing: Packing::DWordBL, }, ], )); diff --git a/prover/src/tables/lt.rs b/prover/src/tables/lt.rs index da1bc948e..0ad89bcee 100644 --- a/prover/src/tables/lt.rs +++ b/prover/src/tables/lt.rs @@ -23,16 +23,17 @@ //! ## Bus Interactions //! - Sender: MSB16 (×2 for lhs_msb, rhs_msb) //! - Sender: IS_HALFWORD (×6: ×4 for lhs_sub_rhs, ×1 for lhs[1], ×1 for rhs[1]) -//! - Receiver: LT (provides less-than results to other tables) +//! - Receiver: ALU (all less-than lookups — CPU SLT/BLT/BGE dispatch and the +//! internal `memw`/`memw_aligned`/`dvrm` timestamp / |r|<|d| checks) use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; use stark::constraints::transition::TransitionConstraint; -use stark::lookup::{BusInteraction, BusValue, Multiplicity, Packing}; +use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; use stark::table::TableView; use stark::trace::TraceTable; -use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField, SHIFT_16}; +use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField, SHIFT_16, alu_op}; // ========================================================================= // Column indices for LT table @@ -80,12 +81,18 @@ pub mod cols { /// rhs_msb: Bit (MSB of rhs, i.e., bit 63) pub const RHS_MSB: usize = 13; - // Multiplicity column - /// μ: multiplicity for bus interactions - pub const MU: usize = 14; + // Every LT lookup (CPU SLT/BLT/BGE dispatch and the internal + // memw/memw_aligned/dvrm comparisons) goes through the unified `ALU` bus, + // so one multiplicity column suffices. + /// invert: Bit — invert the comparison (BGE/BGEU); `out = lt XOR invert`. + pub const INVERT: usize = 14; + /// out: the ALU result `lt XOR invert` (the low word; high word is 0). + pub const OUT: usize = 15; + /// μ: multiplicity for the `ALU` bus receiver. 
+ pub const MU: usize = 16; /// Total number of columns - pub const NUM_COLUMNS: usize = 15; + pub const NUM_COLUMNS: usize = 17; } // ========================================================================= @@ -94,6 +101,10 @@ pub mod cols { /// A single LT operation to be added to the trace. /// +/// Every operation is dispatched on the unified `ALU` bus; the `invert` flag +/// distinguishes plain less-than (memw/dvrm internal checks, CPU `SLT[U]`/`BLT[U]`) +/// from the inverted form (`BGE[U]`). +/// /// Derives Hash and Eq so it can be used as a HashMap key for deduplication. #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct LtOperation { @@ -103,15 +114,32 @@ pub struct LtOperation { pub rhs: u64, /// Whether to do signed comparison pub signed: bool, + /// Whether to invert the result (`out = lt XOR invert`); used for BGE/BGEU. + pub invert: bool, } impl LtOperation { - /// Create a new LT operation. + /// Create a new LT operation with `invert = false` (plain less-than). pub fn new(lhs: u64, rhs: u64, signed: bool) -> Self { - Self { lhs, rhs, signed } + Self { + lhs, + rhs, + signed, + invert: false, + } + } + + /// Create a new LT operation with an explicit `invert` flag (BGE/BGEU dispatch). + pub fn new_with_invert(lhs: u64, rhs: u64, signed: bool, invert: bool) -> Self { + Self { + lhs, + rhs, + signed, + invert, + } } - /// Compute the less-than result. + /// Compute the raw less-than result (before inversion). pub fn compute_lt(&self) -> bool { if self.signed { (self.lhs as i64) < (self.rhs as i64) @@ -119,6 +147,11 @@ impl LtOperation { self.lhs < self.rhs } } + + /// The ALU output: `lt XOR invert`. + pub fn compute_out(&self) -> bool { + self.compute_lt() ^ self.invert + } } /// Generates the LT trace table from a list of operations. 
@@ -186,7 +219,11 @@ pub fn generate_lt_trace( data[base + cols::LHS_MSB] = FE::from(lhs_msb); data[base + cols::RHS_MSB] = FE::from(rhs_msb); - // Multiplicity: aggregated count of this operation + // ALU-bus fields: invert + the inverted output. + data[base + cols::INVERT] = FE::from(op.invert as u64); + data[base + cols::OUT] = FE::from(op.compute_out() as u64); + + // All LT lookups go through the unified ALU bus → single multiplicity. data[base + cols::MU] = FE::from(*multiplicity); } @@ -291,34 +328,40 @@ pub fn bus_interactions() -> Vec { packing: Packing::Direct, }], ), - // LT[lhs, rhs, signed] -> lt (receiver) - // lhs is DWordHHW, rhs is DWordHHW, signed is Bit, lt is Bit - // Uses DWordHHW packing: reads 3 columns (Word, Half, Half), produces 2 bus elements [lo32, hi32] - // This allows DWordWL senders (like MEMW timestamps) to match via Packing::DWordWL + // ALU[lhs, rhs, opsel(LT) + 32*signed + 64*invert] -> out (receiver). + // Every LT lookup arrives here: the CPU dispatches SLT/BLT/BGE on the + // unified ALU bus, and the internal memw/memw_aligned/dvrm comparisons + // (timestamps and |r|<|d|) encode `signed=0, invert=0`. lhs/rhs are + // packed DWordHHW -> [lo32, hi32] (matching DWordWL senders); the + // output is [out, 0] (a comparison result fits in the low word). 
BusInteraction::receiver( - BusId::Lt, + BusId::Alu, Multiplicity::Column(cols::MU), vec![ - // lhs as DWordHHW (reads 3 columns: Word, Half, Half; produces 2 elements: [lo32, hi32]) BusValue::Packed { start_column: cols::LHS_0, packing: Packing::DWordHHW, }, - // rhs as DWordHHW (reads 3 columns, produces 2 elements) BusValue::Packed { start_column: cols::RHS_0, packing: Packing::DWordHHW, }, - // signed - BusValue::Packed { - start_column: cols::SIGNED, - packing: Packing::Direct, - }, - // lt (output) + BusValue::linear(vec![ + LinearTerm::Constant(alu_op::LT as i64), + LinearTerm::Column { + coefficient: 32, + column: cols::SIGNED, + }, + LinearTerm::Column { + coefficient: 64, + column: cols::INVERT, + }, + ]), BusValue::Packed { - start_column: cols::LT, + start_column: cols::OUT, packing: Packing::Direct, }, + BusValue::constant(0), ], ), ] diff --git a/prover/src/tables/memw.rs b/prover/src/tables/memw.rs index 7bf75741a..39a02ead4 100644 --- a/prover/src/tables/memw.rs +++ b/prover/src/tables/memw.rs @@ -22,7 +22,8 @@ //! - `μ_sum`: μ_read + μ_write //! //! ## Bus Interactions (26) -//! - 8 LT timestamp checks (old_timestamp[i] < timestamp) +//! - 8 ALU lookups for timestamp ordering (old_timestamp[i] < timestamp, +//! dispatched as `ALU[old_ts, ts, opsel(LT), 1, 0]` on the unified bus) //! - 16 Memory bus tokens (read old + write new, per byte) //! - 2 MEMW output interactions (read + write, from CPU) //! @@ -35,7 +36,7 @@ use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing} use stark::table::TableView; use stark::trace::TraceTable; -use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField}; +use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField, alu_op}; use crate::constraints::templates::IsBitConstraint; /// Maximum number of rows per MEMW table chunk. 
@@ -747,12 +748,15 @@ pub fn bus_interactions() -> Vec { )); // ------------------------------------------------------------------------- - // LT interactions for timestamp ordering (MEMW-C4 through C7) + // ALU interactions for timestamp ordering (MEMW-C4 through C7). + // Each lookup is dispatched on the unified ALU bus as + // `[old_ts, ts, opsel(LT), 1, 0]` (signed=0, invert=0, asserting + // old_ts < ts); there is no dedicated `Lt` bus. // ------------------------------------------------------------------------- - // MEMW-C4: LT[1; old_timestamp[0], timestamp] with μ_sum + // MEMW-C4: old_timestamp[0] < timestamp with μ_sum interactions.push(BusInteraction::sender( - BusId::Lt, + BusId::Alu, Multiplicity::Sum(cols::MU_READ, cols::MU_WRITE), vec![ BusValue::Packed { @@ -763,14 +767,15 @@ pub fn bus_interactions() -> Vec { start_column: cols::TIMESTAMP_0, packing: Packing::DWordWL, }, - BusValue::constant(0), + BusValue::constant(alu_op::LT as u64), BusValue::constant(1), + BusValue::constant(0), ], )); - // MEMW-C5: LT[1; old_timestamp[1], timestamp] with w2 + // MEMW-C5: old_timestamp[1] < timestamp with w2 interactions.push(BusInteraction::sender( - BusId::Lt, + BusId::Alu, Multiplicity::Sum3(cols::WRITE2, cols::WRITE4, cols::WRITE8), vec![ BusValue::Packed { @@ -781,15 +786,16 @@ pub fn bus_interactions() -> Vec { start_column: cols::TIMESTAMP_0, packing: Packing::DWordWL, }, - BusValue::constant(0), + BusValue::constant(alu_op::LT as u64), BusValue::constant(1), + BusValue::constant(0), ], )); - // MEMW-C6: LT[1; old_timestamp[i], timestamp] for i ∈ [2,3] with w4 + // MEMW-C6: old_timestamp[i] < timestamp for i ∈ [2,3] with w4 for i in 2..4 { interactions.push(BusInteraction::sender( - BusId::Lt, + BusId::Alu, Multiplicity::Sum(cols::WRITE4, cols::WRITE8), vec![ BusValue::Packed { @@ -800,16 +806,17 @@ pub fn bus_interactions() -> Vec { start_column: cols::TIMESTAMP_0, packing: Packing::DWordWL, }, - BusValue::constant(0), + BusValue::constant(alu_op::LT as 
u64), BusValue::constant(1), + BusValue::constant(0), ], )); } - // MEMW-C7: LT[1; old_timestamp[i], timestamp] for i ∈ [4,7] with write8 + // MEMW-C7: old_timestamp[i] < timestamp for i ∈ [4,7] with write8 for i in 4..8 { interactions.push(BusInteraction::sender( - BusId::Lt, + BusId::Alu, Multiplicity::Column(cols::WRITE8), vec![ BusValue::Packed { @@ -820,8 +827,9 @@ pub fn bus_interactions() -> Vec { start_column: cols::TIMESTAMP_0, packing: Packing::DWordWL, }, - BusValue::constant(0), + BusValue::constant(alu_op::LT as u64), BusValue::constant(1), + BusValue::constant(0), ], )); } @@ -867,6 +875,8 @@ pub enum MemwConstraintKind { MuSumIsBit, /// w2 => μ_sum: if accessing 2+ bytes, must be active row W2ImpliesMuSum, + /// IS_BIT: the width-sum is 0 or 1 (spec assumption). + WidthSumIsBit, } /// MEMW table constraint. @@ -900,6 +910,10 @@ impl MemwConstraint { let mu_sum = compute_mu_sum(step); &w2 * (&one - &mu_sum) } + MemwConstraintKind::WidthSumIsBit => { + let w2 = compute_w2(step); + &w2 * (&one - &w2) + } } } } @@ -909,6 +923,7 @@ impl TransitionConstraint for MemwConstrai match self.kind { MemwConstraintKind::MuSumIsBit => 2, MemwConstraintKind::W2ImpliesMuSum => 2, + MemwConstraintKind::WidthSumIsBit => 2, } } @@ -927,12 +942,13 @@ impl TransitionConstraint for MemwConstrai /// Creates all constraints for the MEMW table. /// -/// 11 constraints total: +/// 15 constraints total: /// - IS_BIT<μ_sum> (1) /// - w2 => μ_sum (1) /// - IS_BIT<μ_read> (1) /// - IS_BIT<μ_write> (1) /// - IS_BIT for carry[0..6] (7) +/// - IS_BIT (3) + IS_BIT (1) [spec assumption] pub fn constraints() -> Vec>> { let mut constraints: Vec< @@ -963,5 +979,12 @@ pub fn constraints() idx += 1; } + // IS_BIT on the width flags + their sum (spec defense-in-depth assumption). 
+ for &col in &[cols::WRITE2, cols::WRITE4, cols::WRITE8] { + constraints.push(IsBitConstraint::unconditional(col, idx).boxed()); + idx += 1; + } + constraints.push(MemwConstraint::new(MemwConstraintKind::WidthSumIsBit, idx).boxed()); + constraints } diff --git a/prover/src/tables/memw_aligned.rs b/prover/src/tables/memw_aligned.rs index f61c66679..91a9e8fd8 100644 --- a/prover/src/tables/memw_aligned.rs +++ b/prover/src/tables/memw_aligned.rs @@ -20,7 +20,7 @@ //! //! ## Bus Interactions (20) //! - 1 IS_HALF[base_address[0] + mask] (range check: address span fits in 16 bits) -//! - 1 LT[old_timestamp, timestamp, 0] → 1 +//! - 1 ALU[old_timestamp, timestamp, opsel(LT), 1, 0] → asserts old_ts < ts //! - 16 Memory bus tokens //! - 2 MEMW output interactions (read + write) //! @@ -42,7 +42,7 @@ use stark::table::TableView; use stark::trace::TraceTable; use super::memw::MemwOperation; -use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField}; +use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField, alu_op}; use crate::constraints::templates::IsBitConstraint; /// Maximum number of rows per MEMW_A table chunk. @@ -180,10 +180,12 @@ pub fn bus_interactions() -> Vec { )); // ------------------------------------------------------------------------- - // LT[old_timestamp, timestamp, 0] → 1 with μ_sum + // ALU[old_timestamp, timestamp, opsel(LT), 1, 0] → asserts old_ts < ts. + // (Every LT lookup goes through the unified ALU bus with + // signed=0/invert=0; there is no dedicated `Lt` bus.) 
// ------------------------------------------------------------------------- interactions.push(BusInteraction::sender( - BusId::Lt, + BusId::Alu, mu_sum.clone(), vec![ BusValue::Packed { @@ -194,8 +196,9 @@ pub fn bus_interactions() -> Vec { start_column: cols::TIMESTAMP_0, packing: Packing::DWordWL, }, - BusValue::constant(0), + BusValue::constant(alu_op::LT as u64), BusValue::constant(1), + BusValue::constant(0), ], )); @@ -665,6 +668,8 @@ pub enum MemwAlignedConstraintKind { MuSumIsBit, /// w2 => μ_sum: if accessing 2+ bytes, must be active row W2ImpliesMuSum, + /// IS_BIT: the width-sum is 0 or 1 (spec assumption). + WidthSumIsBit, } pub struct MemwAlignedConstraint { @@ -699,6 +704,13 @@ impl MemwAlignedConstraint { let w2 = write2 + write4 + write8; &w2 * (&one - &mu_sum) } + MemwAlignedConstraintKind::WidthSumIsBit => { + let write2 = step.get_main_evaluation_element(0, cols::WRITE2).clone(); + let write4 = step.get_main_evaluation_element(0, cols::WRITE4).clone(); + let write8 = step.get_main_evaluation_element(0, cols::WRITE8).clone(); + let w2 = write2 + write4 + write8; + &w2 * (&one - &w2) + } } } } @@ -721,7 +733,8 @@ impl TransitionConstraint for MemwAlignedC } } -/// Creates all constraints for the MEMW_A table (4 total). +/// Creates all constraints for the MEMW_A table (8 total). The last four are the +/// spec's defense-in-depth width-flag assumptions. 
pub fn constraints() -> Vec>> { vec![ @@ -729,5 +742,9 @@ pub fn constraints() MemwAlignedConstraint::new(MemwAlignedConstraintKind::W2ImpliesMuSum, 1).boxed(), IsBitConstraint::unconditional(cols::MU_READ, 2).boxed(), IsBitConstraint::unconditional(cols::MU_WRITE, 3).boxed(), + IsBitConstraint::unconditional(cols::WRITE2, 4).boxed(), + IsBitConstraint::unconditional(cols::WRITE4, 5).boxed(), + IsBitConstraint::unconditional(cols::WRITE8, 6).boxed(), + MemwAlignedConstraint::new(MemwAlignedConstraintKind::WidthSumIsBit, 7).boxed(), ] } diff --git a/prover/src/tables/mod.rs b/prover/src/tables/mod.rs index 3c1e97736..6e5bae947 100644 --- a/prover/src/tables/mod.rs +++ b/prover/src/tables/mod.rs @@ -23,10 +23,13 @@ pub mod types; pub mod bitwise; pub mod branch; +pub mod bytewise; pub mod commit; pub mod cpu; +pub mod cpu32; pub mod decode; pub mod dvrm; +pub mod eq; pub mod halt; pub mod keccak; pub mod keccak_rc; @@ -40,6 +43,7 @@ pub mod mul; pub mod page; pub mod register; pub mod shift; +pub mod store; pub mod trace_builder; pub use types::BusId; @@ -82,6 +86,11 @@ pub mod max_rows { pub const LOAD: usize = 1 << 20; // 1,048,576 — eff. width 33 pub const BRANCH: usize = 1 << 20; // 1,048,576 — eff. width 32 pub const MEMW_R: usize = 1 << 20; // 1,048,576 — eff. width 31 + // Auxiliary ALU / memory / CPU32 dispatch chips + pub const EQ: usize = 1 << 20; + pub const BYTEWISE: usize = 1 << 20; + pub const STORE: usize = 1 << 20; + pub const CPU32: usize = 1 << 19; } /// Per-table maximum row limits, configurable for different environments. 
@@ -100,6 +109,10 @@ pub struct MaxRowsConfig { pub load: usize, pub branch: usize, pub memw_register: usize, + pub eq: usize, + pub bytewise: usize, + pub store: usize, + pub cpu32: usize, } impl Default for MaxRowsConfig { @@ -115,6 +128,10 @@ impl Default for MaxRowsConfig { load: max_rows::LOAD, branch: max_rows::BRANCH, memw_register: max_rows::MEMW_R, + eq: max_rows::EQ, + bytewise: max_rows::BYTEWISE, + store: max_rows::STORE, + cpu32: max_rows::CPU32, } } } @@ -134,6 +151,10 @@ impl MaxRowsConfig { load: 1 << 5, branch: 1 << 5, memw_register: 1 << 5, + eq: 1 << 5, + bytewise: 1 << 5, + store: 1 << 5, + cpu32: 1 << 5, } } } diff --git a/prover/src/tables/mul.rs b/prover/src/tables/mul.rs index ecb72a4d1..6e5ca3a72 100644 --- a/prover/src/tables/mul.rs +++ b/prover/src/tables/mul.rs @@ -27,7 +27,8 @@ //! - Sender: MSB16 (×2 for sign extraction) //! - Sender: IS_HALF (×8 for lo/hi range checks) //! - Sender: IS_B20 (×4 for carry range checks) -//! - Receiver: MUL (×2 for lo and hi results) +//! - Receiver: ALU (×2 for lo and hi results — every MUL lookup, CPU +//! MUL/MULH dispatch and dvrm's internal `d*q` consistency) use std::collections::HashMap; @@ -41,9 +42,15 @@ use stark::trace::TraceTable; use super::types::{ BusId, FE, GoldilocksExtension, GoldilocksField, INV_2_32, INV_2_64, INV_2_96, INV_2_128, NEG_INV_2_16, NEG_INV_2_32, NEG_INV_2_48, NEG_INV_2_64, NEG_INV_2_80, NEG_INV_2_96, - NEG_INV_2_112, NEG_INV_2_128, SHIFT_16, + NEG_INV_2_112, NEG_INV_2_128, SHIFT_16, alu_op, }; +/// Total row multiplicity (`ALU` bus, lo + hi), used by the internal +/// range-check sends so they fire once per row-instance. 
+fn row_mult() -> Multiplicity { + Multiplicity::Sum(cols::MU_LO, cols::MU_HI) +} + // ========================================================================= // Column indices for MUL table // ========================================================================= @@ -112,10 +119,11 @@ pub mod cols { /// raw_product[3]: Intermediate convolution value pub const RAW_PRODUCT_3: usize = 23; - // Multiplicity columns - /// μ_lo: multiplicity for lo result lookups + // Multiplicity columns. All MUL lookups (CPU MUL/MULH dispatch and dvrm's + // internal `d*q` consistency checks) go through the unified `ALU` bus. + /// μ_lo: `ALU` bus multiplicity for lo result lookups pub const MU_LO: usize = 24; - /// μ_hi: multiplicity for hi result lookups + /// μ_hi: `ALU` bus multiplicity for hi result lookups pub const MU_HI: usize = 25; /// Total number of columns @@ -135,6 +143,10 @@ const SIGN_FILL: u64 = 0xFFFF; /// A single MUL operation to be added to the trace. /// +/// Every operation is dispatched on the unified `ALU` bus (CPU MUL/MULH and +/// dvrm's internal `d*q` consistency checks); the lo/hi half is selected by +/// the sender's `flags` byte at lookup time. +/// /// Derives Hash and Eq for HashMap-based deduplication. #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct MulOperation { @@ -148,12 +160,12 @@ pub struct MulOperation { pub rhs_signed: bool, } -/// Multiplicities for a MUL operation (separate for lo and hi lookups). +/// Multiplicities for a MUL operation, split by lo/hi result lookup. 
#[derive(Debug, Clone, Default)] pub struct MulMultiplicities { - /// Count of lookups requesting lo result + /// `ALU` bus count requesting lo result pub mu_lo: u64, - /// Count of lookups requesting hi result + /// `ALU` bus count requesting hi result pub mu_hi: u64, } @@ -342,7 +354,7 @@ pub fn generate_mul_trace( data[base + cols::RAW_PRODUCT_2] = FE::from(raw[2]); data[base + cols::RAW_PRODUCT_3] = FE::from(raw[3]); - // Fill multiplicities + // Fill multiplicities (ALU bus, lo/hi) data[base + cols::MU_LO] = FE::from(multiplicities.mu_lo); data[base + cols::MU_HI] = FE::from(multiplicities.mu_hi); } @@ -405,7 +417,7 @@ pub fn bus_interactions() -> Vec { for col in [cols::LO_0, cols::LO_1, cols::LO_2, cols::LO_3] { interactions.push(BusInteraction::sender( BusId::IsHalfword, - Multiplicity::Sum(cols::MU_LO, cols::MU_HI), + row_mult(), vec![BusValue::Packed { start_column: col, packing: Packing::Direct, @@ -419,7 +431,7 @@ pub fn bus_interactions() -> Vec { for col in [cols::HI_0, cols::HI_1, cols::HI_2, cols::HI_3] { interactions.push(BusInteraction::sender( BusId::IsHalfword, - Multiplicity::Sum(cols::MU_LO, cols::MU_HI), + row_mult(), vec![BusValue::Packed { start_column: col, packing: Packing::Direct, @@ -438,7 +450,7 @@ pub fn bus_interactions() -> Vec { // carry[0] = 2^-32 * raw_product[0] - 2^-32 * lo[0] - 2^-16 * lo[1] interactions.push(BusInteraction::sender( BusId::IsB20, - Multiplicity::Sum(cols::MU_LO, cols::MU_HI), + row_mult(), vec![BusValue::linear(vec![ LinearTerm::ColumnUnsigned { coefficient: INV_2_32, @@ -459,7 +471,7 @@ pub fn bus_interactions() -> Vec { // - 2^-64 * lo[0] - 2^-48 * lo[1] - 2^-32 * lo[2] - 2^-16 * lo[3] interactions.push(BusInteraction::sender( BusId::IsB20, - Multiplicity::Sum(cols::MU_LO, cols::MU_HI), + row_mult(), vec![BusValue::linear(vec![ LinearTerm::ColumnUnsigned { coefficient: INV_2_32, @@ -493,7 +505,7 @@ pub fn bus_interactions() -> Vec { // - 2^-32 * hi[0] - 2^-16 * hi[1] interactions.push(BusInteraction::sender( 
BusId::IsB20, - Multiplicity::Sum(cols::MU_LO, cols::MU_HI), + row_mult(), vec![BusValue::linear(vec![ LinearTerm::ColumnUnsigned { coefficient: INV_2_32, @@ -539,7 +551,7 @@ pub fn bus_interactions() -> Vec { // - 2^-64 * hi[0] - 2^-48 * hi[1] - 2^-32 * hi[2] - 2^-16 * hi[3] interactions.push(BusInteraction::sender( BusId::IsB20, - Multiplicity::Sum(cols::MU_LO, cols::MU_HI), + row_mult(), vec![BusValue::linear(vec![ LinearTerm::ColumnUnsigned { coefficient: INV_2_32, @@ -593,78 +605,62 @@ pub fn bus_interactions() -> Vec { )); // ------------------------------------------------------------------------- - // MUL receiver for lo result + // ALU receivers: every MUL lookup arrives here — CPU + // MUL/MULH/MULHSU/MULHU dispatch and dvrm's internal `d*q` consistency. + // ALU[lhs, rhs, flags, result] where flags = + // opsel(MUL) + 32*lhs_signed + 64*rhs_signed (+128 for the hi result). // ------------------------------------------------------------------------- - // MUL[lhs, lhs_signed, rhs, rhs_signed, lo, 0] per spec MUL-C7 + let mul_flags = |hi: i64| { + BusValue::linear(vec![ + LinearTerm::Constant(alu_op::MUL as i64 + hi), + LinearTerm::Column { + coefficient: 32, + column: cols::LHS_SIGNED, + }, + LinearTerm::Column { + coefficient: 64, + column: cols::RHS_SIGNED, + }, + ]) + }; + // ALU lo (muldiv bit 7 = 0) interactions.push(BusInteraction::receiver( - BusId::Mul, + BusId::Alu, Multiplicity::Column(cols::MU_LO), vec![ - // lhs as DWordHL (4 halfwords -> 2 words) BusValue::Packed { start_column: cols::LHS_0, packing: Packing::DWordHL, }, - // lhs_signed - BusValue::Packed { - start_column: cols::LHS_SIGNED, - packing: Packing::Direct, - }, - // rhs as DWordHL BusValue::Packed { start_column: cols::RHS_0, packing: Packing::DWordHL, }, - // rhs_signed - BusValue::Packed { - start_column: cols::RHS_SIGNED, - packing: Packing::Direct, - }, - // lo as DWordHL (result) + mul_flags(0), BusValue::Packed { start_column: cols::LO_0, packing: Packing::DWordHL, }, - // 
muldiv_selector = 0 (lo) - BusValue::constant(0), ], )); - - // ------------------------------------------------------------------------- - // MUL receiver for hi result - // ------------------------------------------------------------------------- - // MUL[lhs, lhs_signed, rhs, rhs_signed, hi, 1] per spec MUL-C8 + // ALU hi (muldiv bit 7 = 1 => +128) interactions.push(BusInteraction::receiver( - BusId::Mul, + BusId::Alu, Multiplicity::Column(cols::MU_HI), vec![ - // lhs as DWordHL BusValue::Packed { start_column: cols::LHS_0, packing: Packing::DWordHL, }, - // lhs_signed - BusValue::Packed { - start_column: cols::LHS_SIGNED, - packing: Packing::Direct, - }, - // rhs as DWordHL BusValue::Packed { start_column: cols::RHS_0, packing: Packing::DWordHL, }, - // rhs_signed - BusValue::Packed { - start_column: cols::RHS_SIGNED, - packing: Packing::Direct, - }, - // hi as DWordHL (result) + mul_flags(128), BusValue::Packed { start_column: cols::HI_0, packing: Packing::DWordHL, }, - // muldiv_selector = 1 (hi) - BusValue::constant(1), ], )); @@ -682,6 +678,10 @@ pub enum MulConstraintKind { LhsSign, /// SIGN constraint for rhs: (1 - rhs_signed) * rhs_is_negative = 0 RhsSign, + /// IS_BIT range check on a sign flag column: `x * (1 - x) = 0`. Required + /// because `lhs_signed`/`rhs_signed` are used as bus multiplicities, so an + /// out-of-range value (e.g. `lhs_signed = 3`) would otherwise be accepted. 
+ SignedIsBit(usize), /// Raw product convolution formula for index i RawProduct(usize), } @@ -730,6 +730,12 @@ impl MulConstraint { let one = FieldElement::::one(); (&one - &rhs_signed) * &rhs_is_neg } + MulConstraintKind::SignedIsBit(col) => { + // x * (1 - x) = 0 + let x = step.get_main_evaluation_element(0, col).clone(); + let one = FieldElement::::one(); + &x * &(&one - &x) + } MulConstraintKind::RawProduct(i) => { // raw_product[i] = convolution formula // This requires computing the sign-extended values and convolution @@ -827,6 +833,8 @@ impl TransitionConstraint for MulConstrain match self.kind { // (1 - signed) * is_negative is degree 2 MulConstraintKind::LhsSign | MulConstraintKind::RhsSign => 2, + // x * (1 - x) is degree 2 + MulConstraintKind::SignedIsBit(_) => 2, // Raw product: lhs_ext[j] * rhs_ext[idx-j] where each may involve // sign_fill * is_negative (degree 1), so product is degree 2 // But we're summing many degree-2 terms, still degree 2 @@ -854,6 +862,18 @@ pub fn mul_constraints(constraint_idx_start: usize) -> (Vec, usiz let mut idx = constraint_idx_start; let mut constraints = Vec::new(); + // IS_BIT range checks on the sign flags (used as bus multiplicities). 
+ constraints.push(MulConstraint::new( + MulConstraintKind::SignedIsBit(cols::LHS_SIGNED), + idx, + )); + idx += 1; + constraints.push(MulConstraint::new( + MulConstraintKind::SignedIsBit(cols::RHS_SIGNED), + idx, + )); + idx += 1; + // SIGN constraints constraints.push(MulConstraint::new(MulConstraintKind::LhsSign, idx)); idx += 1; diff --git a/prover/src/tables/shift.rs b/prover/src/tables/shift.rs index 9014799e5..20f66ff90 100644 --- a/prover/src/tables/shift.rs +++ b/prover/src/tables/shift.rs @@ -24,7 +24,7 @@ use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing} use stark::table::TableView; use stark::trace::TraceTable; -use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField, SHIFT_16}; +use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField, SHIFT_16, alu_op}; // ========================================================================= // Column indices @@ -74,7 +74,25 @@ pub mod cols { // Multiplicity pub const MU: usize = 25; - pub const NUM_COLUMNS: usize = 26; + // The unified ALU bus carries the full (un-reduced) shift + // amount `arg2` as in2. This mirrors the spec's `shift : DWordWHBB` layout + // `[Byte, Byte, Half, Word]`: SHIFT_AMOUNT (col 4) = shift[0] (low byte, used + // by the computation, which reduces mod 32/64), then SHIFT_B1 = shift[1], + // SHIFT_H1 = shift[2], SHIFT_HIGH = shift[3]. The low-word limbs are + // range-checked (byte/half) so the decomposition is unique → SHIFT_AMOUNT is + // forced to `arg2 & 0xFF`. + /// bits 8-15 of the shift amount (byte) — spec `shift[1]` + pub const SHIFT_B1: usize = 26; + /// bits 16-31 of the shift amount (half) — spec `shift[2]` + pub const SHIFT_H1: usize = 27; + /// bits 32-63 of the shift amount (word) — spec `shift[3]`. `IS_WORD` is + /// *assumed* (per the spec): on the ALU bus this column equals the CPU's + /// `arg2` high word, which is already a well-formed 32-bit word, so it needs + /// no in-chip range check. 
The high shift bits never affect the result + /// (`shift mod 32/64` only uses the low byte). + pub const SHIFT_HIGH: usize = 28; + + pub const NUM_COLUMNS: usize = 29; // Helpers for iteration pub const IN: [usize; 4] = [IN_0, IN_1, IN_2, IN_3]; @@ -92,8 +110,10 @@ pub mod cols { pub struct ShiftOperation { /// Input value as 4 halfwords (DWordHL) pub in_halves: [u16; 4], - /// Shift amount (byte) + /// Shift amount low byte (used by the computation; effective = mod 32/64). pub shift: u8, + /// Full shift amount `arg2` (the unified ALU bus carries this as in2). + pub shift_amount: u64, /// 0 = left, 1 = right pub direction: bool, /// Whether arithmetic (signed) right shift @@ -103,7 +123,15 @@ pub struct ShiftOperation { } impl ShiftOperation { - pub fn new(value: u64, shift: u8, direction: bool, signed: bool, word_instr: bool) -> Self { + /// `shift_amount` is the full (un-reduced) shift operand `arg2`; only its low + /// byte feeds the computation (the result depends on `arg2 mod 32/64`). + pub fn new( + value: u64, + shift_amount: u64, + direction: bool, + signed: bool, + word_instr: bool, + ) -> Self { Self { in_halves: [ (value & 0xFFFF) as u16, @@ -111,7 +139,8 @@ impl ShiftOperation { ((value >> 32) & 0xFFFF) as u16, ((value >> 48) & 0xFFFF) as u16, ], - shift, + shift: (shift_amount & 0xFF) as u8, + shift_amount, direction, signed, word_instr, @@ -175,6 +204,15 @@ impl ShiftOperation { } } + /// The raw shift output the chip writes to `OUT` (DWordWL) and sends on the + /// ALU bus as `res`. Unlike [`compute_result`](Self::compute_result), this is + /// NOT sign-extended for word shifts — the CPU32 applies that extension to + /// obtain `rvd`. For non-word shifts the two coincide. + pub fn compute_out(&self) -> u64 { + let aux = self.compute_aux(); + aux.out[0] as u64 | ((aux.out[1] as u64) << 32) + } + /// Compute all auxiliary values for trace generation. 
fn compute_aux(&self) -> ShiftAux { let left = !self.direction; @@ -332,6 +370,10 @@ pub fn generate_shift_trace( data[base + cols::IN[i]] = FE::from(op.in_halves[i] as u64); } data[base + cols::SHIFT_AMOUNT] = FE::from(op.shift as u64); + // High bits of the full shift amount (for the ALU bus in2 = arg2). + data[base + cols::SHIFT_B1] = FE::from((op.shift_amount >> 8) & 0xFF); + data[base + cols::SHIFT_H1] = FE::from((op.shift_amount >> 16) & 0xFFFF); + data[base + cols::SHIFT_HIGH] = FE::from(op.shift_amount >> 32); data[base + cols::DIRECTION] = FE::from(op.direction as u64); data[base + cols::SIGNED] = FE::from(op.signed as u64); data[base + cols::WORD_INSTR] = FE::from(op.word_instr as u64); @@ -561,43 +603,103 @@ pub fn bus_interactions() -> Vec { ], )); - // SHIFT-C15: SHIFT[out; in, shift, direction, signed, word_instr] | -μ (receiver) + // Unified ALU receiver: the CPU dispatches SLL/SRL/SRA here. + // ALU[out::DWordWL; in1=in, in2=shift_amount, flags] where + // flags = opsel(SHIFT=5, +word_instr→SHIFTW=6) + 32*signed + 64*direction. + // in2 = the full shift amount: [SHIFT_AMOUNT + 256*SHIFT_B1 + 2^16*SHIFT_H1, + // SHIFT_HIGH]. 
interactions.push(BusInteraction::receiver( - BusId::Shift, + BusId::Alu, Multiplicity::Column(cols::MU), vec![ - // out as DWordWL (2 elements) - BusValue::Packed { - start_column: cols::OUT_0, - packing: Packing::DWordWL, - }, - // in as DWordHL (4 halfwords → 2 elements) + // in1 = in as DWordHL (4 halfwords → 2 words) BusValue::Packed { start_column: cols::IN_0, packing: Packing::DWordHL, }, - // shift + // in2 = full shift amount, low word + BusValue::linear(vec![ + LinearTerm::Column { + coefficient: 1, + column: cols::SHIFT_AMOUNT, + }, + LinearTerm::Column { + coefficient: 1 << 8, + column: cols::SHIFT_B1, + }, + LinearTerm::Column { + coefficient: 1 << 16, + column: cols::SHIFT_H1, + }, + ]), + // in2 high word = arg2 bits 32-63 (spec `shift[3]`, a Word; IS_WORD + // assumed via this column's bus equality with the CPU's well-formed + // arg2 high word). BusValue::Packed { - start_column: cols::SHIFT_AMOUNT, + start_column: cols::SHIFT_HIGH, packing: Packing::Direct, }, - // direction + // flags = opsel(SHIFT) + word_instr + 32*signed + 64*direction + BusValue::linear(vec![ + LinearTerm::Constant(alu_op::SHIFT as i64), + LinearTerm::Column { + coefficient: 1, + column: cols::WORD_INSTR, + }, + LinearTerm::Column { + coefficient: 32, + column: cols::SIGNED, + }, + LinearTerm::Column { + coefficient: 64, + column: cols::DIRECTION, + }, + ]), + // out as DWordWL (2 elements) BusValue::Packed { - start_column: cols::DIRECTION, - packing: Packing::Direct, + start_column: cols::OUT_0, + packing: Packing::DWordWL, }, - // signed + ], + )); + + // Range checks for the low-word high bits (so the in2 low-word decomposition + // is unique → SHIFT_AMOUNT is forced to `arg2 & 0xFF`). SHIFT_AMOUNT is also + // byte-checked implicitly via the AND_BYTE[shift, mask] lookups; we still emit + // the explicit ARE_BYTES[shift[0]] below to match the spec's `IS_BYTE[shift[0]]` + // (defense-in-depth, redundant with AND_BYTE). 
SHIFT_HIGH (the high word) needs + // no check: IS_WORD is assumed (it equals the CPU's well-formed arg2 high word + // on the bus), matching the spec's `shift[3]`. + interactions.push(BusInteraction::sender( + BusId::AreBytes, + Multiplicity::Column(cols::MU), + vec![ BusValue::Packed { - start_column: cols::SIGNED, + start_column: cols::SHIFT_B1, packing: Packing::Direct, }, - // word_instr + BusValue::constant(0), + ], + )); + interactions.push(BusInteraction::sender( + BusId::AreBytes, + Multiplicity::Column(cols::MU), + vec![ BusValue::Packed { - start_column: cols::WORD_INSTR, + start_column: cols::SHIFT_AMOUNT, packing: Packing::Direct, }, + BusValue::constant(0), ], )); + interactions.push(BusInteraction::sender( + BusId::IsHalfword, + Multiplicity::Column(cols::MU), + vec![BusValue::Packed { + start_column: cols::SHIFT_H1, + packing: Packing::Direct, + }], + )); interactions } @@ -932,6 +1034,27 @@ pub fn collect_bitwise_from_shift(operations: &[ShiftOperation]) -> Vec> 8) & 0xFF) as u8, + )); + // ARE_BYTES[shift[0]] — spec IS_BYTE[shift[0]] (defense-in-depth, + // redundant with the AND_BYTE[shift, mask] lookups above). + bitwise_ops.push(BitwiseOperation::single_byte( + BitwiseOperationType::AreBytes, + op.shift, + )); + let half = ((op.shift_amount >> 16) & 0xFFFF) as u16; + bitwise_ops.push(BitwiseOperation::halfword( + BitwiseOperationType::IsHalf, + (half & 0xFF) as u8, + (half >> 8) as u8, + )); } bitwise_ops diff --git a/prover/src/tables/store.rs b/prover/src/tables/store.rs new file mode 100644 index 000000000..d607049c7 --- /dev/null +++ b/prover/src/tables/store.rs @@ -0,0 +1,338 @@ +//! STORE table. +//! +//! Receives the high-level `MEMORY` op from the CPU for store instructions and +//! emits the low-level `MEMW` write. Spec: `spec/src/store.toml`. +//! +//! ## `memory_op` flag bit (spec-faithful) +//! The `MEMORY` receiver flags are `1 + 4·write2 + 8·write4 + 16·write8`; the +//! 
`+1` is `memory_op`, which balances against the CPU's `mem_flags` +//! (`memory_op = 1` for stores). This matches `store.toml`. +//! +//! Note: the `MEMW` *write* fingerprint carries no `old` value — the +//! previous memory contents are handled inside the MEMW table. So STORE needs +//! no `old` column (mirrors the current CPU store sender, `cpu.rs:1672-1748`). +//! +//! ## Columns +//! - `base_address`: DWordWL (2 words) — effective write address +//! - `timestamp`: DWordWL (2 words) +//! - `write2`/`write4`/`write8`: Bit — exclusive width flags (1 byte = none set) +//! - `value`: DWordBL (8 bytes) — value to store +//! - `μ`: multiplicity + +use math::field::element::FieldElement; +use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; +use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing}; +use stark::table::TableView; +use stark::trace::TraceTable; + +use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField}; +use crate::constraints::templates::new_is_bit_constraints; + +// ========================================================================= +// Column indices for STORE table +// ========================================================================= + +/// Column definitions for the STORE table. +pub mod cols { + pub const BASE_ADDRESS_0: usize = 0; + pub const BASE_ADDRESS_1: usize = 1; + pub const TIMESTAMP_0: usize = 2; + pub const TIMESTAMP_1: usize = 3; + pub const WRITE2: usize = 4; + pub const WRITE4: usize = 5; + pub const WRITE8: usize = 6; + /// value as 8 bytes (DWordBL), little-endian. 
+ pub const VALUE: [usize; 8] = [7, 8, 9, 10, 11, 12, 13, 14]; + /// μ: multiplicity + pub const MU: usize = 15; + + /// Total number of columns + pub const NUM_COLUMNS: usize = 16; +} + +// ========================================================================= +// Trace generation +// ========================================================================= + +/// A single STORE operation. Exactly one of `write2/write4/write8` is set, or +/// none for a single-byte store. +#[derive(Debug, Clone, Default, Hash, PartialEq, Eq)] +pub struct StoreOperation { + pub base_address: u64, + pub timestamp: u64, + pub value: u64, + pub write2: bool, + pub write4: bool, + pub write8: bool, +} + +impl StoreOperation { + pub fn new(base_address: u64, timestamp: u64, value: u64, bytes: u8) -> Self { + Self { + base_address, + timestamp, + value, + write2: bytes == 2, + write4: bytes == 4, + write8: bytes == 8, + } + } + + /// The 8 `ARE_BYTES[value[i], 0]` range checks this op sends, for the BITWISE + /// table's multiplicity bookkeeping. + pub fn collect_bitwise_ops(&self) -> Vec { + use super::bitwise::{BitwiseOperation, BitwiseOperationType}; + (0..8) + .map(|i| { + let byte = ((self.value >> (i * 8)) & 0xFF) as u8; + BitwiseOperation::single_byte(BitwiseOperationType::AreBytes, byte) + }) + .collect() + } +} + +/// Generates the STORE trace. Each store has a distinct timestamp, so rows are +/// not deduplicated (μ = 1 each); the table is padded to a power of two (min 4). 
+pub fn generate_store_trace( + operations: &[StoreOperation], +) -> TraceTable { + let num_rows = operations.len().next_power_of_two().max(4); + let mut data = vec![FE::zero(); num_rows * cols::NUM_COLUMNS]; + + for (row_idx, op) in operations.iter().enumerate() { + let base = row_idx * cols::NUM_COLUMNS; + + data[base + cols::BASE_ADDRESS_0] = FE::from(op.base_address & 0xFFFF_FFFF); + data[base + cols::BASE_ADDRESS_1] = FE::from(op.base_address >> 32); + data[base + cols::TIMESTAMP_0] = FE::from(op.timestamp & 0xFFFF_FFFF); + data[base + cols::TIMESTAMP_1] = FE::from(op.timestamp >> 32); + data[base + cols::WRITE2] = FE::from(op.write2 as u64); + data[base + cols::WRITE4] = FE::from(op.write4 as u64); + data[base + cols::WRITE8] = FE::from(op.write8 as u64); + for i in 0..8 { + data[base + cols::VALUE[i]] = FE::from((op.value >> (8 * i)) & 0xFF); + } + data[base + cols::MU] = FE::one(); + } + + TraceTable::new_main(data, cols::NUM_COLUMNS, 1) +} + +// ========================================================================= +// Bus interactions +// ========================================================================= + +/// All bus interactions for the STORE table: +/// - **Sends** the low-level `MEMW` write (16 elements, no `old`). +/// - **Receives** the high-level `MEMORY` op (flags include the `memory_op` bit). +/// - **Sends** `ARE_BYTES[value[i], 0]` (×8) to range-check the stored bytes. +pub fn bus_interactions() -> Vec { + let mut interactions = Vec::with_capacity(10); + + // MEMW[0, base_address, value, timestamp, write2, write4, write8] (write, + // 16 elements, no `old`). Mirrors cpu.rs:1672-1748. 
+ interactions.push(BusInteraction::sender( + BusId::Memw, + Multiplicity::Column(cols::MU), + vec![ + BusValue::constant(0), // is_register = 0 (memory access) + BusValue::Packed { + start_column: cols::BASE_ADDRESS_0, + packing: Packing::DWordWL, + }, + // value as 8 individual bytes + BusValue::Packed { + start_column: cols::VALUE[0], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[1], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[2], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[3], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[4], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[5], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[6], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::VALUE[7], + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::DWordWL, + }, + BusValue::Packed { + start_column: cols::WRITE2, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::WRITE4, + packing: Packing::Direct, + }, + BusValue::Packed { + start_column: cols::WRITE8, + packing: Packing::Direct, + }, + ], + )); + + // MEMORY[timestamp, base_address, value, flags] -> 0 (receiver, mult μ). + // flags = 1 + 4·write2 + 8·write4 + 16·write8 — the `1` is memory_op + // (matches store.toml). 
+ interactions.push(BusInteraction::receiver( + BusId::MemoryOp, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: cols::TIMESTAMP_0, + packing: Packing::DWordWL, + }, + BusValue::Packed { + start_column: cols::BASE_ADDRESS_0, + packing: Packing::DWordWL, + }, + // value cast to DWordWL (8 bytes -> 2 words) + BusValue::Packed { + start_column: cols::VALUE[0], + packing: Packing::DWordBL, + }, + // flags: memory_op(1) + width bits + BusValue::linear(vec![ + LinearTerm::Constant(1), + LinearTerm::Column { + coefficient: 4, + column: cols::WRITE2, + }, + LinearTerm::Column { + coefficient: 8, + column: cols::WRITE4, + }, + LinearTerm::Column { + coefficient: 16, + column: cols::WRITE8, + }, + ]), + // output = 0 (DWordWL): stores write nothing back to rd. + BusValue::constant(0), + BusValue::constant(0), + ], + )); + + // ARE_BYTES[value[i], 0] range checks. + for value_col in cols::VALUE { + interactions.push(BusInteraction::sender( + BusId::AreBytes, + Multiplicity::Column(cols::MU), + vec![ + BusValue::Packed { + start_column: value_col, + packing: Packing::Direct, + }, + BusValue::constant(0), + ], + )); + } + + interactions +} + +// ========================================================================= +// Constraints +// ========================================================================= + +/// Width-flag constraints for the STORE table. +pub struct StoreConstraint { + constraint_idx: usize, + kind: StoreConstraintKind, +} + +#[derive(Debug, Clone, Copy)] +pub enum StoreConstraintKind { + /// `write2 + write4 + write8 ∈ {0, 1}` (at most one width bit set). + WidthSumIsBit, + /// `(write2 + write4 + write8) ⇒ μ`, i.e. `(Σ width)·(1 − μ) = 0`. 
+ WidthImpliesMu, +} + +impl StoreConstraint { + pub fn new(kind: StoreConstraintKind, constraint_idx: usize) -> Self { + Self { + constraint_idx, + kind, + } + } +} + +impl TransitionConstraint for StoreConstraint { + fn degree(&self) -> usize { + 2 + } + + fn constraint_idx(&self) -> usize { + self.constraint_idx + } + + fn evaluate(&self, step: &TableView) -> FieldElement + where + F: IsSubFieldOf, + E: IsField, + { + let w2 = step.get_main_evaluation_element(0, cols::WRITE2).clone(); + let w4 = step.get_main_evaluation_element(0, cols::WRITE4).clone(); + let w8 = step.get_main_evaluation_element(0, cols::WRITE8).clone(); + let sum = &w2 + &w4 + &w8; + let one = FieldElement::::one(); + match self.kind { + StoreConstraintKind::WidthSumIsBit => &sum * (&one - &sum), + StoreConstraintKind::WidthImpliesMu => { + let mu = step.get_main_evaluation_element(0, cols::MU).clone(); + &sum * (&one - &mu) + } + } + } +} + +/// Creates all transition constraints for the STORE table: `IS_BIT` on each +/// width flag, the width-sum-is-bit constraint, and width ⇒ μ. 
+pub fn store_constraints( + constraint_idx_start: usize, +) -> ( + Vec>>, + usize, +) { + let mut constraints: Vec< + Box>, + > = Vec::new(); + + let (is_bit, mut idx) = new_is_bit_constraints( + &[cols::WRITE2, cols::WRITE4, cols::WRITE8, cols::MU], + constraint_idx_start, + ); + for c in is_bit { + constraints.push(c.boxed()); + } + + constraints.push(StoreConstraint::new(StoreConstraintKind::WidthSumIsBit, idx).boxed()); + idx += 1; + constraints.push(StoreConstraint::new(StoreConstraintKind::WidthImpliesMu, idx).boxed()); + idx += 1; + + (constraints, idx) +} diff --git a/prover/src/tables/trace_builder.rs b/prover/src/tables/trace_builder.rs index 76535484b..bbaed139f 100644 --- a/prover/src/tables/trace_builder.rs +++ b/prover/src/tables/trace_builder.rs @@ -39,10 +39,13 @@ use stark::trace::TraceTable; use super::bitwise::{self, BitwiseOperation, BitwiseOperationType}; use super::branch::{self, BranchOperation}; +use super::bytewise; use super::commit::{self, CommitOperation}; use super::cpu::{self, CpuOperation}; +use super::cpu32; use super::decode; use super::dvrm::{self, DvrmOperation}; +use super::eq; use super::halt; use super::keccak::{self, KeccakOperation}; use super::keccak_rc; @@ -56,6 +59,7 @@ use super::mul::{self, MulOperation}; use super::page::{self, FinalByteState, FinalStateMap, PageConfig}; use super::register::{self, FinalRegisterStateMap, FinalRegisterWordState}; use super::shift::{self, ShiftOperation}; +use super::store; use super::types::{GoldilocksExtension, GoldilocksField}; use crate::Error; @@ -290,16 +294,8 @@ impl RegisterState { /// Get byte count and signed flag from CpuOperation memory flags. 
fn cpu_op_to_bytes_and_signed(op: &CpuOperation) -> (usize, bool) { - let byte_count = if op.decode.memory_8bytes { - 8 - } else if op.decode.memory_4bytes { - 4 - } else if op.decode.memory_2bytes { - 2 - } else { - 1 - }; - (byte_count, op.decode.signed) + let f = &op.decode.fields; + (f.mem_bytes(), f.mem_signed()) } /// Pack a 64-bit register value into the MEMW value format. @@ -368,6 +364,7 @@ fn collect_ops_from_cpu( Vec, Vec, Vec, + Vec, ) { let mut memw_ops = Vec::with_capacity(cpu_ops.len() * 3); let mut load_ops = Vec::with_capacity(cpu_ops.len() / 8 + 1); @@ -376,19 +373,27 @@ fn collect_ops_from_cpu( let mut bitwise_ops = Vec::with_capacity(cpu_ops.len() * 4); let mut commit_ops = Vec::new(); let mut keccak_ops = Vec::new(); + let mut cpu32_ops = Vec::new(); let mut current_commit_index = 0u32; let mut commit_ecall_count = 0u32; for op in cpu_ops { + // Word (`*W`) instructions delegate to the CPU32 table (built in program + // order; its register accesses are still emitted via the shared register + // collector below so the MEMW table balances). 
+ if op.decode.fields.word_instr { + cpu32_ops.push(build_cpu32_op(op)); + } + // --- MEMW and LOAD (require state tracking, order matters) --- // Collect memory operations for Load/Store instructions - if op.decode.op_load { + if op.decode.fields.is_load() { let (memw_op, load_op, lookups) = collect_load_op_from_cpu(op, memory_state); memw_ops.push(memw_op); load_ops.push(load_op); bitwise_ops.extend(lookups); - } else if op.decode.op_store { + } else if op.decode.fields.is_store() { let memw_op = collect_store_op_from_cpu(op, memory_state); memw_ops.push(memw_op); } @@ -450,32 +455,37 @@ fn collect_ops_from_cpu( }); } - // --- LT, SHIFT, and Bitwise (no state tracking needed) --- - - // Collect LT operations from SLT/BLT instructions - if op.decode.op_slt || op.decode.op_blt { - let arg1 = op.compute_arg1(); - let arg2 = op.compute_arg2(); - lt_ops.push(LtOperation::new(arg1, arg2, op.decode.signed)); - } - - // Collect SHIFT operations - if op.decode.op_shift { - let input = op.compute_arg1(); - let shift_amount = (op.compute_arg2() & 0xFF) as u8; - let direction = op.decode.mp_selector; // 0=left, 1=right - let signed = op.decode.signed; - let word_instr = op.decode.word_instr; - shift_ops.push(ShiftOperation::new( - input, - shift_amount, - direction, - signed, - word_instr, - )); + // --- ALU chip dispatch (no state tracking) --- + // Word (`*W`) instructions are delegated to CPU32 (which itself drives + // the ALU chips); the main CPU does not send the ALU bus for them, so we + // must not emit chip ops here. CPU32 op-generation is B5b. + let f = op.decode.fields; + if !f.word_instr { + // LT: SLT / BLT / BGE, dispatched on the unified ALU bus. `invert` + // (BGE/BGEU) is applied inside the LT chip (`out = lt XOR invert`). + if f.is_lt() { + lt_ops.push(LtOperation::new_with_invert( + op.rv1, + op.arg2, + f.alu_signed(), + f.alu_signed2_or_invert(), + )); + } + // SHIFT: SLL/SRL/SRA. direction = invert bit (0 = left, 1 = right). 
+ // The full arg2 goes on the ALU bus as in2; the chip uses its low + // byte for the (mod 32/64) computation. + if f.is_shift() { + shift_ops.push(ShiftOperation::new( + op.rv1, + op.arg2, + f.alu_signed2_or_invert(), + f.alu_signed(), + f.word_instr, + )); + } } - // Collect bitwise lookups + // Collect CPU range-check bitwise lookups (ARE_BYTES + IS_HALF). bitwise_ops.extend(op.collect_bitwise_ops()); } @@ -494,6 +504,7 @@ fn collect_ops_from_cpu( bitwise_ops, commit_ops, keccak_ops, + cpu32_ops, ) } @@ -582,19 +593,21 @@ fn collect_store_op_from_cpu(op: &CpuOperation, memory_state: &mut MemoryState) *byte = (store_value >> (j * 8)) & 0xFF; } - // Create MEMW operation (write) - M7 uses timestamp+1 + // The STORE chip now owns this MEMW write (the CPU sends MEMORY instead of + // the old inline M7). It uses the base timestamp — the same the CPU sends on + // the MEMORY bus — per spec store.toml. let memw_op = MemwOperation::new( false, // is_register = false base_address, value_bytes, - op.timestamp + 1, + op.timestamp, byte_count as u8, false, // is_read = false (write) ) .with_old(old_values, old_timestamps); - // Update memory state (using timestamp+1 to match M7) - memory_state.write_bytes(base_address, store_value, byte_count, op.timestamp + 1); + // Update memory state at the base timestamp (matches the STORE MEMW write). + memory_state.write_bytes(base_address, store_value, byte_count, op.timestamp); memw_op } @@ -607,7 +620,11 @@ fn collect_register_ops_from_cpu( register_state: &mut RegisterState, ) -> Vec { let mut memw_ops = Vec::with_capacity(4); - let d = &op.decode; + let d = &op.decode.fields; + // These register accesses happen for every real instruction. For non-word + // rows the main CPU sends the MEMW lookups; for word (`*W`) rows the CPU32 + // table sends them. Either way the MEMW *table* receives the same record, so + // we generate it here (in program order, for register-state timestamps). 
// M1: Read rs1 register at timestamp+0 // Skip x0 (hardwired zero). x255 (the register where the pc is stored) is handled @@ -669,6 +686,152 @@ fn collect_register_ops_from_cpu( memw_ops } +// ============================================================================= +// CPU32 (word `*W` instruction) op-generation +// ============================================================================= + +/// The raw ALU result `res` for a CPU32 row, matching what the dispatched chip +/// (or the ADD/SUB fast-path) computes from the sign-extended `arg1`/`arg2`. +fn cpu32_res(c: &cpu32::Cpu32Operation, arg1: u64, arg2: u64) -> u64 { + use crate::tables::types::alu_op; + if c.add { + return arg1.wrapping_add(arg2); + } + if c.sub { + return arg1.wrapping_sub(arg2); + } + if !c.alu { + return 0; + } + let op = c.alu_flags & 0x1F; + let signed = (c.alu_flags >> 5) & 1 == 1; + let s2_or_inv = (c.alu_flags >> 6) & 1 == 1; + let muldiv = (c.alu_flags >> 7) & 1 == 1; + if op == alu_op::SHIFT || op == alu_op::SHIFTW { + // The ALU bus carries the chip's raw OUT (not the sign-extended value); + // CPU32 sign-extends it to rvd. + ShiftOperation::new(arg1, arg2, s2_or_inv, signed, true).compute_out() + } else if op == alu_op::MUL { + MulOperation::new(arg1, signed, arg2, s2_or_inv) + .compute_product() + .0 + } else if op == alu_op::DIVREM { + let d = DvrmOperation::new(arg1, arg2, signed); + if muldiv { + d.compute_remainder() + } else { + d.compute_quotient() + } + } else { + 0 + } +} + +/// Builds the CPU32 row for a word (`*W`) instruction. `op.rv1/rv2/rvd` carry the +/// real register values (the main CPU delegate row zeroes its own columns). 
+fn build_cpu32_op(op: &CpuOperation) -> cpu32::Cpu32Operation { + let f = &op.decode.fields; + let mut c = cpu32::Cpu32Operation { + timestamp: op.timestamp, + pc: op.decode.pc, + rs1: f.rs1, + read_register1: f.read_register1, + rv1: op.rv1, + rs2: f.rs2, + read_register2: f.read_register2, + rv2: op.rv2, + imm: op.decode.imm, + res: 0, + rd: f.rd, + write_register: f.write_register, + alu: f.alu, + alu_flags: f.alu_flags, + add: f.add, + sub: f.sub, + half_instruction_length: f.half_instruction_length, + }; + let aux = c.compute_aux(); + c.res = cpu32_res(&c, aux.arg1, aux.arg2); + c +} + +/// The BITWISE-table lookups a CPU32 row sends: 5×ARE_BYTES (byte fields), +/// 8×IS_HALF (rv1/rv2 low-word halves + the 4 res halves), 1×BYTE_ALU (extracts +/// the signed bit from `alu_flags`), and 3×MSB16 (rv1/rv2/res sign bits). +fn collect_cpu32_bitwise(c: &cpu32::Cpu32Operation) -> Vec { + let mut ops = Vec::with_capacity(17); + let half = |v: u64, sh: u32| ((v >> sh) & 0xFFFF) as u16; + let push_half = |ops: &mut Vec, kind, h: u16| { + ops.push(BitwiseOperation::halfword( + kind, + (h & 0xFF) as u8, + (h >> 8) as u8, + )); + }; + + for b in [c.half_instruction_length, c.alu_flags, c.rs1, c.rs2, c.rd] { + ops.push(BitwiseOperation::single_byte( + BitwiseOperationType::AreBytes, + b, + )); + } + // IS_HALF: rv1[0],rv1[1],rv2[0],rv2[1],res[0..3] + let rv1_h0 = half(c.rv1, 0); + let rv1_h1 = half(c.rv1, 16); + let rv2_h0 = half(c.rv2, 0); + let rv2_h1 = half(c.rv2, 16); + for h in [rv1_h0, rv1_h1, rv2_h0, rv2_h1] { + push_half(&mut ops, BitwiseOperationType::IsHalf, h); + } + for i in 0..4 { + push_half(&mut ops, BitwiseOperationType::IsHalf, half(c.res, i * 16)); + } + // BYTE_ALU[AND, X=32, Y=alu_flags] -> 32*signed (extract signed bit). + ops.push(BitwiseOperation::byte_op( + BitwiseOperationType::ByteAluAnd, + 32, + c.alu_flags, + )); + // MSB16 on the high half of each low word (rv1, rv2, res). 
+ let res_h1 = half(c.res, 16); + for h in [rv1_h1, rv2_h1, res_h1] { + push_half(&mut ops, BitwiseOperationType::Msb16, h); + } + ops +} + +/// Queues the ALU-chip op a word ALU instruction dispatches (SHIFT/MUL/DVRM). +/// ADDW/SUBW use the CPU32 ADD/SUB fast-path (no external chip): nothing is queued. +#[allow(clippy::type_complexity)] +fn cpu32_chip_op( + c: &cpu32::Cpu32Operation, + shift_ops: &mut Vec, + mul_ops: &mut Vec<(MulOperation, bool)>, + dvrm_ops: &mut Vec<(DvrmOperation, bool)>, +) { + use crate::tables::types::alu_op; + if c.add || c.sub || !c.alu { + return; + } + let aux = c.compute_aux(); + let op = c.alu_flags & 0x1F; + let signed = aux.signed; + let s2_or_inv = (c.alu_flags >> 6) & 1 == 1; + let muldiv = (c.alu_flags >> 7) & 1 == 1; + if op == alu_op::SHIFT || op == alu_op::SHIFTW { + shift_ops.push(ShiftOperation::new( + aux.arg1, aux.arg2, s2_or_inv, signed, true, + )); + } else if op == alu_op::MUL { + mul_ops.push(( + MulOperation::new(aux.arg1, signed, aux.arg2, s2_or_inv), + muldiv, + )); + } else if op == alu_op::DIVREM { + dvrm_ops.push((DvrmOperation::new(aux.arg1, aux.arg2, signed), muldiv)); + } +} + /// Collects MEMW operations for a COMMIT ECALL from CpuOperation. /// /// All operations use the raw ECALL timestamp (no offsets). Per the spec, @@ -771,12 +934,15 @@ fn collect_commit_memw_ops( /// Collects HALT finalization MEMW operations for all 33 registers. /// -/// Per spec (halt.toml): at timestamp 2^64-1, HALT finalizes every register: +/// Per spec (halt.toml): at timestamp 2^64-1, HALT finalizes the GP registers: /// - x1-x9, x11-x31: write 0 (zeroize) /// - x10: read (verify exit code = 0; if x10 ≠ 0, proof fails via bus mismatch) -/// - x255 (PC): write 1 (halted sentinel) /// -/// Also updates `register_state` so `to_final_state_map()` reflects the finalized values. 
+/// The PC (x255) is NOT finalized here — it is handled on the inline-PC `memory` +/// bus by the HALT chip's consume_pc/emit_pc plus the CPU padding chain (its +/// REGISTER final token is set separately by the caller, at the last padding +/// timestamp). Also updates `register_state` so `to_final_state_map()` reflects +/// the finalized GP register values. fn collect_halt_ops(register_state: &mut RegisterState) -> Vec { let mut ops = Vec::with_capacity(32); let ts = u64::MAX; @@ -816,16 +982,9 @@ fn collect_halt_ops(register_state: &mut RegisterState) -> Vec { register_state.write(i, 0, ts); } - // x255 (PC): write 1 - { - let (old_val, old_ts) = register_state.read_pc(); - let old_value = pack_register_value(old_val); - let old_timestamps = [old_ts, old_ts, 0, 0, 0, 0, 0, 0]; - let memw_op = MemwOperation::new(true, 510, pack_register_value(1), ts, 2, false) - .with_old(old_value, old_timestamps); - ops.push(memw_op); - register_state.write_pc(1, ts); - } + // x255 (PC) is finalized via the inline-PC `memory` bus + REGISTER table, not + // via a MEMW write at 2^64-1. See `collect_halt_ops` doc and the PC finalization + // in the caller. ops } @@ -1441,25 +1600,23 @@ fn collect_byte_check_ops_for_padding(num_padding_rows: usize) -> Vec>, + // Auxiliary ALU / memory / CPU32 dispatch chips (split into chunks of their max_rows) + pub eqs: Vec>, + pub bytewises: Vec>, + pub stores: Vec>, + pub cpu32s: Vec>, } /// Intermediate state from Phase 2: all ops collected from CPU, ready for @@ -2068,6 +2230,11 @@ struct CollectedOps { dvrm_ops: Vec<(DvrmOperation, bool)>, commit_ops: Vec, keccak_ops: Vec, + // Auxiliary ALU / memory / CPU32 dispatch chips (driven by the CPU ALU/MEMORY dispatch). + eq_ops: Vec, + bytewise_ops: Vec, + store_ops: Vec, + cpu32_ops: Vec, } /// Chunk raw ops and generate one trace table per chunk. 
When `storage_mode` @@ -2109,10 +2276,11 @@ fn collect_all_ops( mut memw_ops: Vec, load_ops: Vec, mut lt_ops: Vec, - shift_ops: Vec, - bitwise_ops: Vec, + mut shift_ops: Vec, + mut bitwise_ops: Vec, commit_ops: Vec, keccak_ops: Vec, + cpu32_ops: Vec, register_state: &mut RegisterState, ) -> CollectedOps { // HALT finalization: 33 register MEMW operations at timestamp u64::MAX. @@ -2135,44 +2303,81 @@ fn collect_all_ops( BranchOperation::new( op.decode.pc, op.decode.imm, // offset as full 64-bit DWordWL (already sign-extended) - op.compute_arg1(), // register value must match CPU's arg1 for bus signature - op.decode.op_jalr, + op.rv1, // register value must match the CPU's BRANCH bus signature + op.decode.fields.jalr(), ) }) .collect(); - // Collect MUL operations from CPU ops where op_mul = true + // Collect MUL operations from non-word MUL instructions. lhs_signed = `signed` + // (alu_flags bit 5); rhs_signed = `signed2` (bit 6); wants_hi = `muldiv` (bit 7). let mut mul_ops: Vec<(MulOperation, bool)> = cpu_ops .iter() - .filter(|op| op.decode.op_mul) + .filter(|op| !op.decode.fields.word_instr && op.decode.fields.is_mul()) .map(|op| { - let lhs = op.compute_arg1(); - let lhs_signed = op.decode.signed; - // rhs_signed = mp_selector per spec CPU-CA44: - // MUL/MULH have mp_selector=1 (both signed), MULHU/MULHSU have mp_selector=0 (rhs unsigned) - let rhs_signed = op.decode.mp_selector; - let rhs = op.compute_arg2(); - let wants_hi = op.decode.muldiv_selector; + let f = op.decode.fields; ( - MulOperation::new(lhs, lhs_signed, rhs, rhs_signed), - wants_hi, + MulOperation::new(op.rv1, f.alu_signed(), op.arg2, f.alu_signed2_or_invert()), + f.alu_muldiv(), ) }) .collect(); - // Collect DVRM operations from CPU ops where op_divrem = true - let dvrm_ops: Vec<(DvrmOperation, bool)> = cpu_ops + // Collect DVRM operations from non-word DIV/REM instructions. 
+ let mut dvrm_ops: Vec<(DvrmOperation, bool)> = cpu_ops .iter() - .filter(|op| op.decode.op_divrem) + .filter(|op| !op.decode.fields.word_instr && op.decode.fields.is_divrem()) .map(|op| { - let n = op.compute_arg1(); - let d = op.compute_arg2(); - let signed = op.decode.signed; - let wants_remainder = op.decode.muldiv_selector; - (DvrmOperation::new(n, d, signed), wants_remainder) + let f = op.decode.fields; + ( + DvrmOperation::new(op.rv1, op.arg2, f.alu_signed()), + f.alu_muldiv(), + ) }) .collect(); + // Collect the ALU/MEMORY chip ops (non-word rows). + // EQ: BEQ/BNE (invert = alu_flags bit 6). BYTEWISE: AND/OR/XOR (op = alu_op). + let eq_ops: Vec = cpu_ops + .iter() + .filter(|op| !op.decode.fields.word_instr && op.decode.fields.is_eq()) + .map(|op| eq::EqOperation::new(op.rv1, op.arg2, op.decode.fields.alu_signed2_or_invert())) + .collect(); + let bytewise_ops: Vec = cpu_ops + .iter() + .filter(|op| { + let f = &op.decode.fields; + !f.word_instr && (f.is_and() || f.is_or() || f.is_xor()) + }) + .map(|op| bytewise::BytewiseOperation::new(op.rv1, op.arg2, op.decode.fields.alu_op())) + .collect(); + // STORE: receives MEMORY(memory_op=1) from the CPU and sends the MEMW write + // at the base timestamp (mirrors `collect_store_op_from_cpu`, which records + // the MEMW table row). + let store_ops: Vec = cpu_ops + .iter() + .filter(|op| op.decode.fields.is_store()) + .map(|op| { + // The MEMORY bus and the STORE chip's MEMW write share the base + // timestamp (spec store.toml uses one `timestamp` for both). + store::StoreOperation::new( + op.res, + op.timestamp, + op.rv2, + op.decode.fields.mem_bytes() as u8, + ) + }) + .collect(); + + // CPU32 (word `*W`) dispatch: each CPU32 row that uses the full ALU sends to + // the SHIFT/MUL/DVRM chips (ADDW/SUBW are the CPU32 ADD/SUB fast-path). These + // word DVRM ops are added before the DVRM→LT/MUL loops so they get their own + // internal consistency lookups. CPU32 also sends its own BITWISE range checks. 
+ for c in &cpu32_ops { + cpu32_chip_op(c, &mut shift_ops, &mut mul_ops, &mut dvrm_ops); + bitwise_ops.extend(collect_cpu32_bitwise(c)); + } + // Collect LT operations from DVRM: |r| < |d| (unsigned comparison) for (op, _wants_remainder) in &dvrm_ops { lt_ops.push(LtOperation::new(op.abs_r(), op.abs_d(), false)); @@ -2203,6 +2408,10 @@ fn collect_all_ops( dvrm_ops, commit_ops, keccak_ops, + eq_ops, + bytewise_ops, + store_ops, + cpu32_ops, } } @@ -2218,7 +2427,7 @@ fn build_traces( entry_point: u64, decode_trace: TraceTable, decode_pc_to_row: HashMap, - register_state: RegisterState, + mut register_state: RegisterState, max_rows: &super::MaxRowsConfig, #[cfg(feature = "disk-spill")] storage_mode: StorageMode, private_input: &[u8], @@ -2237,6 +2446,10 @@ fn build_traces( dvrm_ops, commit_ops, keccak_ops, + eq_ops, + bytewise_ops, + store_ops, + cpu32_ops, } = ops; // ===================================================================== @@ -2253,6 +2466,16 @@ fn build_traces( bitwise_ops.extend(collect_bitwise_from_dvrm(&dvrm_ops)); bitwise_ops.extend(collect_bitwise_from_branch(&branch_ops)); bitwise_ops.extend(shift::collect_bitwise_from_shift(&shift_ops)); + // Auxiliary chips: BYTEWISE sends 8× BYTE_ALU/op; EQ sends 4× IS_HALF + ZERO. 
+ for op in &bytewise_ops { + bitwise_ops.extend(op.collect_bitwise_ops()); + } + for op in &eq_ops { + bitwise_ops.extend(op.collect_bitwise_ops()); + } + for op in &store_ops { + bitwise_ops.extend(op.collect_bitwise_ops()); + } bitwise_ops.extend(collect_bitwise_from_memw_aligned(&memw_aligned_ops)); // MEMW_R sends IS_HALFWORD[timestamp_0 - old_timestamp_lo - 1] bitwise_ops.extend(collect_bitwise_from_memw_register(&memw_register_ops)); @@ -2287,9 +2510,18 @@ fn build_traces( let halt_op = cpu_ops .iter() .rev() - .find(|op| op.decode.op_ecall) + .find(|op| op.decode.fields.ecall) .ok_or(Error::MissingHaltEcall)?; let halt_timestamp = halt_op.timestamp; + let halt_next_pc = halt_op.next_pc; + + // Finalize the PC (x255) on the REGISTER table. The CPU padding rows carry + // pc=1 and chain the inline-PC `memory` tokens with a +4 timestamp cadence + // starting from the HALT chip's emit_pc at `halt_timestamp + 1`; the last + // padding write therefore lands at `halt_timestamp + 4*num_padding_rows + 1` + // (= `halt_timestamp + 1` when there is no padding). The REGISTER final token + // must match that last write to balance the memory argument. + register_state.write_pc(1, halt_timestamp + 4 * num_padding_rows as u64 + 1); let cpus = chunk_and_generate( &cpu_ops, @@ -2362,6 +2594,38 @@ fn build_traces( storage_mode, )?; + // Auxiliary ALU / memory / CPU32 dispatch chips. Not yet driven by the CPU + // dispatch, so they are generated empty — one padded (μ=0) chunk each, which + // contributes nothing to any bus. 
+ let eqs = chunk_and_generate::( + &eq_ops, + max_rows.eq, + eq::generate_eq_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + )?; + let bytewises = chunk_and_generate::( + &bytewise_ops, + max_rows.bytewise, + bytewise::generate_bytewise_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + )?; + let stores = chunk_and_generate::( + &store_ops, + max_rows.store, + store::generate_store_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + )?; + let cpu32s = chunk_and_generate::( + &cpu32_ops, + max_rows.cpu32, + cpu32::generate_cpu32_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + )?; + let mut bitwise = bitwise::generate_bitwise_trace(); bitwise::update_multiplicities(&mut bitwise, &bitwise_ops); @@ -2410,7 +2674,7 @@ fn build_traces( || register::generate_register_trace(&register_final_state, entry_point), ) }, - || halt::generate_halt_trace(halt_timestamp), + || halt::generate_halt_trace(halt_timestamp, halt_next_pc), ); let (pages_v, page_configs_v) = pages_val; pages = pages_v; @@ -2432,7 +2696,7 @@ fn build_traces( } } register_trace = register::generate_register_trace(&register_final_state, entry_point); - halt_trace = halt::generate_halt_trace(halt_timestamp); + halt_trace = halt::generate_halt_trace(halt_timestamp, halt_next_pc); } // Fixed-size and per-page tables aren't built through `chunk_and_generate`, @@ -2488,6 +2752,10 @@ fn build_traces( keccak_rnd: keccak_rnd_trace, keccak_rc: keccak_rc_trace, memw_registers, + eqs, + bytewises, + stores, + cpu32s, }) } @@ -2596,7 +2864,7 @@ pub fn count_table_lengths( cpu_count += 1; // Memory ops from load/store - if cpu_op.decode.op_load { + if cpu_op.decode.fields.is_load() { let (memw_op, _load_op, _bitwise) = collect_load_op_from_cpu(&cpu_op, &mut memory_state); partition_memw( @@ -2606,7 +2874,7 @@ pub fn count_table_lengths( &mut memw_register_count, ); load_count += 1; - } else if cpu_op.decode.op_store { + } else if cpu_op.decode.fields.is_store() { let memw_op = 
collect_store_op_from_cpu(&cpu_op, &mut memory_state); partition_memw( &memw_op, @@ -2651,17 +2919,18 @@ pub fn count_table_lengths( .ok_or_else(|| Error::Execution("commit index exceeds u32 range".into()))?; } - // CPU-side per-instruction-kind counters - if cpu_op.decode.op_slt || cpu_op.decode.op_blt { + // CPU-side per-instruction-kind counters (non-word; word → CPU32, B5b) + let f = &cpu_op.decode.fields; + if !f.word_instr && f.is_lt() { lt_count += 1; } - if cpu_op.decode.op_shift { + if !f.word_instr && f.is_shift() { shift_count += 1; } - if cpu_op.decode.op_mul { + if !f.word_instr && f.is_mul() { mul_count += 1; } - if cpu_op.decode.op_divrem { + if !f.word_instr && f.is_divrem() { dvrm_count += 1; } if cpu_op.branch_cond { @@ -2728,11 +2997,14 @@ impl Traces { use super::bitwise::NUM_PRECOMPUTED_COLS as BITWISE_PRECOMPUTED; use super::bitwise::cols::NUM_COLUMNS as BITWISE_COLS; use super::branch::cols::NUM_COLUMNS as BRANCH_COLS; + use super::bytewise::cols::NUM_COLUMNS as BYTEWISE_COLS; use super::commit::cols::NUM_COLUMNS as COMMIT_COLS; use super::cpu::cols::NUM_COLUMNS as CPU_COLS; + use super::cpu32::cols::NUM_COLUMNS as CPU32_COLS; use super::decode::NUM_PRECOMPUTED_COLS as DECODE_PRECOMPUTED; use super::decode::cols::NUM_COLUMNS as DECODE_COLS; use super::dvrm::cols::NUM_COLUMNS as DVRM_COLS; + use super::eq::cols::NUM_COLUMNS as EQ_COLS; use super::halt::cols::NUM_COLUMNS as HALT_COLS; use super::keccak::cols::NUM_COLUMNS as KECCAK_COLS; use super::keccak_rc::NUM_PRECOMPUTED_COLS as KECCAK_RC_PRECOMPUTED; @@ -2749,6 +3021,7 @@ impl Traces { use super::register::NUM_PREPROCESSED_COLS as REGISTER_PREPROCESSED; use super::register::cols::NUM_COLUMNS as REGISTER_COLS; use super::shift::cols::NUM_COLUMNS as SHIFT_COLS; + use super::store::cols::NUM_COLUMNS as STORE_COLS; let Traces { cpus, @@ -2770,6 +3043,10 @@ impl Traces { keccak_rnd, keccak_rc, memw_registers, + eqs, + bytewises, + stores, + cpu32s, page_configs: _, public_output_bytes: _, } = 
self; @@ -2816,6 +3093,18 @@ impl Traces { total += (keccak.num_rows() * KECCAK_COLS) as u64; total += (keccak_rnd.num_rows() * KECCAK_RND_COLS) as u64; total += (keccak_rc.num_rows() * (KECCAK_RC_COLS - KECCAK_RC_PRECOMPUTED)) as u64; + for t in eqs { + total += (t.num_rows() * EQ_COLS) as u64; + } + for t in bytewises { + total += (t.num_rows() * BYTEWISE_COLS) as u64; + } + for t in stores { + total += (t.num_rows() * STORE_COLS) as u64; + } + for t in cpu32s { + total += (t.num_rows() * CPU32_COLS) as u64; + } total } @@ -2851,6 +3140,10 @@ impl Traces { let n_keccak = aux_cols(super::keccak::bus_interactions().len()); let n_keccak_rnd = aux_cols(super::keccak_rnd::bus_interactions().len()); let n_keccak_rc = aux_cols(super::keccak_rc::bus_interactions().len()); + let n_eq = aux_cols(super::eq::bus_interactions().len()); + let n_bytewise = aux_cols(super::bytewise::bus_interactions().len()); + let n_store = aux_cols(super::store::bus_interactions().len()); + let n_cpu32 = aux_cols(super::cpu32::bus_interactions().len()); let Traces { cpus, @@ -2872,6 +3165,10 @@ impl Traces { keccak_rnd, keccak_rc, memw_registers, + eqs, + bytewises, + stores, + cpu32s, page_configs: _, public_output_bytes: _, } = self; @@ -2918,6 +3215,18 @@ impl Traces { total += (keccak.num_rows() * n_keccak) as u64; total += (keccak_rnd.num_rows() * n_keccak_rnd) as u64; total += (keccak_rc.num_rows() * n_keccak_rc) as u64; + for t in eqs { + total += (t.num_rows() * n_eq) as u64; + } + for t in bytewises { + total += (t.num_rows() * n_bytewise) as u64; + } + for t in stores { + total += (t.num_rows() * n_store) as u64; + } + for t in cpu32s { + total += (t.num_rows() * n_cpu32) as u64; + } total } @@ -2934,6 +3243,10 @@ impl Traces { shift: self.shifts.len(), branch: self.branches.len(), memw_register: self.memw_registers.len(), + eq: self.eqs.len(), + bytewise: self.bytewises.len(), + store: self.stores.len(), + cpu32: self.cpu32s.len(), } } @@ -3077,7 +3390,7 @@ impl Traces { let mut 
memory_state = MemoryState::from_elf(elf); memory_state.add_private_input(private_input); let mut register_state = RegisterState::new(elf.entry_point); - let (memw_ops, load_ops, lt_ops, shift_ops, bitwise_ops, commit_ops, keccak_ops) = + let (memw_ops, load_ops, lt_ops, shift_ops, bitwise_ops, commit_ops, keccak_ops, cpu32_ops) = collect_ops_from_cpu(&cpu_ops, &mut memory_state, &mut register_state); let ops = collect_all_ops( @@ -3089,6 +3402,7 @@ impl Traces { bitwise_ops, commit_ops, keccak_ops, + cpu32_ops, &mut register_state, ); @@ -3126,7 +3440,7 @@ impl Traces { let mut memory_state = MemoryState::new(); let entry_point = cpu_ops.first().map_or(0, |op| op.decode.pc); let mut register_state = RegisterState::new(entry_point); - let (memw_ops, load_ops, lt_ops, shift_ops, bitwise_ops, commit_ops, keccak_ops) = + let (memw_ops, load_ops, lt_ops, shift_ops, bitwise_ops, commit_ops, keccak_ops, cpu32_ops) = collect_ops_from_cpu(&cpu_ops, &mut memory_state, &mut register_state); let ops = collect_all_ops( @@ -3138,6 +3452,7 @@ impl Traces { bitwise_ops, commit_ops, keccak_ops, + cpu32_ops, &mut register_state, ); diff --git a/prover/src/tables/types.rs b/prover/src/tables/types.rs index ceefbbc60..9b7c17c53 100644 --- a/prover/src/tables/types.rs +++ b/prover/src/tables/types.rs @@ -47,73 +47,92 @@ pub enum BusId { /// Single-byte checks (spec template `IS_BYTE`) send the second value as 0. 
AreBytes = 0, /// Range check: value is a valid halfword [0, 2^16) - IsHalfword, + IsHalfword = 1, /// Range check: value is a 20-bit value [0, 2^20) - IsB20, + IsB20 = 2, // ========================================================================= // Bitwise operations (BITWISE table provides) // ========================================================================= /// Bitwise AND of two bytes: AND_BYTE[X, Y] -> X & Y - AndByte, + AndByte = 3, /// Bitwise OR of two bytes: OR_BYTE[X, Y] -> X | Y - OrByte, + OrByte = 4, /// Bitwise XOR of two bytes: XOR_BYTE[X, Y] -> X ^ Y - XorByte, + XorByte = 5, /// Most significant bit of a byte: MSB8[X] -> (X >> 7) & 1 - Msb8, + Msb8 = 6, /// Most significant bit of a halfword: MSB16[X] -> (X >> 15) & 1 - Msb16, + Msb16 = 7, /// Check if value is zero: ZERO[X] -> X == 0 ? 1 : 0 - Zero, + Zero = 8, // ========================================================================= // Shift helpers (BITWISE table provides) // ========================================================================= /// Halfword shift left: HWSL[X, Z] -> [(X << Z) & 0xFFFF, X >> (16 - Z)] - Hwsl, + Hwsl = 9, // ========================================================================= // Arithmetic operations (separate tables) // ========================================================================= - /// Less-than comparison: LT[lhs, rhs, signed] -> lhs < rhs - Lt, - /// Multiplication: MUL[lhs, lhs_signed, rhs, rhs_signed, hi] -> product - Mul, - /// Division/Remainder: DVRM[result; n, d, signed, muldiv_selector] - Dvrm, - /// Shift operation: SHIFT[in, shift, dir, signed, word] -> out - Shift, + // The four per-chip ALU buses (LT, MUL, DVRM, SHIFT — IDs 10/11/12/13) + // are collapsed into [`Alu`](BusId::Alu). Their numeric IDs are reserved + // (not removed) so the live variants below keep their discriminants stable. 
// ========================================================================= // Memory/Control // ========================================================================= /// Memory word read/write with timestamps (lookup bus from CPU) - Memw, - /// Memory load with sign/zero extension (lookup bus from CPU) - Load, + Memw = 14, + // ID 15 (Load) is reserved: the load lookup is now dispatched through + // [`MemoryOp`](BusId::MemoryOp). /// Internal memory consistency bus: memory[is_register, address, timestamp, value] /// Used for read/write pairing in MEMW table (M1-M8 in spec) - Memory, + Memory = 16, /// Branch target computation - Branch, + Branch = 17, // ========================================================================= // System (specs not yet defined) // ========================================================================= /// Instruction decode lookup - Decode, + Decode = 18, /// System call handling (CPU → HALT/COMMIT for all ECALLs) - Ecall, + Ecall = 19, /// COMMIT self-referencing recursive bus (row N → row N+1) - CommitNextByte, + CommitNextByte = 20, /// COMMIT output bus: verifier computes the receiver contribution externally /// from `VmProof.public_output` using the shared LogUp challenges - Commit, + Commit = 21, /// Keccak core ↔ round chip: (timestamp, round, state[200 bytes]) - Keccak, + Keccak = 22, /// Keccak round ↔ RC lookup: (round, rc[8 bytes]) - KeccakRc, + KeccakRc = 23, + + // ========================================================================= + // Byte ALU (BITWISE table provides) + // ========================================================================= + /// Unified byte-level ALU lookup: `BYTE_ALU[opsel, X, Y] -> out`, where + /// `opsel` is an [`alu_op`] descriptor (AND=0/OR=1/XOR=2). Collapses the + /// separate `AndByte`/`OrByte`/`XorByte` buses into one. 
+ ByteAlu = 24, + + // ========================================================================= + // Unified ALU + high-level memory dispatch + // ========================================================================= + /// Unified ALU lookup: `ALU[out; in1, in2, alu_flags]`. The CPU (sender) + /// dispatches to the ALU chips (lt/mul/dvrm/shift/eq/bytewise/cpu32) which + /// receive on this bus, selected by the `alu_flags` byte. Replaces the + /// per-chip `Lt`/`Mul`/`Dvrm`/`Shift` output buses. + Alu = 25, + /// High-level memory op: `MEMORY[out; timestamp, address, value, mem_flags]`. + /// The CPU (sender) dispatches to `LOAD`/`STORE` based on `mem_flags`. + /// Distinct from the low-level [`Memory`](BusId::Memory) token bus. + MemoryOp = 26, + /// CPU → CPU32 delegation of word (`*W`) instructions: + /// `CPU32[timestamp, pc, instruction_length]`. + Cpu32 = 27, } impl BusId { @@ -130,20 +149,19 @@ impl BusId { BusId::Msb16 => "Msb16", BusId::Zero => "Zero", BusId::Hwsl => "Hwsl", - BusId::Lt => "Lt", - BusId::Mul => "Mul", - BusId::Shift => "Shift", BusId::Memw => "Memw", - BusId::Load => "Load", BusId::Memory => "Memory", BusId::Branch => "Branch", BusId::Decode => "Decode", BusId::Ecall => "Ecall", - BusId::Dvrm => "Dvrm", BusId::CommitNextByte => "CommitNextByte", BusId::Commit => "Commit", BusId::Keccak => "Keccak", BusId::KeccakRc => "KeccakRc", + BusId::ByteAlu => "ByteAlu", + BusId::Alu => "Alu", + BusId::MemoryOp => "MemoryOp", + BusId::Cpu32 => "Cpu32", } } } @@ -163,12 +181,7 @@ impl TryFrom for BusId { 7 => Ok(BusId::Msb16), 8 => Ok(BusId::Zero), 9 => Ok(BusId::Hwsl), - 10 => Ok(BusId::Lt), - 11 => Ok(BusId::Mul), - 12 => Ok(BusId::Dvrm), - 13 => Ok(BusId::Shift), 14 => Ok(BusId::Memw), - 15 => Ok(BusId::Load), 16 => Ok(BusId::Memory), 17 => Ok(BusId::Branch), 18 => Ok(BusId::Decode), @@ -177,6 +190,10 @@ impl TryFrom for BusId { 21 => Ok(BusId::Commit), 22 => Ok(BusId::Keccak), 23 => Ok(BusId::KeccakRc), + 24 => Ok(BusId::ByteAlu), + 25 => 
Ok(BusId::Alu), + 26 => Ok(BusId::MemoryOp), + 27 => Ok(BusId::Cpu32), other => Err(other), } } @@ -232,260 +249,199 @@ pub const NEG_INV_2_112: u64 = 18446462594437939201; pub const NEG_INV_2_128: u64 = 18446744065119617026; // ========================================================================= -// packed_decode bit positions (shared between CPU and DECODE tables) +// ALU operation descriptors // ========================================================================= -/// Bit positions for the packed_decode field. +/// Numerical descriptors for ALU operations, per `spec/decode.typ`. /// -/// This is the single source of truth for how decode fields are packed into -/// a 51-bit value. Used by: -/// - `DecodeEntry::packed_decode()` - packs fields into a u64 -/// - CPU table bus interaction - builds LinearTerm coefficients -/// -/// ## Format (51 bits total) -/// -/// ```text -/// Bits [0-10]: Control flags (read_reg1, read_reg2, write_reg, memory_*, etc.) -/// Bits [11-26]: ALU operation flags (ADD, SUB, SLT, AND, OR, XOR, etc.) 
-/// Bits [27-34]: rs1 register index (8 bits) -/// Bits [35-42]: rs2 register index (8 bits) -/// Bits [43-50]: rd register index (8 bits) -/// ``` -pub mod packed_decode { - // Control flags (bits 0-10) - pub const READ_REG1: u32 = 0; - pub const READ_REG2: u32 = 1; - pub const WRITE_REG: u32 = 2; - pub const MEMORY_2BYTES: u32 = 3; - pub const MEMORY_4BYTES: u32 = 4; - pub const MEMORY_8BYTES: u32 = 5; - pub const C_TYPE: u32 = 6; - pub const SIGNED: u32 = 7; - pub const MP_SELECTOR: u32 = 8; - pub const MULDIV_SELECTOR: u32 = 9; - pub const WORD_INSTR: u32 = 10; - - // ALU operation flags (bits 11-26) - pub const OP_ADD: u32 = 11; - pub const OP_SUB: u32 = 12; - pub const OP_SLT: u32 = 13; - pub const OP_AND: u32 = 14; - pub const OP_OR: u32 = 15; - pub const OP_XOR: u32 = 16; - pub const OP_SHIFT: u32 = 17; - pub const OP_JALR: u32 = 18; - pub const OP_BEQ: u32 = 19; - pub const OP_BLT: u32 = 20; - pub const OP_LOAD: u32 = 21; - pub const OP_STORE: u32 = 22; - pub const OP_MUL: u32 = 23; - pub const OP_DIVREM: u32 = 24; - pub const OP_ECALL: u32 = 25; - pub const OP_EBREAK: u32 = 26; - - // Register indices (bits 27-50) - pub const RS1: u32 = 27; - pub const RS2: u32 = 35; - pub const RD: u32 = 43; +/// These values are the single source of truth for: +/// - the `opsel` selector of the [`BusId::ByteAlu`] lookup (AND/OR/XOR), and +/// - the low 5 bits (`alu_op`) of the packed `alu_flags` byte consumed by the +/// unified `ALU` bus and the ALU chips (shift/lt/mul/dvrm). 
+pub mod alu_op { + pub const AND: u8 = 0; + pub const OR: u8 = 1; + pub const XOR: u8 = 2; + pub const EQ: u8 = 3; + pub const LT: u8 = 4; + pub const SHIFT: u8 = 5; + pub const SHIFTW: u8 = 6; + pub const MUL: u8 = 7; + pub const DIVREM: u8 = 8; } // ========================================================================= -// DecodeEntry - Shared decode information for CPU and DECODE tables +// packed_decode layout // ========================================================================= -/// A single decoded instruction entry. -/// -/// This struct contains all static decode-time information extracted from an instruction. -/// It is shared between the CPU table (which uses it for execution) and the DECODE table -/// (which provides it as a lookup table). -/// -/// ## Usage +/// Bit layout of the shrunk `packed_decode` field (58 bits used), per +/// `cpu.toml:184-205` and `decode_uncompressed.toml`. /// -/// - **CPU table**: `CpuOperation` contains a `DecodeEntry` plus runtime values (rv1, rv2, etc.) -/// - **DECODE table**: Stores `DecodeEntry` directly, with multiplicity tracking +/// This is the single source of truth shared by the DECODE-table producer and +/// the CPU's `packed_decode` reconstruction, so the DECODE bus fingerprint +/// matches on both sides. 
/// -/// ## packed_decode Format (51 bits) -/// -/// ```text -/// Bits [0]: read_register1 -/// Bits [1]: read_register2 -/// Bits [2]: write_register -/// Bits [3]: memory_2bytes -/// Bits [4]: memory_4bytes -/// Bits [5]: memory_8bytes -/// Bits [6]: c_type -/// Bits [7]: signed -/// Bits [8]: mp_selector -/// Bits [9]: muldiv_selector -/// Bits [10]: word_instr -/// Bits [11-26]: ALU flags (ADD, SUB, SLT, AND, OR, XOR, SHIFT, JALR, -/// BEQ, BLT, LOAD, STORE, MUL, DIVREM, ECALL, EBREAK) -/// Bits [27:35]: rs1 (8 bits) -/// Bits [35:43]: rs2 (8 bits) -/// Bits [43:51]: rd (8 bits) -/// ``` -#[derive(Debug, Clone, Hash, PartialEq, Eq, Default)] -pub struct DecodeEntry { - // Program counter - /// Program counter (64-bit) - pub pc: u64, +/// NOTE: not yet wired into the DECODE/CPU tables — those still use the older +/// [`packed_decode`] layout. +pub mod packed_decode_shrunk { + // Top-level flags + register indices. + pub const READ_REG1: u32 = 0; + pub const READ_REG2: u32 = 1; + pub const WRITE_REG: u32 = 2; + pub const WORD_INSTR: u32 = 3; + pub const ALU: u32 = 4; + pub const ADD: u32 = 5; + pub const SUB: u32 = 6; + pub const MEMORY: u32 = 7; + pub const BRANCH: u32 = 8; + pub const ECALL: u32 = 9; + pub const RS1: u32 = 10; + pub const RS2: u32 = 18; + pub const RD: u32 = 26; + /// `half_instruction_length`: bytes/2 (1 for C-type, 2 for regular). The + /// half-encoding makes odd (misaligned) instruction lengths unrepresentable + /// (`spec/src/cpu.toml`). + pub const HALF_INSTRUCTION_LENGTH: u32 = 34; + pub const ALU_FLAGS: u32 = 42; + pub const MEM_FLAGS: u32 = 50; + + // `alu_flags` byte interior: bits 0-4 are the `alu_op` descriptor + // (see [`super::alu_op`]); the high bits are flags. + pub const ALU_FLAGS_OP_MASK: u8 = 0x1F; + pub const ALU_FLAGS_SIGNED: u32 = 5; + /// `signed2` (MUL) and `invert` (SHIFT/EQ/LT) are mutually exclusive and + /// share this bit (`64·(signed2 + invert)` in `decode_uncompressed.toml`). 
+ pub const ALU_FLAGS_SIGNED2_OR_INVERT: u32 = 6; + pub const ALU_FLAGS_MULDIV: u32 = 7; + + // `mem_flags` byte interior. Bit 0 aliases `JALR` (under BRANCH) and + // `memory_op` (0=LOAD/1=STORE, under MEMORY); the two are mutually exclusive. + pub const MEM_FLAGS_JALR_OR_OP: u32 = 0; + pub const MEM_FLAGS_SIGNED: u32 = 1; + pub const MEM_FLAGS_2B: u32 = 2; + pub const MEM_FLAGS_4B: u32 = 3; + pub const MEM_FLAGS_8B: u32 = 4; +} - // Register indices (8 bits each) - /// Source register 1 index - pub rs1: u8, - /// Source register 2 index - pub rs2: u8, - /// Destination register index - pub rd: u8, +/// Build the `alu_flags` byte: `alu_op + 32·signed + 64·(signed2|invert) + 128·muldiv`. +pub fn build_alu_flags(alu_op: u8, signed: bool, signed2_or_invert: bool, muldiv: bool) -> u8 { + use packed_decode_shrunk as b; + debug_assert!(alu_op <= b::ALU_FLAGS_OP_MASK, "alu_op must fit in 5 bits"); + alu_op + | ((signed as u8) << b::ALU_FLAGS_SIGNED) + | ((signed2_or_invert as u8) << b::ALU_FLAGS_SIGNED2_OR_INVERT) + | ((muldiv as u8) << b::ALU_FLAGS_MULDIV) +} + +/// Build the `mem_flags` byte: `jalr_or_op + 2·mem_signed + 4·mem_2B + 8·mem_4B + 16·mem_8B`. +pub fn build_mem_flags( + jalr_or_memory_op: bool, + mem_signed: bool, + mem_2b: bool, + mem_4b: bool, + mem_8b: bool, +) -> u8 { + use packed_decode_shrunk as b; + ((jalr_or_memory_op as u8) << b::MEM_FLAGS_JALR_OR_OP) + | ((mem_signed as u8) << b::MEM_FLAGS_SIGNED) + | ((mem_2b as u8) << b::MEM_FLAGS_2B) + | ((mem_4b as u8) << b::MEM_FLAGS_4B) + | ((mem_8b as u8) << b::MEM_FLAGS_8B) +} - // Control flags - /// Whether to read from rs1 +/// Logical (unpacked) view of the reworked `packed_decode` field. `alu_flags` +/// and `mem_flags` are stored already-packed (build them with +/// [`build_alu_flags`] / [`build_mem_flags`]). 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +pub struct ShrunkDecode { pub read_register1: bool, - /// Whether to read from rs2 pub read_register2: bool, - /// Whether to write to rd pub write_register: bool, - /// Memory access is 2 bytes - pub memory_2bytes: bool, - /// Memory access is 4 bytes - pub memory_4bytes: bool, - /// Memory access is 8 bytes - pub memory_8bytes: bool, - /// Compressed instruction (2 bytes instead of 4) - pub c_type: bool, - /// Signed operation - pub signed: bool, - /// Multi-purpose selector (shift direction, branch invert, etc.) - pub mp_selector: bool, - /// MUL/DIV output selector - pub muldiv_selector: bool, - /// Word instruction (32-bit with sign extension) pub word_instr: bool, - - // ALU selector flags (one-hot) - /// ADD operation - pub op_add: bool, - /// SUB operation - pub op_sub: bool, - /// SLT (Set Less Than) operation - pub op_slt: bool, - /// AND operation - pub op_and: bool, - /// OR operation - pub op_or: bool, - /// XOR operation - pub op_xor: bool, - /// SHIFT operation - pub op_shift: bool, - /// JALR operation - pub op_jalr: bool, - /// BEQ (Branch if Equal) operation - pub op_beq: bool, - /// BLT (Branch if Less Than) operation - pub op_blt: bool, - /// LOAD operation - pub op_load: bool, - /// STORE operation - pub op_store: bool, - /// MUL operation - pub op_mul: bool, - /// DIVREM operation - pub op_divrem: bool, - /// ECALL operation - pub op_ecall: bool, - /// EBREAK operation - pub op_ebreak: bool, - - // Immediate value - /// Fully extended 64-bit immediate - pub imm: u64, + pub alu: bool, + pub add: bool, + pub sub: bool, + pub memory: bool, + pub branch: bool, + pub ecall: bool, + pub rs1: u8, + pub rs2: u8, + pub rd: u8, + /// Half the byte length of the instruction (1 for C-type, 2 for regular); + /// the real length is `2 * half_instruction_length`. 
+ pub half_instruction_length: u8, + pub alu_flags: u8, + pub mem_flags: u8, } -impl DecodeEntry { - /// Creates a new empty DecodeEntry. - pub fn new() -> Self { - Self::default() +impl ShrunkDecode { + /// Pack into the 58-bit `packed_decode` field value. + pub fn pack(&self) -> u64 { + use packed_decode_shrunk as b; + ((self.read_register1 as u64) << b::READ_REG1) + | ((self.read_register2 as u64) << b::READ_REG2) + | ((self.write_register as u64) << b::WRITE_REG) + | ((self.word_instr as u64) << b::WORD_INSTR) + | ((self.alu as u64) << b::ALU) + | ((self.add as u64) << b::ADD) + | ((self.sub as u64) << b::SUB) + | ((self.memory as u64) << b::MEMORY) + | ((self.branch as u64) << b::BRANCH) + | ((self.ecall as u64) << b::ECALL) + | ((self.rs1 as u64) << b::RS1) + | ((self.rs2 as u64) << b::RS2) + | ((self.rd as u64) << b::RD) + | ((self.half_instruction_length as u64) << b::HALF_INSTRUCTION_LENGTH) + | ((self.alu_flags as u64) << b::ALU_FLAGS) + | ((self.mem_flags as u64) << b::MEM_FLAGS) } - /// Creates the special padding entry for DECODE table. - /// - /// Uses pc=7 with EBREAK=1 flag set. This makes padding rows - /// unprovable since CPU asserts EBREAK=0. - pub fn padding_entry() -> Self { + /// Inverse of [`pack`](Self::pack). 
+ pub fn unpack(packed: u64) -> Self { + use packed_decode_shrunk as b; + let bit = |pos: u32| (packed >> pos) & 1 == 1; + let byte = |pos: u32| ((packed >> pos) & 0xFF) as u8; Self { - pc: 7, - op_ebreak: true, - ..Default::default() + read_register1: bit(b::READ_REG1), + read_register2: bit(b::READ_REG2), + write_register: bit(b::WRITE_REG), + word_instr: bit(b::WORD_INSTR), + alu: bit(b::ALU), + add: bit(b::ADD), + sub: bit(b::SUB), + memory: bit(b::MEMORY), + branch: bit(b::BRANCH), + ecall: bit(b::ECALL), + rs1: byte(b::RS1), + rs2: byte(b::RS2), + rd: byte(b::RD), + half_instruction_length: byte(b::HALF_INSTRUCTION_LENGTH), + alu_flags: byte(b::ALU_FLAGS), + mem_flags: byte(b::MEM_FLAGS), } } - /// Packs all flags and register indices into a single 51-bit value. - /// - /// This matches the spec's packed_decode format (decode.md). - /// Bit positions are defined in the `packed_decode` module. + /// Build the reworked packed-decode flags for an instruction, per + /// `spec/decode.typ`. Does NOT include `pc`/`imm` (separate DECODE columns). /// - /// Note: The register flags (read_register1, read_register2, write_register) - /// are adjusted to exclude x0 (hardwired zero) and x255 (virtual PC for AUIPC/JAL). - /// This matches the CPU trace columns and ensures the DECODE bus balances. - pub fn packed_decode(&self) -> u64 { - use crate::tables::types::packed_decode as bits; - - let mut packed: u64 = 0; - - // Control flags (bits 0-10) - // x0 is hardwired to zero and never physically read. - // x255 is the register where the pc is stored (per spec decode.md), - // so read_register1=1 for rs1=255. 
- let read_reg1_physical = self.read_register1 && self.rs1 != 0; - let read_reg2_physical = self.read_register2 && self.rs2 != 0; - let write_reg_physical = self.write_register && self.rd != 0; - packed |= (read_reg1_physical as u64) << bits::READ_REG1; - packed |= (read_reg2_physical as u64) << bits::READ_REG2; - packed |= (write_reg_physical as u64) << bits::WRITE_REG; - packed |= (self.memory_2bytes as u64) << bits::MEMORY_2BYTES; - packed |= (self.memory_4bytes as u64) << bits::MEMORY_4BYTES; - packed |= (self.memory_8bytes as u64) << bits::MEMORY_8BYTES; - packed |= (self.c_type as u64) << bits::C_TYPE; - packed |= (self.signed as u64) << bits::SIGNED; - packed |= (self.mp_selector as u64) << bits::MP_SELECTOR; - packed |= (self.muldiv_selector as u64) << bits::MULDIV_SELECTOR; - packed |= (self.word_instr as u64) << bits::WORD_INSTR; - - // ALU flags (bits 11-26) - packed |= (self.op_add as u64) << bits::OP_ADD; - packed |= (self.op_sub as u64) << bits::OP_SUB; - packed |= (self.op_slt as u64) << bits::OP_SLT; - packed |= (self.op_and as u64) << bits::OP_AND; - packed |= (self.op_or as u64) << bits::OP_OR; - packed |= (self.op_xor as u64) << bits::OP_XOR; - packed |= (self.op_shift as u64) << bits::OP_SHIFT; - packed |= (self.op_jalr as u64) << bits::OP_JALR; - packed |= (self.op_beq as u64) << bits::OP_BEQ; - packed |= (self.op_blt as u64) << bits::OP_BLT; - packed |= (self.op_load as u64) << bits::OP_LOAD; - packed |= (self.op_store as u64) << bits::OP_STORE; - packed |= (self.op_mul as u64) << bits::OP_MUL; - packed |= (self.op_divrem as u64) << bits::OP_DIVREM; - packed |= (self.op_ecall as u64) << bits::OP_ECALL; - packed |= (self.op_ebreak as u64) << bits::OP_EBREAK; - - // Register indices (bits 27-50) - packed |= (self.rs1 as u64) << bits::RS1; - packed |= (self.rs2 as u64) << bits::RS2; - packed |= (self.rd as u64) << bits::RD; - - packed - } - - /// Creates a DecodeEntry from a PC and Instruction. 
+ /// `instruction_length` is the byte length: 2 (RV64C compressed) or 4. It is + /// stored as `half_instruction_length = instruction_length / 2`; the real + /// length is recovered as `2 * half_instruction_length`. /// - /// Extracts all decode-time information: pc, registers, flags, immediate. - pub fn from_instruction(pc: u64, instruction: Instruction) -> Self { - let mut entry = Self { - pc, + /// Per `spec/decode.typ`: conditional branches set + /// `BRANCH=1 ∧ ALU=1` (the EQ/LT chip computes the comparison; `BRANCH` + /// selects `arg2 = rv2`). JAL/JALR set `BRANCH=1 ∧ JALR=1` with no ALU op — + /// the return address `pc + instruction_length` is written to `rvd` by the + /// CPU branch group, not the ALU. + pub fn from_instruction(instruction: Instruction, instruction_length: u8) -> Self { + debug_assert!( + instruction_length.is_multiple_of(2), + "instruction_length must be even (RISC-V instructions are 2 or 4 bytes)" + ); + let mut d = Self { + half_instruction_length: instruction_length / 2, ..Default::default() }; - match instruction { Instruction::Arith { dst, @@ -493,309 +449,365 @@ impl DecodeEntry { src2, op, } => { - entry.rd = dst as u8; - entry.rs1 = src1 as u8; - entry.rs2 = src2 as u8; - entry.read_register1 = src1 != 0; - entry.read_register2 = src2 != 0; - if dst != 0 { - entry.write_register = true; - } - Self::set_arith_op(&mut entry, op); - } - - Instruction::ArithImm { dst, src, imm, op } => { - entry.rd = dst as u8; - entry.rs1 = src as u8; - entry.rs2 = 0; - entry.imm = imm as i64 as u64; // Sign extend - entry.read_register1 = src != 0; - if dst != 0 { - entry.write_register = true; - } - Self::set_arith_op(&mut entry, op); + d.rd = dst as u8; + d.rs1 = src1 as u8; + d.rs2 = src2 as u8; + d.read_register1 = src1 != 0; + d.read_register2 = src2 != 0; + d.write_register = dst != 0; + d.apply_arith_op(op, false); + } + Instruction::ArithImm { dst, src, op, .. 
} => { + d.rd = dst as u8; + d.rs1 = src as u8; + d.read_register1 = src != 0; + d.write_register = dst != 0; + d.apply_arith_op(op, false); } - Instruction::ArithW { dst, src1, src2, op, } => { - entry.rd = dst as u8; - entry.rs1 = src1 as u8; - entry.rs2 = src2 as u8; - entry.word_instr = true; - entry.read_register1 = src1 != 0; - entry.read_register2 = src2 != 0; - if dst != 0 { - entry.write_register = true; - } - Self::set_arith_op(&mut entry, op); - } - - Instruction::ArithImmW { dst, src, imm, op } => { - entry.rd = dst as u8; - entry.rs1 = src as u8; - entry.rs2 = 0; - entry.imm = imm as i64 as u64; // Sign extend - entry.word_instr = true; - entry.read_register1 = src != 0; - if dst != 0 { - entry.write_register = true; - } - Self::set_arith_op(&mut entry, op); - } - - Instruction::JumpAndLink { dst, offset } => { - entry.op_jalr = true; - entry.rd = dst as u8; - // Per spec: JAL is represented as JALR rd, x255, imm - // x255 is the virtual register holding PC - entry.rs1 = 255; - entry.read_register1 = true; // rs1 ≠ 0 - entry.imm = offset as i64 as u64; - if dst != 0 { - entry.write_register = true; - } + d.rd = dst as u8; + d.rs1 = src1 as u8; + d.rs2 = src2 as u8; + d.read_register1 = src1 != 0; + d.read_register2 = src2 != 0; + d.write_register = dst != 0; + d.word_instr = true; + d.apply_arith_op(op, true); + } + Instruction::ArithImmW { dst, src, op, .. } => { + d.rd = dst as u8; + d.rs1 = src as u8; + d.read_register1 = src != 0; + d.write_register = dst != 0; + d.word_instr = true; + d.apply_arith_op(op, true); + } + // JAL is represented as JALR rd, x255, imm (x255 holds pc). + Instruction::JumpAndLink { dst, .. } => { + d.rd = dst as u8; + d.rs1 = 255; + d.read_register1 = true; + d.write_register = dst != 0; + d.branch = true; + d.mem_flags = build_mem_flags(true, false, false, false, false); // JALR bit + } + Instruction::JumpAndLinkRegister { base, dst, .. 
} => { + d.rd = dst as u8; + d.rs1 = base as u8; + d.read_register1 = base != 0; + d.write_register = dst != 0; + d.branch = true; + d.mem_flags = build_mem_flags(true, false, false, false, false); // JALR bit } - - Instruction::JumpAndLinkRegister { base, dst, offset } => { - entry.op_jalr = true; - entry.rd = dst as u8; - entry.rs1 = base as u8; - entry.imm = offset as i64 as u64; - entry.read_register1 = base != 0; - if dst != 0 { - entry.write_register = true; - } - } - Instruction::Store { - src, - offset, - base, - width, + src, base, width, .. } => { - entry.op_store = true; - entry.rs1 = base as u8; - entry.rs2 = src as u8; - entry.imm = offset as i64 as u64; - entry.read_register1 = base != 0; - entry.read_register2 = src != 0; - // write_register = false for STORE - Self::set_memory_width(&mut entry, width); + d.rs1 = base as u8; + d.rs2 = src as u8; + d.read_register1 = base != 0; + d.read_register2 = src != 0; + d.add = true; // address = rv1 + imm + d.memory = true; + let (m2, m4, m8) = store_width_bits(width); + d.mem_flags = build_mem_flags(true, false, m2, m4, m8); // memory_op = store } - Instruction::Load { - dst, - offset, - base, - width, + dst, base, width, .. 
} => { - entry.op_load = true; - entry.rd = dst as u8; - entry.rs1 = base as u8; - entry.imm = offset as i64 as u64; - entry.read_register1 = base != 0; - if dst != 0 { - entry.write_register = true; - } - Self::set_memory_width(&mut entry, width); - // Set signed flag for sign-extending loads - match width { - LoadStoreWidth::Byte | LoadStoreWidth::Half | LoadStoreWidth::Word => { - entry.signed = true; - } - _ => {} - } + d.rd = dst as u8; + d.rs1 = base as u8; + d.read_register1 = base != 0; + d.write_register = dst != 0; + d.add = true; // address = rv1 + imm + d.memory = true; + let (m2, m4, m8, signed) = load_width_bits(width); + d.mem_flags = build_mem_flags(false, signed, m2, m4, m8); // memory_op = load } - Instruction::Branch { - src1, - src2, - cond, - offset, + src1, src2, cond, .. } => { - entry.rs1 = src1 as u8; - entry.rs2 = src2 as u8; - entry.imm = offset as i64 as u64; - entry.read_register1 = src1 != 0; - entry.read_register2 = src2 != 0; - - match cond { - Comparison::Equal => { - entry.op_beq = true; - } - Comparison::NotEqual => { - entry.op_beq = true; - entry.mp_selector = true; // Inverted - } - Comparison::LessThan => { - entry.op_blt = true; - entry.signed = true; - } - Comparison::LessThanUnsigned => { - entry.op_blt = true; - } - Comparison::GreaterOrEqual => { - entry.op_blt = true; - entry.signed = true; - entry.mp_selector = true; // Inverted - } - Comparison::GreaterOrEqualUnsigned => { - entry.op_blt = true; - entry.mp_selector = true; // Inverted - } - } + d.rs1 = src1 as u8; + d.rs2 = src2 as u8; + d.read_register1 = src1 != 0; + d.read_register2 = src2 != 0; + d.branch = true; + d.alu = true; // Q3: conditional branches go through the EQ/LT ALU chip + let (op, signed, invert) = branch_cond_flags(cond); + d.alu_flags = build_alu_flags(op, signed, invert, false); + } + // LUI is represented as ADDI rd, x0, imm. + Instruction::LoadUpperImm { dst, .. 
} => { + d.rd = dst as u8; + d.write_register = dst != 0; + d.add = true; + } + // AUIPC is represented as ADDI rd, x255, imm (x255 holds pc). + Instruction::AddUpperImmToPc { dst, .. } => { + d.rd = dst as u8; + d.rs1 = 255; + d.read_register1 = true; + d.write_register = dst != 0; + d.add = true; } - - Instruction::LoadUpperImm { dst, imm } => { - entry.op_add = true; - entry.rd = dst as u8; - entry.rs1 = 0; - entry.rs2 = 0; - // LUI immediate is sign-extended to 64 bits - entry.imm = (imm as i32) as i64 as u64; - if dst != 0 { - entry.write_register = true; - } + Instruction::EcallEbreak => { + d.rs1 = 17; // a7 holds the syscall number + d.read_register1 = true; + d.ecall = true; } - - Instruction::AddUpperImmToPc { dst, imm } => { - entry.op_add = true; - entry.rd = dst as u8; - // Per spec: AUIPC is represented as ADDI rd, x255, imm - // x255 is the virtual register holding PC - entry.rs1 = 255; - entry.read_register1 = true; // rs1 ≠ 0 - // AUIPC immediate is sign-extended to 64 bits - entry.imm = (imm as i32) as i64 as u64; - if dst != 0 { - entry.write_register = true; - } + // FENCE and CSR are treated as no-ops (ADDI x0, x0, 0). + Instruction::Fence | Instruction::CSR { .. } => { + d.add = true; } + } + d + } - Instruction::CSR { .. } => { - // CSR instructions are executed as no-ops by the VM (see - // executor Instruction::CSR arm returning dst_val: 0, - // src1/2_val: 0). Mirror that here by treating them as - // `ADDI x0, x0, 0` — same pattern as `Fence`. This sets - // `op_add=true` so CM54's multiplicity is non-zero and the - // CPU's PC-update Memw sender fires. - entry.op_add = true; - } + /// Set the `ADD`/`SUB`/`ALU` flags and `alu_flags` byte for an `ArithOp`, + /// per `spec/decode.typ`. `ADD`/`SUB` are fast-paths (ALU not set). 
+ fn apply_arith_op(&mut self, op: ArithOp, word_instr: bool) { + let shift = if word_instr { + alu_op::SHIFTW + } else { + alu_op::SHIFT + }; + // (alu_op, signed, signed2|invert, muldiv, is_add, is_sub) + let (alu, signed, s2_or_inv, muldiv, is_add, is_sub) = match op { + ArithOp::Add => (0, false, false, false, true, false), + ArithOp::Sub => (0, false, false, false, false, true), + ArithOp::And => (alu_op::AND, false, false, false, false, false), + ArithOp::Or => (alu_op::OR, false, false, false, false, false), + ArithOp::Xor => (alu_op::XOR, false, false, false, false, false), + ArithOp::ShiftLeftLogical => (shift, false, false, false, false, false), + ArithOp::ShiftRightLogical => (shift, false, true, false, false, false), // invert = right + ArithOp::ShiftRightArith => (shift, true, true, false, false, false), + ArithOp::SetLessThan => (alu_op::LT, true, false, false, false, false), + ArithOp::SetLessThanU => (alu_op::LT, false, false, false, false, false), + ArithOp::Mul => (alu_op::MUL, true, true, false, false, false), + ArithOp::MulHigh => (alu_op::MUL, true, true, true, false, false), + ArithOp::MulHighSignedUnsigned => (alu_op::MUL, true, false, true, false, false), + ArithOp::MulHighUnsigned => (alu_op::MUL, false, false, true, false, false), + ArithOp::Div => (alu_op::DIVREM, true, false, false, false, false), + ArithOp::DivUnsigned => (alu_op::DIVREM, false, false, false, false, false), + ArithOp::Remainder => (alu_op::DIVREM, true, false, true, false, false), + ArithOp::RemainderUnsigned => (alu_op::DIVREM, false, false, true, false, false), + }; + self.add = is_add; + self.sub = is_sub; + self.alu = !(is_add || is_sub); + self.alu_flags = build_alu_flags(alu, signed, s2_or_inv, muldiv); + } - Instruction::EcallEbreak => { - entry.op_ecall = true; - entry.rs1 = 17; // a7 (syscall number) - entry.read_register1 = true; // M1 reads a7 → rv1 = syscall number - // rs2 and rd default to 0 per spec; read_register2 and write_register remain false. 
- // HALT/COMMIT chips access registers via direct MEMW interactions. - } + // ---- packed `alu_flags` accessors ---- - Instruction::Fence => { - // Per spec, FENCE is a no-op interpreted as ADDI x0, x0, 0. - entry.op_add = true; - } + /// The `alu_op` descriptor (bits 0-4 of `alu_flags`). + #[inline] + pub fn alu_op(&self) -> u8 { + self.alu_flags & packed_decode_shrunk::ALU_FLAGS_OP_MASK + } + /// `signed` flag (bit 5 of `alu_flags`). + #[inline] + pub fn alu_signed(&self) -> bool { + (self.alu_flags >> packed_decode_shrunk::ALU_FLAGS_SIGNED) & 1 == 1 + } + /// Shared `signed2`/`invert` flag (bit 6 of `alu_flags`); meaning depends on + /// `alu_op` (MUL: `signed2`; SHIFT/EQ/LT: `invert`). + #[inline] + pub fn alu_signed2_or_invert(&self) -> bool { + (self.alu_flags >> packed_decode_shrunk::ALU_FLAGS_SIGNED2_OR_INVERT) & 1 == 1 + } + /// `muldiv_selector` flag (bit 7 of `alu_flags`). + #[inline] + pub fn alu_muldiv(&self) -> bool { + (self.alu_flags >> packed_decode_shrunk::ALU_FLAGS_MULDIV) & 1 == 1 + } + + // ---- packed `mem_flags` accessors (valid under `memory`/`branch`) ---- + + /// Virtual `JALR` bit (bit 0 of `mem_flags`); valid under `branch`. + #[inline] + pub fn jalr(&self) -> bool { + self.mem_flags & 1 == 1 + } + /// STORE (vs LOAD) when `memory`: `memory_op` is bit 0 of `mem_flags`. + #[inline] + pub fn is_store(&self) -> bool { + self.memory && (self.mem_flags & 1 == 1) + } + /// LOAD (vs STORE) when `memory`. + #[inline] + pub fn is_load(&self) -> bool { + self.memory && (self.mem_flags & 1 == 0) + } + /// `mem_signed` flag (bit 1 of `mem_flags`). + #[inline] + pub fn mem_signed(&self) -> bool { + (self.mem_flags >> packed_decode_shrunk::MEM_FLAGS_SIGNED) & 1 == 1 + } + /// Memory access width in bytes (from the `mem_flags` width bits; default 1). 
+ #[inline] + pub fn mem_bytes(&self) -> usize { + use packed_decode_shrunk as b; + if (self.mem_flags >> b::MEM_FLAGS_8B) & 1 == 1 { + 8 + } else if (self.mem_flags >> b::MEM_FLAGS_4B) & 1 == 1 { + 4 + } else if (self.mem_flags >> b::MEM_FLAGS_2B) & 1 == 1 { + 2 + } else { + 1 } + } + + // ---- ALU operation classifiers (valid only when `alu`) ---- - entry + #[inline] + pub fn is_and(&self) -> bool { + self.alu && self.alu_op() == alu_op::AND + } + #[inline] + pub fn is_or(&self) -> bool { + self.alu && self.alu_op() == alu_op::OR + } + #[inline] + pub fn is_xor(&self) -> bool { + self.alu && self.alu_op() == alu_op::XOR + } + #[inline] + pub fn is_eq(&self) -> bool { + self.alu && self.alu_op() == alu_op::EQ } + #[inline] + pub fn is_lt(&self) -> bool { + self.alu && self.alu_op() == alu_op::LT + } + #[inline] + pub fn is_shift(&self) -> bool { + self.alu && matches!(self.alu_op(), x if x == alu_op::SHIFT || x == alu_op::SHIFTW) + } + #[inline] + pub fn is_mul(&self) -> bool { + self.alu && self.alu_op() == alu_op::MUL + } + #[inline] + pub fn is_divrem(&self) -> bool { + self.alu && self.alu_op() == alu_op::DIVREM + } +} - /// Helper to set ALU operation flags based on ArithOp. 
- fn set_arith_op(entry: &mut Self, arith_op: ArithOp) { - match arith_op { - ArithOp::Add => { - entry.op_add = true; - } - ArithOp::Sub => { - entry.op_sub = true; - } - ArithOp::Xor => entry.op_xor = true, - ArithOp::Or => entry.op_or = true, - ArithOp::And => entry.op_and = true, - ArithOp::ShiftLeftLogical => { - entry.op_shift = true; - // mp_selector = 0 for left shift - } - ArithOp::ShiftRightLogical => { - entry.op_shift = true; - entry.mp_selector = true; // Right shift - } - ArithOp::ShiftRightArith => { - entry.op_shift = true; - entry.mp_selector = true; - entry.signed = true; - } - ArithOp::SetLessThan => { - entry.op_slt = true; - entry.signed = true; - } - ArithOp::SetLessThanU => { - entry.op_slt = true; - } - ArithOp::Mul => { - entry.op_mul = true; - entry.mp_selector = true; - entry.signed = true; - } - ArithOp::MulHigh => { - entry.op_mul = true; - entry.muldiv_selector = true; - entry.mp_selector = true; // both operands signed for MULH - entry.signed = true; - } - ArithOp::MulHighSignedUnsigned => { - entry.op_mul = true; - entry.muldiv_selector = true; - // mp_selector = false (default): rhs is unsigned for MULHSU - entry.signed = true; - } - ArithOp::MulHighUnsigned => { - entry.op_mul = true; - entry.muldiv_selector = true; - } - ArithOp::Div => { - entry.op_divrem = true; - entry.signed = true; - } - ArithOp::DivUnsigned => { - entry.op_divrem = true; - } - ArithOp::Remainder => { - entry.op_divrem = true; - entry.muldiv_selector = true; - entry.signed = true; - } - ArithOp::RemainderUnsigned => { - entry.op_divrem = true; - entry.muldiv_selector = true; - } +/// Memory-width bits `(mem_2B, mem_4B, mem_8B)` for STORE (1 byte = none set). 
+fn store_width_bits(width: LoadStoreWidth) -> (bool, bool, bool) { + match width { + LoadStoreWidth::Byte | LoadStoreWidth::ByteUnsigned => (false, false, false), + LoadStoreWidth::Half | LoadStoreWidth::HalfUnsigned => (true, false, false), + LoadStoreWidth::Word | LoadStoreWidth::WordUnsigned => (false, true, false), + LoadStoreWidth::DoubleWord => (false, false, true), + } +} + +/// Memory-width bits `(mem_2B, mem_4B, mem_8B, mem_signed)` for LOAD. +/// `mem_signed = ¬[U]`; the full-width `LD` is not sign-extended. +fn load_width_bits(width: LoadStoreWidth) -> (bool, bool, bool, bool) { + match width { + LoadStoreWidth::Byte => (false, false, false, true), + LoadStoreWidth::ByteUnsigned => (false, false, false, false), + LoadStoreWidth::Half => (true, false, false, true), + LoadStoreWidth::HalfUnsigned => (true, false, false, false), + LoadStoreWidth::Word => (false, true, false, true), + LoadStoreWidth::WordUnsigned => (false, true, false, false), + LoadStoreWidth::DoubleWord => (false, false, true, false), + } +} + +/// `(alu_op, signed, invert)` for a branch comparison, per `spec/decode.typ`. +fn branch_cond_flags(cond: Comparison) -> (u8, bool, bool) { + match cond { + Comparison::Equal => (alu_op::EQ, false, false), + Comparison::NotEqual => (alu_op::EQ, false, true), + Comparison::LessThan => (alu_op::LT, true, false), + Comparison::LessThanUnsigned => (alu_op::LT, false, false), + Comparison::GreaterOrEqual => (alu_op::LT, true, true), + Comparison::GreaterOrEqualUnsigned => (alu_op::LT, false, true), + } +} + +// ========================================================================= +// DecodeEntry - Shared decode information for CPU and DECODE tables +// ========================================================================= + +/// A single decoded instruction entry. +/// +/// This struct contains all static decode-time information extracted from an instruction. 
+/// It is shared between the CPU table (which uses it for execution) and the DECODE table +/// (which provides it as a lookup table). +/// +/// ## Usage +/// +/// - **CPU table**: `CpuOperation` contains a `DecodeEntry` plus runtime values (rv1, rv2, etc.) +/// - **DECODE table**: Stores `DecodeEntry` directly, with multiplicity tracking +/// +/// The packed decode layout is defined by [`packed_decode_shrunk`] and produced +/// by [`ShrunkDecode::pack`]; consult those for the bit positions of every flag, +/// the ALU/MEM flag bytes, and the rs1/rs2/rd register indices. +#[derive(Debug, Clone, Default, Hash, PartialEq, Eq)] +pub struct DecodeEntry { + /// Program counter (64-bit). + pub pc: u64, + /// Fully sign-extended 64-bit immediate. + pub imm: u64, + /// Packed decode flags + register indices. + pub fields: ShrunkDecode, +} + +impl DecodeEntry { + /// Creates an empty DecodeEntry. + pub fn new() -> Self { + Self::default() + } + + /// Padding row for the DECODE/CPU tables: an odd PC (never a valid fetch + /// target, hence unprovable) with all flags zero. Replaces the old + /// EBREAK-based padding (EBREAK has no decoding in this layout). + pub fn padding_entry() -> Self { + Self { + pc: 1, + imm: 0, + fields: ShrunkDecode::default(), } } - /// Helper to set memory width flags (exclusive encoding per spec). 
- /// - /// Memory width uses exclusive flags ("exactly N bytes"): - /// - 1 byte: no flags - /// - 2 bytes: memory_2bytes = true - /// - 4 bytes: memory_4bytes = true - /// - 8 bytes: memory_8bytes = true - fn set_memory_width(entry: &mut Self, width: LoadStoreWidth) { - match width { - LoadStoreWidth::Byte | LoadStoreWidth::ByteUnsigned => { - // 1 byte - no flags set - } - LoadStoreWidth::Half | LoadStoreWidth::HalfUnsigned => { - entry.memory_2bytes = true; - } - LoadStoreWidth::Word | LoadStoreWidth::WordUnsigned => { - entry.memory_4bytes = true; - } - LoadStoreWidth::DoubleWord => { - entry.memory_8bytes = true; - } + /// Packs the decode fields into the `packed_decode` field-element value. + pub fn packed_decode(&self) -> u64 { + self.fields.pack() + } + + /// Decode an instruction into `(pc, imm, fields)`. `instruction_length` is + /// 2 (RV64C compressed) or 4. + pub fn from_instruction(pc: u64, instruction: Instruction, instruction_length: u8) -> Self { + Self { + pc, + imm: imm_from_instruction(instruction), + fields: ShrunkDecode::from_instruction(instruction, instruction_length), + } + } +} + +/// The fully sign-extended 64-bit immediate for an instruction (0 when none). +fn imm_from_instruction(instruction: Instruction) -> u64 { + match instruction { + Instruction::ArithImm { imm, .. } | Instruction::ArithImmW { imm, .. } => imm as i64 as u64, + Instruction::JumpAndLink { offset, .. } + | Instruction::JumpAndLinkRegister { offset, .. } + | Instruction::Store { offset, .. } + | Instruction::Load { offset, .. } + | Instruction::Branch { offset, .. } => offset as i64 as u64, + Instruction::LoadUpperImm { imm, .. } | Instruction::AddUpperImmToPc { imm, .. 
} => { + (imm as i32) as i64 as u64 } + _ => 0, } } diff --git a/prover/src/test_utils.rs b/prover/src/test_utils.rs index 1b608034c..b4c93f27d 100644 --- a/prover/src/test_utils.rs +++ b/prover/src/test_utils.rs @@ -37,6 +37,9 @@ use crate::tables::bitwise::{ use crate::tables::branch::{ branch_constraints, bus_interactions as branch_bus_interactions, cols as branch_cols, }; +use crate::tables::bytewise::{ + bus_interactions as bytewise_bus_interactions, cols as bytewise_cols, +}; use crate::tables::commit::{ bus_interactions as commit_bus_interactions, cols as commit_cols, create_constraints as commit_constraints, @@ -44,10 +47,14 @@ use crate::tables::commit::{ use crate::tables::cpu::{ CpuOperation, bus_interactions as cpu_bus_interactions, cols as cpu_cols, }; +use crate::tables::cpu32::{ + bus_interactions as cpu32_bus_interactions, cols as cpu32_cols, cpu32_constraints, +}; use crate::tables::decode::{bus_interactions as decode_bus_interactions, cols as decode_cols}; use crate::tables::dvrm::{ bus_interactions as dvrm_bus_interactions, cols as dvrm_cols, dvrm_constraints, }; +use crate::tables::eq::{bus_interactions as eq_bus_interactions, cols as eq_cols, eq_constraints}; use crate::tables::halt::{bus_interactions as halt_bus_interactions, cols as halt_cols}; use crate::tables::keccak::{bus_interactions as keccak_bus_interactions, cols as keccak_cols}; use crate::tables::keccak_rc::{ @@ -79,6 +86,9 @@ use crate::tables::register::{ use crate::tables::shift::{ bus_interactions as shift_bus_interactions, cols as shift_cols, shift_constraints, }; +use crate::tables::store::{ + bus_interactions as store_bus_interactions, cols as store_cols, store_constraints, +}; use crate::tables::types::{GoldilocksExtension, GoldilocksField}; pub type F = GoldilocksField; @@ -408,7 +418,7 @@ pub fn generate_minimal_bitwise_trace(ops: &[BitwiseOperation]) -> TraceTable = HashMap::new(); + let mut row_data: HashMap<(u8, u8, u8), [u64; 13]> = HashMap::new(); for op in ops { let 
key = (op.x, op.y, op.z); @@ -423,8 +433,11 @@ pub fn generate_minimal_bitwise_trace(ops: &[BitwiseOperation]) -> TraceTable 7, BitwiseOperationType::IsB20 => 8, BitwiseOperationType::Hwsl => 9, + BitwiseOperationType::ByteAluAnd => 10, + BitwiseOperationType::ByteAluOr => 11, + BitwiseOperationType::ByteAluXor => 12, }; - row_data.entry(key).or_insert([0; 10])[mu_idx] += 1; + row_data.entry(key).or_insert([0; 13])[mu_idx] += 1; } // Need at least 4 rows for FRI, pad to power of 2 @@ -482,6 +495,9 @@ pub fn generate_minimal_bitwise_trace(ops: &[BitwiseOperation]) -> TraceTable VmAir { .with_name("SHIFT") } +/// Create the EQ AIR. +pub fn create_eq_air(proof_options: &ProofOptions) -> VmAir { + let (transition_constraints, _) = eq_constraints(0); + let auxiliary_trace_build_data = AuxiliaryTraceBuildData { + interactions: eq_bus_interactions(), + }; + AirWithBuses::new( + eq_cols::NUM_COLUMNS, + auxiliary_trace_build_data, + proof_options, + 1, + transition_constraints, + ) + .with_name("EQ") +} + +/// Create the BYTEWISE AIR. No polynomial constraints. +pub fn create_bytewise_air(proof_options: &ProofOptions) -> VmAir { + let transition_constraints: Vec>> = vec![]; + let auxiliary_trace_build_data = AuxiliaryTraceBuildData { + interactions: bytewise_bus_interactions(), + }; + AirWithBuses::new( + bytewise_cols::NUM_COLUMNS, + auxiliary_trace_build_data, + proof_options, + 1, + transition_constraints, + ) + .with_name("BYTEWISE") +} + +/// Create the STORE AIR. +pub fn create_store_air(proof_options: &ProofOptions) -> VmAir { + let (transition_constraints, _) = store_constraints(0); + let auxiliary_trace_build_data = AuxiliaryTraceBuildData { + interactions: store_bus_interactions(), + }; + AirWithBuses::new( + store_cols::NUM_COLUMNS, + auxiliary_trace_build_data, + proof_options, + 1, + transition_constraints, + ) + .with_name("STORE") +} + +/// Create the CPU32 AIR. 
+pub fn create_cpu32_air(proof_options: &ProofOptions) -> VmAir { + let (transition_constraints, _) = cpu32_constraints(0); + let auxiliary_trace_build_data = AuxiliaryTraceBuildData { + interactions: cpu32_bus_interactions(), + }; + AirWithBuses::new( + cpu32_cols::NUM_COLUMNS, + auxiliary_trace_build_data, + proof_options, + 1, + transition_constraints, + ) + .with_name("CPU32") +} + /// Create MEMW AIR with constraints and bus interactions. pub fn create_memw_air(proof_options: &ProofOptions) -> VmAir { let transition_constraints = memw_constraints(); diff --git a/prover/src/tests/bitwise_tests.rs b/prover/src/tests/bitwise_tests.rs index 8337f8bf7..f02748939 100644 --- a/prover/src/tests/bitwise_tests.rs +++ b/prover/src/tests/bitwise_tests.rs @@ -4,9 +4,10 @@ use crate::tables::bitwise::{ NUM_PRECOMPUTED_COLS, NUM_ROWS, bus_interactions, cols, generate_bitwise_row, generate_bitwise_trace, is_preprocessed, preprocessed_commitment, row_index, }; -use crate::tables::types::FE; +use crate::tables::types::{BusId, FE}; use crate::test_utils::multi_prove_ram; use math::field::element::FieldElement; +use stark::lookup::Multiplicity; use stark::proof::options::ProofOptions; #[test] @@ -95,8 +96,43 @@ fn test_zero_check() { #[test] fn test_bus_interactions_count() { let interactions = bus_interactions(); - // Should have 10 interactions (one per lookup type; HWSLC merged into HWSL) - assert_eq!(interactions.len(), 10); + // 10 legacy lookups (one per type; HWSLC merged into HWSL) + 3 BYTE_ALU + // receivers (opsel AND/OR/XOR). + assert_eq!(interactions.len(), 13); +} + +#[test] +fn test_byte_alu_receivers() { + let byte_alu: Vec<_> = bus_interactions() + .into_iter() + .filter(|i| i.bus_id == u64::from(BusId::ByteAlu)) + .collect(); + + // One receiver per opsel (AND/OR/XOR), each carrying [opsel, X, Y, out]. 
+ assert_eq!(byte_alu.len(), 3); + for interaction in &byte_alu { + assert!(!interaction.is_sender, "BYTE_ALU lookups are receivers"); + assert_eq!(interaction.values.len(), 4, "[opsel, X, Y, out]"); + } + + // Each opsel uses its own multiplicity column, reusing the precomputed + // AND/OR/XOR result columns. + let mut mu_columns: Vec = byte_alu + .iter() + .map(|i| match i.multiplicity { + Multiplicity::Column(c) => c, + _ => panic!("BYTE_ALU multiplicity must be a column"), + }) + .collect(); + mu_columns.sort_unstable(); + assert_eq!( + mu_columns, + vec![ + cols::MU_BYTE_ALU_AND, + cols::MU_BYTE_ALU_OR, + cols::MU_BYTE_ALU_XOR + ] + ); } #[test] diff --git a/prover/src/tests/branch_constraints_tests.rs b/prover/src/tests/branch_constraints_tests.rs index 2fd1fead0..af0b3aadb 100644 --- a/prover/src/tests/branch_constraints_tests.rs +++ b/prover/src/tests/branch_constraints_tests.rs @@ -17,23 +17,26 @@ use stark::constraints::transition::TransitionConstraint; fn test_branch_constraint_degree() { let (constraints, _) = branch_constraints(0); - // All 4 conditional carry IS_BIT constraints have degree 3: - // cond (degree 1) * carry (degree 1) * (1 - carry) (degree 1) - for c in &constraints { + // The 4 conditional carry IS_BIT constraints have degree 3: + // cond (degree 1) * carry (degree 1) * (1 - carry) (degree 1) + // and the IS_BIT constraint has degree 2: JALR * (1 - JALR). 
+ for c in &constraints[..4] { assert_eq!(c.degree(), 3); } + assert_eq!(constraints[4].degree(), 2); } #[test] fn test_branch_constraint_indices_unique() { let (constraints, next_idx) = branch_constraints(0); - assert_eq!(constraints.len(), 4); + assert_eq!(constraints.len(), 5); assert_eq!(constraints[0].constraint_idx(), 0); assert_eq!(constraints[1].constraint_idx(), 1); assert_eq!(constraints[2].constraint_idx(), 2); assert_eq!(constraints[3].constraint_idx(), 3); - assert_eq!(next_idx, 4); + assert_eq!(constraints[4].constraint_idx(), 4); + assert_eq!(next_idx, 5); } #[test] @@ -44,7 +47,8 @@ fn test_branch_constraint_indices_with_offset() { assert_eq!(constraints[1].constraint_idx(), 11); assert_eq!(constraints[2].constraint_idx(), 12); assert_eq!(constraints[3].constraint_idx(), 13); - assert_eq!(next_idx, 14); + assert_eq!(constraints[4].constraint_idx(), 14); + assert_eq!(next_idx, 15); } // ========================================================================= diff --git a/prover/src/tests/bytewise_tests.rs b/prover/src/tests/bytewise_tests.rs new file mode 100644 index 000000000..ac534cdc2 --- /dev/null +++ b/prover/src/tests/bytewise_tests.rs @@ -0,0 +1,89 @@ +//! Tests for the BYTEWISE ALU table. + +use crate::tables::bytewise::{BytewiseOperation, bus_interactions, cols, generate_bytewise_trace}; +use crate::tables::types::{BusId, FE, alu_op}; + +#[test] +fn test_compute_res() { + let a = 0xFF00u64; + let b = 0x0FF0u64; + assert_eq!( + BytewiseOperation::new(a, b, alu_op::AND).compute_res(), + 0x0F00 + ); + assert_eq!( + BytewiseOperation::new(a, b, alu_op::OR).compute_res(), + 0xFFF0 + ); + assert_eq!( + BytewiseOperation::new(a, b, alu_op::XOR).compute_res(), + 0xF0F0 + ); +} + +#[test] +fn test_trace_byte_decomposition() { + // a XOR b across all 8 bytes. 
+ let a = 0x1122_3344_5566_7788u64; + let b = 0x00FF_00FF_00FF_00FFu64; + let trace = generate_bytewise_trace(&[BytewiseOperation::new(a, b, alu_op::XOR)]); + assert_eq!(trace.main_table.width, cols::NUM_COLUMNS); + assert_eq!(trace.main_table.height, 4); // padded to min 4 + + let row = trace.main_table.get_row(0); + // Little-endian: byte 0 is the least significant. + assert_eq!(row[cols::A[0]], FE::from(0x88u64)); + assert_eq!(row[cols::A[7]], FE::from(0x11u64)); + assert_eq!(row[cols::B[0]], FE::from(0xFFu64)); + assert_eq!(row[cols::OP], FE::from(alu_op::XOR as u64)); + // res byte 0 = 0x88 ^ 0xFF = 0x77 + assert_eq!(row[cols::RES[0]], FE::from(0x77u64)); + // res byte 7 = 0x11 ^ 0x00 = 0x11 + assert_eq!(row[cols::RES[7]], FE::from(0x11u64)); + assert_eq!(row[cols::MU], FE::from(1u64)); +} + +#[test] +fn test_multiplicity_aggregation() { + let ops = vec![ + BytewiseOperation::new(1, 2, alu_op::AND), + BytewiseOperation::new(3, 4, alu_op::OR), + BytewiseOperation::new(1, 2, alu_op::AND), + ]; + let trace = generate_bytewise_trace(&ops); + assert_eq!(trace.main_table.height, 4); + + let mut found = false; + for row_idx in 0..4 { + let row = trace.main_table.get_row(row_idx); + if row[cols::A[0]] == FE::from(1u64) + && row[cols::B[0]] == FE::from(2u64) + && row[cols::OP] == FE::from(alu_op::AND as u64) + { + assert_eq!(row[cols::MU], FE::from(2u64)); + found = true; + } + } + assert!(found, "expected the (1, 2, AND) row with multiplicity 2"); +} + +#[test] +fn test_bus_interactions_shape() { + let interactions = bus_interactions(); + // 8 BYTE_ALU senders + 1 ALU receiver. 
+ assert_eq!(interactions.len(), 9); + + let byte_alu_senders = interactions + .iter() + .filter(|i| i.bus_id == u64::from(BusId::ByteAlu) && i.is_sender) + .count(); + assert_eq!(byte_alu_senders, 8); + + let alu: Vec<_> = interactions + .iter() + .filter(|i| i.bus_id == u64::from(BusId::Alu)) + .collect(); + assert_eq!(alu.len(), 1); + assert!(!alu[0].is_sender, "ALU is a receiver for BYTEWISE"); + assert_eq!(alu[0].values.len(), 4); // [a, b, op, res] +} diff --git a/prover/src/tests/constraints_tests.rs b/prover/src/tests/constraints_tests.rs index e48f73d67..e52cc6c0e 100644 --- a/prover/src/tests/constraints_tests.rs +++ b/prover/src/tests/constraints_tests.rs @@ -513,132 +513,104 @@ fn test_dword_bl_repack_formula() { // ========================================================================= use crate::constraints::cpu::{ - Arg1LowerConstraint, Arg1UpperConstraint, BIT_FLAG_COLUMNS, BranchCondConstraint, - EbreakConstraint, ExtBitZeroConstraint, NUM_CPU_CONSTRAINTS, NextPcAddConstraint, + Arg2Constraint, BIT_FLAG_COLUMNS, BranchCondConstraint, NUM_CPU_CONSTRAINTS, + NextPcAddConstraint, ProductZeroConstraint, RegNotReadIsZeroConstraint, RvdEqResConstraint, create_add_constraints, create_all_cpu_constraints, create_is_bit_constraints, - create_slt_res_zero_constraints, + create_sub_constraints, }; - use crate::tables::cpu::cols as cpu_cols; #[test] fn test_cpu_bit_flag_columns_count() { - // Should have 34 bit flag columns (includes read_register1, read_register2, inline-pc columns) - assert_eq!(BIT_FLAG_COLUMNS.len(), 34); + // 10 top-level flags + pc_double_read + prev_pc_timestamp_borrow + non_padding. 
+ assert_eq!(BIT_FLAG_COLUMNS.len(), 12); } #[test] fn test_cpu_bit_flag_columns_valid() { - // All columns should be valid CPU column indices for &col in BIT_FLAG_COLUMNS { assert!(col < cpu_cols::NUM_COLUMNS, "Column {} out of range", col); } } #[test] -fn test_create_is_bit_constraints() { - let (constraints, next_idx) = create_is_bit_constraints(0); - - assert_eq!(constraints.len(), 34); - assert_eq!(next_idx, 34); - - // Check constraint indices are sequential - for (i, c) in constraints.iter().enumerate() { - assert_eq!(c.constraint_idx(), i); - } -} - -#[test] -fn test_create_add_constraints() { - let (constraints, next_idx) = create_add_constraints(0); - - // Should create 4 constraints: 2 for ADD+LOAD, 2 for STORE (res = arg1 + imm) - assert_eq!(constraints.len(), 4); - assert_eq!(next_idx, 4); - - assert_eq!(constraints[0].constraint_idx(), 0); - assert_eq!(constraints[1].constraint_idx(), 1); - assert_eq!(constraints[2].constraint_idx(), 2); - assert_eq!(constraints[3].constraint_idx(), 3); +fn test_create_is_bit_constraints_count() { + let (cs, next) = create_is_bit_constraints(0); + assert_eq!(cs.len(), BIT_FLAG_COLUMNS.len()); + assert_eq!(next, BIT_FLAG_COLUMNS.len()); } #[test] -fn test_create_slt_res_zero_constraints() { - let (constraints, next_idx) = create_slt_res_zero_constraints(0); - - // Should create 7 constraints (for bytes 1-7) - assert_eq!(constraints.len(), 7); - assert_eq!(next_idx, 7); - - for (i, c) in constraints.iter().enumerate() { - assert_eq!(c.constraint_idx(), i); - } +fn test_add_sub_constraint_pairs() { + let (add, next) = create_add_constraints(0); + assert_eq!(add.len(), 2, "ADD carry pair"); + let (sub, next2) = create_sub_constraints(next); + assert_eq!(sub.len(), 2, "SUB carry pair"); + assert_eq!(next2, next + 2, "constraint indices are contiguous"); } #[test] -fn test_branch_cond_constraint_degree() { - let c = BranchCondConstraint::new(0); - assert_eq!(c.degree(), 3); +fn test_product_zero_constraint_degree() { + // 
word_instr · MEMORY = 0 (decode mutex): degree 2. + let c = ProductZeroConstraint::new(cpu_cols::WORD_INSTR, cpu_cols::MEMORY, 0); + assert_eq!(c.degree(), 2); } #[test] -fn test_ebreak_constraint_degree() { - let c = EbreakConstraint::new(0); - assert_eq!(c.degree(), 1); +fn test_arg2_constraint_degree() { + // (1 - MEMORY - BRANCH)·(rv2 + imm): degree 2 (relies on the live + // MEMORY·BRANCH = 0 mutex). + assert_eq!(Arg2Constraint::new(0, 0).degree(), 2); + assert_eq!(Arg2Constraint::new(1, 0).degree(), 2); } #[test] -fn test_arg1_lower_constraint_degree() { - let c = Arg1LowerConstraint::new(0); - assert_eq!(c.degree(), 1); +fn test_rvd_eq_res_constraint_degree() { + // (1 - MEMORY - BRANCH)·(rvd[i] - cast(res, WL)[i]): degree 2. + // BRANCH rows are exempt — their rvd (`pc + len`) is pinned by + // BranchRvdConstraint instead. Well within the blowup=2 budget. + assert_eq!(RvdEqResConstraint::new(0, 0).degree(), 2); + assert_eq!(RvdEqResConstraint::new(1, 0).degree(), 2); } #[test] -fn test_arg1_upper_constraint_degree() { - let c = Arg1UpperConstraint::new(0); - assert_eq!(c.degree(), 3); +fn test_branch_cond_constraint_degree() { + // branch_cond = BRANCH·JALR + BRANCH·(1-JALR)·res[0]: degree 3. 
+ assert_eq!(BranchCondConstraint::new(0).degree(), 3); } #[test] -fn test_ext_bit_zero_constraint_degree() { - let c = ExtBitZeroConstraint::new(0, cpu_cols::RV1_EXT_BIT); +fn test_reg_not_read_is_zero_degree() { + let c = RegNotReadIsZeroConstraint::new(cpu_cols::READ_REGISTER1, cpu_cols::RV1_0, 0); assert_eq!(c.degree(), 2); } #[test] -fn test_next_pc_add_constraint_degree() { - let c = NextPcAddConstraint::new(0, 0); - assert_eq!(c.degree(), 3); -} - -#[test] -fn test_next_pc_add_constraint_new_pair() { - let (c0, c1) = NextPcAddConstraint::new_pair(10); - assert_eq!(c0.constraint_idx(), 10); - assert_eq!(c1.constraint_idx(), 11); +fn test_next_pc_add_constraint() { + let (c0, c1) = NextPcAddConstraint::new_pair(5); + assert_eq!(c0.degree(), 3); + assert_eq!(c1.degree(), 3); + assert_eq!(c0.constraint_idx(), 5); + assert_eq!(c1.constraint_idx(), 6); } #[test] -fn test_create_all_cpu_constraints() { +fn test_create_all_cpu_constraints_count() { let (is_bit, add, other, total) = create_all_cpu_constraints(); - - assert_eq!(is_bit.len(), 34); - // ADD constraints: 2 (ADD+LOAD) + 2 (STORE: arg1+imm) + 2 (SUB+BEQ) + 2 (JALR) = 8 - assert_eq!(add.len(), 8); - // Other: branch_cond(1) + ebreak(1) + rv1_zero_forcing(3) + rv2_zero_forcing(3) + arg1(2) + arg2(2) + rvd(2) + slt_zero(7) + ext_bit_zero(3) + next_pc(2) = 26 - assert_eq!(other.len(), 26); - - // Total should be 34 + 8 + 26 = 68 - assert_eq!(total, 68); + // IS_BIT: 12, ADD+SUB pairs: 4, other (mutex 6 + arg2 2 + reg-zero 4 + rvd 2 + // + branch rvd 2 + branch_cond 1 + next_pc 2 + assumptions 4): 23. 
+ assert_eq!(is_bit.len(), 12); + assert_eq!(add.len(), 4); + assert_eq!(other.len(), 23); assert_eq!(total, NUM_CPU_CONSTRAINTS); + assert_eq!(is_bit.len() + add.len() + other.len(), NUM_CPU_CONSTRAINTS); } #[test] -fn test_cpu_constraint_indices_are_unique() { +fn test_cpu_constraint_indices_are_unique_and_sequential() { let (is_bit, add, other, _) = create_all_cpu_constraints(); let mut indices: Vec = Vec::new(); - for c in &is_bit { indices.push(c.constraint_idx()); } @@ -649,19 +621,8 @@ fn test_cpu_constraint_indices_are_unique() { indices.push(c.constraint_idx()); } - // Check no duplicates - indices.sort(); - for i in 1..indices.len() { - assert_ne!( - indices[i], - indices[i - 1], - "Duplicate constraint index: {}", - indices[i] - ); - } - - // Check sequential + indices.sort_unstable(); for (i, &idx) in indices.iter().enumerate() { - assert_eq!(idx, i, "Expected index {} but got {}", i, idx); + assert_eq!(idx, i, "constraint indices must be unique and cover 0..N"); } } diff --git a/prover/src/tests/cpu32_tests.rs b/prover/src/tests/cpu32_tests.rs new file mode 100644 index 000000000..f055b2ceb --- /dev/null +++ b/prover/src/tests/cpu32_tests.rs @@ -0,0 +1,260 @@ +//! Tests for the CPU32 table — column layout, sign-extension aux math, and the +//! sign-extension / register-zero constraints. + +use crate::tables::cpu32::{ + Cpu32Constraint, Cpu32ConstraintKind, Cpu32Operation, bus_interactions, cols, + generate_cpu32_trace, +}; +use crate::tables::types::{ + BusId, FE, GoldilocksExtension, GoldilocksField, alu_op, build_alu_flags, +}; +use stark::constraints::transition::TransitionConstraint; +use stark::table::TableView; + +#[test] +fn test_aux_signed_input_extension() { + // Signed op (signed bit set in alu_flags) with a negative low word. 
+ let op = Cpu32Operation { + rv1: 0x8000_0000, // bit 31 set → negative as i32 + alu_flags: build_alu_flags(alu_op::SHIFTW, true, true, false), // signed = true + ..Default::default() + }; + let aux = op.compute_aux(); + assert!(aux.signed); + assert!(aux.rv1_sign); + // arg1 sign-extended: high word all ones. + assert_eq!(aux.arg1, 0xFFFF_FFFF_8000_0000); +} + +#[test] +fn test_aux_unsigned_input_zero_extension() { + // Unsigned op (signed bit clear) with the same low word → zero-extended. + let op = Cpu32Operation { + rv1: 0x8000_0000, + alu_flags: build_alu_flags(alu_op::SHIFTW, false, false, false), // signed = false + ..Default::default() + }; + let aux = op.compute_aux(); + assert!(!aux.signed); + assert_eq!(aux.arg1, 0x0000_0000_8000_0000); +} + +#[test] +fn test_aux_arg2_from_immediate() { + // Immediate path: rv2 = 0, imm fully sign-extended. + let op = Cpu32Operation { + rv2: 0, + read_register2: false, + imm: 0xFFFF_FFFF_FFFF_FF00, + alu_flags: build_alu_flags(alu_op::SHIFTW, true, false, false), + ..Default::default() + }; + let aux = op.compute_aux(); + assert_eq!(aux.arg2, 0xFFFF_FFFF_FFFF_FF00); +} + +#[test] +fn test_aux_arg2_from_register() { + // Register path: imm = 0, rv2 negative, signed → sign-extended rv2. + let op = Cpu32Operation { + rv2: 0x8000_0001, + read_register2: true, + imm: 0, + alu_flags: build_alu_flags(alu_op::SHIFTW, true, true, false), // signed + ..Default::default() + }; + let aux = op.compute_aux(); + assert!(aux.rv2_sign); + assert_eq!(aux.arg2, 0xFFFF_FFFF_8000_0001); +} + +#[test] +fn test_aux_rvd_always_sign_extended() { + // rvd is always sign-extended from the low 32 bits of res, regardless of `signed`. 
+ let op = Cpu32Operation { + res: 0x0000_0000_8000_0000, // low word negative + alu_flags: build_alu_flags(alu_op::SHIFTW, false, false, false), // unsigned op + ..Default::default() + }; + let aux = op.compute_aux(); + assert!(aux.res_sign); + assert_eq!(aux.rvd, 0xFFFF_FFFF_8000_0000); + + // Positive low word → zero high word. + let op2 = Cpu32Operation { + res: 0x0000_0000_0000_0001, + ..Default::default() + }; + assert_eq!(op2.compute_aux().rvd, 0x0000_0000_0000_0001); +} + +#[test] +fn test_trace_layout() { + let op = Cpu32Operation { + timestamp: 0x1234, + pc: 0xABCD, + rs1: 3, + read_register1: true, + rv1: 0x1122_3344_5566_7788, + rs2: 5, + read_register2: true, + rv2: 0x9900, + rd: 7, + write_register: true, + res: 0x42, + alu: true, + alu_flags: build_alu_flags(alu_op::SHIFTW, true, true, false), + half_instruction_length: 2, + ..Default::default() + }; + let trace = generate_cpu32_trace(&[op]); + assert_eq!(trace.main_table.width, cols::NUM_COLUMNS); + assert_eq!(trace.main_table.height, 4); // padded to min 4 + + let row = trace.main_table.get_row(0); + assert_eq!(row[cols::PC_0], FE::from(0xABCDu64)); + assert_eq!(row[cols::RS1], FE::from(3u64)); + // rv1 as DWordWHH: half0, half1, word. + assert_eq!(row[cols::RV1_0], FE::from(0x7788u64)); + assert_eq!(row[cols::RV1_1], FE::from(0x5566u64)); + assert_eq!(row[cols::RV1_2], FE::from(0x1122_3344u64)); + assert_eq!(row[cols::RD], FE::from(7u64)); + assert_eq!(row[cols::HALF_INSTRUCTION_LENGTH], FE::from(2u64)); + assert_eq!(row[cols::SIGNED], FE::from(1u64)); + assert_eq!(row[cols::MU], FE::from(1u64)); +} + +/// Build a single-row `TableView` from a CPU32 trace generated for `op`. 
+fn view_for(op: Cpu32Operation) -> TableView { + let trace = generate_cpu32_trace(&[op]); + let row = trace.main_table.get_row(0).to_vec(); + TableView::new(vec![row], vec![vec![]]) +} + +#[test] +fn test_ext_and_regzero_constraints_hold_on_valid_row() { + // A signed word op via the immediate path (read_register2 = 0, rv2 = 0). + let op = Cpu32Operation { + rv1: 0x8000_0001, // negative low word + read_register1: true, + rv2: 0, + read_register2: false, + imm: 0xFFFF_FFFF_FFFF_FFF0, + res: 0x0000_0000_1234_5678, + rd: 5, + write_register: true, + alu: true, + alu_flags: build_alu_flags(alu_op::SHIFTW, true, true, false), // signed + half_instruction_length: 2, + ..Default::default() + }; + let view = view_for(op); + + // All sign-extension arithmetic constraints evaluate to zero. + for kind in [ + Cpu32ConstraintKind::Arg1Lo, + Cpu32ConstraintKind::Arg1Hi, + Cpu32ConstraintKind::Arg2Lo, + Cpu32ConstraintKind::Arg2Hi, + Cpu32ConstraintKind::RvdLo, + Cpu32ConstraintKind::RvdHi, + ] { + let c = Cpu32Constraint::new(kind, 0); + assert_eq!(c.evaluate(&view), FE::zero(), "{kind:?} must hold"); + } + + // Register-zero checks: read_register1=1 ⇒ trivially 0; read_register2=0 with rv2=0 ⇒ 0. + for (read_col, value_col) in [ + (cols::READ_REGISTER1, cols::RV1_0), + (cols::READ_REGISTER1, cols::RV1_1), + (cols::READ_REGISTER2, cols::RV2_0), + (cols::READ_REGISTER2, cols::RV2_1), + ] { + let c = Cpu32Constraint::new( + Cpu32ConstraintKind::RegZero { + read_col, + value_col, + }, + 0, + ); + assert_eq!(c.evaluate(&view), FE::zero()); + } +} + +#[test] +fn test_constraints_catch_corruption() { + let op = Cpu32Operation { + rv1: 0x8000_0001, + read_register1: true, + res: 0x0000_0000_8000_0000, + write_register: true, + alu: true, + alu_flags: build_alu_flags(alu_op::SHIFTW, true, true, false), + half_instruction_length: 2, + ..Default::default() + }; + let trace = generate_cpu32_trace(&[op]); + + // Corrupt arg1[1] (the sign-extended high word) → Arg1Hi must fire. 
+ let mut row = trace.main_table.get_row(0).to_vec(); + row[cols::ARG1_1] = &row[cols::ARG1_1] + FE::one(); + let bad: TableView = + TableView::new(vec![row], vec![vec![]]); + let c = Cpu32Constraint::new(Cpu32ConstraintKind::Arg1Hi, 0); + assert_ne!( + c.evaluate(&bad), + FE::zero(), + "Arg1Hi should catch a bad arg1[1]" + ); + + // read_register1 = 1 but a non-zero unread half would only matter when 0; + // instead corrupt with read=0 case: a value present while read flag cleared. + let op2 = Cpu32Operation { + rv2: 0x1234, // non-zero + read_register2: false, // but flagged unread + ..Default::default() + }; + let view2 = view_for(op2); + let c2 = Cpu32Constraint::new( + Cpu32ConstraintKind::RegZero { + read_col: cols::READ_REGISTER2, + value_col: cols::RV2_0, + }, + 0, + ); + assert_ne!( + c2.evaluate(&view2), + FE::zero(), + "RegZero should catch rv2≠0 when unread" + ); +} + +#[test] +fn test_bus_interactions_shape() { + let interactions = bus_interactions(); + assert_eq!(interactions.len(), 23); + + let count = |bus: BusId, sender: bool| { + interactions + .iter() + .filter(|i| i.bus_id == u64::from(bus) && i.is_sender == sender) + .count() + }; + + assert_eq!(count(BusId::Decode, true), 1); + assert_eq!(count(BusId::AreBytes, true), 5); + assert_eq!(count(BusId::IsHalfword, true), 8); + assert_eq!(count(BusId::Memw, true), 3); // rv1 read, rv2 read, rvd write + assert_eq!(count(BusId::Alu, true), 1); + assert_eq!(count(BusId::ByteAlu, true), 1); + assert_eq!(count(BusId::Msb16, true), 3); + + // CPU32 is a receiver (the main CPU sends the delegation). 
+ let cpu32: Vec<_> = interactions + .iter() + .filter(|i| i.bus_id == u64::from(BusId::Cpu32)) + .collect(); + assert_eq!(cpu32.len(), 1); + assert!(!cpu32[0].is_sender, "CPU32 receives from the main CPU"); + assert_eq!(cpu32[0].values.len(), 3); // [timestamp, pc, instruction_length] +} diff --git a/prover/src/tests/cpu_tests.rs b/prover/src/tests/cpu_tests.rs index f05d1005c..3381d1821 100644 --- a/prover/src/tests/cpu_tests.rs +++ b/prover/src/tests/cpu_tests.rs @@ -1,484 +1,364 @@ //! Tests for the CPU table. //! -//! This module contains: -//! - Unit tests for CpuOperation struct and its methods -//! - Trace generation tests -//! - Integration tests for CpuOperation::from_log (ELF execution) +//! Unit tests for the reworked `CpuOperation::from_log` (arg2 multiplex, res, +//! rvd, branch decision, word-instruction delegation), `generate_cpu_trace` +//! (column layout, padding, word-row masking), and `collect_bitwise_ops`. -use crate::tables::cpu::{CpuOperation, bus_interactions, cols, generate_cpu_trace}; -use crate::tables::trace_builder::Traces; -use crate::tables::types::{DecodeEntry, FE}; +use crate::tables::cpu::{CPU_PADDING_PC, CpuOperation, cols, generate_cpu_trace}; +use crate::tables::types::DecodeEntry; -use executor::{ - elf::Elf, - vm::{execution::Executor, instruction::decoding::Instruction, memory::U64HashMap}, +use executor::vm::{ + instruction::decoding::{ArithOp, Comparison, Instruction, LoadStoreWidth}, + logs::Log, }; -/// Helper to create 4 operations from a template (required for power-of-2 trace). 
-fn ops4(op: CpuOperation) -> Vec { - (0..4) - .map(|i| { - let mut new_op = op.clone(); - new_op.timestamp = (i as u64) * 4 + 4; - new_op.decode.pc = op.decode.pc + (i as u64) * 4; - new_op.next_pc = op.decode.pc + (i as u64) * 4 + 4; - new_op - }) - .collect() -} - -#[test] -fn test_cpu_operation_default() { - let op = CpuOperation::new(); - assert_eq!(op.timestamp, 0); - assert_eq!(op.decode.pc, 0); - assert!(!op.decode.op_add); - assert!(!op.branch_cond); -} - -#[test] -fn test_cpu_operation_compute_arg1_no_extension() { - let mut op = CpuOperation::new(); - op.rv1 = 0x1234_5678_9ABC_DEF0; - op.decode.word_instr = false; - - assert_eq!(op.compute_arg1(), 0x1234_5678_9ABC_DEF0); -} - -#[test] -fn test_cpu_operation_compute_arg1_word_zero_extend() { - let mut op = CpuOperation::new(); - op.rv1 = 0x1234_5678_9ABC_DEF0; - op.decode.word_instr = true; - op.decode.signed = false; +const PC: u64 = 0x1000; - // Should zero-extend from lower 32 bits - assert_eq!(op.compute_arg1(), 0x9ABC_DEF0); -} - -#[test] -fn test_cpu_operation_compute_arg1_word_sign_extend_positive() { - let mut op = CpuOperation::new(); - op.rv1 = 0x1234_5678_1ABC_DEF0; // Positive 32-bit value - op.decode.word_instr = true; - op.decode.signed = true; - - // Bit 31 is 0, so sign extension keeps it positive - assert_eq!(op.compute_arg1(), 0x1ABC_DEF0); -} - -#[test] -fn test_cpu_operation_compute_arg1_word_sign_extend_negative() { - let mut op = CpuOperation::new(); - op.rv1 = 0x1234_5678_8000_0001; // Negative when viewed as 32-bit signed - op.decode.word_instr = true; - op.decode.signed = true; - - // Per spec constraint: arg1[4:] = (2^32-1) * rv1_sign_bit * signed - // For signed word instructions with sign bit set, arg1 is sign-extended. - assert_eq!(op.compute_arg1(), 0xFFFF_FFFF_8000_0001); +/// Build a CpuOperation from an instruction + register values. 
+fn op_of(instr: Instruction, src1: u64, src2: u64, dst: u64, next_pc: u64) -> CpuOperation { + let decode = DecodeEntry::from_instruction(PC, instr, 4); + let log = Log { + current_pc: PC, + next_pc, + src1_val: src1, + src2_val: src2, + dst_val: dst, + }; + CpuOperation::from_log(&log, 4, decode) } -#[test] -fn test_cpu_operation_compute_arg2_store() { - let mut op = CpuOperation::new(); - op.rv2 = 0xDEAD_BEEF; - op.decode.imm = 0x1234; - op.decode.op_store = true; - - // STORE: arg2 = rv2 (the data being stored) - // Address is computed separately as res = arg1 + imm - assert_eq!(op.compute_arg2(), 0xDEAD_BEEF); -} +// ========================================================================= +// from_log: arg2 multiplex, res, rvd, branch decision +// ========================================================================= #[test] -fn test_cpu_operation_compute_arg2_load() { - let mut op = CpuOperation::new(); - op.rv2 = 0xDEAD_BEEF; - op.decode.imm = 0x1234; - op.decode.op_load = true; - - // LOAD uses imm for address calculation (addr = rv1 + imm) - assert_eq!(op.compute_arg2(), 0x1234); +fn test_from_log_add_reg_reg() { + let op = op_of( + Instruction::Arith { + dst: 3, + src1: 1, + src2: 2, + op: ArithOp::Add, + }, + 10, + 20, + 30, + PC + 4, + ); + assert_eq!(op.rv1, 10); + assert_eq!(op.rv2, 20); + assert_eq!(op.arg2, 20, "reg-reg: arg2 = rv2 (imm = 0)"); + assert_eq!(op.res, 30, "res = rv1 + arg2"); + assert_eq!(op.rvd, 30, "rvd = res (not memory)"); + assert_eq!(op.next_pc, PC + 4); + assert!(!op.branch_cond); } #[test] -fn test_cpu_operation_compute_arg2_beq() { - let mut op = CpuOperation::new(); - op.rv2 = 0xCAFE_BABE; - op.decode.imm = 0x5678; - op.decode.op_beq = true; - - // BEQ uses rv2 - assert_eq!(op.compute_arg2(), 0xCAFE_BABE); +fn test_from_log_addi() { + let op = op_of( + Instruction::ArithImm { + dst: 3, + src: 1, + imm: 5, + op: ArithOp::Add, + }, + 10, + 0, + 15, + PC + 4, + ); + assert_eq!(op.arg2, 5, "reg-imm: arg2 = imm (rv2 = 0)"); + 
assert_eq!(op.res, 15); + assert_eq!(op.rvd, 15); } #[test] -fn test_cpu_operation_compute_arg2_add_with_imm() { - let mut op = CpuOperation::new(); - op.rv2 = 0; - op.decode.rs2 = 0; // rs2 = 0 means use immediate - op.decode.imm = 0x1234_5678; - op.decode.op_add = true; - - // ADD with rs2=0 uses imm - assert_eq!(op.compute_arg2(), 0x1234_5678); +fn test_from_log_sub() { + let op = op_of( + Instruction::Arith { + dst: 3, + src1: 1, + src2: 2, + op: ArithOp::Sub, + }, + 30, + 20, + 10, + PC + 4, + ); + assert_eq!(op.res, 10, "res = rv1 - arg2"); + assert_eq!(op.rvd, 10); } #[test] -fn test_cpu_operation_compute_arg2_add_with_rs2() { - let mut op = CpuOperation::new(); - op.rv2 = 0xABCD_EF00; - op.decode.rs2 = 5; // Non-zero rs2 - op.decode.imm = 0; // Per CPU-A2: when rs2 != 0, imm must be 0 - op.decode.op_add = true; - - // ADD with rs2 != 0: arg2 = rv2 + imm = rv2 + 0 = rv2 - assert_eq!(op.compute_arg2(), 0xABCD_EF00); +fn test_from_log_beq_taken() { + let op = op_of( + Instruction::Branch { + src1: 1, + src2: 2, + cond: Comparison::Equal, + offset: 8, + }, + 5, + 5, + 0, + PC + 8, + ); + assert!(op.branch_cond, "BEQ with equal operands is taken"); + assert_eq!(op.arg2, 5, "conditional branch: arg2 = rv2"); + assert_eq!(op.res, 1, "EQ result on the ALU bus is 1 when taken"); + assert_eq!(op.next_pc, PC + 8, "taken branch uses the executor next_pc"); } #[test] -fn test_sign_bit_32_positive() { - assert!(!CpuOperation::sign_bit_32(0x7FFF_FFFF)); - assert!(!CpuOperation::sign_bit_32(0x0000_0000)); - assert!(!CpuOperation::sign_bit_32(0x1234_5678)); +fn test_from_log_beq_not_taken() { + let op = op_of( + Instruction::Branch { + src1: 1, + src2: 2, + cond: Comparison::Equal, + offset: 8, + }, + 5, + 6, + 0, + PC + 4, + ); + assert!(!op.branch_cond); + assert_eq!(op.res, 0); + assert_eq!( + op.next_pc, + PC + 4, + "untaken branch falls through to pc + len" + ); } #[test] -fn test_sign_bit_32_negative() { - assert!(CpuOperation::sign_bit_32(0x8000_0000)); - 
assert!(CpuOperation::sign_bit_32(0xFFFF_FFFF)); - assert!(CpuOperation::sign_bit_32(0x8000_0001)); +fn test_from_log_bne_taken() { + let op = op_of( + Instruction::Branch { + src1: 1, + src2: 2, + cond: Comparison::NotEqual, + offset: 8, + }, + 5, + 6, + 0, + PC + 8, + ); + assert!( + op.branch_cond, + "BNE with differing operands is taken (invert)" + ); + assert_eq!(op.res, 1); } #[test] -fn test_trace_generation_basic() { - let ops = ops4(CpuOperation { - decode: DecodeEntry { - pc: 0x1000, - rs1: 1, - rs2: 2, - rd: 3, - write_register: true, - op_add: true, - ..Default::default() +fn test_from_log_load() { + let op = op_of( + Instruction::Load { + dst: 3, + offset: 4, + base: 1, + width: LoadStoreWidth::Word, }, - rv1: 10, - rv2: 20, - res: 30, - rvd: 30, - ..Default::default() - }); - - let trace = generate_cpu_trace(&ops); - - assert_eq!(trace.main_table.height, 4); - assert_eq!(trace.main_table.width, cols::NUM_COLUMNS); - - // Check first row values - let row0 = trace.main_table.get_row(0); - assert_eq!(row0[cols::TIMESTAMP], FE::from(4u64)); - assert_eq!(row0[cols::PC_0], FE::from(0x1000u64)); - assert_eq!(row0[cols::PC_1], FE::zero()); - assert_eq!(row0[cols::RS1], FE::from(1u64)); - assert_eq!(row0[cols::RS2], FE::from(2u64)); - assert_eq!(row0[cols::RD], FE::from(3u64)); - assert_eq!(row0[cols::WRITE_REGISTER], FE::one()); - assert_eq!(row0[cols::ADD], FE::one()); - assert_eq!(row0[cols::SUB], FE::zero()); + 0x100, + 0, + 0xDEAD, + PC + 4, + ); + assert_eq!(op.res, 0x104, "load address = rv1 + imm"); + assert_eq!(op.rvd, 0xDEAD, "load rvd = the loaded value"); } #[test] -fn test_trace_generation_64bit_pc() { - let ops = ops4(CpuOperation { - decode: DecodeEntry { - pc: 0x8000_0000_1234_5678, - op_add: true, - ..Default::default() +fn test_from_log_store() { + let op = op_of( + Instruction::Store { + src: 2, + offset: 8, + base: 1, + width: LoadStoreWidth::Word, }, - ..Default::default() - }); - - let trace = generate_cpu_trace(&ops); - let row0 = 
trace.main_table.get_row(0); - - // Check 64-bit PC is split correctly - assert_eq!(row0[cols::PC_0], FE::from(0x1234_5678u64)); - assert_eq!(row0[cols::PC_1], FE::from(0x8000_0000u64)); - // next_pc set by ops4 helper - assert_eq!(row0[cols::NEXT_PC_0], FE::from(0x1234_567Cu64)); - assert_eq!(row0[cols::NEXT_PC_1], FE::from(0x8000_0000u64)); + 0x100, + 0xAB, + 0, + PC + 4, + ); + assert_eq!(op.res, 0x108, "store address = rv1 + imm"); + assert_eq!(op.rv2, 0xAB, "store value comes from rs2"); + assert_eq!(op.rvd, 0, "store writes nothing back to rd"); } #[test] -fn test_trace_generation_rv1_dwordwhh() { - let ops = ops4(CpuOperation { - decode: DecodeEntry { - op_add: true, - ..Default::default() +fn test_from_log_word_carries_real_register_values() { + let op = op_of( + Instruction::ArithW { + dst: 3, + src1: 1, + src2: 2, + op: ArithOp::Add, }, - rv1: 0xFFFF_EEEE_DDDD_CCCCu64, - ..Default::default() - }); + 10, + 20, + 30, + PC + 4, + ); + assert!(op.decode.fields.word_instr); + // The delegate CpuOperation carries the real values for CPU32/register ops. 
+ assert_eq!(op.rv1, 10); + assert_eq!(op.rv2, 20); + assert_eq!(op.rvd, 30); + assert_eq!(op.res, 0, "the main CPU delegate row computes no result"); + assert_eq!(op.next_pc, PC + 4); +} - let trace = generate_cpu_trace(&ops); - let row0 = trace.main_table.get_row(0); +// ========================================================================= +// generate_cpu_trace +// ========================================================================= - // rv1 stored as DWordWHH: [Half, Half, Word] - Word is MSB - assert_eq!(row0[cols::RV1_0], FE::from(0xCCCCu64)); // bits 0-15 (Half) - assert_eq!(row0[cols::RV1_1], FE::from(0xDDDDu64)); // bits 16-31 (Half) - assert_eq!(row0[cols::RV1_2], FE::from(0xFFFF_EEEEu64)); // bits 32-63 (Word) +fn ops4(instr: Instruction) -> Vec { + (0..4) + .map(|i| { + let decode = DecodeEntry::from_instruction(PC + i * 4, instr, 4); + let log = Log { + current_pc: PC + i * 4, + next_pc: PC + i * 4 + 4, + src1_val: 10, + src2_val: 20, + dst_val: 30, + }; + CpuOperation::from_log(&log, i * 4 + 4, decode) + }) + .collect() } #[test] -fn test_trace_generation_arg1_dwordbl() { - let ops = ops4(CpuOperation { - decode: DecodeEntry { - word_instr: false, - op_add: true, - ..Default::default() - }, - rv1: 0x0807_0605_0403_0201u64, - ..Default::default() +fn test_trace_width_and_real_row() { + let ops = ops4(Instruction::Arith { + dst: 3, + src1: 1, + src2: 2, + op: ArithOp::Add, }); - let trace = generate_cpu_trace(&ops); - let row0 = trace.main_table.get_row(0); - - // arg1 stored as DWordBL: 8 bytes - assert_eq!(row0[cols::ARG1_0], FE::from(0x01u64)); - assert_eq!(row0[cols::ARG1_1], FE::from(0x02u64)); - assert_eq!(row0[cols::ARG1_2], FE::from(0x03u64)); - assert_eq!(row0[cols::ARG1_3], FE::from(0x04u64)); - assert_eq!(row0[cols::ARG1_4], FE::from(0x05u64)); - assert_eq!(row0[cols::ARG1_5], FE::from(0x06u64)); - assert_eq!(row0[cols::ARG1_6], FE::from(0x07u64)); - assert_eq!(row0[cols::ARG1_7], FE::from(0x08u64)); + 
assert_eq!(trace.main_table.width, cols::NUM_COLUMNS); + assert_eq!(cols::NUM_COLUMNS, 38); + assert_eq!(trace.main_table.height, 4); + let row = trace.main_table.get_row(0); + assert_eq!(row[cols::PC_0], (PC).into()); + assert_eq!(row[cols::ADD], 1u64.into(), "ADD fast-path flag set"); + assert_eq!(row[cols::RES_0], 30u64.into()); } #[test] -fn test_trace_generation_res_dwordbl() { - // For op_add, compute_res() calculates arg1 + arg2 (not using self.res directly). - // Set rv1 to the desired result value since arg1 = rv1 when word_instr=false, - // and arg2 = 0 (imm default) when rs2=0. - let ops = ops4(CpuOperation { - decode: DecodeEntry { - op_add: true, - ..Default::default() - }, - rv1: 0xFEDC_BA98_7654_3210u64, - ..Default::default() - }); - +fn test_trace_padding_row() { + // One real op → padded to 4 rows; rows 1..4 are padding. + let ops = vec![ + ops4(Instruction::Arith { + dst: 3, + src1: 1, + src2: 2, + op: ArithOp::Add, + }) + .remove(0), + ]; let trace = generate_cpu_trace(&ops); - let row0 = trace.main_table.get_row(0); - - // res = arg1 + arg2 = rv1 + 0 = 0xFEDC_BA98_7654_3210 - // Stored as DWordBL: 8 bytes (little-endian) - assert_eq!(row0[cols::RES_0], FE::from(0x10u64)); - assert_eq!(row0[cols::RES_1], FE::from(0x32u64)); - assert_eq!(row0[cols::RES_2], FE::from(0x54u64)); - assert_eq!(row0[cols::RES_3], FE::from(0x76u64)); - assert_eq!(row0[cols::RES_4], FE::from(0x98u64)); - assert_eq!(row0[cols::RES_5], FE::from(0xBAu64)); - assert_eq!(row0[cols::RES_6], FE::from(0xDCu64)); - assert_eq!(row0[cols::RES_7], FE::from(0xFEu64)); + let pad = trace.main_table.get_row(1); + assert_eq!( + pad[cols::PC_0], + CPU_PADDING_PC.into(), + "padding pc = 1 (odd)" + ); + assert_eq!( + pad[cols::NEXT_PC_0], + CPU_PADDING_PC.into(), + "next_pc = pc (half_instruction_length = 0)" + ); + assert_eq!(pad[cols::HALF_INSTRUCTION_LENGTH], 0u64.into()); + assert_eq!(pad[cols::WORD_INSTR], 0u64.into()); } #[test] -fn test_trace_generation_ext_bits() { - let ops = 
ops4(CpuOperation { - decode: DecodeEntry { - word_instr: true, - op_add: true, - ..Default::default() - }, - rv1: 0x0000_0000_8000_0000u64, // bit 31 set - res: 0x0000_0000_8000_0000u64, // bit 31 set - ..Default::default() +fn test_trace_word_row_columns_masked() { + let ops = ops4(Instruction::ArithW { + dst: 3, + src1: 1, + src2: 2, + op: ArithOp::Add, }); - let trace = generate_cpu_trace(&ops); - let row0 = trace.main_table.get_row(0); - - assert_eq!(row0[cols::RV1_EXT_BIT], FE::one()); - assert_eq!(row0[cols::RES_EXT_BIT], FE::one()); -} - -#[test] -fn test_bus_interactions_count() { - let interactions = bus_interactions(); - - // Expected interactions: - // - 8 AND_BYTE - // - 8 OR_BYTE - // - 8 XOR_BYTE - // - 2 MSB16 (rv1_sign_bit, arg2_sign_bit) - // - 1 MSB8 (res_sign_bit) - // - 1 ZERO (is_equal for BEQ) - // - 1 LT (less-than comparison) - // - 1 M1 (MEMW read rs1 register) - // - 1 M3 (MEMW read rs2 register) - // - 1 M5 (MEMW write rd register) - // - 1 M6 (LOAD from memory) - // - 1 M7 (STORE to memory) - // - 4 inline PC (2 reads + 2 writes to Memory bus for x255) - // - 1 DECODE (instruction fetch) - // - 1 MUL (multiplication) - // - 1 DVRM (division/remainder) - // - 1 SHIFT (shift operations) - // - 1 BRANCH (branch/jump target calculation) - // - 1 ECALL (shared bus for HALT, COMMIT, and KECCAK, mult = ECALL) - // - 1 ARE_BYTES for (RS1, RS2) paired - // - 1 ARE_BYTES for (RD, 0) - // - 12 ARE_BYTES (ARG1/ARG2/RES byte pairs: 4 pairs × 3 arrays) - // Inline PC replaces CM54: -1 CM54, +4 inline PC → net +3 vs pre-PR main. 
- // Total: 8 + 8 + 8 + 2 + 1 + 1 + 1 + 1 + 5 + 4 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 12 = 58 - assert_eq!(interactions.len(), 58); -} - -#[test] -fn test_column_count() { - assert_eq!(cols::NUM_COLUMNS, 76); -} - -#[test] -fn test_column_arrays() { - // Verify ARG1, ARG2, RES arrays are correct - assert_eq!(cols::ARG1.len(), 8); - assert_eq!(cols::ARG2.len(), 8); - assert_eq!(cols::RES.len(), 8); - - // Check they're consecutive - for i in 0..7 { - assert_eq!(cols::ARG1[i + 1], cols::ARG1[i] + 1); - assert_eq!(cols::ARG2[i + 1], cols::ARG2[i] + 1); - assert_eq!(cols::RES[i + 1], cols::RES[i] + 1); - } -} - -// ============================================================================= -// ELF execution helpers and from_log tests -// ============================================================================= - -/// Helper to run an ELF and return the logs and instructions -fn run_elf(path: &str) -> (Vec, U64HashMap) { - let elf_data = std::fs::read(path).expect("Failed to read ELF"); - let program = Elf::load(&elf_data).expect("Failed to load ELF"); - let executor = Executor::new(&program, vec![]).expect("Failed to create executor"); - let result = executor.run().expect("Failed to run program"); - (result.logs, result.instructions) -} - -/// Helper to run an ELF from the program_artifacts directory -fn run_asm_elf(name: &str) -> (Vec, U64HashMap) { - run_elf(&format!( - "{}/executor/program_artifacts/asm/{}.elf", - env!("CARGO_MANIFEST_DIR").replace("/prover", ""), - name - )) -} - -#[test] -fn test_trace_from_logs_subw() { - // subw test - 4 steps (power of 2, works without padding) - let (logs, instructions) = run_asm_elf("subw"); - let traces = Traces::from_logs(&logs, instructions, &Default::default()).unwrap(); - - // Should have SUB instruction with word_instr flag - let has_sub = - (0..logs.len()).any(|i| traces.cpus[0].main_table.get_row(i)[cols::SUB] == FE::one()); - assert!(has_sub, "subw.elf should have SUB instruction"); + let row = 
trace.main_table.get_row(0); + // Delegate row: word_instr set, but all operational columns masked to 0. + assert_eq!(row[cols::WORD_INSTR], 1u64.into()); + assert_eq!(row[cols::HALF_INSTRUCTION_LENGTH], 2u64.into()); + assert_eq!( + row[cols::RV1_0], + 0u64.into(), + "rv1 column masked on word row" + ); + assert_eq!(row[cols::READ_REGISTER1], 0u64.into()); + assert_eq!(row[cols::ADD], 0u64.into()); + assert_eq!(row[cols::RVD_0], 0u64.into()); } -#[test] -fn test_cpu_operation_from_log_arith() { - use executor::vm::instruction::decoding::ArithOp; - use executor::vm::logs::Log; - - let instruction = Instruction::Arith { - dst: 10, - src1: 11, - src2: 12, - op: ArithOp::Add, - }; - - let log = Log { - current_pc: 0x1000, - next_pc: 0x1004, - src1_val: 100, - src2_val: 200, - dst_val: 300, - }; - - let op = CpuOperation::from_log_and_instruction(&log, 0, instruction); - - assert_eq!(op.decode.pc, 0x1000); - assert_eq!(op.next_pc, 0x1004); - assert_eq!(op.decode.rd, 10); - assert_eq!(op.decode.rs1, 11); - assert_eq!(op.decode.rs2, 12); - assert!(op.decode.op_add); - assert!(op.decode.write_register); - assert_eq!(op.rv1, 100); - assert_eq!(op.rv2, 200); - assert_eq!(op.res, 300); -} +// ========================================================================= +// collect_bitwise_ops +// ========================================================================= #[test] -fn test_cpu_operation_from_log_branch() { - use executor::vm::instruction::decoding::Comparison; - use executor::vm::logs::Log; - - let instruction = Instruction::Branch { - src1: 5, - src2: 6, - cond: Comparison::LessThan, - offset: 8, - }; - - let log = Log { - current_pc: 0x2000, - next_pc: 0x2008, // Branch taken - src1_val: 10, - src2_val: 20, - dst_val: 0, - }; - - let op = CpuOperation::from_log_and_instruction(&log, 4, instruction); - - assert_eq!(op.timestamp, 4); - assert_eq!(op.decode.pc, 0x2000); - assert!(op.decode.op_blt); - assert!(op.decode.signed); - assert!(op.branch_cond); // 10 < 20 - 
// For BLT, res is the comparison result (0 or 1), not subtraction - // res[0] = 1 if arg1 < arg2, res[1..7] = 0 (enforced by SLT res zero constraint) - assert_eq!(op.res, 1); // 10 < 20 = true +fn test_collect_bitwise_ops_shape() { + use crate::tables::bitwise::BitwiseOperationType; + let op = op_of( + Instruction::Arith { + dst: 3, + src1: 1, + src2: 2, + op: ArithOp::Add, + }, + 10, + 20, + 30, + PC + 4, + ); + let ops = op.collect_bitwise_ops(); + assert_eq!(ops.len(), 7, "3 ARE_BYTES + 4 IS_HALF"); + assert!( + ops[0..3] + .iter() + .all(|o| o.lookup_type == BitwiseOperationType::AreBytes) + ); + assert!( + ops[3..7] + .iter() + .all(|o| o.lookup_type == BitwiseOperationType::IsHalf) + ); + // First ARE_BYTES is (rs1, rs2) = (1, 2). + assert_eq!(ops[0].x, 1); + assert_eq!(ops[0].y, 2); } #[test] -fn test_cpu_operation_from_log_word_instr() { - use executor::vm::instruction::decoding::ArithOp; - use executor::vm::logs::Log; - - let instruction = Instruction::ArithW { - dst: 1, - src1: 2, - src2: 3, - op: ArithOp::Add, - }; - - let log = Log { - current_pc: 0x3000, - next_pc: 0x3004, - src1_val: 0xFFFF_FFFF_8000_0000, // Would be negative as 32-bit - src2_val: 1, - dst_val: 0xFFFF_FFFF_8000_0001, // Result sign-extended - }; - - let op = CpuOperation::from_log_and_instruction(&log, 8, instruction); - - assert!(op.decode.word_instr); - assert!(op.decode.op_add); +fn test_collect_bitwise_ops_word_row_zeroed() { + let op = op_of( + Instruction::ArithW { + dst: 3, + src1: 1, + src2: 2, + op: ArithOp::Add, + }, + 10, + 20, + 30, + PC + 4, + ); + let ops = op.collect_bitwise_ops(); + // On a word delegate row the CPU zeroes rs1/rs2/rd/alu_flags/mem_flags/res, + // but half_instruction_length stays (it is set unconditionally in the trace). 
+ assert_eq!(ops[0].x, 0, "rs1 zeroed"); + assert_eq!(ops[0].y, 0, "rs2 zeroed"); + assert_eq!(ops[1].x, 0, "rd zeroed"); + assert_eq!(ops[1].y, 2, "half_instruction_length retained"); } diff --git a/prover/src/tests/decode_layout_tests.rs b/prover/src/tests/decode_layout_tests.rs new file mode 100644 index 000000000..e56a40316 --- /dev/null +++ b/prover/src/tests/decode_layout_tests.rs @@ -0,0 +1,589 @@ +//! Tests for the `packed_decode` layout. +//! +//! These validate the single source of truth (`types::packed_decode_shrunk`, +//! `build_alu_flags`/`build_mem_flags`, `ShrunkDecode`) before it is wired into +//! the DECODE/CPU tables in Phase 2+3 of the rework. + +use crate::tables::types::{ + ShrunkDecode, alu_op, build_alu_flags, build_mem_flags, packed_decode_shrunk as bits, +}; +use executor::vm::instruction::decoding::{ArithOp, Comparison, Instruction, LoadStoreWidth}; + +#[test] +fn test_build_alu_flags_matches_spec_formula() { + // alu_flags = alu_op + 32·signed + 64·(signed2|invert) + 128·muldiv + assert_eq!(build_alu_flags(alu_op::AND, false, false, false), 0); + assert_eq!(build_alu_flags(alu_op::OR, false, false, false), 1); + assert_eq!(build_alu_flags(alu_op::XOR, false, false, false), 2); + // SLT (signed less-than) + assert_eq!(build_alu_flags(alu_op::LT, true, false, false), 4 + 32); + // SLTU (unsigned) + assert_eq!(build_alu_flags(alu_op::LT, false, false, false), 4); + // SRL (logical right shift): invert set + assert_eq!(build_alu_flags(alu_op::SHIFT, false, true, false), 5 + 64); + // SRA (arithmetic right shift): signed + invert + assert_eq!( + build_alu_flags(alu_op::SHIFT, true, true, false), + 5 + 32 + 64 + ); + // MUL: signed + signed2 + assert_eq!(build_alu_flags(alu_op::MUL, true, true, false), 7 + 32 + 64); + // MULH: signed + signed2 + muldiv + assert_eq!( + build_alu_flags(alu_op::MUL, true, true, true), + 7 + 32 + 64 + 128 + ); + // MULHU: muldiv only + assert_eq!(build_alu_flags(alu_op::MUL, false, false, true), 7 + 128); + // REM 
(signed): DIVREM + signed + muldiv + assert_eq!( + build_alu_flags(alu_op::DIVREM, true, false, true), + 8 + 32 + 128 + ); +} + +#[test] +fn test_build_mem_flags_matches_spec_formula() { + // mem_flags = jalr_or_op + 2·signed + 4·2B + 8·4B + 16·8B + // LB (signed byte load): mem_signed only + assert_eq!(build_mem_flags(false, true, false, false, false), 2); + // LBU (unsigned byte load): nothing + assert_eq!(build_mem_flags(false, false, false, false, false), 0); + // LH (signed halfword): signed + 2B + assert_eq!(build_mem_flags(false, true, true, false, false), 2 + 4); + // LW (signed word): signed + 4B + assert_eq!(build_mem_flags(false, true, false, true, false), 2 + 8); + // LD (doubleword, always full): 8B + assert_eq!(build_mem_flags(false, false, false, false, true), 16); + // SB (store byte): memory_op bit + assert_eq!(build_mem_flags(true, false, false, false, false), 1); + // SD (store doubleword): memory_op + 8B + assert_eq!(build_mem_flags(true, false, false, false, true), 1 + 16); +} + +#[test] +fn test_field_placement() { + // Each field lands at its declared offset and nowhere else. + assert_eq!( + ShrunkDecode { + memory: true, + ..Default::default() + } + .pack(), + 1 << bits::MEMORY + ); + assert_eq!( + ShrunkDecode { + rs1: 0xFF, + ..Default::default() + } + .pack(), + 0xFF << bits::RS1 + ); + assert_eq!( + ShrunkDecode { + rd: 0xAB, + ..Default::default() + } + .pack(), + 0xAB << bits::RD + ); + assert_eq!( + ShrunkDecode { + half_instruction_length: 2, + ..Default::default() + } + .pack(), + 2 << bits::HALF_INSTRUCTION_LENGTH + ); + assert_eq!( + ShrunkDecode { + alu_flags: 0xFF, + ..Default::default() + } + .pack(), + 0xFF << bits::ALU_FLAGS + ); + assert_eq!( + ShrunkDecode { + mem_flags: 0xFF, + ..Default::default() + } + .pack(), + 0xFF << bits::MEM_FLAGS + ); +} + +#[test] +fn test_fields_are_disjoint_and_fit_in_58_bits() { + // All fields maxed out. 
+ let full = ShrunkDecode { + read_register1: true, + read_register2: true, + write_register: true, + word_instr: true, + alu: true, + add: true, + sub: true, + memory: true, + branch: true, + ecall: true, + rs1: 0xFF, + rs2: 0xFF, + rd: 0xFF, + half_instruction_length: 0xFF, + alu_flags: 0xFF, + mem_flags: 0xFF, + }; + let packed = full.pack(); + + // Fits in 58 bits (mem_flags ends at bit 50+8 = 58). + assert_eq!(packed >> 58, 0, "packed_decode must fit in 58 bits"); + + // Disjointness: with no overlap, summing each field's individual pack + // equals the combined pack (OR == sum iff masks are disjoint). + let individual_sum: u64 = [ + ShrunkDecode { + read_register1: true, + ..Default::default() + }, + ShrunkDecode { + read_register2: true, + ..Default::default() + }, + ShrunkDecode { + write_register: true, + ..Default::default() + }, + ShrunkDecode { + word_instr: true, + ..Default::default() + }, + ShrunkDecode { + alu: true, + ..Default::default() + }, + ShrunkDecode { + add: true, + ..Default::default() + }, + ShrunkDecode { + sub: true, + ..Default::default() + }, + ShrunkDecode { + memory: true, + ..Default::default() + }, + ShrunkDecode { + branch: true, + ..Default::default() + }, + ShrunkDecode { + ecall: true, + ..Default::default() + }, + ShrunkDecode { + rs1: 0xFF, + ..Default::default() + }, + ShrunkDecode { + rs2: 0xFF, + ..Default::default() + }, + ShrunkDecode { + rd: 0xFF, + ..Default::default() + }, + ShrunkDecode { + half_instruction_length: 0xFF, + ..Default::default() + }, + ShrunkDecode { + alu_flags: 0xFF, + ..Default::default() + }, + ShrunkDecode { + mem_flags: 0xFF, + ..Default::default() + }, + ] + .iter() + .map(ShrunkDecode::pack) + .sum(); + + assert_eq!( + individual_sum, packed, + "packed_decode fields must be disjoint" + ); +} + +#[test] +fn test_pack_unpack_round_trip() { + let entries = [ + ShrunkDecode::default(), + // An ALU register op: ADD rd, rs1, rs2 + ShrunkDecode { + read_register1: true, + read_register2: true, + 
write_register: true, + add: true, + rs1: 0x11, + rs2: 0x22, + rd: 0x33, + half_instruction_length: 2, + ..Default::default() + }, + // A signed word ALU op going through the ALU bus (e.g. SRAW) + ShrunkDecode { + read_register1: true, + write_register: true, + word_instr: true, + alu: true, + rs1: 7, + rd: 9, + half_instruction_length: 2, + alu_flags: build_alu_flags(alu_op::SHIFTW, true, true, false), + ..Default::default() + }, + // A load: LW rd, imm(rs1) + ShrunkDecode { + read_register1: true, + write_register: true, + memory: true, + rs1: 5, + rd: 6, + half_instruction_length: 2, + mem_flags: build_mem_flags(false, true, false, true, false), + ..Default::default() + }, + // Fully saturated. + ShrunkDecode { + read_register1: true, + read_register2: true, + write_register: true, + word_instr: true, + alu: true, + add: true, + sub: true, + memory: true, + branch: true, + ecall: true, + rs1: 0xFF, + rs2: 0xFF, + rd: 0xFF, + half_instruction_length: 0xFF, + alu_flags: 0xFF, + mem_flags: 0xFF, + }, + ]; + + for entry in entries { + assert_eq!(ShrunkDecode::unpack(entry.pack()), entry); + } +} + +#[test] +fn test_from_instruction_arith_ops() { + // ADD rd=3, rs1=1, rs2=2 → ADD fast-path (ALU not set), all reg flags on. + let d = ShrunkDecode::from_instruction( + Instruction::Arith { + dst: 3, + src1: 1, + src2: 2, + op: ArithOp::Add, + }, + 4, + ); + assert!(d.add && !d.alu && !d.sub); + assert!(d.read_register1 && d.read_register2 && d.write_register); + assert_eq!( + (d.rs1, d.rs2, d.rd, d.half_instruction_length), + (1, 2, 3, 2) + ); + assert_eq!(d.alu_flags, 0); + + // AND → ALU path, alu_flags = AND. + let d = ShrunkDecode::from_instruction( + Instruction::Arith { + dst: 5, + src1: 6, + src2: 7, + op: ArithOp::And, + }, + 4, + ); + assert!(d.alu && !d.add && !d.sub); + assert_eq!( + d.alu_flags, + build_alu_flags(alu_op::AND, false, false, false) + ); + + // SUB → SUB fast-path. 
+ let d = ShrunkDecode::from_instruction( + Instruction::Arith { + dst: 1, + src1: 2, + src2: 3, + op: ArithOp::Sub, + }, + 4, + ); + assert!(d.sub && !d.add && !d.alu); + + // SLT (signed) → ALU, LT signed. + let d = ShrunkDecode::from_instruction( + Instruction::Arith { + dst: 1, + src1: 2, + src2: 3, + op: ArithOp::SetLessThan, + }, + 4, + ); + assert_eq!(d.alu_flags, build_alu_flags(alu_op::LT, true, false, false)); + + // x0 operands/dest → no read/write flags. + let d = ShrunkDecode::from_instruction( + Instruction::Arith { + dst: 0, + src1: 0, + src2: 0, + op: ArithOp::Add, + }, + 4, + ); + assert!(!d.write_register && !d.read_register1 && !d.read_register2); +} + +#[test] +fn test_from_instruction_word_shifts() { + // SRAW → word_instr, SHIFTW, signed + invert. + let d = ShrunkDecode::from_instruction( + Instruction::ArithW { + dst: 1, + src1: 2, + src2: 3, + op: ArithOp::ShiftRightArith, + }, + 4, + ); + assert!(d.word_instr && d.alu); + assert_eq!( + d.alu_flags, + build_alu_flags(alu_op::SHIFTW, true, true, false) + ); + + // SLL (non-word) → SHIFT, no invert. + let d = ShrunkDecode::from_instruction( + Instruction::Arith { + dst: 1, + src1: 2, + src2: 3, + op: ArithOp::ShiftLeftLogical, + }, + 4, + ); + assert!(!d.word_instr); + assert_eq!( + d.alu_flags, + build_alu_flags(alu_op::SHIFT, false, false, false) + ); +} + +#[test] +fn test_from_instruction_mul_div() { + // MULHU → unsigned, muldiv. + let d = ShrunkDecode::from_instruction( + Instruction::Arith { + dst: 1, + src1: 2, + src2: 3, + op: ArithOp::MulHighUnsigned, + }, + 4, + ); + assert_eq!( + d.alu_flags, + build_alu_flags(alu_op::MUL, false, false, true) + ); + + // REM (signed) → DIVREM, signed, muldiv. 
+ let d = ShrunkDecode::from_instruction( + Instruction::Arith { + dst: 1, + src1: 2, + src2: 3, + op: ArithOp::Remainder, + }, + 4, + ); + assert_eq!( + d.alu_flags, + build_alu_flags(alu_op::DIVREM, true, false, true) + ); +} + +#[test] +fn test_from_instruction_branches_set_branch_and_alu() { + // Q3: conditional branches set BRANCH ∧ ALU; mem_flags = 0 (not JALR); no rd write. + let d = ShrunkDecode::from_instruction( + Instruction::Branch { + src1: 1, + src2: 2, + cond: Comparison::Equal, + offset: 16, + }, + 4, + ); + assert!(d.branch && d.alu && !d.write_register); + assert_eq!( + d.alu_flags, + build_alu_flags(alu_op::EQ, false, false, false) + ); + assert_eq!(d.mem_flags, 0); + + // BNE → EQ inverted. + let d = ShrunkDecode::from_instruction( + Instruction::Branch { + src1: 1, + src2: 2, + cond: Comparison::NotEqual, + offset: 16, + }, + 4, + ); + assert_eq!(d.alu_flags, build_alu_flags(alu_op::EQ, false, true, false)); + + // BGE → LT signed inverted. + let d = ShrunkDecode::from_instruction( + Instruction::Branch { + src1: 1, + src2: 2, + cond: Comparison::GreaterOrEqual, + offset: 16, + }, + 4, + ); + assert_eq!(d.alu_flags, build_alu_flags(alu_op::LT, true, true, false)); +} + +#[test] +fn test_from_instruction_jumps() { + // JAL → BRANCH + JALR bit, no ALU op, rs1 = x255. + let d = ShrunkDecode::from_instruction(Instruction::JumpAndLink { dst: 1, offset: 32 }, 4); + assert!(d.branch && d.write_register && d.read_register1); + assert!(!d.add && !d.sub && !d.alu); + assert_eq!(d.rs1, 255); + assert_eq!( + d.mem_flags, + build_mem_flags(true, false, false, false, false) + ); + + // JALR → BRANCH + JALR bit, no ALU op, rs1 = base. 
+ let d = ShrunkDecode::from_instruction( + Instruction::JumpAndLinkRegister { + base: 9, + dst: 1, + offset: 0, + }, + 4, + ); + assert!(d.branch); + assert!(!d.add && !d.sub && !d.alu); + assert_eq!(d.rs1, 9); + assert_eq!(d.mem_flags & 1, 1); +} + +#[test] +fn test_from_instruction_load_store() { + // LW (signed) → ADD + MEMORY, mem_signed + mem_4B. + let d = ShrunkDecode::from_instruction( + Instruction::Load { + dst: 1, + offset: 0, + base: 2, + width: LoadStoreWidth::Word, + }, + 4, + ); + assert!(d.add && d.memory && d.write_register); + assert_eq!( + d.mem_flags, + build_mem_flags(false, true, false, true, false) + ); + + // LBU → no signed, no width bits. + let d = ShrunkDecode::from_instruction( + Instruction::Load { + dst: 1, + offset: 0, + base: 2, + width: LoadStoreWidth::ByteUnsigned, + }, + 4, + ); + assert_eq!(d.mem_flags, 0); + + // SD → ADD + MEMORY, memory_op + mem_8B, no rd write. + let d = ShrunkDecode::from_instruction( + Instruction::Store { + src: 3, + offset: 0, + base: 2, + width: LoadStoreWidth::DoubleWord, + }, + 4, + ); + assert!(d.add && d.memory && !d.write_register); + assert_eq!( + d.mem_flags, + build_mem_flags(true, false, false, false, true) + ); +} + +#[test] +fn test_from_instruction_system() { + // ECALL → ECALL, rs1 = x17 (a7). + let d = ShrunkDecode::from_instruction(Instruction::EcallEbreak, 4); + assert!(d.ecall && d.read_register1); + assert_eq!(d.rs1, 17); + + // LUI → ADD, rs1 = x0. + let d = ShrunkDecode::from_instruction( + Instruction::LoadUpperImm { + dst: 5, + imm: 0x1000, + }, + 4, + ); + assert!(d.add && d.write_register); + assert_eq!(d.rs1, 0); + + // AUIPC → ADD, rs1 = x255. + let d = ShrunkDecode::from_instruction( + Instruction::AddUpperImmToPc { + dst: 5, + imm: 0x1000, + }, + 4, + ); + assert!(d.add && d.read_register1); + assert_eq!(d.rs1, 255); + + // FENCE → ADD no-op. 
+ let d = ShrunkDecode::from_instruction(Instruction::Fence, 4); + assert!(d.add); + + // Compressed instruction length (2 bytes) propagates as half = 1. + let d = ShrunkDecode::from_instruction( + Instruction::Arith { + dst: 1, + src1: 2, + src2: 3, + op: ArithOp::Add, + }, + 2, + ); + assert_eq!(d.half_instruction_length, 1); +} diff --git a/prover/src/tests/decode_tests.rs b/prover/src/tests/decode_tests.rs index 84ae8ff3a..6aaeb0b14 100644 --- a/prover/src/tests/decode_tests.rs +++ b/prover/src/tests/decode_tests.rs @@ -1,1169 +1,183 @@ //! Tests for the DECODE table. - -use executor::elf::{Elf, Segment}; -use executor::vm::instruction::decoding::{ArithOp, Instruction}; -use executor::vm::memory::U64HashMap; -use math::field::element::FieldElement; - -use stark::proof::options::GoldilocksCubicProofOptions; - -use crate::tables::decode::{ - DecodeEntry, bus_interactions, cols, commitment_from_elf, generate_decode_trace, - instructions_from_elf, tables_from_elf, update_multiplicities, -}; -use crate::tables::trace_builder::Traces; -use crate::tables::types::{FE, packed_decode as bits}; +//! +//! `decode_layout_tests` covers the `ShrunkDecode` pack/unpack/from_instruction +//! bit layout in isolation; here we test the `DecodeEntry` wrapper (pc/imm +//! extraction, padding) and the DECODE *table* generation (`generate_decode_trace`): +//! the per-instruction rows, the `pc = 1` padding entry, and the `pc_to_row` map. 
+ +use crate::tables::cpu::CPU_PADDING_PC; +use crate::tables::decode::{cols, commitment_from_elf, generate_decode_trace}; +use crate::tables::types::DecodeEntry; use crate::test_utils::asm_elf_bytes; -use crate::test_utils::multi_prove_ram; -use crate::test_utils::run_asm_elf; use crate::{prove, verify_with_options}; -// ========================================================================= -// Packed decode tests -// ========================================================================= - -#[test] -fn test_packed_decode_flags() { - // Test each control flag individually using the constants from packed_decode module. - // This validates that the constants match the actual bit packing logic. - let mut entry = DecodeEntry::new(); - - // READ_REG1: excludes x0 and x255, so we need rs1 != 0 && rs1 != 255 - entry.read_register1 = true; - entry.rs1 = 1; - assert_eq!( - entry.packed_decode() & (1 << bits::READ_REG1), - 1 << bits::READ_REG1 - ); - entry.read_register1 = false; - entry.rs1 = 0; - - // READ_REG2: excludes x0, so we need rs2 != 0 - entry.read_register2 = true; - entry.rs2 = 1; - assert_eq!( - entry.packed_decode() & (1 << bits::READ_REG2), - 1 << bits::READ_REG2 - ); - entry.read_register2 = false; - entry.rs2 = 0; - - // WRITE_REG: excludes x0, so we need rd != 0 - entry.write_register = true; - entry.rd = 1; - assert_eq!( - entry.packed_decode() & (1 << bits::WRITE_REG), - 1 << bits::WRITE_REG - ); - entry.write_register = false; - entry.rd = 0; - - // MEMORY_2BYTES - entry.memory_2bytes = true; - assert_eq!( - entry.packed_decode() & (1 << bits::MEMORY_2BYTES), - 1 << bits::MEMORY_2BYTES - ); - entry.memory_2bytes = false; - - // MEMORY_4BYTES - entry.memory_4bytes = true; - assert_eq!( - entry.packed_decode() & (1 << bits::MEMORY_4BYTES), - 1 << bits::MEMORY_4BYTES - ); - entry.memory_4bytes = false; - - // MEMORY_8BYTES - entry.memory_8bytes = true; - assert_eq!( - entry.packed_decode() & (1 << bits::MEMORY_8BYTES), - 1 << bits::MEMORY_8BYTES - ); 
- entry.memory_8bytes = false; - - // C_TYPE - entry.c_type = true; - assert_eq!( - entry.packed_decode() & (1 << bits::C_TYPE), - 1 << bits::C_TYPE - ); - entry.c_type = false; - - // SIGNED - entry.signed = true; - assert_eq!( - entry.packed_decode() & (1 << bits::SIGNED), - 1 << bits::SIGNED - ); - entry.signed = false; - - // MP_SELECTOR - entry.mp_selector = true; - assert_eq!( - entry.packed_decode() & (1 << bits::MP_SELECTOR), - 1 << bits::MP_SELECTOR - ); - entry.mp_selector = false; - - // MULDIV_SELECTOR - entry.muldiv_selector = true; - assert_eq!( - entry.packed_decode() & (1 << bits::MULDIV_SELECTOR), - 1 << bits::MULDIV_SELECTOR - ); - entry.muldiv_selector = false; - - // WORD_INSTR - entry.word_instr = true; - assert_eq!( - entry.packed_decode() & (1 << bits::WORD_INSTR), - 1 << bits::WORD_INSTR - ); -} - -#[test] -fn test_packed_decode_alu_flags() { - // ALU flags - using constants to validate they match the packing logic - let mut entry = DecodeEntry::new(); - - entry.op_add = true; - assert_eq!( - entry.packed_decode() & (1 << bits::OP_ADD), - 1 << bits::OP_ADD - ); - entry.op_add = false; - - entry.op_sub = true; - assert_eq!( - entry.packed_decode() & (1 << bits::OP_SUB), - 1 << bits::OP_SUB - ); - entry.op_sub = false; - - entry.op_slt = true; - assert_eq!( - entry.packed_decode() & (1 << bits::OP_SLT), - 1 << bits::OP_SLT - ); - entry.op_slt = false; - - entry.op_and = true; - assert_eq!( - entry.packed_decode() & (1 << bits::OP_AND), - 1 << bits::OP_AND - ); - entry.op_and = false; - - entry.op_or = true; - assert_eq!(entry.packed_decode() & (1 << bits::OP_OR), 1 << bits::OP_OR); - entry.op_or = false; - - entry.op_xor = true; - assert_eq!( - entry.packed_decode() & (1 << bits::OP_XOR), - 1 << bits::OP_XOR - ); - entry.op_xor = false; - - entry.op_shift = true; - assert_eq!( - entry.packed_decode() & (1 << bits::OP_SHIFT), - 1 << bits::OP_SHIFT - ); - entry.op_shift = false; - - entry.op_jalr = true; - assert_eq!( - entry.packed_decode() & 
(1 << bits::OP_JALR), - 1 << bits::OP_JALR - ); - entry.op_jalr = false; - - entry.op_beq = true; - assert_eq!( - entry.packed_decode() & (1 << bits::OP_BEQ), - 1 << bits::OP_BEQ - ); - entry.op_beq = false; - - entry.op_blt = true; - assert_eq!( - entry.packed_decode() & (1 << bits::OP_BLT), - 1 << bits::OP_BLT - ); - entry.op_blt = false; - - entry.op_load = true; - assert_eq!( - entry.packed_decode() & (1 << bits::OP_LOAD), - 1 << bits::OP_LOAD - ); - entry.op_load = false; - - entry.op_store = true; - assert_eq!( - entry.packed_decode() & (1 << bits::OP_STORE), - 1 << bits::OP_STORE - ); - entry.op_store = false; - - entry.op_mul = true; - assert_eq!( - entry.packed_decode() & (1 << bits::OP_MUL), - 1 << bits::OP_MUL - ); - entry.op_mul = false; - - entry.op_divrem = true; - assert_eq!( - entry.packed_decode() & (1 << bits::OP_DIVREM), - 1 << bits::OP_DIVREM - ); - entry.op_divrem = false; - - entry.op_ecall = true; - assert_eq!( - entry.packed_decode() & (1 << bits::OP_ECALL), - 1 << bits::OP_ECALL - ); - entry.op_ecall = false; - - entry.op_ebreak = true; - assert_eq!( - entry.packed_decode() & (1 << bits::OP_EBREAK), - 1 << bits::OP_EBREAK - ); -} - -#[test] -fn test_packed_decode_registers() { - // Register positions - using constants - let mut entry = DecodeEntry::new(); - - // rs1 - entry.rs1 = 0b10101010; - let packed = entry.packed_decode(); - let rs1_extracted = (packed >> bits::RS1) & 0xFF; - assert_eq!(rs1_extracted, 0b10101010); - entry.rs1 = 0; - - // rs2 - entry.rs2 = 0b11001100; - let packed = entry.packed_decode(); - let rs2_extracted = (packed >> bits::RS2) & 0xFF; - assert_eq!(rs2_extracted, 0b11001100); - entry.rs2 = 0; - - // rd - entry.rd = 0b11110000; - let packed = entry.packed_decode(); - let rd_extracted = (packed >> bits::RD) & 0xFF; - assert_eq!(rd_extracted, 0b11110000); -} - -#[test] -fn test_packed_decode_combined() { - // Test with realistic ADD instruction: rd=10, rs1=5, rs2=6 - // Per decode.md spec: read_register1 at bit 0, 
read_register2 at bit 1, - // write_register at bit 2, op_add at bit 11 - let entry = DecodeEntry { - pc: 0x1000, - rs1: 5, - rs2: 6, - rd: 10, - read_register1: true, - read_register2: true, - write_register: true, - op_add: true, - ..Default::default() - }; - - let packed = entry.packed_decode(); - - // Verify flags per spec - assert_eq!( - packed & (1 << 0), - 1 << 0, - "read_register1 should be set at bit 0" - ); - assert_eq!( - packed & (1 << 1), - 1 << 1, - "read_register2 should be set at bit 1" - ); - assert_eq!( - packed & (1 << 2), - 1 << 2, - "write_register should be set at bit 2" - ); - assert_eq!( - packed & (1 << 11), - 1 << 11, - "op_add should be set at bit 11" - ); - - // Verify registers per spec: rs1 at bits 27-34, rs2 at bits 35-42, rd at bits 43-50 - assert_eq!((packed >> 27) & 0xFF, 5, "rs1 should be 5"); - assert_eq!((packed >> 35) & 0xFF, 6, "rs2 should be 6"); - assert_eq!((packed >> 43) & 0xFF, 10, "rd should be 10"); -} - -// ========================================================================= -// Padding entry tests -// ========================================================================= - -#[test] -fn test_padding_entry() { - let padding = DecodeEntry::padding_entry(); - - assert_eq!(padding.pc, 7, "Padding entry should have pc=7"); - assert!(padding.op_ebreak, "Padding entry should have EBREAK=1"); - - // All other flags should be false - assert!(!padding.read_register1); - assert!(!padding.read_register2); - assert!(!padding.write_register); - assert!(!padding.op_add); - assert!(!padding.op_sub); - assert_eq!(padding.rs1, 0); - assert_eq!(padding.rs2, 0); - assert_eq!(padding.rd, 0); - assert_eq!(padding.imm, 0); -} - -// ========================================================================= -// from_instruction tests -// ========================================================================= - -#[test] -fn test_from_instruction_arith() { - // ADD x10, x5, x6 - let instr = Instruction::Arith { - dst: 10, - src1: 5, - 
src2: 6, - op: ArithOp::Add, - }; - - let entry = DecodeEntry::from_instruction(0x1000, instr); - - assert_eq!(entry.pc, 0x1000); - assert_eq!(entry.rd, 10); - assert_eq!(entry.rs1, 5); - assert_eq!(entry.rs2, 6); - assert!(entry.read_register1); - assert!(entry.read_register2); - assert!(entry.write_register); - assert!(entry.op_add); -} - -#[test] -fn test_from_instruction_arith_imm() { - // ADDI x10, x5, 100 - let instr = Instruction::ArithImm { - dst: 10, - src: 5, - imm: 100, - op: ArithOp::Add, - }; - - let entry = DecodeEntry::from_instruction(0x1000, instr); - - assert_eq!(entry.pc, 0x1000); - assert_eq!(entry.rd, 10); - assert_eq!(entry.rs1, 5); - assert_eq!(entry.rs2, 0); - assert_eq!(entry.imm, 100); - assert!(entry.read_register1); - assert!(!entry.read_register2); - assert!(entry.write_register); - assert!(entry.op_add); -} +use executor::elf::Elf; +use executor::vm::instruction::decoding::{ArithOp, Comparison, Instruction, LoadStoreWidth}; +use executor::vm::memory::U64HashMap; +use stark::proof::options::GoldilocksCubicProofOptions; // ========================================================================= -// Trace generation tests +// DecodeEntry // ========================================================================= #[test] -fn test_trace_generation_basic() { - let mut instructions = U64HashMap::default(); - instructions.insert( - 0x1000, - Instruction::Arith { - dst: 1, - src1: 2, - src2: 3, - op: ArithOp::Add, - }, - ); - instructions.insert( - 0x1004, - Instruction::Arith { - dst: 4, - src1: 5, - src2: 6, - op: ArithOp::Sub, - }, - ); - - let (trace, _pc_to_row) = generate_decode_trace(&instructions); - - // 2 instructions + 1 CPU padding entry = 3, padded to power of 2 = 4 - assert_eq!(trace.main_table.height, 4); - assert_eq!(trace.main_table.width, cols::NUM_COLUMNS); +fn test_decode_entry_default_and_padding() { + let d = DecodeEntry::new(); + assert_eq!(d.pc, 0); + assert_eq!(d.imm, 0); + assert_eq!(d.packed_decode(), 0); + + let 
pad = DecodeEntry::padding_entry(); + assert_eq!(pad.pc, CPU_PADDING_PC, "padding sits at the odd address 1"); + assert_eq!(pad.imm, 0); + assert_eq!(pad.packed_decode(), 0, "padding has all flags zero"); } #[test] -fn test_trace_multiplicities() { - let mut instructions = U64HashMap::default(); - instructions.insert( - 0x1000, +fn test_decode_entry_packed_decode_matches_fields() { + let d = DecodeEntry::from_instruction( + 0x2000, Instruction::Arith { - dst: 1, - src1: 2, - src2: 3, + dst: 3, + src1: 1, + src2: 2, op: ArithOp::Add, }, + 4, ); - - let (mut trace, pc_to_row) = generate_decode_trace(&instructions); - - // PC 0x1000 executed 5 times - let lookups = vec![0x1000, 0x1000, 0x1000, 0x1000, 0x1000]; - update_multiplicities(&mut trace, &pc_to_row, &lookups); - - // Should be padded to 2 (1 entry -> next power of 2) - assert_eq!(trace.main_table.height, 2); - - // Find the row with pc=0x1000 - let mut found = false; - for row_idx in 0..trace.main_table.height { - let row = trace.main_table.get_row(row_idx); - if row[cols::PC_0] == FE::from(0x1000u64) { - assert_eq!(row[cols::MU], FE::from(5u64), "Multiplicity should be 5"); - found = true; - } - } - assert!(found, "Row with pc=0x1000 not found"); + assert_eq!(d.packed_decode(), d.fields.pack()); + assert!(d.fields.add, "ADD is a fast-path flag"); + assert_eq!(d.fields.half_instruction_length, 2); } #[test] -fn test_trace_multiple_instructions_different_multiplicities() { - let mut instructions = U64HashMap::default(); - instructions.insert( - 0x1000, +fn test_decode_entry_imm_extraction() { + let add = DecodeEntry::from_instruction( + 0, Instruction::Arith { - dst: 1, - src1: 2, - src2: 3, + dst: 3, + src1: 1, + src2: 2, op: ArithOp::Add, }, + 4, ); - instructions.insert( - 0x1004, - Instruction::Arith { - dst: 4, - src1: 5, - src2: 6, - op: ArithOp::Sub, - }, - ); - - let (mut trace, pc_to_row) = generate_decode_trace(&instructions); - - // 0x1000 executed 3 times, 0x1004 executed 7 times - let lookups = 
vec![ - 0x1000, 0x1004, 0x1000, 0x1004, 0x1004, 0x1000, 0x1004, 0x1004, 0x1004, 0x1004, - ]; - update_multiplicities(&mut trace, &pc_to_row, &lookups); - - // 2 instructions + 1 CPU padding entry = 3, padded to 4 - assert_eq!(trace.main_table.height, 4); - - let mut mu_1000 = None; - let mut mu_1004 = None; + assert_eq!(add.imm, 0, "reg-reg has no immediate"); - for row_idx in 0..trace.main_table.height { - let row = trace.main_table.get_row(row_idx); - if row[cols::PC_0] == FE::from(0x1000u64) { - mu_1000 = Some(row[cols::MU]); - } - if row[cols::PC_0] == FE::from(0x1004u64) { - mu_1004 = Some(row[cols::MU]); - } - } - - assert_eq!(mu_1000, Some(FE::from(3u64)), "PC 0x1000 should have mu=3"); - assert_eq!(mu_1004, Some(FE::from(7u64)), "PC 0x1004 should have mu=7"); -} - -#[test] -fn test_trace_padding_to_power_of_two() { - let mut instructions = U64HashMap::default(); - instructions.insert( - 0x1000, - Instruction::Arith { - dst: 1, - src1: 2, - src2: 3, - op: ArithOp::Add, - }, - ); - instructions.insert( - 0x1004, - Instruction::Arith { - dst: 4, - src1: 5, - src2: 6, - op: ArithOp::Sub, - }, - ); - instructions.insert( - 0x1008, - Instruction::Arith { - dst: 7, - src1: 8, - src2: 9, - op: ArithOp::Add, - }, - ); - - let (trace, _pc_to_row) = generate_decode_trace(&instructions); - - // 3 instructions + 1 CPU padding entry = 4, already power of 2 - assert_eq!( - trace.main_table.height, 4, - "3 instructions + 1 CPU padding entry = 4 rows" - ); - - // Verify the CPU padding row has pc=1 and all flags=0 - let mut found_cpu_padding = false; - for row_idx in 0..trace.main_table.height { - let row = trace.main_table.get_row(row_idx); - if row[cols::PC_0] == FE::from(1u64) { - assert_eq!( - row[cols::PACKED_DECODE], - FE::zero(), - "CPU padding entry should have all flags=0" - ); - assert_eq!( - row[cols::MU], - FE::zero(), - "CPU padding entry should have mu=0" - ); - found_cpu_padding = true; - } - } - assert!(found_cpu_padding, "CPU padding row with pc=1 not 
found"); -} - -#[test] -fn test_trace_dword_encoding() { - // Test 64-bit PC and immediate encoding as DWordWL - let mut instructions = U64HashMap::default(); - instructions.insert( - 0xDEAD_BEEF_1234_5678, + let addi = DecodeEntry::from_instruction( + 0, Instruction::ArithImm { - dst: 1, - src: 2, - imm: 0x8765_4321u32 as i32, // Will be sign-extended - op: ArithOp::Add, - }, - ); - - let (trace, _pc_to_row) = generate_decode_trace(&instructions); - - // Find the row (could be row 0 or 1 due to HashMap ordering) - let mut found = false; - for row_idx in 0..trace.main_table.height { - let row = trace.main_table.get_row(row_idx); - if row[cols::PC_0] == FE::from(0x1234_5678u64) { - // PC low word - assert_eq!(row[cols::PC_0], FE::from(0x1234_5678u64)); - // PC high word - assert_eq!(row[cols::PC_1], FE::from(0xDEAD_BEEFu64)); - found = true; - } - } - assert!(found, "Row with expected PC not found"); -} - -// ========================================================================= -// Bus interaction tests -// ========================================================================= - -#[test] -fn test_bus_interactions_count() { - let interactions = bus_interactions(); - - // DECODE table should have exactly 1 interaction (receiver for DECODE bus) - assert_eq!( - interactions.len(), - 1, - "DECODE should have 1 bus interaction" - ); -} - -#[test] -fn test_bus_interactions_is_receiver() { - let interactions = bus_interactions(); - - // The single interaction should be a receiver (is_sender = false) - assert!( - !interactions[0].is_sender, - "DECODE should be a receiver, not sender" - ); -} - -// ========================================================================= -// Precomputed commitment tests -// ========================================================================= - -#[test] -fn test_compute_precomputed_commitment_deterministic() { - use crate::tables::decode::compute_precomputed_commitment; - use stark::proof::options::ProofOptions; - - // Same 
instructions should produce same commitment - let mut instructions = U64HashMap::default(); - instructions.insert( - 0x1000, - Instruction::Arith { - dst: 1, - src1: 2, - src2: 3, + dst: 3, + src: 1, + imm: 5, op: ArithOp::Add, }, + 4, ); - instructions.insert( - 0x1004, - Instruction::Arith { - dst: 4, - src1: 5, - src2: 6, - op: ArithOp::Sub, - }, - ); + assert_eq!(addi.imm, 5); - let options = ProofOptions::default_test_options(); - - let commitment1 = compute_precomputed_commitment(&instructions, &options); - let commitment2 = compute_precomputed_commitment(&instructions, &options); - - assert_eq!( - commitment1, commitment2, - "Same instructions should produce same commitment" - ); -} - -#[test] -fn test_compute_precomputed_commitment_different_programs() { - use crate::tables::decode::compute_precomputed_commitment; - use stark::proof::options::ProofOptions; - - let options = ProofOptions::default_test_options(); - - // Program A: ADD instruction - let mut program_a = U64HashMap::default(); - program_a.insert( - 0x1000, - Instruction::Arith { - dst: 1, - src1: 2, - src2: 3, - op: ArithOp::Add, + let beq = DecodeEntry::from_instruction( + 0, + Instruction::Branch { + src1: 1, + src2: 2, + cond: Comparison::Equal, + offset: 8, }, + 4, ); + assert_eq!(beq.imm, 8, "branch offset"); - // Program B: SUB instruction (different from A) - let mut program_b = U64HashMap::default(); - program_b.insert( - 0x1000, - Instruction::Arith { - dst: 1, - src1: 2, - src2: 3, - op: ArithOp::Sub, // Different operation + let lw = DecodeEntry::from_instruction( + 0, + Instruction::Load { + dst: 3, + offset: 16, + base: 1, + width: LoadStoreWidth::Word, }, + 4, ); - - let commitment_a = compute_precomputed_commitment(&program_a, &options); - let commitment_b = compute_precomputed_commitment(&program_b, &options); - - assert_ne!( - commitment_a, commitment_b, - "Different programs should produce different commitments" - ); + assert_eq!(lw.imm, 16, "load offset"); } #[test] -fn 
test_compute_precomputed_commitment_different_pc() { - use crate::tables::decode::compute_precomputed_commitment; - use stark::proof::options::ProofOptions; - - let options = ProofOptions::default_test_options(); - - // Program A: instruction at PC 0x1000 - let mut program_a = U64HashMap::default(); - program_a.insert( - 0x1000, - Instruction::Arith { - dst: 1, - src1: 2, - src2: 3, - op: ArithOp::Add, - }, - ); - - // Program B: same instruction at different PC - let mut program_b = U64HashMap::default(); - program_b.insert( - 0x2000, // Different PC - Instruction::Arith { - dst: 1, - src1: 2, - src2: 3, +fn test_decode_entry_negative_imm_sign_extended() { + let addi = DecodeEntry::from_instruction( + 0, + Instruction::ArithImm { + dst: 3, + src: 1, + imm: -1, op: ArithOp::Add, }, + 4, ); - - let commitment_a = compute_precomputed_commitment(&program_a, &options); - let commitment_b = compute_precomputed_commitment(&program_b, &options); - - assert_ne!( - commitment_a, commitment_b, - "Programs with different PCs should produce different commitments" + assert_eq!( + addi.imm, + u64::MAX, + "-1 sign-extends to the full 64-bit word" ); } // ========================================================================= -// instructions_from_elf tests (verifier vs executor consistency) +// generate_decode_trace // ========================================================================= -/// Test that instructions_from_elf produces the same result as the executor. 
-#[test] -fn test_instructions_from_elf_matches_executor() { - // Run executor to get instructions - let (_elf, _logs, executor_instructions) = run_asm_elf("arith_8"); - - // Load the same ELF and extract instructions directly - let manifest_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")); - let elf_path = manifest_dir - .parent() - .unwrap() - .join("executor/program_artifacts/asm/arith_8.elf"); - let elf_bytes = std::fs::read(&elf_path).expect("Failed to read ELF file"); - let elf = Elf::load(&elf_bytes).expect("Failed to load ELF"); - - let verifier_instructions = - instructions_from_elf(&elf).expect("Failed to extract instructions"); +const TEST_PC: u64 = 0x1000; - // Compare via DecodeEntry (what matters for the DECODE table) - for (pc, executor_instr) in executor_instructions.iter() { - let verifier_instr = verifier_instructions - .get(pc) - .unwrap_or_else(|| panic!("Verifier missing instruction at PC {:#x}", pc)); - - // Compare by converting to DecodeEntry - this is what the DECODE table uses - let executor_entry = DecodeEntry::from_instruction(*pc, *executor_instr); - let verifier_entry = DecodeEntry::from_instruction(*pc, *verifier_instr); - - assert_eq!( - executor_entry.packed_decode(), - verifier_entry.packed_decode(), - "packed_decode mismatch at PC {:#x}", - pc - ); - assert_eq!( - executor_entry.imm, verifier_entry.imm, - "imm mismatch at PC {:#x}", - pc - ); +fn test_instr() -> Instruction { + Instruction::ArithImm { + dst: 3, + src: 1, + imm: 7, + op: ArithOp::Add, } - - // Verifier may have more instructions (all executable code vs only executed code) - // but every executed instruction must match - assert!( - verifier_instructions.len() >= executor_instructions.len(), - "Verifier should have at least as many instructions as executor" - ); } -/// Test instructions_from_elf with a more complex program. 
#[test] -fn test_instructions_from_elf_matches_executor_complex() { - let (_elf, _logs, executor_instructions) = run_asm_elf("all_instructions_64"); - - let manifest_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")); - let elf_path = manifest_dir - .parent() - .unwrap() - .join("executor/program_artifacts/asm/all_instructions_64.elf"); - let elf_bytes = std::fs::read(&elf_path).expect("Failed to read ELF file"); - let elf = Elf::load(&elf_bytes).expect("Failed to load ELF"); - - let verifier_instructions = - instructions_from_elf(&elf).expect("Failed to extract instructions"); - - // Every executed instruction must be present and match - for (pc, executor_instr) in executor_instructions.iter() { - let verifier_instr = verifier_instructions - .get(pc) - .unwrap_or_else(|| panic!("Verifier missing instruction at PC {:#x}", pc)); - - // Compare via DecodeEntry - let executor_entry = DecodeEntry::from_instruction(*pc, *executor_instr); - let verifier_entry = DecodeEntry::from_instruction(*pc, *verifier_instr); - - assert_eq!( - executor_entry.packed_decode(), - verifier_entry.packed_decode(), - "packed_decode mismatch at PC {:#x}", - pc - ); - assert_eq!( - executor_entry.imm, verifier_entry.imm, - "imm mismatch at PC {:#x}", - pc - ); - } +fn test_decode_table_instruction_row() { + let entry = DecodeEntry::from_instruction(TEST_PC, test_instr(), 4); + let mut instrs: U64HashMap = U64HashMap::default(); + instrs.insert(TEST_PC, test_instr()); + let (trace, pc_to_row) = generate_decode_trace(&instrs); + + let row = trace.main_table.get_row(pc_to_row[&TEST_PC]); + assert_eq!(row[cols::PC_0], (TEST_PC & 0xFFFF_FFFF).into()); + assert_eq!(row[cols::PACKED_DECODE], entry.packed_decode().into()); + assert_eq!(row[cols::IMM_0], (entry.imm & 0xFFFF_FFFF).into()); } -/// Test that instructions_from_elf includes all executable instructions, -/// not just the ones that were executed. 
#[test] -fn test_instructions_from_elf_includes_all_executable() { - let manifest_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")); - let elf_path = manifest_dir - .parent() - .unwrap() - .join("executor/program_artifacts/asm/all_branches_16.elf"); - let elf_bytes = std::fs::read(&elf_path).expect("Failed to read ELF file"); - let elf = Elf::load(&elf_bytes).expect("Failed to load ELF"); - - let instructions = instructions_from_elf(&elf).expect("Failed to extract instructions"); - - // Should have decoded all executable code - assert!( - !instructions.is_empty(), - "Should have extracted some instructions" - ); - - // All PCs should be 4-byte aligned - for (pc, _) in instructions.iter() { - assert_eq!(pc % 4, 0, "PC {:#x} is not 4-byte aligned", pc); - } -} - -// ========================================================================= -// Soundness tests (prover/verifier decoupling) -// ========================================================================= - -/// SECURITY TEST: Verifier with different ELF rejects proof. -/// -/// This test proves the security model works: -/// - Prover runs program A, generates proof with DECODE commitment from ELF A -/// - Verifier has ELF B, computes DECODE commitment from ELF B -/// - Commitments differ → Fiat-Shamir challenges differ → verification FAILS -/// -/// This demonstrates that a verifier who independently has the correct ELF -/// will reject proofs from a prover who ran a different program. 
-#[test] -fn test_decode_soundness_different_elf_rejected() { - use crypto::fiat_shamir::default_transcript::DefaultTranscript; - use stark::proof::options::ProofOptions; - use stark::traits::AIR; - use stark::verifier::{IsStarkVerifier, Verifier}; - - use crate::tables::decode::{self, commitment_from_elf}; - use crate::tables::trace_builder::Traces; - use crate::tables::types::{GoldilocksExtension, GoldilocksField}; - use crate::test_utils::{ - create_bitwise_air, create_branch_air, create_cpu_air, create_decode_air, create_halt_air, - create_load_air, create_lt_air, create_memw_air, - }; - - type F = GoldilocksField; - type E = GoldilocksExtension; - - let proof_options = ProofOptions::default_test_options(); - - // Load two DIFFERENT ELF files - let manifest_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")); - let elf_path_a = manifest_dir - .parent() - .unwrap() - .join("executor/program_artifacts/asm/arith_8.elf"); - let elf_path_b = manifest_dir - .parent() - .unwrap() - .join("executor/program_artifacts/asm/test_sub_8.elf"); - - let elf_bytes_a = std::fs::read(&elf_path_a).expect("Failed to read ELF A"); - let elf_bytes_b = std::fs::read(&elf_path_b).expect("Failed to read ELF B"); - - let elf_a = Elf::load(&elf_bytes_a).expect("Failed to load ELF A"); - let elf_b = Elf::load(&elf_bytes_b).expect("Failed to load ELF B"); - - // Verify the two programs produce different commitments - let commitment_a = commitment_from_elf(&elf_a, &proof_options).expect("commitment A"); - let commitment_b = commitment_from_elf(&elf_b, &proof_options).expect("commitment B"); - assert_ne!( - commitment_a, commitment_b, - "Test requires two different programs with different commitments" - ); - - // ========================================================================= - // PROVER: Runs program A, builds traces, generates proof - // ========================================================================= - let executor_a = - 
executor::vm::execution::Executor::new(&elf_a, vec![]).expect("Failed to create executor"); - let result_a = executor_a.run().expect("Failed to run program A"); - - let mut traces = - Traces::from_logs_minimal(&result_a.logs, result_a.instructions, &Default::default()) - .unwrap(); - - // Prover builds AIRs with commitment from ELF A - let prover_cpu_air = create_cpu_air(&proof_options); - let prover_bitwise_air = create_bitwise_air(&proof_options); - let prover_lt_air = create_lt_air(&proof_options); - let prover_memw_air = create_memw_air(&proof_options); - let prover_load_air = create_load_air(&proof_options); - let prover_branch_air = create_branch_air(&proof_options); - let prover_halt_air = create_halt_air(&proof_options); - let prover_decode_air = create_decode_air(&proof_options).with_preprocessed( - commitment_a, // Prover uses commitment from ELF A - decode::NUM_PRECOMPUTED_COLS, - ); - - let air_trace_pairs: Vec<( - &dyn AIR, - _, - _, - )> = vec![ - (&prover_cpu_air, &mut traces.cpus[0], &()), - (&prover_bitwise_air, &mut traces.bitwise, &()), - (&prover_lt_air, &mut traces.lts[0], &()), - (&prover_memw_air, &mut traces.memws[0], &()), - (&prover_load_air, &mut traces.loads[0], &()), - (&prover_branch_air, &mut traces.branches[0], &()), - (&prover_halt_air, &mut traces.halt, &()), - (&prover_decode_air, &mut traces.decode, &()), - ]; - - let proof = multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])) - .expect("Prover failed to generate proof"); - - // ========================================================================= - // VERIFIER: Has ELF B (different program!), computes commitment from it - // ========================================================================= - let verifier_cpu_air = create_cpu_air(&proof_options); - let verifier_bitwise_air = create_bitwise_air(&proof_options); - let verifier_lt_air = create_lt_air(&proof_options); - let verifier_memw_air = create_memw_air(&proof_options); - let verifier_load_air = 
create_load_air(&proof_options); - let verifier_branch_air = create_branch_air(&proof_options); - let verifier_halt_air = create_halt_air(&proof_options); - let verifier_decode_air = create_decode_air(&proof_options).with_preprocessed( - commitment_b, // Verifier uses commitment from ELF B (DIFFERENT!) - decode::NUM_PRECOMPUTED_COLS, - ); - - let verifier_airs: Vec<&dyn AIR> = vec![ - &verifier_cpu_air, - &verifier_bitwise_air, - &verifier_lt_air, - &verifier_memw_air, - &verifier_load_air, - &verifier_branch_air, - &verifier_halt_air, - &verifier_decode_air, - ]; +fn test_decode_table_padding_row() { + let mut instrs: U64HashMap = U64HashMap::default(); + instrs.insert(TEST_PC, test_instr()); + let (trace, pc_to_row) = generate_decode_trace(&instrs); - let result = Verifier::multi_verify( - &verifier_airs, - &proof, - &mut DefaultTranscript::::new(&[]), - &FieldElement::zero(), - ); - - // With different ELFs, verification should FAIL (secure!) - assert!( - !result, - "Verifier with different ELF should REJECT the proof" - ); -} - -/// SECURITY TEST: Verifier with same ELF accepts proof. -/// -/// Complementary test: when prover and verifier have the SAME ELF, -/// verification should succeed. 
-#[test] -fn test_decode_soundness_same_elf_accepted() { - use crypto::fiat_shamir::default_transcript::DefaultTranscript; - use stark::proof::options::ProofOptions; - use stark::verifier::{IsStarkVerifier, Verifier}; - - use crate::VmAirs; - use crate::tables::types::GoldilocksExtension; - - type E = GoldilocksExtension; - - let proof_options = ProofOptions::default_test_options(); - - // Load the SAME ELF for both prover and verifier - let manifest_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")); - let elf_path = manifest_dir - .parent() - .unwrap() - .join("executor/program_artifacts/asm/arith_8.elf"); - - let elf_bytes = std::fs::read(&elf_path).expect("Failed to read ELF"); - - // Prover loads ELF - let prover_elf = Elf::load(&elf_bytes).expect("Prover: failed to load ELF"); - // Verifier loads ELF independently (same bytes) - let verifier_elf = Elf::load(&elf_bytes).expect("Verifier: failed to load ELF"); - - // ========================================================================= - // PROVER: Runs program, builds traces, generates proof - // ========================================================================= - let executor = executor::vm::execution::Executor::new(&prover_elf, vec![]) - .expect("Failed to create executor"); - let result = executor.run().expect("Failed to run program"); - - let mut traces = Traces::from_elf_and_logs( - &prover_elf, - &result.logs, - &Default::default(), - &[], - #[cfg(feature = "disk-spill")] - stark::storage_mode::StorageMode::Ram, - ) - .unwrap(); - let table_counts = traces.table_counts(); - let prover_airs = VmAirs::new( - &prover_elf, - &proof_options, - false, - &traces.page_configs, - &table_counts, - None, - ); - - let proof = multi_prove_ram( - prover_airs.air_trace_pairs(&mut traces), - &mut DefaultTranscript::::new(&[]), - ) - .expect("Prover failed to generate proof"); - // ========================================================================= - // VERIFIER: Loads same ELF independently, 
verifies proof - // ========================================================================= - let verifier_airs = VmAirs::new( - &verifier_elf, - &proof_options, - false, - &traces.page_configs, - &table_counts, - None, - ); - let verifier_air_refs = verifier_airs.air_refs(); - let mut replay_transcript = DefaultTranscript::::new(&[]); - let expected_bus_balance = crate::compute_expected_commit_bus_balance( - &verifier_air_refs, - &proof, - &traces.public_output_bytes, - &mut replay_transcript, - ) - .expect("fingerprint collision in test"); - - let result = Verifier::multi_verify( - &verifier_air_refs, - &proof, - &mut DefaultTranscript::::new(&[]), - &expected_bus_balance, - ); - - // With same ELF, verification should SUCCEED - assert!(result, "Verifier with same ELF should ACCEPT the proof"); -} - -#[test] -fn test_tables_from_elf_single_executable_segment() { - // ADDI x1, x0, 42 (opcode: 0x02a00093) - // ADDI x2, x1, 10 (opcode: 0x00a08113) - let elf = Elf { - entry_point: 0x1000, - data: vec![Segment { - base_addr: 0x1000, - values: vec![0x02a00093, 0x00a08113], - is_executable: true, - }], - }; - - let tables = tables_from_elf(&elf).unwrap(); - - // Check DECODE table - assert_eq!(tables.pc_to_row.len(), 3); // 2 instructions + CPU padding - assert!(tables.pc_to_row.contains_key(&0x1000)); - assert!(tables.pc_to_row.contains_key(&0x1004)); - assert!( - tables - .pc_to_row - .contains_key(&crate::tables::cpu::CPU_PADDING_PC) + let row = trace.main_table.get_row(pc_to_row[&CPU_PADDING_PC]); + assert_eq!(row[cols::PC_0], CPU_PADDING_PC.into()); + assert_eq!( + row[cols::PACKED_DECODE], + 0u64.into(), + "padding entry has packed_decode = 0" ); + assert_eq!(row[cols::IMM_0], 0u64.into()); } #[test] -fn test_tables_from_elf_mixed_segments() { - // Executable segment with instructions - // Data segment with data (not included in DECODE) - let elf = Elf { - entry_point: 0x1000, - data: vec![ - Segment { - base_addr: 0x1000, - values: vec![0x02a00093], // ADDI 
instruction - is_executable: true, - }, - Segment { - base_addr: 0x2000, - values: vec![0xDEADBEEF, 0xCAFEBABE], // Data - is_executable: false, - }, - ], - }; - - let tables = tables_from_elf(&elf).unwrap(); - - // DECODE: only executable segment (1 instruction + CPU padding) - assert_eq!(tables.pc_to_row.len(), 2); - assert!(tables.pc_to_row.contains_key(&0x1000)); - assert!(!tables.pc_to_row.contains_key(&0x2000)); // Data not in decode -} - -#[test] -fn test_tables_from_elf_empty() { - let elf = Elf { - entry_point: 0x1000, - data: vec![], - }; - - let tables = tables_from_elf(&elf).unwrap(); - - // DECODE: only CPU padding entry - assert_eq!(tables.pc_to_row.len(), 1); +fn test_decode_table_is_power_of_two() { + let mut instrs: U64HashMap = U64HashMap::default(); + instrs.insert(TEST_PC, test_instr()); + let (trace, _) = generate_decode_trace(&instrs); assert!( - tables - .pc_to_row - .contains_key(&crate::tables::cpu::CPU_PADDING_PC) + trace.main_table.height.is_power_of_two(), + "decode table is padded to a power of two" ); + assert_eq!(trace.main_table.width, cols::NUM_COLUMNS); } // ========================================================================= -// verify_with_options: optional decode_commitment parameter +// verify_with_options: optional decode_commitment parameter (#640) // ========================================================================= #[test] @@ -1228,8 +242,8 @@ fn decode_commitment_zero_bytes_rejects() { /// AIR or FFT pipeline changes, this drifts and the test fails — /// regenerate via the `print_decode_commitment_for_sub` helper below. 
const SUB_DECODE_COMMITMENT_BLOWUP_2: [u8; 32] = [ - 0x00, 0x83, 0x59, 0xa3, 0x34, 0x5f, 0x86, 0x79, 0x59, 0x71, 0xc8, 0x71, 0x54, 0x2c, 0xc4, 0xac, - 0x8b, 0x9c, 0x48, 0x9b, 0x25, 0xa3, 0x6a, 0xc7, 0x48, 0xee, 0x71, 0xe6, 0x77, 0xfb, 0x59, 0xfa, + 0x60, 0x66, 0x0b, 0x18, 0x0d, 0x41, 0x08, 0xb3, 0x3a, 0x03, 0x99, 0x03, 0x8c, 0x9d, 0x12, 0x57, + 0x68, 0x8d, 0xed, 0x13, 0x60, 0xeb, 0x1d, 0x2b, 0xa8, 0xea, 0x1c, 0x76, 0xc9, 0xdd, 0x25, 0xaf, ]; #[test] diff --git a/prover/src/tests/eq_tests.rs b/prover/src/tests/eq_tests.rs new file mode 100644 index 000000000..df0c76fdb --- /dev/null +++ b/prover/src/tests/eq_tests.rs @@ -0,0 +1,125 @@ +//! Tests for the EQ (equality) table. + +use crate::tables::eq::{EqOperation, bus_interactions, cols, generate_eq_trace}; +use crate::tables::types::{BusId, FE}; + +#[test] +fn test_compute_eq_and_res() { + assert!(EqOperation::new(5, 5, false).compute_eq()); + assert!(!EqOperation::new(5, 3, false).compute_eq()); + + // res = eq XOR invert + assert!(EqOperation::new(5, 5, false).compute_res()); // 1 XOR 0 + assert!(!EqOperation::new(5, 5, true).compute_res()); // 1 XOR 1 + assert!(!EqOperation::new(5, 3, false).compute_res()); // 0 XOR 0 + assert!(EqOperation::new(5, 3, true).compute_res()); // 0 XOR 1 +} + +#[test] +fn test_trace_equal_operands() { + // a == b → diff = 0, eq = 1, res = 1 (invert = 0) + let trace = generate_eq_trace(&[EqOperation::new(5, 5, false)]); + assert_eq!(trace.main_table.width, cols::NUM_COLUMNS); + assert_eq!(trace.main_table.height, 4); // padded to min 4 + + let row = trace.main_table.get_row(0); + assert_eq!(row[cols::A_0], FE::from(5u64)); + assert_eq!(row[cols::B_0], FE::from(5u64)); + assert_eq!(row[cols::DIFF_0], FE::from(0u64)); + assert_eq!(row[cols::DIFF_1], FE::from(0u64)); + assert_eq!(row[cols::DIFF_2], FE::from(0u64)); + assert_eq!(row[cols::DIFF_3], FE::from(0u64)); + assert_eq!(row[cols::EQ], FE::from(1u64)); + assert_eq!(row[cols::RES], FE::from(1u64)); + assert_eq!(row[cols::INVERT], 
FE::from(0u64)); + assert_eq!(row[cols::MU], FE::from(1u64)); +} + +#[test] +fn test_trace_unequal_operands() { + // a = 5, b = 3 → diff = 2, eq = 0, res = 0 + let trace = generate_eq_trace(&[EqOperation::new(5, 3, false)]); + let row = trace.main_table.get_row(0); + assert_eq!(row[cols::DIFF_0], FE::from(2u64)); + assert_eq!(row[cols::EQ], FE::from(0u64)); + assert_eq!(row[cols::RES], FE::from(0u64)); +} + +#[test] +fn test_trace_invert_and_wrapping() { + // a == b with invert = 1 → eq = 1, res = 0 + let trace = generate_eq_trace(&[EqOperation::new(7, 7, true)]); + let row = trace.main_table.get_row(0); + assert_eq!(row[cols::EQ], FE::from(1u64)); + assert_eq!(row[cols::INVERT], FE::from(1u64)); + assert_eq!(row[cols::RES], FE::from(0u64)); + + // a = 0, b = 1 → diff = 0 - 1 = 0xFFFF_FFFF_FFFF_FFFF (all halves 0xFFFF), eq = 0 + let trace = generate_eq_trace(&[EqOperation::new(0, 1, false)]); + let row = trace.main_table.get_row(0); + assert_eq!(row[cols::DIFF_0], FE::from(0xFFFFu64)); + assert_eq!(row[cols::DIFF_3], FE::from(0xFFFFu64)); + assert_eq!(row[cols::EQ], FE::from(0u64)); +} + +#[test] +fn test_trace_dword_split() { + // a spanning both words: 0x1234_5678_9ABC_DEF0 + let a = 0x1234_5678_9ABC_DEF0u64; + let trace = generate_eq_trace(&[EqOperation::new(a, 0, false)]); + let row = trace.main_table.get_row(0); + assert_eq!(row[cols::A_0], FE::from(0x9ABC_DEF0u64)); + assert_eq!(row[cols::A_1], FE::from(0x1234_5678u64)); +} + +#[test] +fn test_multiplicity_aggregation() { + // Same op three times + one distinct → 2 unique rows, padded to 4. 
+ let ops = vec![ + EqOperation::new(5, 5, false), + EqOperation::new(9, 8, false), + EqOperation::new(5, 5, false), + EqOperation::new(5, 5, false), + ]; + let trace = generate_eq_trace(&ops); + assert_eq!(trace.main_table.height, 4); + + let mut found = false; + for row_idx in 0..4 { + let row = trace.main_table.get_row(row_idx); + if row[cols::A_0] == FE::from(5u64) && row[cols::B_0] == FE::from(5u64) { + assert_eq!(row[cols::MU], FE::from(3u64)); + found = true; + } + } + assert!(found, "expected the (5,5) row with multiplicity 3"); +} + +#[test] +fn test_bus_interactions_shape() { + let interactions = bus_interactions(); + // 4 IS_HALF senders + 1 ZERO sender + 1 ALU receiver + assert_eq!(interactions.len(), 6); + + let is_half = interactions + .iter() + .filter(|i| i.bus_id == u64::from(BusId::IsHalfword) && i.is_sender) + .count(); + assert_eq!(is_half, 4); + + let zero = interactions + .iter() + .filter(|i| i.bus_id == u64::from(BusId::Zero) && i.is_sender) + .count(); + assert_eq!(zero, 1); + + // Exactly one ALU receiver carrying [a, b, flags, res]. + let alu: Vec<_> = interactions + .iter() + .filter(|i| i.bus_id == u64::from(BusId::Alu)) + .collect(); + assert_eq!(alu.len(), 1); + assert!(!alu[0].is_sender, "ALU is a receiver for EQ"); + // [a, b, flags, res, 0] — the ALU output is DWordWL ([res, 0]). 
+ assert_eq!(alu[0].values.len(), 5); +} diff --git a/prover/src/tests/lt_bus_tests.rs b/prover/src/tests/lt_bus_tests.rs index dcc555780..b6148cfdc 100644 --- a/prover/src/tests/lt_bus_tests.rs +++ b/prover/src/tests/lt_bus_tests.rs @@ -70,7 +70,7 @@ fn new_sender_air( let auxiliary_trace_build_data = AuxiliaryTraceBuildData { interactions: vec![BusInteraction::sender( - BusId::Lt, + BusId::Alu, Multiplicity::Column(sender_cols::MU), vec![ BusValue::Packed { @@ -126,7 +126,7 @@ fn new_receiver_air( // Use the same bus interaction as the LT table let auxiliary_trace_build_data = AuxiliaryTraceBuildData { interactions: vec![BusInteraction::receiver( - BusId::Lt, + BusId::Alu, Multiplicity::Column(cols::MU), vec![ BusValue::Packed { diff --git a/prover/src/tests/lt_tests.rs b/prover/src/tests/lt_tests.rs index 859b2808f..21ce549aa 100644 --- a/prover/src/tests/lt_tests.rs +++ b/prover/src/tests/lt_tests.rs @@ -162,6 +162,9 @@ fn test_multiplicity_different_signed_flags() { #[test] fn test_bus_interactions_count() { let interactions = bus_interactions(); - // MSB16 x2 + IS_HALFWORD x6 (lhs_sub_rhs x4 + lhs[1] + rhs[1]) + LT x1 = 9 interactions + // MSB16 x2 + IS_HALFWORD x6 (lhs_sub_rhs x4 + lhs[1] + rhs[1]) + // + ALU receiver x1 (every LT lookup goes through the unified ALU bus + // — CPU SLT/BLT/BGE dispatch and the internal memw/dvrm + // timestamp / |r|<|d| checks) = 9. 
assert_eq!(interactions.len(), 9); } diff --git a/prover/src/tests/mod.rs b/prover/src/tests/mod.rs index 47d8956f3..f9c7e93f5 100644 --- a/prover/src/tests/mod.rs +++ b/prover/src/tests/mod.rs @@ -9,6 +9,8 @@ pub mod branch_bus_tests; #[cfg(test)] pub mod branch_constraints_tests; #[cfg(test)] +pub mod bytewise_tests; +#[cfg(test)] pub mod commit_tests; #[cfg(test)] pub mod compute_commit_bus_offset_tests; @@ -17,14 +19,20 @@ pub mod constraints_tests; #[cfg(all(test, feature = "disk-spill"))] pub mod count_table_lengths_drift_tests; #[cfg(test)] +pub mod cpu32_tests; +#[cfg(test)] pub mod cpu_tests; #[cfg(test)] +pub mod decode_layout_tests; +#[cfg(test)] pub mod decode_tests; #[cfg(all(test, feature = "disk-spill"))] pub mod disk_spill_tests; #[cfg(test)] pub mod dvrm_tests; #[cfg(test)] +pub mod eq_tests; +#[cfg(test)] pub mod keccak_rnd_tests; #[cfg(test)] pub mod load_tests; @@ -51,6 +59,8 @@ pub mod statement_tests; #[cfg(test)] pub mod static_commitments_tests; #[cfg(test)] +pub mod store_tests; +#[cfg(test)] pub mod templates_tests; #[cfg(test)] pub mod trace_builder_tests; diff --git a/prover/src/tests/mul_tests.rs b/prover/src/tests/mul_tests.rs index 731e7952d..b58c4ecef 100644 --- a/prover/src/tests/mul_tests.rs +++ b/prover/src/tests/mul_tests.rs @@ -254,11 +254,12 @@ fn test_different_signed_flags_separate_rows() { #[test] fn test_bus_interactions_count() { let interactions = bus_interactions(); - // Expected interactions: + // Expected interactions (every MUL lookup goes through the unified ALU + // bus — CPU MUL/MULH dispatch and dvrm's `d*q` consistency): // - 2x MSB16 senders (lhs sign, rhs sign) // - 8x IS_HALF senders (lo[0..4], hi[0..4]) // - 4x IS_B20 senders (carry[0..4] virtual range checks) - // - 2x MUL receivers (lo, hi) + // - 2x ALU receivers (lo, hi) // Total: 2 + 8 + 4 + 2 = 16 assert_eq!(interactions.len(), 16, "Expected 16 bus interactions"); } diff --git a/prover/src/tests/prove_elfs_tests.rs b/prover/src/tests/prove_elfs_tests.rs 
index 5b9ec0f8a..4ba28ced7 100644 --- a/prover/src/tests/prove_elfs_tests.rs +++ b/prover/src/tests/prove_elfs_tests.rs @@ -2165,6 +2165,10 @@ fn test_verify_rejects_zero_table_counts() { shift: 0, branch: 0, memw_register: 0, + eq: 0, + bytewise: 0, + store: 0, + cpu32: 0, }, ..vm_proof }; @@ -2233,6 +2237,10 @@ fn test_crafted_zero_count_proof_must_not_verify() { shift: 0, branch: 0, memw_register: 0, + eq: 0, + bytewise: 0, + store: 0, + cpu32: 0, }; let airs = VmAirs::new(&elf, &proof_options, true, &[], &zero_counts, None); diff --git a/prover/src/tests/statement_tests.rs b/prover/src/tests/statement_tests.rs index 659113be6..55ac5a15b 100644 --- a/prover/src/tests/statement_tests.rs +++ b/prover/src/tests/statement_tests.rs @@ -19,6 +19,10 @@ fn sample_counts() -> TableCounts { shift: 1, branch: 2, memw_register: 1, + eq: 1, + bytewise: 1, + store: 1, + cpu32: 1, } } diff --git a/prover/src/tests/store_tests.rs b/prover/src/tests/store_tests.rs new file mode 100644 index 000000000..6b0ba1fd9 --- /dev/null +++ b/prover/src/tests/store_tests.rs @@ -0,0 +1,83 @@ +//! Tests for the STORE table. 
+ +use crate::tables::store::{StoreOperation, bus_interactions, cols, generate_store_trace}; +use crate::tables::types::{BusId, FE}; +use stark::lookup::{BusValue, LinearTerm}; + +#[test] +fn test_new_sets_width_flags() { + let sb = StoreOperation::new(0, 0, 0, 1); + assert!(!sb.write2 && !sb.write4 && !sb.write8); // 1 byte: none set + assert!(StoreOperation::new(0, 0, 0, 2).write2); + assert!(StoreOperation::new(0, 0, 0, 4).write4); + assert!(StoreOperation::new(0, 0, 0, 8).write8); +} + +#[test] +fn test_trace_layout() { + let op = StoreOperation::new(0xDEAD_BEEF_0000_1000, 0x40, 0x1122_3344_5566_7788, 8); + let trace = generate_store_trace(&[op]); + assert_eq!(trace.main_table.width, cols::NUM_COLUMNS); + assert_eq!(trace.main_table.height, 4); // padded to min 4 + + let row = trace.main_table.get_row(0); + assert_eq!(row[cols::BASE_ADDRESS_0], FE::from(0x0000_1000u64)); + assert_eq!(row[cols::BASE_ADDRESS_1], FE::from(0xDEAD_BEEFu64)); + assert_eq!(row[cols::TIMESTAMP_0], FE::from(0x40u64)); + assert_eq!(row[cols::WRITE8], FE::from(1u64)); + assert_eq!(row[cols::WRITE2], FE::from(0u64)); + // value little-endian byte split + assert_eq!(row[cols::VALUE[0]], FE::from(0x88u64)); + assert_eq!(row[cols::VALUE[7]], FE::from(0x11u64)); + assert_eq!(row[cols::MU], FE::from(1u64)); +} + +#[test] +fn test_bus_interactions_shape() { + let interactions = bus_interactions(); + // 1 MEMW write + 1 MEMORY receiver + 8 ARE_BYTES. 
+ assert_eq!(interactions.len(), 10); + + let memw = interactions + .iter() + .filter(|i| i.bus_id == u64::from(BusId::Memw) && i.is_sender) + .count(); + assert_eq!(memw, 1); + + let are_bytes = interactions + .iter() + .filter(|i| i.bus_id == u64::from(BusId::AreBytes) && i.is_sender) + .count(); + assert_eq!(are_bytes, 8); + + let memory: Vec<_> = interactions + .iter() + .filter(|i| i.bus_id == u64::from(BusId::MemoryOp)) + .collect(); + assert_eq!(memory.len(), 1); + assert!(!memory[0].is_sender, "STORE receives MEMORY"); + // [timestamp, base_address, value, flags, out_lo, out_hi] + assert_eq!(memory[0].values.len(), 6); +} + +#[test] +fn test_memory_flags_include_memory_op_bit() { + // Q7 fix: the MEMORY flags must carry the memory_op bit (constant 1). + let interactions = bus_interactions(); + let memory = interactions + .iter() + .find(|i| i.bus_id == u64::from(BusId::MemoryOp)) + .expect("MEMORY receiver exists"); + + // flags is the 4th value (index 3). + match &memory.values[3] { + BusValue::Linear(terms) => { + let has_memory_op = terms.iter().any(|t| matches!(t, LinearTerm::Constant(1))); + assert!( + has_memory_op, + "MEMORY flags must include the memory_op constant 1 (Q7 fix)" + ); + } + _ => panic!("expected a linear flags term"), + } +} diff --git a/prover/src/tests/trace_builder_tests.rs b/prover/src/tests/trace_builder_tests.rs index 199ce71db..ee57f2075 100644 --- a/prover/src/tests/trace_builder_tests.rs +++ b/prover/src/tests/trace_builder_tests.rs @@ -217,7 +217,9 @@ fn test_lt_deduplication() { && row[lt::cols::RHS_0] == FE::from(10u64) && row[lt::cols::SIGNED] == FE::from(1u64) { - // Found our SLT row - verify multiplicity is 3 + // Found our SLT row - verify multiplicity is 3. Every LT lookup + // (including SLT) goes through the unified ALU bus and + // is counted in the single `MU` column. 
assert_eq!(row[lt::cols::MU], FE::from(3u64)); found_slt = true; break; @@ -268,10 +270,11 @@ fn test_bitwise_lookups_collected() { let traces = Traces::from_logs(&logs, instructions, &Default::default()).unwrap(); - // Check AND multiplicity was updated for (0x12, 0x34, 0) + // AND/OR/XOR now go through the BYTEWISE chip on the unified BYTE_ALU bus, + // so the AND byte (0x12, 0x34) increments MU_BYTE_ALU_AND. let row_idx = bitwise::row_index(0x12, 0x34, 0); let row = traces.bitwise.main_table.get_row(row_idx); - assert_eq!(row[bitwise::cols::MU_AND], FE::one()); + assert_eq!(row[bitwise::cols::MU_BYTE_ALU_AND], FE::one()); } #[test]