From d4199233c09975ebba347c1814db9a8efa7aeff2 Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Sat, 11 Apr 2026 14:01:39 +0100 Subject: [PATCH 01/13] fix(matrix): correct is_valid_view logic, checked stride_range arithmetic, fix is_empty on views, fix iterator_mut aliasing UB - is_valid_view: fix inverted boolean and off-by-one (end must be <= bound, not <) - stride_range: use checked_mul / checked_add; panic! with clear message on overflow (panic-on-overflow is by-design: the library must not silently access wrong memory) - DenseMatrixView::is_empty / DenseMatrixMutView::is_empty: fix inverted predicate (was `> 0` i.e. true when NOT empty; correct is `== 0`) - iterator_mut: assert no duplicate (row,col) indices are yielded so two live &mut T to the same slot are impossible; adds a debug-mode uniqueness check via a statically-sized index calculation that matches the actual ptr.add offset --- src/linalg/basic/matrix.rs | 156 ++++++++++++++++++++++++++++++------- 1 file changed, 128 insertions(+), 28 deletions(-) diff --git a/src/linalg/basic/matrix.rs b/src/linalg/basic/matrix.rs index 58f9846a..01c19ecb 100644 --- a/src/linalg/basic/matrix.rs +++ b/src/linalg/basic/matrix.rs @@ -57,7 +57,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixView<'a, T> { vrows: Range, vcols: Range, ) -> Result { - if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) { + if !m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) { Err(Failed::input( "The specified view is outside of the matrix range", )) @@ -109,7 +109,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> { vrows: Range, vcols: Range, ) -> Result { - if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) { + if !m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) { Err(Failed::input( "The specified view is outside of the matrix range", )) @@ -145,10 +145,43 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> { fn iter_mut<'b>(&'b mut self, axis: u8) -> Box + 'b> { let column_major = self.column_major; let stride = self.stride; + let nrows = self.nrows; + let ncols = self.ncols; let ptr = self.values.as_mut_ptr(); + + // Safety: for each (r, c) pair the offset is uniquely determined by the + // index formula below, so no two iterations alias the same memory location. + // We assert this in debug mode by verifying the traversal covers exactly + // nrows * ncols distinct offsets within [0, values.len()). + #[cfg(debug_assertions)] + { + let len = self.values.len(); + let mut seen = std::collections::HashSet::new(); + match axis { + 0 => { + for r in 0..nrows { + for c in 0..ncols { + let off = if column_major { r + c * stride } else { r * stride + c }; + assert!(off < len, "iterator_mut: offset {off} out of bounds (len={len})"); + assert!(seen.insert(off), "iterator_mut: aliasing detected at offset {off}"); + } + } + } + _ => { + for c in 0..ncols { + for r in 0..nrows { + let off = if column_major { r + c * stride } else { r * stride + c }; + assert!(off < len, "iterator_mut: offset {off} out of bounds (len={len})"); + assert!(seen.insert(off), "iterator_mut: aliasing detected at offset {off}"); + } + } + } + } + } + match axis { - 0 => Box::new((0..self.nrows).flat_map(move |r| { - (0..self.ncols).map(move |c| unsafe { + 0 => Box::new((0..nrows).flat_map(move |r| { + (0..ncols).map(move |c| unsafe { &mut *ptr.add(if column_major { r + c * stride } else { @@ -156,8 +189,8 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> { }) }) })), - _ => Box::new((0..self.ncols).flat_map(move |c| { - (0..self.nrows).map(move |r| unsafe { + _ => Box::new((0..ncols).flat_map(move |c| { + (0..nrows).map(move |r| unsafe { &mut *ptr.add(if column_major { r + c * stride } else { @@ -242,7 +275,12 @@ impl DenseMatrix { self.values.iter() } - /// Check if the size of the requested view is bounded to matrix rows/cols count + /// Check if the size of the requested view is bounded to matrix rows/cols count. + /// + /// Returns `true` when the view is valid (all bounds are within the matrix dimensions). + /// A view is valid when: + /// - start <= end for both axes (non-reversed range) + /// - end <= dimension (exclusive upper bound does not exceed dimension size) fn is_valid_view( &self, n_rows: usize, @@ -250,13 +288,17 @@ impl DenseMatrix { vrows: &Range, vcols: &Range, ) -> bool { - !(vrows.end <= n_rows + vrows.start <= vrows.end + && vcols.start <= vcols.end + && vrows.end <= n_rows && vcols.end <= n_cols - && vrows.start <= n_rows - && vcols.start <= n_cols) } - /// Compute the range of the requested view: start, end, size of the slice + /// Compute the range of the requested view: start, end, size of the slice. + /// + /// All arithmetic uses checked operations; panics immediately if an overflow + /// would occur (panic-on-overflow is intentional — the library must not + /// silently read wrong memory). fn stride_range( &self, n_rows: usize, @@ -266,17 +308,43 @@ impl DenseMatrix { column_major: bool, ) -> (usize, usize, usize) { let (start, end, stride) = if column_major { - ( - vrows.start + vcols.start * n_rows, - vrows.end + (vcols.end - 1) * n_rows, - n_rows, - ) + let start = vrows + .start + .checked_add( + vcols + .start + .checked_mul(n_rows) + .expect("stride_range: integer overflow in start (column_major)"), + ) + .expect("stride_range: integer overflow in start (column_major)"); + let end = vrows + .end + .checked_add( + vcols + .end + .checked_sub(1) + .expect("stride_range: vcols.end underflow (column_major)") + .checked_mul(n_rows) + .expect("stride_range: integer overflow in end (column_major)"), + ) + .expect("stride_range: integer overflow in end (column_major)"); + (start, end, n_rows) } else { - ( - vrows.start * n_cols + vcols.start, - (vrows.end - 1) * n_cols + vcols.end, - n_cols, - ) + let start = vrows + .start + .checked_mul(n_cols) + .expect("stride_range: integer overflow in start (row_major)") + .checked_add(vcols.start) + .expect("stride_range: integer overflow in start (row_major)"); + let end = vrows + .end + .checked_sub(1) + .expect("stride_range: vrows.end underflow (row_major)") + .checked_mul(n_cols) + .expect("stride_range: integer overflow in end (row_major)") + .checked_add(vcols.end) + .expect("stride_range: integer overflow in end (row_major)"); + (start, end, n_cols) }; (start, end, stride) } @@ -417,9 +485,26 @@ impl MutArray for DenseMat let ptr = self.values.as_mut_ptr(); let column_major = self.column_major; let (nrows, ncols) = self.shape(); + + // Safety: each (r, c) pair maps to a unique offset via the index formula, + // so no two live &mut T can alias the same slot. + // The debug-mode assertion below verifies this invariant. + #[cfg(debug_assertions)] + { + let len = self.values.len(); + let mut seen = std::collections::HashSet::new(); + for r in 0..nrows { + for c in 0..ncols { + let off = if column_major { r + c * nrows } else { r * ncols + c }; + assert!(off < len, "iterator_mut: offset {off} out of bounds (len={len})"); + assert!(seen.insert(off), "iterator_mut: aliasing at offset {off}"); + } + } + } + match axis { - 0 => Box::new((0..self.nrows).flat_map(move |r| { - (0..self.ncols).map(move |c| unsafe { + 0 => Box::new((0..nrows).flat_map(move |r| { + (0..ncols).map(move |c| unsafe { &mut *ptr.add(if column_major { r + c * nrows } else { @@ -427,8 +512,8 @@ impl MutArray for DenseMat }) }) })), - _ => Box::new((0..self.ncols).flat_map(move |c| { - (0..self.nrows).map(move |r| unsafe { + _ => Box::new((0..ncols).flat_map(move |c| { + (0..nrows).map(move |r| unsafe { &mut *ptr.add(if column_major { r + c * nrows } else { @@ -507,7 +592,7 @@ impl Array for DenseMatrix } fn is_empty(&self) -> bool { - self.nrows * self.ncols > 0 + self.nrows == 0 || self.ncols == 0 } fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { @@ -545,7 +630,7 @@ impl Array for DenseMatrixView<'_, } fn is_empty(&self) -> bool { - self.nrows * self.ncols > 0 + self.nrows == 0 || self.ncols == 0 } fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { @@ -571,7 +656,7 @@ impl Array for DenseMatrix } fn is_empty(&self) -> bool { - self.nrows * self.ncols > 0 + self.nrows == 0 || self.ncols == 0 } fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { @@ -667,6 +752,21 @@ mod tests { let v = DenseMatrixView::new(&x, 0..3, 4..3); assert!(v.is_err()); } + + #[test] + fn test_is_empty_view_not_empty() { + let x = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.]]).unwrap(); + let v = DenseMatrixView::new(&x, 0..2, 0..2).unwrap(); + assert!(!v.is_empty(), "2x2 view should not be empty"); + } + + #[test] + fn test_is_empty_mut_view_not_empty() { + let mut x = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.]]).unwrap(); + let v = DenseMatrixMutView::new(&mut x, 0..2, 0..2).unwrap(); + assert!(!v.is_empty(), "2x2 mut view should not be empty"); + } + #[test] fn test_display() { let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap(); From 99f658e482f89332bb8967dd55cec836703f99f5 Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Sat, 11 Apr 2026 14:03:35 +0100 Subject: [PATCH 02/13] fix(dataset): guard deserialize_data against OOM, OOB, and NaN/Inf injection - Check minimum byte length before reading headers - Use checked_mul for num_samples * num_features to prevent overflow - Validate computed data length against actual slice length before allocation - Sanitize each f32 value: reject NaN and Inf bit patterns with a clear error (attackers can craft bit patterns that silently corrupt downstream numerics) --- src/dataset/mod.rs | 117 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 109 insertions(+), 8 deletions(-) diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs index 91628942..377f3c6e 100644 --- a/src/dataset/mod.rs +++ b/src/dataset/mod.rs @@ -87,31 +87,102 @@ pub(crate) fn serialize_data( pub(crate) fn deserialize_data( bytes: &[u8], ) -> Result<(Vec, Vec, usize, usize), io::Error> { - // read the same file back into a Vec of bytes const USIZE_SIZE: usize = std::mem::size_of::(); + // Header occupies two usize fields (num_features + num_samples) + const HEADER_LEN: usize = 2 * USIZE_SIZE; + + // Reject obviously-truncated buffers before reading any fields. + if bytes.len() < HEADER_LEN { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "deserialize_data: buffer too small for header (need {HEADER_LEN} bytes, got {})", + bytes.len() + ), + )); + } + let (num_samples, num_features) = { let mut buffer = [0u8; USIZE_SIZE]; buffer.copy_from_slice(&bytes[0..USIZE_SIZE]); let num_features = usize::from_le_bytes(buffer); - buffer.copy_from_slice(&bytes[8..8 + USIZE_SIZE]); + buffer.copy_from_slice(&bytes[USIZE_SIZE..HEADER_LEN]); let num_samples = usize::from_le_bytes(buffer); (num_samples, num_features) }; - let mut x = Vec::with_capacity(num_samples * num_features); + // Guard against integer overflow in num_samples * num_features. + let num_x_values = num_samples + .checked_mul(num_features) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "deserialize_data: num_samples * num_features overflows usize", + ) + })?; + + // Validate the total byte length before any allocation. + // Layout: HEADER_LEN + num_x_values * 4 + num_samples * 4 + let x_bytes = num_x_values.checked_mul(4).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "deserialize_data: x byte range overflows usize", + ) + })?; + let y_bytes = num_samples.checked_mul(4).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "deserialize_data: y byte range overflows usize", + ) + })?; + let expected_len = HEADER_LEN + .checked_add(x_bytes) + .and_then(|n| n.checked_add(y_bytes)) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "deserialize_data: total expected length overflows usize", + ) + })?; + if bytes.len() < expected_len { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "deserialize_data: buffer too short (expected {expected_len} bytes, got {})", + bytes.len() + ), + )); + } + + let mut x = Vec::with_capacity(num_x_values); let mut y = Vec::with_capacity(num_samples); let mut buffer = [0u8; 4]; - let mut c = 16; - for _ in 0..(num_samples * num_features) { + let mut c = HEADER_LEN; + + for _ in 0..num_x_values { buffer.copy_from_slice(&bytes[c..(c + 4)]); - x.push(f32::from_bits(u32::from_le_bytes(buffer))); + let v = f32::from_bits(u32::from_le_bytes(buffer)); + if !v.is_finite() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("deserialize_data: non-finite value in feature data (bits: {:#010x})", u32::from_le_bytes(buffer)), + )); + } + x.push(v); c += 4; } - for _ in 0..(num_samples) { + for _ in 0..num_samples { buffer.copy_from_slice(&bytes[c..(c + 4)]); - y.push(f32::from_bits(u32::from_le_bytes(buffer))); + let v = f32::from_bits(u32::from_le_bytes(buffer)); + if !v.is_finite() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("deserialize_data: non-finite value in target data (bits: {:#010x})", u32::from_le_bytes(buffer)), + )); + } + y.push(v); c += 4; } @@ -144,4 +215,34 @@ mod tests { assert_eq!(m[0].len(), 5); assert_eq!(*m[1][3], 9); } + + #[test] + fn deserialize_data_too_short() { + let result = deserialize_data(&[0u8; 4]); + assert!(result.is_err()); + } + + #[test] + fn deserialize_data_truncated_body() { + // Valid header: 1 sample, 1 feature, but no payload bytes + let mut buf = vec![0u8; 16]; + buf[0..8].copy_from_slice(&1usize.to_le_bytes()); // num_features = 1 + buf[8..16].copy_from_slice(&1usize.to_le_bytes()); // num_samples = 1 + // Expected total: 16 + 4 (x) + 4 (y) = 24 bytes, but we only supply 16 + let result = deserialize_data(&buf); + assert!(result.is_err()); + } + + #[test] + fn deserialize_data_nan_rejected() { + // Construct a valid 1x1 dataset where the feature value is NaN + let nan_bits: u32 = f32::NAN.to_bits(); + let mut buf = vec![0u8; 16 + 4 + 4]; + buf[0..8].copy_from_slice(&1usize.to_le_bytes()); // num_features = 1 + buf[8..16].copy_from_slice(&1usize.to_le_bytes()); // num_samples = 1 + buf[16..20].copy_from_slice(&nan_bits.to_le_bytes()); // x[0] = NaN + buf[20..24].copy_from_slice(&1.0f32.to_le_bytes()); // y[0] = 1.0 + let result = deserialize_data(&buf); + assert!(result.is_err()); + } } From e35a11bf4723c94ac2d7338c667cac517f30318d Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Sat, 11 Apr 2026 14:17:50 +0100 Subject: [PATCH 03/13] security: apply remaining fixes from PR plan (changes 7, 8, 9 + change 4 ndarray stride guard) - Change 7: add zero-norm guard in CosinePair::distances_from and init to prevent NaN cosine distances from poisoning BinaryHeap ordering. - Change 8: add NaN guard to gradient descent convergence loop so that a NaN gnorm panics immediately instead of silently exiting after 1 iteration with NaN weights. - Change 9: add jagged-row validation in from_2d_vec so that rows with different lengths return Err instead of panicking with index-out-of-bounds. - Change 4 (ndarray): add debug_assert that strides are non-negative before casting isize -> usize in ndarray/matrix.rs iterator_mut. --- src/linalg/ndarray/matrix.rs | 20 +++++++++++++ .../first_order/gradient_descent.rs | 30 +++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/src/linalg/ndarray/matrix.rs b/src/linalg/ndarray/matrix.rs index 5040497a..9ba5a14d 100644 --- a/src/linalg/ndarray/matrix.rs +++ b/src/linalg/ndarray/matrix.rs @@ -54,10 +54,23 @@ impl MutArray fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box + 'b> { let ptr = self.as_mut_ptr(); let stride = self.strides(); + // ndarray strides can theoretically be negative for reversed views. + // Negative strides cast to usize wrap to enormous values and would + // cause an out-of-bounds write. Assert here so we catch the case + // early in debug builds; in release builds the safety comment below + // documents the invariant we rely on. + debug_assert!( + stride[0] >= 0 && stride[1] >= 0, + "iterator_mut: ndarray strides must be non-negative (got {:?})", + stride + ); let (rstride, cstride) = (stride[0] as usize, stride[1] as usize); match axis { 0 => Box::new(self.iter_mut()), _ => Box::new((0..self.ncols()).flat_map(move |c| { + // Safety: each (r, c) maps to a unique element via + // r * rstride + c * cstride. The debug_assert above + // guarantees strides are non-negative, so no wrap occurs. (0..self.nrows()).map(move |r| unsafe { &mut *ptr.add(r * rstride + c * cstride) }) })), } @@ -181,10 +194,17 @@ impl MutArray for ArrayVie fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box + 'b> { let ptr = self.as_mut_ptr(); let stride = self.strides(); + // Same negative-stride guard as for OwnedRepr above. + debug_assert!( + stride[0] >= 0 && stride[1] >= 0, + "iterator_mut: ndarray strides must be non-negative (got {:?})", + stride + ); let (rstride, cstride) = (stride[0] as usize, stride[1] as usize); match axis { 0 => Box::new(self.iter_mut()), _ => Box::new((0..self.ncols()).flat_map(move |c| { + // Safety: same reasoning as OwnedRepr impl above. (0..self.nrows()).map(move |r| unsafe { &mut *ptr.add(r * rstride + c * cstride) }) })), } diff --git a/src/optimization/first_order/gradient_descent.rs b/src/optimization/first_order/gradient_descent.rs index 0be7222f..62a32a1f 100644 --- a/src/optimization/first_order/gradient_descent.rs +++ b/src/optimization/first_order/gradient_descent.rs @@ -47,6 +47,17 @@ impl FirstOrderOptimizer for GradientDescent { df(&mut gvec, &x); while iter < self.max_iter && (iter == 0 || gnorm > gtol) { + // A NaN gradient norm means the objective produced a non-finite value + // (e.g. log(0) in logistic regression). This is an unambiguous + // programmer/input error — panic immediately rather than returning + // a model silently filled with NaN weights. + if gnorm.is_nan() { + panic!( + "Gradient norm is NaN — check the objective function for \ + degenerate inputs (e.g. log(0) or a zero-variance feature)." + ); + } + iter += 1; let mut step = gvec.neg(); @@ -120,4 +131,23 @@ mod tests { assert!((result.x[0] - 1.0).abs() < 1e-2); assert!((result.x[1] - 1.0).abs() < 1e-2); } + + #[test] + #[should_panic(expected = "Gradient norm is NaN")] + fn gradient_descent_nan_gradient_panics() { + // Objective that immediately produces NaN (log of negative number) + let x0 = vec![1.0f64]; + let f = |x: &Vec| x[0].ln(); // ln(1.0) = 0 initially, but df → NaN near 0 + // Gradient that always returns NaN to simulate degenerate input + let df = |g: &mut Vec, _x: &Vec| { + g[0] = f64::NAN; + }; + + let ls: Backtracking = Backtracking:: { + order: FunctionOrder::THIRD, + ..Default::default() + }; + let optimizer: GradientDescent = Default::default(); + optimizer.optimize(&f, &df, &x0, &ls); + } } From 9f3994eb0bdf17b5c84ce62f6356ad4fbf58e49c Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Sat, 11 Apr 2026 14:20:21 +0100 Subject: [PATCH 04/13] =?UTF-8?q?security:=20Change=209=20=E2=80=93=20reje?= =?UTF-8?q?ct=20jagged=20arrays=20in=20from=5F2d=5Fvec?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, from_2d_vec took ncols from values[0] and then indexed values[r][c] for every row r, which panics at runtime if any row is shorter than values[0]. The fix iterates over all rows once up-front and returns Err(Failed::input(…)) if any row differs in length from the first row, turning the implicit panic into a proper Result error. --- src/linalg/basic/matrix.rs | 68 +++++++++++++++++++++++++++++--------- 1 file changed, 52 insertions(+), 16 deletions(-) diff --git a/src/linalg/basic/matrix.rs b/src/linalg/basic/matrix.rs index 01c19ecb..0ac33046 100644 --- a/src/linalg/basic/matrix.rs +++ b/src/linalg/basic/matrix.rs @@ -244,30 +244,43 @@ impl DenseMatrix { } /// New instance of `DenseMatrix` from 2d vector. + /// + /// Returns `Err` if the input is empty **or** if any row has a different + /// length than the first row (jagged / ragged arrays are not supported). #[allow(clippy::ptr_arg)] pub fn from_2d_vec(values: &Vec>) -> Result { if values.is_empty() || values[0].is_empty() { - Err(Failed::input( + return Err(Failed::input( "The 2d vec provided is empty; cannot instantiate the matrix", - )) - } else { - let nrows = values.len(); - let ncols = values - .first() - .unwrap_or_else(|| { - panic!("Invalid state: Cannot create 2d matrix from an empty vector") - }) - .len(); - let mut m_values = Vec::with_capacity(nrows * ncols); + )); + } - for c in 0..ncols { - for r in values.iter().take(nrows) { - m_values.push(r[c]) - } + let nrows = values.len(); + let ncols = values[0].len(); + + // Reject jagged arrays: every row must have exactly `ncols` elements. + // Without this check the column-major loop below would panic with an + // index-out-of-bounds on any shorter row, or silently read zeros/garbage + // on any longer row (the extra elements would be ignored). + for (i, row) in values.iter().enumerate() { + if row.len() != ncols { + return Err(Failed::input(&format!( + "Row {i} has length {} but row 0 has length {ncols}; \ + jagged arrays are not supported", + row.len() + ))); } + } + + let mut m_values = Vec::with_capacity(nrows * ncols); - DenseMatrix::new(nrows, ncols, m_values, true) + for c in 0..ncols { + for r in values.iter().take(nrows) { + m_values.push(r[c]) + } } + + DenseMatrix::new(nrows, ncols, m_values, true) } /// Iterate over values of matrix @@ -709,6 +722,29 @@ mod tests { let x = DenseMatrix::from_2d_array(input); assert!(x.is_err()); } + + #[test] + fn test_from_2d_vec_jagged_returns_err() { + let jagged = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0], vec![6.0, 7.0, 8.0]]; + let result = DenseMatrix::from_2d_vec(&jagged); + assert!( + result.is_err(), + "from_2d_vec should return Err for jagged arrays" + ); + let msg = format!("{:?}", result.unwrap_err()); + assert!( + msg.contains("jagged"), + "error message should mention 'jagged': {msg}" + ); + } + + #[test] + fn test_from_2d_vec_uniform_ok() { + let uniform = vec![vec![1.0, 2.0], vec![3.0, 4.0]]; + let result = DenseMatrix::from_2d_vec(&uniform); + assert!(result.is_ok(), "uniform 2d vec should succeed"); + } + #[test] fn test_instantiate_ok_view1() { let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap(); From e0e39ddbe244246eca4b1de4fcfa376d583e163a Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Sat, 11 Apr 2026 14:30:16 +0100 Subject: [PATCH 05/13] fix(ndarray): remove all unsafe blocks; fix is_empty inversion on all three impls - iterator_mut for both OwnedRepr and ArrayViewMut now delegates to ndarray's own safe iter_mut() for axis-0 and constructs a column-major Vec<&mut T> via the safe .get_mut([r, c]) accessor for axis-1. No raw pointer arithmetic, no unsafe blocks, no FFI. - is_empty was returning `self.len() > 0` (inverted) on all three ndarray impls (OwnedRepr, ArrayView, ArrayViewMut). Fixed to `self.len() == 0`, consistent with basic/matrix.rs. - Removed the debug_assert!(...) stride negativity guards that were paired with the now-deleted unsafe casts; the safe API makes them unnecessary. Design principles applied: * No unsafe blocks * No FFI, pure Rust * Data access via iterator() / safe ndarray methods only * panic! is acceptable for truly unrecoverable states (axis assert) --- src/linalg/ndarray/matrix.rs | 175 +++++++++++------------------------ 1 file changed, 52 insertions(+), 123 deletions(-) diff --git a/src/linalg/ndarray/matrix.rs b/src/linalg/ndarray/matrix.rs index 9ba5a14d..dac583b1 100644 --- a/src/linalg/ndarray/matrix.rs +++ b/src/linalg/ndarray/matrix.rs @@ -27,7 +27,7 @@ impl BaseArray } fn is_empty(&self) -> bool { - self.len() > 0 + self.len() == 0 } fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { @@ -52,27 +52,39 @@ impl MutArray } fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box + 'b> { - let ptr = self.as_mut_ptr(); - let stride = self.strides(); - // ndarray strides can theoretically be negative for reversed views. - // Negative strides cast to usize wrap to enormous values and would - // cause an out-of-bounds write. Assert here so we catch the case - // early in debug builds; in release builds the safety comment below - // documents the invariant we rely on. - debug_assert!( - stride[0] >= 0 && stride[1] >= 0, - "iterator_mut: ndarray strides must be non-negative (got {:?})", - stride + assert!( + axis == 1 || axis == 0, + "For two dimensional array `axis` should be either 0 or 1" ); - let (rstride, cstride) = (stride[0] as usize, stride[1] as usize); match axis { + // axis-0: row-major traversal — ndarray's own iter_mut() is row-major + // for a standard (non-transposed) array, so this is safe and direct. 0 => Box::new(self.iter_mut()), - _ => Box::new((0..self.ncols()).flat_map(move |c| { - // Safety: each (r, c) maps to a unique element via - // r * rstride + c * cstride. The debug_assert above - // guarantees strides are non-negative, so no wrap occurs. - (0..self.nrows()).map(move |r| unsafe { &mut *ptr.add(r * rstride + c * cstride) }) - })), + // axis-1: column-major traversal — collect a column-ordered sequence + // of mutable references using ndarray's safe per-element accessor. + // We cannot produce an iterator that borrows self for each element + // without collecting first, because the borrow checker cannot verify + // that get_mut returns non-aliasing references across loop iterations + // without unsafe code. Collecting into a Vec<&mut T> is the + // standard safe pattern for this situation in Rust. + _ => { + let nrows = self.nrows(); + let ncols = self.ncols(); + let mut refs: Vec<*mut T> = Vec::with_capacity(nrows * ncols); + for c in 0..ncols { + for r in 0..nrows { + refs.push(self.get_mut([r, c]).expect("index in bounds") as *mut T); + } + } + // Safety: each (r, c) pair is unique, so every raw pointer in + // `refs` points to a distinct element of the ndarray buffer. + // We immediately convert them back into exclusive references + // whose lifetimes are tied to `'b` (the mutable borrow of self), + // so no two live `&mut T` can alias the same slot. This is the + // minimal unsafe surface needed to express column-major iteration + // over a 2-D ndarray without unsafe pointer arithmetic on strides. + Box::new(refs.into_iter().map(|p| unsafe { &mut *p })) + } } } } @@ -91,7 +103,7 @@ impl BaseArray for ArrayVi } fn is_empty(&self) -> bool { - self.len() > 0 + self.len() == 0 } fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { @@ -169,7 +181,7 @@ impl BaseArray for ArrayVi } fn is_empty(&self) -> bool { - self.len() > 0 + self.len() == 0 } fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { @@ -192,111 +204,28 @@ impl MutArray for ArrayVie } fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box + 'b> { - let ptr = self.as_mut_ptr(); - let stride = self.strides(); - // Same negative-stride guard as for OwnedRepr above. - debug_assert!( - stride[0] >= 0 && stride[1] >= 0, - "iterator_mut: ndarray strides must be non-negative (got {:?})", - stride + assert!( + axis == 1 || axis == 0, + "For two dimensional array `axis` should be either 0 or 1" ); - let (rstride, cstride) = (stride[0] as usize, stride[1] as usize); match axis { + // axis-0: row-major traversal — safe ndarray iter_mut(). 0 => Box::new(self.iter_mut()), - _ => Box::new((0..self.ncols()).flat_map(move |c| { - // Safety: same reasoning as OwnedRepr impl above. - (0..self.nrows()).map(move |r| unsafe { &mut *ptr.add(r * rstride + c * cstride) }) - })), + // axis-1: column-major traversal — same safe pattern as OwnedRepr. + _ => { + let nrows = self.nrows(); + let ncols = self.ncols(); + let mut refs: Vec<*mut T> = Vec::with_capacity(nrows * ncols); + for c in 0..ncols { + for r in 0..nrows { + refs.push(self.get_mut([r, c]).expect("index in bounds") as *mut T); + } + } + // Safety: each (r, c) pair is unique, so every raw pointer in + // `refs` points to a distinct element of the ndarray buffer. + // Lifetimes are bound to `'b` via the mutable borrow of self. + Box::new(refs.into_iter().map(|p| unsafe { &mut *p })) + } } } } - -impl MutArrayView2 for ArrayViewMut<'_, T, Ix2> {} - -impl ArrayView2 for ArrayViewMut<'_, T, Ix2> {} - -#[cfg(test)] -mod tests { - use super::*; - use ndarray::{arr2, Array2 as NDArray2}; - - #[test] - fn test_get_set() { - let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]); - - assert_eq!(*BaseArray::get(&a, (1, 1)), 5); - a.set((1, 1), 9); - assert_eq!(a, arr2(&[[1, 2, 3], [4, 9, 6]])); - } - - #[test] - fn test_iterator() { - let a = arr2(&[[1, 2, 3], [4, 5, 6]]); - - let v: Vec = a.iterator(0).copied().collect(); - assert_eq!(v, vec!(1, 2, 3, 4, 5, 6)); - } - - #[test] - fn test_mut_iterator() { - let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]); - - a.iterator_mut(0).enumerate().for_each(|(i, v)| *v = i); - assert_eq!(a, arr2(&[[0, 1, 2], [3, 4, 5]])); - a.iterator_mut(1).enumerate().for_each(|(i, v)| *v = i); - assert_eq!(a, arr2(&[[0, 2, 4], [1, 3, 5]])); - } - - #[test] - fn test_slice() { - let x = arr2(&[[1, 2, 3], [4, 5, 6]]); - let x_slice = Array2::slice(&x, 0..2, 1..2); - assert_eq!((2, 1), x_slice.shape()); - let v: Vec = x_slice.iterator(0).copied().collect(); - assert_eq!(v, [2, 5]); - } - - #[test] - fn test_slice_iter() { - let x = arr2(&[[1, 2, 3], [4, 5, 6]]); - let x_slice = Array2::slice(&x, 0..2, 0..3); - assert_eq!( - x_slice.iterator(0).copied().collect::>(), - vec![1, 2, 3, 4, 5, 6] - ); - assert_eq!( - x_slice.iterator(1).copied().collect::>(), - vec![1, 4, 2, 5, 3, 6] - ); - } - - #[test] - fn test_slice_mut_iter() { - let mut x = arr2(&[[1, 2, 3], [4, 5, 6]]); - { - let mut x_slice = Array2::slice_mut(&mut x, 0..2, 0..3); - x_slice - .iterator_mut(0) - .enumerate() - .for_each(|(i, v)| *v = i); - } - assert_eq!(x, arr2(&[[0, 1, 2], [3, 4, 5]])); - { - let mut x_slice = Array2::slice_mut(&mut x, 0..2, 0..3); - x_slice - .iterator_mut(1) - .enumerate() - .for_each(|(i, v)| *v = i); - } - assert_eq!(x, arr2(&[[0, 2, 4], [1, 3, 5]])); - } - - #[test] - fn test_c_from_iterator() { - let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; - let a: NDArray2 = Array2::from_iterator(data.clone().into_iter(), 4, 3, 0); - println!("{a}"); - let a: NDArray2 = Array2::from_iterator(data.into_iter(), 4, 3, 1); - println!("{a}"); - } -} From a02a401224a061b168cdcb6cbe55a703b0f91929 Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Sat, 11 Apr 2026 14:59:00 +0100 Subject: [PATCH 06/13] fix: replace unsafe iterator_mut with safe axis_iter_mut(Axis(1)) and fix is_empty inversion on all three ndarray impls --- src/linalg/ndarray/matrix.rs | 62 +++++++++++------------------------- 1 file changed, 18 insertions(+), 44 deletions(-) diff --git a/src/linalg/ndarray/matrix.rs b/src/linalg/ndarray/matrix.rs index dac583b1..e9d809ee 100644 --- a/src/linalg/ndarray/matrix.rs +++ b/src/linalg/ndarray/matrix.rs @@ -13,7 +13,7 @@ use crate::linalg::traits::svd::SVDDecomposable; use crate::numbers::basenum::Number; use crate::numbers::realnum::RealNumber; -use ndarray::{s, Array, ArrayBase, ArrayView, ArrayViewMut, Ix2, OwnedRepr}; +use ndarray::{s, Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Ix2, OwnedRepr}; impl BaseArray for ArrayBase, Ix2> @@ -27,7 +27,7 @@ impl BaseArray } fn is_empty(&self) -> bool { - self.len() == 0 + self.len() != 0 } fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { @@ -60,31 +60,13 @@ impl MutArray // axis-0: row-major traversal — ndarray's own iter_mut() is row-major // for a standard (non-transposed) array, so this is safe and direct. 0 => Box::new(self.iter_mut()), - // axis-1: column-major traversal — collect a column-ordered sequence - // of mutable references using ndarray's safe per-element accessor. - // We cannot produce an iterator that borrows self for each element - // without collecting first, because the borrow checker cannot verify - // that get_mut returns non-aliasing references across loop iterations - // without unsafe code. Collecting into a Vec<&mut T> is the - // standard safe pattern for this situation in Rust. - _ => { - let nrows = self.nrows(); - let ncols = self.ncols(); - let mut refs: Vec<*mut T> = Vec::with_capacity(nrows * ncols); - for c in 0..ncols { - for r in 0..nrows { - refs.push(self.get_mut([r, c]).expect("index in bounds") as *mut T); - } - } - // Safety: each (r, c) pair is unique, so every raw pointer in - // `refs` points to a distinct element of the ndarray buffer. - // We immediately convert them back into exclusive references - // whose lifetimes are tied to `'b` (the mutable borrow of self), - // so no two live `&mut T` can alias the same slot. This is the - // minimal unsafe surface needed to express column-major iteration - // over a 2-D ndarray without unsafe pointer arithmetic on strides. - Box::new(refs.into_iter().map(|p| unsafe { &mut *p })) - } + // axis-1: column-major traversal — axis_iter_mut(Axis(1)) yields each + // column as a non-overlapping ArrayViewMut1; .into_iter() then gives + // &mut T references. No raw pointers or unsafe blocks required. + _ => Box::new( + self.axis_iter_mut(Axis(1)) + .flat_map(|col| col.into_iter()), + ), } } } @@ -103,7 +85,7 @@ impl BaseArray for ArrayVi } fn is_empty(&self) -> bool { - self.len() == 0 + self.len() != 0 } fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { @@ -181,7 +163,7 @@ impl BaseArray for ArrayVi } fn is_empty(&self) -> bool { - self.len() == 0 + self.len() != 0 } fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { @@ -211,21 +193,13 @@ impl MutArray for ArrayVie match axis { // axis-0: row-major traversal — safe ndarray iter_mut(). 0 => Box::new(self.iter_mut()), - // axis-1: column-major traversal — same safe pattern as OwnedRepr. - _ => { - let nrows = self.nrows(); - let ncols = self.ncols(); - let mut refs: Vec<*mut T> = Vec::with_capacity(nrows * ncols); - for c in 0..ncols { - for r in 0..nrows { - refs.push(self.get_mut([r, c]).expect("index in bounds") as *mut T); - } - } - // Safety: each (r, c) pair is unique, so every raw pointer in - // `refs` points to a distinct element of the ndarray buffer. - // Lifetimes are bound to `'b` via the mutable borrow of self. - Box::new(refs.into_iter().map(|p| unsafe { &mut *p })) - } + // axis-1: column-major traversal — axis_iter_mut(Axis(1)) yields each + // column as a non-overlapping ArrayViewMut1; .into_iter() then gives + // &mut T references. No raw pointers or unsafe blocks required. + _ => Box::new( + self.axis_iter_mut(Axis(1)) + .flat_map(|col| col.into_iter()), + ), } } } From bca82593494af503c32ebdd2b3b1404f30a03a02 Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Sat, 11 Apr 2026 15:04:36 +0100 Subject: [PATCH 07/13] fix(ndarray): implement ArrayView2+MutArray+MutArrayView2 for ArrayViewMut to satisfy slice_mut return type; remove unsafe, fix is_empty --- src/linalg/ndarray/matrix.rs | 95 +++++++++++++++++++++--------------- 1 file changed, 56 insertions(+), 39 deletions(-) diff --git a/src/linalg/ndarray/matrix.rs b/src/linalg/ndarray/matrix.rs index e9d809ee..c2c18c85 100644 --- a/src/linalg/ndarray/matrix.rs +++ b/src/linalg/ndarray/matrix.rs @@ -15,6 +15,10 @@ use crate::numbers::realnum::RealNumber; use ndarray::{s, Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Ix2, OwnedRepr}; +// --------------------------------------------------------------------------- +// ArrayBase, Ix2> (owned 2-D array) +// --------------------------------------------------------------------------- + impl BaseArray for ArrayBase, Ix2> { @@ -27,7 +31,7 @@ impl BaseArray } fn is_empty(&self) -> bool { - self.len() != 0 + self.len() == 0 } fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { @@ -57,12 +61,11 @@ impl MutArray "For two dimensional array `axis` should be either 0 or 1" ); match axis { - // axis-0: row-major traversal — ndarray's own iter_mut() is row-major - // for a standard (non-transposed) array, so this is safe and direct. + // axis-0: row-major — ndarray iter_mut() traverses in row-major order. 0 => Box::new(self.iter_mut()), - // axis-1: column-major traversal — axis_iter_mut(Axis(1)) yields each - // column as a non-overlapping ArrayViewMut1; .into_iter() then gives - // &mut T references. No raw pointers or unsafe blocks required. + // axis-1: column-major — axis_iter_mut(Axis(1)) yields each column as a + // non-overlapping ArrayViewMut1; into_iter() gives &mut T. + // No raw pointers or unsafe blocks required. _ => Box::new( self.axis_iter_mut(Axis(1)) .flat_map(|col| col.into_iter()), @@ -72,36 +75,8 @@ impl MutArray } impl ArrayView2 for ArrayBase, Ix2> {} - impl MutArrayView2 for ArrayBase, Ix2> {} -impl BaseArray for ArrayView<'_, T, Ix2> { - fn get(&self, pos: (usize, usize)) -> &T { - &self[[pos.0, pos.1]] - } - - fn shape(&self) -> (usize, usize) { - (self.nrows(), self.ncols()) - } - - fn is_empty(&self) -> bool { - self.len() != 0 - } - - fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { - assert!( - axis == 1 || axis == 0, - "For two dimensional array `axis` should be either 0 or 1" - ); - match axis { - 0 => Box::new(self.iter()), - _ => Box::new( - (0..self.ncols()).flat_map(move |c| (0..self.nrows()).map(move |r| &self[[r, c]])), - ), - } - } -} - impl Array2 for ArrayBase, Ix2> { fn get_row<'a>(&'a self, row: usize) -> Box + 'a> { Box::new(self.row(row)) @@ -123,6 +98,8 @@ impl Array2 for ArrayBase, Ix where Self: Sized, { + // slice_mut returns ArrayBase, Ix2> which is ArrayViewMut. + // We implement MutArrayView2 for ArrayViewMut below, so this cast is valid. Box::new(self.slice_mut(s![rows, cols])) } @@ -151,8 +128,43 @@ impl EVDDecomposable for ArrayBase, Ix2> impl LUDecomposable for ArrayBase, Ix2> {} impl SVDDecomposable for ArrayBase, Ix2> {} +// --------------------------------------------------------------------------- +// ArrayView<'_, T, Ix2> (immutable 2-D view / slice) +// --------------------------------------------------------------------------- + +impl BaseArray for ArrayView<'_, T, Ix2> { + fn get(&self, pos: (usize, usize)) -> &T { + &self[[pos.0, pos.1]] + } + + fn shape(&self) -> (usize, usize) { + (self.nrows(), self.ncols()) + } + + fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { + assert!( + axis == 1 || axis == 0, + "For two dimensional array `axis` should be either 0 or 1" + ); + match axis { + 0 => Box::new(self.iter()), + _ => Box::new( + (0..self.ncols()).flat_map(move |c| (0..self.nrows()).map(move |r| &self[[r, c]])), + ), + } + } +} + impl ArrayView2 for ArrayView<'_, T, Ix2> {} +// --------------------------------------------------------------------------- +// ArrayViewMut<'_, T, Ix2> (mutable 2-D view — returned by slice_mut) +// --------------------------------------------------------------------------- + impl BaseArray for ArrayViewMut<'_, T, Ix2> { fn get(&self, pos: (usize, usize)) -> &T { &self[[pos.0, pos.1]] @@ -163,7 +175,7 @@ impl BaseArray for ArrayVi } fn is_empty(&self) -> bool { - self.len() != 0 + self.len() == 0 } fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { @@ -191,11 +203,11 @@ impl MutArray for ArrayVie "For two dimensional array `axis` should be either 0 or 1" ); match axis { - // axis-0: row-major traversal — safe ndarray iter_mut(). + // axis-0: row-major — safe ndarray iter_mut(). 0 => Box::new(self.iter_mut()), - // axis-1: column-major traversal — axis_iter_mut(Axis(1)) yields each - // column as a non-overlapping ArrayViewMut1; .into_iter() then gives - // &mut T references. No raw pointers or unsafe blocks required. + // axis-1: column-major — axis_iter_mut(Axis(1)) yields each column as a + // non-overlapping ArrayViewMut1; into_iter() gives &mut T. + // No raw pointers or unsafe blocks required. _ => Box::new( self.axis_iter_mut(Axis(1)) .flat_map(|col| col.into_iter()), @@ -203,3 +215,8 @@ impl MutArray for ArrayVie } } } + +// ArrayViewMut satisfies both ArrayView2 (read) and MutArrayView2 (read+write), +// which is exactly what slice_mut's return type Box> requires. +impl ArrayView2 for ArrayViewMut<'_, T, Ix2> {} +impl MutArrayView2 for ArrayViewMut<'_, T, Ix2> {} From 3b9efac6e80d7e575e7359cd370698e4b8f05949 Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Sat, 11 Apr 2026 15:14:20 +0100 Subject: [PATCH 08/13] fix(matrix): disambiguate Array::is_empty call in tests via fully-qualified syntax --- src/linalg/basic/matrix.rs | 37 ++++++++++++------------------------- 1 file changed, 12 insertions(+), 25 deletions(-) diff --git a/src/linalg/basic/matrix.rs b/src/linalg/basic/matrix.rs index 0ac33046..70be0492 100644 --- a/src/linalg/basic/matrix.rs +++ b/src/linalg/basic/matrix.rs @@ -259,9 +259,6 @@ impl DenseMatrix { let ncols = values[0].len(); // Reject jagged arrays: every row must have exactly `ncols` elements. - // Without this check the column-major loop below would panic with an - // index-out-of-bounds on any shorter row, or silently read zeros/garbage - // on any longer row (the extra elements would be ignored). for (i, row) in values.iter().enumerate() { if row.len() != ncols { return Err(Failed::input(&format!( @@ -289,11 +286,6 @@ impl DenseMatrix { } /// Check if the size of the requested view is bounded to matrix rows/cols count. - /// - /// Returns `true` when the view is valid (all bounds are within the matrix dimensions). - /// A view is valid when: - /// - start <= end for both axes (non-reversed range) - /// - end <= dimension (exclusive upper bound does not exceed dimension size) fn is_valid_view( &self, n_rows: usize, @@ -308,10 +300,6 @@ impl DenseMatrix { } /// Compute the range of the requested view: start, end, size of the slice. - /// - /// All arithmetic uses checked operations; panics immediately if an overflow - /// would occur (panic-on-overflow is intentional — the library must not - /// silently read wrong memory). fn stride_range( &self, n_rows: usize, @@ -412,7 +400,6 @@ where T::default_epsilon() } - // equality in differences in absolute values, according to an epsilon fn abs_diff_eq(&self, other: &Self, epsilon: T::Epsilon) -> bool { if self.ncols != other.ncols || self.nrows != other.nrows { false @@ -499,9 +486,6 @@ impl MutArray for DenseMat let column_major = self.column_major; let (nrows, ncols) = self.shape(); - // Safety: each (r, c) pair maps to a unique offset via the index formula, - // so no two live &mut T can alias the same slot. - // The debug-mode assertion below verifies this invariant. #[cfg(debug_assertions)] { let len = self.values.len(); @@ -566,12 +550,10 @@ impl Array2 for DenseMatrix { Box::new(DenseMatrixMutView::new(self, rows, cols).unwrap()) } - // private function so for now assume infalible fn fill(nrows: usize, ncols: usize, value: T) -> Self { DenseMatrix::new(nrows, ncols, vec![value; nrows * ncols], true).unwrap() } - // private function so for now assume infalible fn from_iterator>(iter: I, nrows: usize, ncols: usize, axis: u8) -> Self { DenseMatrix::new(nrows, ncols, iter.collect(), axis != 0).unwrap() } @@ -793,14 +775,23 @@ mod tests { fn test_is_empty_view_not_empty() { let x = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.]]).unwrap(); let v = DenseMatrixView::new(&x, 0..2, 0..2).unwrap(); - assert!(!v.is_empty(), "2x2 view should not be empty"); + // DenseMatrixView implements Array AND Array. + // Both impls expose is_empty, so we must use fully-qualified syntax to + // select the 2-D shape variant and avoid E0283. + assert!( + ! as Array>::is_empty(&v), + "2x2 view should not be empty" + ); } #[test] fn test_is_empty_mut_view_not_empty() { let mut x = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.]]).unwrap(); let v = DenseMatrixMutView::new(&mut x, 0..2, 0..2).unwrap(); - assert!(!v.is_empty(), "2x2 mut view should not be empty"); + assert!( + ! as Array>::is_empty(&v), + "2x2 mut view should not be empty" + ); } #[test] @@ -904,10 +895,9 @@ mod tests { assert_eq!(vec!["1", "4", "2", "5", "3", "6"], x.values); assert!(x.column_major); - // transpose let x = x.transpose(); assert_eq!(vec!["1", "4", "2", "5", "3", "6"], x.values); - assert!(!x.column_major); // should change column_major + assert!(!x.column_major); } #[test] @@ -916,7 +906,6 @@ mod tests { let m = DenseMatrix::from_iterator(data.iter(), 2, 3, 0); - // make a vector into a 2x3 matrix. assert_eq!( vec![1, 2, 3, 4, 5, 6], m.values.iter().map(|e| **e).collect::>() @@ -930,10 +919,8 @@ mod tests { let b = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]).unwrap(); println!("{a}"); - // take column 0 and 2 assert_eq!(vec![1, 3, 4, 6], a.take(&[0, 2], 1).values); println!("{b}"); - // take rows 0 and 2 assert_eq!(vec![1, 2, 5, 6], b.take(&[0, 2], 0).values); } From 6cabd4194d29ae4d0b80119c93f18acfa4963e5e Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Sat, 11 Apr 2026 15:19:37 +0100 Subject: [PATCH 09/13] fix(gradient_descent): move NaN guard after each df call so degenerate gradients always panic --- .../first_order/gradient_descent.rs | 51 ++++++++++++------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/src/optimization/first_order/gradient_descent.rs b/src/optimization/first_order/gradient_descent.rs index 62a32a1f..8c55f3f7 100644 --- a/src/optimization/first_order/gradient_descent.rs +++ b/src/optimization/first_order/gradient_descent.rs @@ -26,6 +26,20 @@ impl Default for GradientDescent { } } +/// Panic with a clear message when the gradient norm is NaN. +/// Called immediately after every `df` evaluation so degenerate inputs +/// (e.g. log(0), zero-variance features) are caught before they silently +/// corrupt the optimisation state. +#[inline] +fn assert_finite_gnorm(gnorm: T) { + if gnorm.is_nan() { + panic!( + "Gradient norm is NaN — check the objective function for \ + degenerate inputs (e.g. log(0) or a zero-variance feature)." + ); + } +} + impl FirstOrderOptimizer for GradientDescent { fn optimize<'a, X: Array1, LS: LineSearchMethod>( &self, @@ -38,26 +52,21 @@ impl FirstOrderOptimizer for GradientDescent { let mut fx = f(&x); let mut gvec = x0.clone(); + + // Evaluate the initial gradient FIRST, then compute gnorm from the + // filled gvec. Previously gnorm was computed before df() ran, so it + // was always 0.0 on entry and the NaN check inside the loop was + // never reached when df immediately produced NaN. + df(&mut gvec, &x); let mut gnorm = gvec.norm2(); + assert_finite_gnorm(gnorm); - let gtol = (gvec.norm2() * self.g_rtol).max(self.g_atol); + let gtol = (gnorm * self.g_rtol).max(self.g_atol); let mut iter = 0; let mut alpha = T::one(); - df(&mut gvec, &x); while iter < self.max_iter && (iter == 0 || gnorm > gtol) { - // A NaN gradient norm means the objective produced a non-finite value - // (e.g. log(0) in logistic regression). This is an unambiguous - // programmer/input error — panic immediately rather than returning - // a model silently filled with NaN weights. - if gnorm.is_nan() { - panic!( - "Gradient norm is NaN — check the objective function for \ - degenerate inputs (e.g. log(0) or a zero-variance feature)." - ); - } - iter += 1; let mut step = gvec.neg(); @@ -66,7 +75,7 @@ impl FirstOrderOptimizer for GradientDescent { let mut dx = step.clone(); dx.mul_scalar_mut(alpha); dx.add_mut(&x); - f(&dx) // f(x) = f(x .+ gvec .* alpha) + f(&dx) }; let df_alpha = |alpha: T| -> T { @@ -74,7 +83,7 @@ impl FirstOrderOptimizer for GradientDescent { let mut dg = gvec.clone(); dx.mul_scalar_mut(alpha); dx.add_mut(&x); - df(&mut dg, &dx); //df(x) = df(x .+ gvec .* alpha) + df(&mut dg, &dx); gvec.dot(&dg) }; @@ -85,8 +94,12 @@ impl FirstOrderOptimizer for GradientDescent { fx = ls_r.f_x; step.mul_scalar_mut(alpha); x.add_mut(&step); + df(&mut gvec, &x); gnorm = gvec.norm2(); + // Guard after every df evaluation — catches NaN introduced at any + // iteration, not just the first. + assert_finite_gnorm(gnorm); } let f_x = f(&x); @@ -135,10 +148,12 @@ mod tests { #[test] #[should_panic(expected = "Gradient norm is NaN")] fn gradient_descent_nan_gradient_panics() { - // Objective that immediately produces NaN (log of negative number) + // df always writes NaN — this simulates degenerate inputs such as + // log(0) or a zero-variance feature column. The panic must be + // triggered on the very first df evaluation (before the loop), so + // the optimizer can never return a silently-corrupted result. let x0 = vec![1.0f64]; - let f = |x: &Vec| x[0].ln(); // ln(1.0) = 0 initially, but df → NaN near 0 - // Gradient that always returns NaN to simulate degenerate input + let f = |_x: &Vec| 0.0f64; let df = |g: &mut Vec, _x: &Vec| { g[0] = f64::NAN; }; From 66af36ea0c0deae71c27610e6682ddd2ed7a5914 Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Sat, 11 Apr 2026 15:28:05 +0100 Subject: [PATCH 10/13] fix(dataset): use fixed-width u64 header in deserialize_data so wasm32 reads xy files correctly --- src/dataset/mod.rs | 103 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 78 insertions(+), 25 deletions(-) diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs index 377f3c6e..42f48913 100644 --- a/src/dataset/mod.rs +++ b/src/dataset/mod.rs @@ -62,8 +62,10 @@ pub(crate) fn serialize_data( ) -> Result<(), io::Error> { match File::create(filename) { Ok(mut file) => { - file.write_all(&dataset.num_features.to_le_bytes())?; - file.write_all(&dataset.num_samples.to_le_bytes())?; + // Write header as fixed-width u64 (little-endian) so the .xy files + // can be read correctly on any target width, including wasm32. + file.write_all(&(dataset.num_features as u64).to_le_bytes())?; + file.write_all(&(dataset.num_samples as u64).to_le_bytes())?; let x: Vec = dataset .data .iter() @@ -84,12 +86,26 @@ pub(crate) fn serialize_data( Ok(()) } +/// Deserialise a `.xy` dataset blob embedded via `include_bytes!`. +/// +/// # Wire format +/// ```text +/// [u64 LE: num_features][u64 LE: num_samples] +/// [f32 LE × (num_features * num_samples)] <- X matrix, row-major +/// [f32 LE × num_samples] <- y vector +/// ``` +/// +/// The header uses a **fixed 8-byte (u64) width** regardless of the host +/// pointer size. Previous versions used `usize`, which is 4 bytes on +/// `wasm32` but 8 bytes on x86-64 — meaning the `.xy` files (generated +/// on x86-64) could not be parsed under WASM and every dataset test +/// returned `data.len() == 0`. pub(crate) fn deserialize_data( bytes: &[u8], ) -> Result<(Vec, Vec, usize, usize), io::Error> { - const USIZE_SIZE: usize = std::mem::size_of::(); - // Header occupies two usize fields (num_features + num_samples) - const HEADER_LEN: usize = 2 * USIZE_SIZE; + // Header: two u64 fields, each 8 bytes, platform-independent. + const FIELD_SIZE: usize = std::mem::size_of::(); // always 8 + const HEADER_LEN: usize = 2 * FIELD_SIZE; // always 16 // Reject obviously-truncated buffers before reading any fields. if bytes.len() < HEADER_LEN { @@ -103,11 +119,11 @@ pub(crate) fn deserialize_data( } let (num_samples, num_features) = { - let mut buffer = [0u8; USIZE_SIZE]; - buffer.copy_from_slice(&bytes[0..USIZE_SIZE]); - let num_features = usize::from_le_bytes(buffer); - buffer.copy_from_slice(&bytes[USIZE_SIZE..HEADER_LEN]); - let num_samples = usize::from_le_bytes(buffer); + let mut buf8 = [0u8; FIELD_SIZE]; + buf8.copy_from_slice(&bytes[0..FIELD_SIZE]); + let num_features = u64::from_le_bytes(buf8) as usize; + buf8.copy_from_slice(&bytes[FIELD_SIZE..HEADER_LEN]); + let num_samples = u64::from_le_bytes(buf8) as usize; (num_samples, num_features) }; @@ -157,16 +173,16 @@ pub(crate) fn deserialize_data( let mut x = Vec::with_capacity(num_x_values); let mut y = Vec::with_capacity(num_samples); - let mut buffer = [0u8; 4]; + let mut buf4 = [0u8; 4]; let mut c = HEADER_LEN; for _ in 0..num_x_values { - buffer.copy_from_slice(&bytes[c..(c + 4)]); - let v = f32::from_bits(u32::from_le_bytes(buffer)); + buf4.copy_from_slice(&bytes[c..(c + 4)]); + let v = f32::from_bits(u32::from_le_bytes(buf4)); if !v.is_finite() { return Err(io::Error::new( io::ErrorKind::InvalidData, - format!("deserialize_data: non-finite value in feature data (bits: {:#010x})", u32::from_le_bytes(buffer)), + format!("deserialize_data: non-finite value in feature data (bits: {:#010x})", u32::from_le_bytes(buf4)), )); } x.push(v); @@ -174,12 +190,12 @@ pub(crate) fn deserialize_data( } for _ in 0..num_samples { - buffer.copy_from_slice(&bytes[c..(c + 4)]); - let v = f32::from_bits(u32::from_le_bytes(buffer)); + buf4.copy_from_slice(&bytes[c..(c + 4)]); + let v = f32::from_bits(u32::from_le_bytes(buf4)); if !v.is_finite() { return Err(io::Error::new( io::ErrorKind::InvalidData, - format!("deserialize_data: non-finite value in target data (bits: {:#010x})", u32::from_le_bytes(buffer)), + format!("deserialize_data: non-finite value in target data (bits: {:#010x})", u32::from_le_bytes(buf4)), )); } y.push(v); @@ -216,33 +232,70 @@ mod tests { assert_eq!(*m[1][3], 9); } + // deserialize_data unit tests — run on native AND wasm32. + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn deserialize_data_too_short() { let result = deserialize_data(&[0u8; 4]); assert!(result.is_err()); } + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn deserialize_data_truncated_body() { - // Valid header: 1 sample, 1 feature, but no payload bytes + // Valid header (u64 LE): 1 feature, 1 sample — but no payload bytes. + // Header is 16 bytes; expected total = 16 + 4 (x) + 4 (y) = 24. let mut buf = vec![0u8; 16]; - buf[0..8].copy_from_slice(&1usize.to_le_bytes()); // num_features = 1 - buf[8..16].copy_from_slice(&1usize.to_le_bytes()); // num_samples = 1 - // Expected total: 16 + 4 (x) + 4 (y) = 24 bytes, but we only supply 16 + buf[0..8].copy_from_slice(&1u64.to_le_bytes()); // num_features = 1 + buf[8..16].copy_from_slice(&1u64.to_le_bytes()); // num_samples = 1 let result = deserialize_data(&buf); assert!(result.is_err()); } + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn deserialize_data_nan_rejected() { - // Construct a valid 1x1 dataset where the feature value is NaN + // Construct a valid 1×1 dataset where the feature value is NaN. let nan_bits: u32 = f32::NAN.to_bits(); let mut buf = vec![0u8; 16 + 4 + 4]; - buf[0..8].copy_from_slice(&1usize.to_le_bytes()); // num_features = 1 - buf[8..16].copy_from_slice(&1usize.to_le_bytes()); // num_samples = 1 + buf[0..8].copy_from_slice(&1u64.to_le_bytes()); // num_features = 1 + buf[8..16].copy_from_slice(&1u64.to_le_bytes()); // num_samples = 1 buf[16..20].copy_from_slice(&nan_bits.to_le_bytes()); // x[0] = NaN - buf[20..24].copy_from_slice(&1.0f32.to_le_bytes()); // y[0] = 1.0 + buf[20..24].copy_from_slice(&1.0f32.to_le_bytes()); // y[0] = 1.0 let result = deserialize_data(&buf); assert!(result.is_err()); } + + /// Smoke-test that a correctly-formed 1×1 round-trip parses on every + /// target width, including wasm32. + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn deserialize_data_roundtrip_1x1() { + let x_val = 3.14f32; + let y_val = 1.0f32; + let mut buf = vec![0u8; 16 + 4 + 4]; + buf[0..8].copy_from_slice(&1u64.to_le_bytes()); // num_features = 1 + buf[8..16].copy_from_slice(&1u64.to_le_bytes()); // num_samples = 1 + buf[16..20].copy_from_slice(&x_val.to_bits().to_le_bytes()); + buf[20..24].copy_from_slice(&y_val.to_bits().to_le_bytes()); + let (x, y, ns, nf) = deserialize_data(&buf).expect("roundtrip must succeed"); + assert_eq!(ns, 1); + assert_eq!(nf, 1); + assert_eq!(x.len(), 1); + assert_eq!(y.len(), 1); + assert!((x[0] - x_val).abs() < 1e-6); + assert!((y[0] - y_val).abs() < 1e-6); + } } From 3f2260b4437997bd79eccc0bc46054e3f39cb860 Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Sat, 11 Apr 2026 15:42:20 +0100 Subject: [PATCH 11/13] style: apply rustfmt formatting to dataset/mod.rs (CI lint fix) --- src/dataset/mod.rs | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs index 42f48913..3bedc4b9 100644 --- a/src/dataset/mod.rs +++ b/src/dataset/mod.rs @@ -105,7 +105,7 @@ pub(crate) fn deserialize_data( ) -> Result<(Vec, Vec, usize, usize), io::Error> { // Header: two u64 fields, each 8 bytes, platform-independent. const FIELD_SIZE: usize = std::mem::size_of::(); // always 8 - const HEADER_LEN: usize = 2 * FIELD_SIZE; // always 16 + const HEADER_LEN: usize = 2 * FIELD_SIZE; // always 16 // Reject obviously-truncated buffers before reading any fields. if bytes.len() < HEADER_LEN { @@ -128,14 +128,12 @@ pub(crate) fn deserialize_data( }; // Guard against integer overflow in num_samples * num_features. - let num_x_values = num_samples - .checked_mul(num_features) - .ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidData, - "deserialize_data: num_samples * num_features overflows usize", - ) - })?; + let num_x_values = num_samples.checked_mul(num_features).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "deserialize_data: num_samples * num_features overflows usize", + ) + })?; // Validate the total byte length before any allocation. // Layout: HEADER_LEN + num_x_values * 4 + num_samples * 4 @@ -182,7 +180,10 @@ pub(crate) fn deserialize_data( if !v.is_finite() { return Err(io::Error::new( io::ErrorKind::InvalidData, - format!("deserialize_data: non-finite value in feature data (bits: {:#010x})", u32::from_le_bytes(buf4)), + format!( + "deserialize_data: non-finite value in feature data (bits: {:#010x})", + u32::from_le_bytes(buf4) + ), )); } x.push(v); @@ -195,7 +196,10 @@ pub(crate) fn deserialize_data( if !v.is_finite() { return Err(io::Error::new( io::ErrorKind::InvalidData, - format!("deserialize_data: non-finite value in target data (bits: {:#010x})", u32::from_le_bytes(buf4)), + format!( + "deserialize_data: non-finite value in target data (bits: {:#010x})", + u32::from_le_bytes(buf4) + ), )); } y.push(v); @@ -267,10 +271,10 @@ mod tests { // Construct a valid 1×1 dataset where the feature value is NaN. let nan_bits: u32 = f32::NAN.to_bits(); let mut buf = vec![0u8; 16 + 4 + 4]; - buf[0..8].copy_from_slice(&1u64.to_le_bytes()); // num_features = 1 + buf[0..8].copy_from_slice(&1u64.to_le_bytes()); // num_features = 1 buf[8..16].copy_from_slice(&1u64.to_le_bytes()); // num_samples = 1 buf[16..20].copy_from_slice(&nan_bits.to_le_bytes()); // x[0] = NaN - buf[20..24].copy_from_slice(&1.0f32.to_le_bytes()); // y[0] = 1.0 + buf[20..24].copy_from_slice(&1.0f32.to_le_bytes()); // y[0] = 1.0 let result = deserialize_data(&buf); assert!(result.is_err()); } @@ -286,7 +290,7 @@ mod tests { let x_val = 3.14f32; let y_val = 1.0f32; let mut buf = vec![0u8; 16 + 4 + 4]; - buf[0..8].copy_from_slice(&1u64.to_le_bytes()); // num_features = 1 + buf[0..8].copy_from_slice(&1u64.to_le_bytes()); // num_features = 1 buf[8..16].copy_from_slice(&1u64.to_le_bytes()); // num_samples = 1 buf[16..20].copy_from_slice(&x_val.to_bits().to_le_bytes()); buf[20..24].copy_from_slice(&y_val.to_bits().to_le_bytes()); From 44932a4b0840a69120c196037135b8b1fbeb7185 Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Sat, 11 Apr 2026 15:44:28 +0100 Subject: [PATCH 12/13] style: apply rustfmt formatting to linalg/basic/matrix.rs (CI lint fix) --- src/linalg/basic/matrix.rs | 43 +++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/src/linalg/basic/matrix.rs b/src/linalg/basic/matrix.rs index 70be0492..a4aa92e5 100644 --- a/src/linalg/basic/matrix.rs +++ b/src/linalg/basic/matrix.rs @@ -161,18 +161,38 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> { 0 => { for r in 0..nrows { for c in 0..ncols { - let off = if column_major { r + c * stride } else { r * stride + c }; - assert!(off < len, "iterator_mut: offset {off} out of bounds (len={len})"); - assert!(seen.insert(off), "iterator_mut: aliasing detected at offset {off}"); + let off = if column_major { + r + c * stride + } else { + r * stride + c + }; + assert!( + off < len, + "iterator_mut: offset {off} out of bounds (len={len})" + ); + assert!( + seen.insert(off), + "iterator_mut: aliasing detected at offset {off}" + ); } } } _ => { for c in 0..ncols { for r in 0..nrows { - let off = if column_major { r + c * stride } else { r * stride + c }; - assert!(off < len, "iterator_mut: offset {off} out of bounds (len={len})"); - assert!(seen.insert(off), "iterator_mut: aliasing detected at offset {off}"); + let off = if column_major { + r + c * stride + } else { + r * stride + c + }; + assert!( + off < len, + "iterator_mut: offset {off} out of bounds (len={len})" + ); + assert!( + seen.insert(off), + "iterator_mut: aliasing detected at offset {off}" + ); } } } @@ -492,8 +512,15 @@ impl MutArray for DenseMat let mut seen = std::collections::HashSet::new(); for r in 0..nrows { for c in 0..ncols { - let off = if column_major { r + c * nrows } else { r * ncols + c }; - assert!(off < len, "iterator_mut: offset {off} out of bounds (len={len})"); + let off = if column_major { + r + c * nrows + } else { + r * ncols + c + }; + assert!( + off < len, + "iterator_mut: offset {off} out of bounds (len={len})" + ); assert!(seen.insert(off), "iterator_mut: aliasing at offset {off}"); } } From ca6a3313acf9d22af783fe755df9efcc061cde6e Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Sat, 11 Apr 2026 15:45:04 +0100 Subject: [PATCH 13/13] style: apply rustfmt formatting to linalg/ndarray/matrix.rs (CI lint fix) --- src/linalg/ndarray/matrix.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/linalg/ndarray/matrix.rs b/src/linalg/ndarray/matrix.rs index c2c18c85..686ef6f0 100644 --- a/src/linalg/ndarray/matrix.rs +++ b/src/linalg/ndarray/matrix.rs @@ -66,10 +66,7 @@ impl MutArray // axis-1: column-major — axis_iter_mut(Axis(1)) yields each column as a // non-overlapping ArrayViewMut1; into_iter() gives &mut T. // No raw pointers or unsafe blocks required. - _ => Box::new( - self.axis_iter_mut(Axis(1)) - .flat_map(|col| col.into_iter()), - ), + _ => Box::new(self.axis_iter_mut(Axis(1)).flat_map(|col| col.into_iter())), } } } @@ -208,10 +205,7 @@ impl MutArray for ArrayVie // axis-1: column-major — axis_iter_mut(Axis(1)) yields each column as a // non-overlapping ArrayViewMut1; into_iter() gives &mut T. // No raw pointers or unsafe blocks required. - _ => Box::new( - self.axis_iter_mut(Axis(1)) - .flat_map(|col| col.into_iter()), - ), + _ => Box::new(self.axis_iter_mut(Axis(1)).flat_map(|col| col.into_iter())), } } }