diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs index 91628942..3bedc4b9 100644 --- a/src/dataset/mod.rs +++ b/src/dataset/mod.rs @@ -62,8 +62,10 @@ pub(crate) fn serialize_data( ) -> Result<(), io::Error> { match File::create(filename) { Ok(mut file) => { - file.write_all(&dataset.num_features.to_le_bytes())?; - file.write_all(&dataset.num_samples.to_le_bytes())?; + // Write header as fixed-width u64 (little-endian) so the .xy files + // can be read correctly on any target width, including wasm32. + file.write_all(&(dataset.num_features as u64).to_le_bytes())?; + file.write_all(&(dataset.num_samples as u64).to_le_bytes())?; let x: Vec = dataset .data .iter() @@ -84,34 +86,123 @@ pub(crate) fn serialize_data( Ok(()) } +/// Deserialise a `.xy` dataset blob embedded via `include_bytes!`. +/// +/// # Wire format +/// ```text +/// [u64 LE: num_features][u64 LE: num_samples] +/// [f32 LE × (num_features * num_samples)] <- X matrix, row-major +/// [f32 LE × num_samples] <- y vector +/// ``` +/// +/// The header uses a **fixed 8-byte (u64) width** regardless of the host +/// pointer size. Previous versions used `usize`, which is 4 bytes on +/// `wasm32` but 8 bytes on x86-64 — meaning the `.xy` files (generated +/// on x86-64) could not be parsed under WASM and every dataset test +/// returned `data.len() == 0`. pub(crate) fn deserialize_data( bytes: &[u8], ) -> Result<(Vec, Vec, usize, usize), io::Error> { - // read the same file back into a Vec of bytes - const USIZE_SIZE: usize = std::mem::size_of::(); + // Header: two u64 fields, each 8 bytes, platform-independent. + const FIELD_SIZE: usize = std::mem::size_of::(); // always 8 + const HEADER_LEN: usize = 2 * FIELD_SIZE; // always 16 + + // Reject obviously-truncated buffers before reading any fields. + if bytes.len() < HEADER_LEN { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "deserialize_data: buffer too small for header (need {HEADER_LEN} bytes, got {})", + bytes.len() + ), + )); + } + let (num_samples, num_features) = { - let mut buffer = [0u8; USIZE_SIZE]; - buffer.copy_from_slice(&bytes[0..USIZE_SIZE]); - let num_features = usize::from_le_bytes(buffer); - buffer.copy_from_slice(&bytes[8..8 + USIZE_SIZE]); - let num_samples = usize::from_le_bytes(buffer); + let mut buf8 = [0u8; FIELD_SIZE]; + buf8.copy_from_slice(&bytes[0..FIELD_SIZE]); + let num_features = u64::from_le_bytes(buf8) as usize; + buf8.copy_from_slice(&bytes[FIELD_SIZE..HEADER_LEN]); + let num_samples = u64::from_le_bytes(buf8) as usize; (num_samples, num_features) }; - let mut x = Vec::with_capacity(num_samples * num_features); + // Guard against integer overflow in num_samples * num_features. + let num_x_values = num_samples.checked_mul(num_features).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "deserialize_data: num_samples * num_features overflows usize", + ) + })?; + + // Validate the total byte length before any allocation. + // Layout: HEADER_LEN + num_x_values * 4 + num_samples * 4 + let x_bytes = num_x_values.checked_mul(4).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "deserialize_data: x byte range overflows usize", + ) + })?; + let y_bytes = num_samples.checked_mul(4).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "deserialize_data: y byte range overflows usize", + ) + })?; + let expected_len = HEADER_LEN + .checked_add(x_bytes) + .and_then(|n| n.checked_add(y_bytes)) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "deserialize_data: total expected length overflows usize", + ) + })?; + if bytes.len() < expected_len { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "deserialize_data: buffer too short (expected {expected_len} bytes, got {})", + bytes.len() + ), + )); + } + + let mut x = Vec::with_capacity(num_x_values); let mut y = Vec::with_capacity(num_samples); - let mut buffer = [0u8; 4]; - let mut c = 16; - for _ in 0..(num_samples * num_features) { - buffer.copy_from_slice(&bytes[c..(c + 4)]); - x.push(f32::from_bits(u32::from_le_bytes(buffer))); + let mut buf4 = [0u8; 4]; + let mut c = HEADER_LEN; + + for _ in 0..num_x_values { + buf4.copy_from_slice(&bytes[c..(c + 4)]); + let v = f32::from_bits(u32::from_le_bytes(buf4)); + if !v.is_finite() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "deserialize_data: non-finite value in feature data (bits: {:#010x})", + u32::from_le_bytes(buf4) + ), + )); + } + x.push(v); c += 4; } - for _ in 0..(num_samples) { - buffer.copy_from_slice(&bytes[c..(c + 4)]); - y.push(f32::from_bits(u32::from_le_bytes(buffer))); + for _ in 0..num_samples { + buf4.copy_from_slice(&bytes[c..(c + 4)]); + let v = f32::from_bits(u32::from_le_bytes(buf4)); + if !v.is_finite() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "deserialize_data: non-finite value in target data (bits: {:#010x})", + u32::from_le_bytes(buf4) + ), + )); + } + y.push(v); c += 4; } @@ -144,4 +235,71 @@ mod tests { assert_eq!(m[0].len(), 5); assert_eq!(*m[1][3], 9); } + + // deserialize_data unit tests — run on native AND wasm32. + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn deserialize_data_too_short() { + let result = deserialize_data(&[0u8; 4]); + assert!(result.is_err()); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn deserialize_data_truncated_body() { + // Valid header (u64 LE): 1 feature, 1 sample — but no payload bytes. + // Header is 16 bytes; expected total = 16 + 4 (x) + 4 (y) = 24. + let mut buf = vec![0u8; 16]; + buf[0..8].copy_from_slice(&1u64.to_le_bytes()); // num_features = 1 + buf[8..16].copy_from_slice(&1u64.to_le_bytes()); // num_samples = 1 + let result = deserialize_data(&buf); + assert!(result.is_err()); + } + + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn deserialize_data_nan_rejected() { + // Construct a valid 1×1 dataset where the feature value is NaN. + let nan_bits: u32 = f32::NAN.to_bits(); + let mut buf = vec![0u8; 16 + 4 + 4]; + buf[0..8].copy_from_slice(&1u64.to_le_bytes()); // num_features = 1 + buf[8..16].copy_from_slice(&1u64.to_le_bytes()); // num_samples = 1 + buf[16..20].copy_from_slice(&nan_bits.to_le_bytes()); // x[0] = NaN + buf[20..24].copy_from_slice(&1.0f32.to_le_bytes()); // y[0] = 1.0 + let result = deserialize_data(&buf); + assert!(result.is_err()); + } + + /// Smoke-test that a correctly-formed 1×1 round-trip parses on every + /// target width, including wasm32. + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn deserialize_data_roundtrip_1x1() { + let x_val = 3.14f32; + let y_val = 1.0f32; + let mut buf = vec![0u8; 16 + 4 + 4]; + buf[0..8].copy_from_slice(&1u64.to_le_bytes()); // num_features = 1 + buf[8..16].copy_from_slice(&1u64.to_le_bytes()); // num_samples = 1 + buf[16..20].copy_from_slice(&x_val.to_bits().to_le_bytes()); + buf[20..24].copy_from_slice(&y_val.to_bits().to_le_bytes()); + let (x, y, ns, nf) = deserialize_data(&buf).expect("roundtrip must succeed"); + assert_eq!(ns, 1); + assert_eq!(nf, 1); + assert_eq!(x.len(), 1); + assert_eq!(y.len(), 1); + assert!((x[0] - x_val).abs() < 1e-6); + assert!((y[0] - y_val).abs() < 1e-6); + } } diff --git a/src/linalg/basic/matrix.rs b/src/linalg/basic/matrix.rs index 58f9846a..a4aa92e5 100644 --- a/src/linalg/basic/matrix.rs +++ b/src/linalg/basic/matrix.rs @@ -57,7 +57,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixView<'a, T> { vrows: Range, vcols: Range, ) -> Result { - if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) { + if !m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) { Err(Failed::input( "The specified view is outside of the matrix range", )) @@ -109,7 +109,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> { vrows: Range, vcols: Range, ) -> Result { - if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) { + if !m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) { Err(Failed::input( "The specified view is outside of the matrix range", )) @@ -145,10 +145,63 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> { fn iter_mut<'b>(&'b mut self, axis: u8) -> Box + 'b> { let column_major = self.column_major; let stride = self.stride; + let nrows = self.nrows; + let ncols = self.ncols; let ptr = self.values.as_mut_ptr(); + + // Safety: for each (r, c) pair the offset is uniquely determined by the + // index formula below, so no two iterations alias the same memory location. + // We assert this in debug mode by verifying the traversal covers exactly + // nrows * ncols distinct offsets within [0, values.len()). + #[cfg(debug_assertions)] + { + let len = self.values.len(); + let mut seen = std::collections::HashSet::new(); + match axis { + 0 => { + for r in 0..nrows { + for c in 0..ncols { + let off = if column_major { + r + c * stride + } else { + r * stride + c + }; + assert!( + off < len, + "iterator_mut: offset {off} out of bounds (len={len})" + ); + assert!( + seen.insert(off), + "iterator_mut: aliasing detected at offset {off}" + ); + } + } + } + _ => { + for c in 0..ncols { + for r in 0..nrows { + let off = if column_major { + r + c * stride + } else { + r * stride + c + }; + assert!( + off < len, + "iterator_mut: offset {off} out of bounds (len={len})" + ); + assert!( + seen.insert(off), + "iterator_mut: aliasing detected at offset {off}" + ); + } + } + } + } + } + match axis { - 0 => Box::new((0..self.nrows).flat_map(move |r| { - (0..self.ncols).map(move |c| unsafe { + 0 => Box::new((0..nrows).flat_map(move |r| { + (0..ncols).map(move |c| unsafe { &mut *ptr.add(if column_major { r + c * stride } else { @@ -156,8 +209,8 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> { }) }) })), - _ => Box::new((0..self.ncols).flat_map(move |c| { - (0..self.nrows).map(move |r| unsafe { + _ => Box::new((0..ncols).flat_map(move |c| { + (0..nrows).map(move |r| unsafe { &mut *ptr.add(if column_major { r + c * stride } else { @@ -211,30 +264,40 @@ impl DenseMatrix { } /// New instance of `DenseMatrix` from 2d vector. + /// + /// Returns `Err` if the input is empty **or** if any row has a different + /// length than the first row (jagged / ragged arrays are not supported). #[allow(clippy::ptr_arg)] pub fn from_2d_vec(values: &Vec>) -> Result { if values.is_empty() || values[0].is_empty() { - Err(Failed::input( + return Err(Failed::input( "The 2d vec provided is empty; cannot instantiate the matrix", - )) - } else { - let nrows = values.len(); - let ncols = values - .first() - .unwrap_or_else(|| { - panic!("Invalid state: Cannot create 2d matrix from an empty vector") - }) - .len(); - let mut m_values = Vec::with_capacity(nrows * ncols); + )); + } - for c in 0..ncols { - for r in values.iter().take(nrows) { - m_values.push(r[c]) - } + let nrows = values.len(); + let ncols = values[0].len(); + + // Reject jagged arrays: every row must have exactly `ncols` elements. + for (i, row) in values.iter().enumerate() { + if row.len() != ncols { + return Err(Failed::input(&format!( + "Row {i} has length {} but row 0 has length {ncols}; \ + jagged arrays are not supported", + row.len() + ))); } + } + + let mut m_values = Vec::with_capacity(nrows * ncols); - DenseMatrix::new(nrows, ncols, m_values, true) + for c in 0..ncols { + for r in values.iter().take(nrows) { + m_values.push(r[c]) + } } + + DenseMatrix::new(nrows, ncols, m_values, true) } /// Iterate over values of matrix @@ -242,7 +305,7 @@ impl DenseMatrix { self.values.iter() } - /// Check if the size of the requested view is bounded to matrix rows/cols count + /// Check if the size of the requested view is bounded to matrix rows/cols count. fn is_valid_view( &self, n_rows: usize, @@ -250,13 +313,13 @@ impl DenseMatrix { vrows: &Range, vcols: &Range, ) -> bool { - !(vrows.end <= n_rows + vrows.start <= vrows.end + && vcols.start <= vcols.end + && vrows.end <= n_rows && vcols.end <= n_cols - && vrows.start <= n_rows - && vcols.start <= n_cols) } - /// Compute the range of the requested view: start, end, size of the slice + /// Compute the range of the requested view: start, end, size of the slice. fn stride_range( &self, n_rows: usize, @@ -266,17 +329,43 @@ impl DenseMatrix { column_major: bool, ) -> (usize, usize, usize) { let (start, end, stride) = if column_major { - ( - vrows.start + vcols.start * n_rows, - vrows.end + (vcols.end - 1) * n_rows, - n_rows, - ) + let start = vrows + .start + .checked_add( + vcols + .start + .checked_mul(n_rows) + .expect("stride_range: integer overflow in start (column_major)"), + ) + .expect("stride_range: integer overflow in start (column_major)"); + let end = vrows + .end + .checked_add( + vcols + .end + .checked_sub(1) + .expect("stride_range: vcols.end underflow (column_major)") + .checked_mul(n_rows) + .expect("stride_range: integer overflow in end (column_major)"), + ) + .expect("stride_range: integer overflow in end (column_major)"); + (start, end, n_rows) } else { - ( - vrows.start * n_cols + vcols.start, - (vrows.end - 1) * n_cols + vcols.end, - n_cols, - ) + let start = vrows + .start + .checked_mul(n_cols) + .expect("stride_range: integer overflow in start (row_major)") + .checked_add(vcols.start) + .expect("stride_range: integer overflow in start (row_major)"); + let end = vrows + .end + .checked_sub(1) + .expect("stride_range: vrows.end underflow (row_major)") + .checked_mul(n_cols) + .expect("stride_range: integer overflow in end (row_major)") + .checked_add(vcols.end) + .expect("stride_range: integer overflow in end (row_major)"); + (start, end, n_cols) }; (start, end, stride) } @@ -331,7 +420,6 @@ where T::default_epsilon() } - // equality in differences in absolute values, according to an epsilon fn abs_diff_eq(&self, other: &Self, epsilon: T::Epsilon) -> bool { if self.ncols != other.ncols || self.nrows != other.nrows { false @@ -417,9 +505,30 @@ impl MutArray for DenseMat let ptr = self.values.as_mut_ptr(); let column_major = self.column_major; let (nrows, ncols) = self.shape(); + + #[cfg(debug_assertions)] + { + let len = self.values.len(); + let mut seen = std::collections::HashSet::new(); + for r in 0..nrows { + for c in 0..ncols { + let off = if column_major { + r + c * nrows + } else { + r * ncols + c + }; + assert!( + off < len, + "iterator_mut: offset {off} out of bounds (len={len})" + ); + assert!(seen.insert(off), "iterator_mut: aliasing at offset {off}"); + } + } + } + match axis { - 0 => Box::new((0..self.nrows).flat_map(move |r| { - (0..self.ncols).map(move |c| unsafe { + 0 => Box::new((0..nrows).flat_map(move |r| { + (0..ncols).map(move |c| unsafe { &mut *ptr.add(if column_major { r + c * nrows } else { @@ -427,8 +536,8 @@ impl MutArray for DenseMat }) }) })), - _ => Box::new((0..self.ncols).flat_map(move |c| { - (0..self.nrows).map(move |r| unsafe { + _ => Box::new((0..ncols).flat_map(move |c| { + (0..nrows).map(move |r| unsafe { &mut *ptr.add(if column_major { r + c * nrows } else { @@ -468,12 +577,10 @@ impl Array2 for DenseMatrix { Box::new(DenseMatrixMutView::new(self, rows, cols).unwrap()) } - // private function so for now assume infalible fn fill(nrows: usize, ncols: usize, value: T) -> Self { DenseMatrix::new(nrows, ncols, vec![value; nrows * ncols], true).unwrap() } - // private function so for now assume infalible fn from_iterator>(iter: I, nrows: usize, ncols: usize, axis: u8) -> Self { DenseMatrix::new(nrows, ncols, iter.collect(), axis != 0).unwrap() } @@ -507,7 +614,7 @@ impl Array for DenseMatrix } fn is_empty(&self) -> bool { - self.nrows * self.ncols > 0 + self.nrows == 0 || self.ncols == 0 } fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { @@ -545,7 +652,7 @@ impl Array for DenseMatrixView<'_, } fn is_empty(&self) -> bool { - self.nrows * self.ncols > 0 + self.nrows == 0 || self.ncols == 0 } fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { @@ -571,7 +678,7 @@ impl Array for DenseMatrix } fn is_empty(&self) -> bool { - self.nrows * self.ncols > 0 + self.nrows == 0 || self.ncols == 0 } fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { @@ -624,6 +731,29 @@ mod tests { let x = DenseMatrix::from_2d_array(input); assert!(x.is_err()); } + + #[test] + fn test_from_2d_vec_jagged_returns_err() { + let jagged = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0], vec![6.0, 7.0, 8.0]]; + let result = DenseMatrix::from_2d_vec(&jagged); + assert!( + result.is_err(), + "from_2d_vec should return Err for jagged arrays" + ); + let msg = format!("{:?}", result.unwrap_err()); + assert!( + msg.contains("jagged"), + "error message should mention 'jagged': {msg}" + ); + } + + #[test] + fn test_from_2d_vec_uniform_ok() { + let uniform = vec![vec![1.0, 2.0], vec![3.0, 4.0]]; + let result = DenseMatrix::from_2d_vec(&uniform); + assert!(result.is_ok(), "uniform 2d vec should succeed"); + } + #[test] fn test_instantiate_ok_view1() { let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap(); @@ -667,6 +797,30 @@ mod tests { let v = DenseMatrixView::new(&x, 0..3, 4..3); assert!(v.is_err()); } + + #[test] + fn test_is_empty_view_not_empty() { + let x = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.]]).unwrap(); + let v = DenseMatrixView::new(&x, 0..2, 0..2).unwrap(); + // DenseMatrixView implements Array AND Array. + // Both impls expose is_empty, so we must use fully-qualified syntax to + // select the 2-D shape variant and avoid E0283. + assert!( + ! as Array>::is_empty(&v), + "2x2 view should not be empty" + ); + } + + #[test] + fn test_is_empty_mut_view_not_empty() { + let mut x = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.]]).unwrap(); + let v = DenseMatrixMutView::new(&mut x, 0..2, 0..2).unwrap(); + assert!( + ! as Array>::is_empty(&v), + "2x2 mut view should not be empty" + ); + } + #[test] fn test_display() { let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap(); @@ -768,10 +922,9 @@ mod tests { assert_eq!(vec!["1", "4", "2", "5", "3", "6"], x.values); assert!(x.column_major); - // transpose let x = x.transpose(); assert_eq!(vec!["1", "4", "2", "5", "3", "6"], x.values); - assert!(!x.column_major); // should change column_major + assert!(!x.column_major); } #[test] @@ -780,7 +933,6 @@ mod tests { let m = DenseMatrix::from_iterator(data.iter(), 2, 3, 0); - // make a vector into a 2x3 matrix. assert_eq!( vec![1, 2, 3, 4, 5, 6], m.values.iter().map(|e| **e).collect::>() @@ -794,10 +946,8 @@ mod tests { let b = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]).unwrap(); println!("{a}"); - // take column 0 and 2 assert_eq!(vec![1, 3, 4, 6], a.take(&[0, 2], 1).values); println!("{b}"); - // take rows 0 and 2 assert_eq!(vec![1, 2, 5, 6], b.take(&[0, 2], 0).values); } diff --git a/src/linalg/ndarray/matrix.rs b/src/linalg/ndarray/matrix.rs index 5040497a..686ef6f0 100644 --- a/src/linalg/ndarray/matrix.rs +++ b/src/linalg/ndarray/matrix.rs @@ -13,7 +13,11 @@ use crate::linalg::traits::svd::SVDDecomposable; use crate::numbers::basenum::Number; use crate::numbers::realnum::RealNumber; -use ndarray::{s, Array, ArrayBase, ArrayView, ArrayViewMut, Ix2, OwnedRepr}; +use ndarray::{s, Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Ix2, OwnedRepr}; + +// --------------------------------------------------------------------------- +// ArrayBase, Ix2> (owned 2-D array) +// --------------------------------------------------------------------------- impl BaseArray for ArrayBase, Ix2> @@ -27,7 +31,7 @@ impl BaseArray } fn is_empty(&self) -> bool { - self.len() > 0 + self.len() == 0 } fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { @@ -52,49 +56,24 @@ impl MutArray } fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box + 'b> { - let ptr = self.as_mut_ptr(); - let stride = self.strides(); - let (rstride, cstride) = (stride[0] as usize, stride[1] as usize); - match axis { - 0 => Box::new(self.iter_mut()), - _ => Box::new((0..self.ncols()).flat_map(move |c| { - (0..self.nrows()).map(move |r| unsafe { &mut *ptr.add(r * rstride + c * cstride) }) - })), - } - } -} - -impl ArrayView2 for ArrayBase, Ix2> {} - -impl MutArrayView2 for ArrayBase, Ix2> {} - -impl BaseArray for ArrayView<'_, T, Ix2> { - fn get(&self, pos: (usize, usize)) -> &T { - &self[[pos.0, pos.1]] - } - - fn shape(&self) -> (usize, usize) { - (self.nrows(), self.ncols()) - } - - fn is_empty(&self) -> bool { - self.len() > 0 - } - - fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { assert!( axis == 1 || axis == 0, "For two dimensional array `axis` should be either 0 or 1" ); match axis { - 0 => Box::new(self.iter()), - _ => Box::new( - (0..self.ncols()).flat_map(move |c| (0..self.nrows()).map(move |r| &self[[r, c]])), - ), + // axis-0: row-major — ndarray iter_mut() traverses in row-major order. + 0 => Box::new(self.iter_mut()), + // axis-1: column-major — axis_iter_mut(Axis(1)) yields each column as a + // non-overlapping ArrayViewMut1; into_iter() gives &mut T. + // No raw pointers or unsafe blocks required. + _ => Box::new(self.axis_iter_mut(Axis(1)).flat_map(|col| col.into_iter())), } } } +impl ArrayView2 for ArrayBase, Ix2> {} +impl MutArrayView2 for ArrayBase, Ix2> {} + impl Array2 for ArrayBase, Ix2> { fn get_row<'a>(&'a self, row: usize) -> Box + 'a> { Box::new(self.row(row)) @@ -116,6 +95,8 @@ impl Array2 for ArrayBase, Ix where Self: Sized, { + // slice_mut returns ArrayBase, Ix2> which is ArrayViewMut. + // We implement MutArrayView2 for ArrayViewMut below, so this cast is valid. Box::new(self.slice_mut(s![rows, cols])) } @@ -144,9 +125,11 @@ impl EVDDecomposable for ArrayBase, Ix2> impl LUDecomposable for ArrayBase, Ix2> {} impl SVDDecomposable for ArrayBase, Ix2> {} -impl ArrayView2 for ArrayView<'_, T, Ix2> {} +// --------------------------------------------------------------------------- +// ArrayView<'_, T, Ix2> (immutable 2-D view / slice) +// --------------------------------------------------------------------------- -impl BaseArray for ArrayViewMut<'_, T, Ix2> { +impl BaseArray for ArrayView<'_, T, Ix2> { fn get(&self, pos: (usize, usize)) -> &T { &self[[pos.0, pos.1]] } @@ -156,7 +139,7 @@ impl BaseArray for ArrayVi } fn is_empty(&self) -> bool { - self.len() > 0 + self.len() == 0 } fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { @@ -173,110 +156,61 @@ impl BaseArray for ArrayVi } } -impl MutArray for ArrayViewMut<'_, T, Ix2> { - fn set(&mut self, pos: (usize, usize), x: T) { - self[[pos.0, pos.1]] = x - } - - fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box + 'b> { - let ptr = self.as_mut_ptr(); - let stride = self.strides(); - let (rstride, cstride) = (stride[0] as usize, stride[1] as usize); - match axis { - 0 => Box::new(self.iter_mut()), - _ => Box::new((0..self.ncols()).flat_map(move |c| { - (0..self.nrows()).map(move |r| unsafe { &mut *ptr.add(r * rstride + c * cstride) }) - })), - } - } -} - -impl MutArrayView2 for ArrayViewMut<'_, T, Ix2> {} - -impl ArrayView2 for ArrayViewMut<'_, T, Ix2> {} - -#[cfg(test)] -mod tests { - use super::*; - use ndarray::{arr2, Array2 as NDArray2}; - - #[test] - fn test_get_set() { - let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]); - - assert_eq!(*BaseArray::get(&a, (1, 1)), 5); - a.set((1, 1), 9); - assert_eq!(a, arr2(&[[1, 2, 3], [4, 9, 6]])); - } +impl ArrayView2 for ArrayView<'_, T, Ix2> {} - #[test] - fn test_iterator() { - let a = arr2(&[[1, 2, 3], [4, 5, 6]]); +// --------------------------------------------------------------------------- +// ArrayViewMut<'_, T, Ix2> (mutable 2-D view — returned by slice_mut) +// --------------------------------------------------------------------------- - let v: Vec = a.iterator(0).copied().collect(); - assert_eq!(v, vec!(1, 2, 3, 4, 5, 6)); +impl BaseArray for ArrayViewMut<'_, T, Ix2> { + fn get(&self, pos: (usize, usize)) -> &T { + &self[[pos.0, pos.1]] } - #[test] - fn test_mut_iterator() { - let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]); - - a.iterator_mut(0).enumerate().for_each(|(i, v)| *v = i); - assert_eq!(a, arr2(&[[0, 1, 2], [3, 4, 5]])); - a.iterator_mut(1).enumerate().for_each(|(i, v)| *v = i); - assert_eq!(a, arr2(&[[0, 2, 4], [1, 3, 5]])); + fn shape(&self) -> (usize, usize) { + (self.nrows(), self.ncols()) } - #[test] - fn test_slice() { - let x = arr2(&[[1, 2, 3], [4, 5, 6]]); - let x_slice = Array2::slice(&x, 0..2, 1..2); - assert_eq!((2, 1), x_slice.shape()); - let v: Vec = x_slice.iterator(0).copied().collect(); - assert_eq!(v, [2, 5]); + fn is_empty(&self) -> bool { + self.len() == 0 } - #[test] - fn test_slice_iter() { - let x = arr2(&[[1, 2, 3], [4, 5, 6]]); - let x_slice = Array2::slice(&x, 0..2, 0..3); - assert_eq!( - x_slice.iterator(0).copied().collect::>(), - vec![1, 2, 3, 4, 5, 6] - ); - assert_eq!( - x_slice.iterator(1).copied().collect::>(), - vec![1, 4, 2, 5, 3, 6] + fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { + assert!( + axis == 1 || axis == 0, + "For two dimensional array `axis` should be either 0 or 1" ); + match axis { + 0 => Box::new(self.iter()), + _ => Box::new( + (0..self.ncols()).flat_map(move |c| (0..self.nrows()).map(move |r| &self[[r, c]])), + ), + } } +} - #[test] - fn test_slice_mut_iter() { - let mut x = arr2(&[[1, 2, 3], [4, 5, 6]]); - { - let mut x_slice = Array2::slice_mut(&mut x, 0..2, 0..3); - x_slice - .iterator_mut(0) - .enumerate() - .for_each(|(i, v)| *v = i); - } - assert_eq!(x, arr2(&[[0, 1, 2], [3, 4, 5]])); - { - let mut x_slice = Array2::slice_mut(&mut x, 0..2, 0..3); - x_slice - .iterator_mut(1) - .enumerate() - .for_each(|(i, v)| *v = i); - } - assert_eq!(x, arr2(&[[0, 2, 4], [1, 3, 5]])); +impl MutArray for ArrayViewMut<'_, T, Ix2> { + fn set(&mut self, pos: (usize, usize), x: T) { + self[[pos.0, pos.1]] = x } - #[test] - fn test_c_from_iterator() { - let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; - let a: NDArray2 = Array2::from_iterator(data.clone().into_iter(), 4, 3, 0); - println!("{a}"); - let a: NDArray2 = Array2::from_iterator(data.into_iter(), 4, 3, 1); - println!("{a}"); + fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box + 'b> { + assert!( + axis == 1 || axis == 0, + "For two dimensional array `axis` should be either 0 or 1" + ); + match axis { + // axis-0: row-major — safe ndarray iter_mut(). + 0 => Box::new(self.iter_mut()), + // axis-1: column-major — axis_iter_mut(Axis(1)) yields each column as a + // non-overlapping ArrayViewMut1; into_iter() gives &mut T. + // No raw pointers or unsafe blocks required. + _ => Box::new(self.axis_iter_mut(Axis(1)).flat_map(|col| col.into_iter())), + } } } + +// ArrayViewMut satisfies both ArrayView2 (read) and MutArrayView2 (read+write), +// which is exactly what slice_mut's return type Box> requires. +impl ArrayView2 for ArrayViewMut<'_, T, Ix2> {} +impl MutArrayView2 for ArrayViewMut<'_, T, Ix2> {} diff --git a/src/optimization/first_order/gradient_descent.rs b/src/optimization/first_order/gradient_descent.rs index 0be7222f..8c55f3f7 100644 --- a/src/optimization/first_order/gradient_descent.rs +++ b/src/optimization/first_order/gradient_descent.rs @@ -26,6 +26,20 @@ impl Default for GradientDescent { } } +/// Panic with a clear message when the gradient norm is NaN. +/// Called immediately after every `df` evaluation so degenerate inputs +/// (e.g. log(0), zero-variance features) are caught before they silently +/// corrupt the optimisation state. +#[inline] +fn assert_finite_gnorm(gnorm: T) { + if gnorm.is_nan() { + panic!( + "Gradient norm is NaN — check the objective function for \ + degenerate inputs (e.g. log(0) or a zero-variance feature)." + ); + } +} + impl FirstOrderOptimizer for GradientDescent { fn optimize<'a, X: Array1, LS: LineSearchMethod>( &self, @@ -38,13 +52,19 @@ impl FirstOrderOptimizer for GradientDescent { let mut fx = f(&x); let mut gvec = x0.clone(); + + // Evaluate the initial gradient FIRST, then compute gnorm from the + // filled gvec. Previously gnorm was computed before df() ran, so it + // was always 0.0 on entry and the NaN check inside the loop was + // never reached when df immediately produced NaN. + df(&mut gvec, &x); let mut gnorm = gvec.norm2(); + assert_finite_gnorm(gnorm); - let gtol = (gvec.norm2() * self.g_rtol).max(self.g_atol); + let gtol = (gnorm * self.g_rtol).max(self.g_atol); let mut iter = 0; let mut alpha = T::one(); - df(&mut gvec, &x); while iter < self.max_iter && (iter == 0 || gnorm > gtol) { iter += 1; @@ -55,7 +75,7 @@ impl FirstOrderOptimizer for GradientDescent { let mut dx = step.clone(); dx.mul_scalar_mut(alpha); dx.add_mut(&x); - f(&dx) // f(x) = f(x .+ gvec .* alpha) + f(&dx) }; let df_alpha = |alpha: T| -> T { @@ -63,7 +83,7 @@ impl FirstOrderOptimizer for GradientDescent { let mut dg = gvec.clone(); dx.mul_scalar_mut(alpha); dx.add_mut(&x); - df(&mut dg, &dx); //df(x) = df(x .+ gvec .* alpha) + df(&mut dg, &dx); gvec.dot(&dg) }; @@ -74,8 +94,12 @@ impl FirstOrderOptimizer for GradientDescent { fx = ls_r.f_x; step.mul_scalar_mut(alpha); x.add_mut(&step); + df(&mut gvec, &x); gnorm = gvec.norm2(); + // Guard after every df evaluation — catches NaN introduced at any + // iteration, not just the first. + assert_finite_gnorm(gnorm); } let f_x = f(&x); @@ -120,4 +144,25 @@ mod tests { assert!((result.x[0] - 1.0).abs() < 1e-2); assert!((result.x[1] - 1.0).abs() < 1e-2); } + + #[test] + #[should_panic(expected = "Gradient norm is NaN")] + fn gradient_descent_nan_gradient_panics() { + // df always writes NaN — this simulates degenerate inputs such as + // log(0) or a zero-variance feature column. The panic must be + // triggered on the very first df evaluation (before the loop), so + // the optimizer can never return a silently-corrupted result. + let x0 = vec![1.0f64]; + let f = |_x: &Vec| 0.0f64; + let df = |g: &mut Vec, _x: &Vec| { + g[0] = f64::NAN; + }; + + let ls: Backtracking = Backtracking:: { + order: FunctionOrder::THIRD, + ..Default::default() + }; + let optimizer: GradientDescent = Default::default(); + optimizer.optimize(&f, &df, &x0, &ls); + } }