Add BiLipREN implementations in Jax and Pytorch#48
Conversation
RoverLiu
left a comment
There was a problem hiding this comment.
I did not explicitly look into ren related files (ren_base_jax, ren_composition_jax, ren_composition_torch, ren_jax, ren_torch). @nic-barbara , could you please have a look at it?
| return w_eq_final | ||
|
|
||
|
|
||
| def peaceman_rachford_layer(activation, D11, b, tol=1e-9, alpha=0.8, |
| return w_eq_final | ||
|
|
||
|
|
||
| def douglas_rachford_layer(activation, D11, b, tol=1e-9, alpha=0.6, |
| return total_elements | ||
|
|
||
|
|
||
| def compute_lipschitz_constants( |
There was a problem hiding this comment.
I am not sure if we actually need this function in robustnn package. This feels like an evaluation approach and is not related to model itself
There was a problem hiding this comment.
When using BiLipREN, we might want to estimate the empirical Lip bounds from data as a baseline for choosing Lip bounds for BiLipREN
There was a problem hiding this comment.
Pull request overview
This PR adds bi-Lipschitz Recurrent Equilibrium Network (BiLipREN) implementations and supporting layers in both JAX and PyTorch, along with shared solver utilities and tests to validate forward/inverse behavior, bounds, and gradients.
Changes:
- Added JAX and PyTorch BiLipschitz REN implementations (including explicit inverse support) plus composition networks with (dynamic) unitary layers.
- Consolidated fixed-point / splitting solvers into a shared
robustnn/solvers.pymodule and updated MonLipNet imports accordingly. - Added new JAX and PyTorch test suites covering forward evaluation, inverse round-trips, empirical bi-Lipschitz bounds, and gradient flow.
Reviewed changes
Copilot reviewed 16 out of 16 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| test/test_bilipren.py | New JAX tests for BiLipschitzREN and CompositionREN, including inverse and bound checks. |
| test/test_bilipren_torch.py | New PyTorch tests for BiLipschitzREN and CompositionREN, including inverse and bound checks. |
| robustnn/utils.py | Removes trailing stray whitespace lines. |
| robustnn/solvers.py | New shared solver module (Davis–Yin + equilibrium-layer splitting solvers). |
| robustnn/solver_DYS.py | Removed; Davis–Yin solver moved into robustnn/solvers.py. |
| robustnn/ren_torch.py | New PyTorch REN + bi-Lipschitz parameterization and inverse evaluation utilities. |
| robustnn/ren_jax.py | Adds _get_qsr override hook and introduces BiLipschitzREN as a GeneralREN specialization. |
| robustnn/ren_composition_torch.py | New PyTorch composition of BiLipschitz REN blocks with (dynamic) unitary layers and inverses. |
| robustnn/ren_composition_jax.py | New JAX composition of BiLipschitz REN blocks with (dynamic) unitary layers and inverses. |
| robustnn/ren_base_jax.py | Adds full (non-triangular) equilibrium solver support with an IFT-based custom VJP + explicit inverse plumbing. |
| robustnn/orthogonal_torch.py | Trailing newline addition. |
| robustnn/orthogonal_jax.py | Trailing newline addition. |
| robustnn/monlipnet_torch.py | Updates import to use robustnn.solvers.DavisYinSplit. |
| robustnn/monlipnet_jax.py | Updates import to use robustnn.solvers.DavisYinSplit. |
| robustnn/dyn_orthogonal_torch.py | New PyTorch dynamic unitary (stateful orthogonal) layer. |
| robustnn/dyn_orthogonal_jax.py | New JAX dynamic unitary (stateful orthogonal) layer. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| for i in range(w_eq_bar.shape[0]): | ||
| ji = j_diag[i, ...] | ||
| y_bar_i = y_bar[i, ...] | ||
| w_grad = jnp.linalg.solve(I - (ji * D11.T), y_bar_i.T).T | ||
| w_eq_bar = w_eq_bar.at[i, ...].set(w_grad) |
| def _equilibrium_ift_grad_fwd(activation, D11, v, w_eq): | ||
| I = jnp.identity(v.shape[-1]) | ||
| return w_eq, (D11, v, I) |
| def _get_qsr(self): | ||
| n = self.input_size | ||
| I = torch.eye(n, dtype=self._dtype) | ||
| alpha1 = 2.0 * (self.mu * self.nu) / (self.mu + self.nu) | ||
| alpha2 = 2.0 / (self.mu + self.nu) | ||
| Q = -alpha2 * I | ||
| S = I.clone() | ||
| R = -alpha1 * I | ||
| return Q, S, R |
| nu, nx, ny = self.input_size, self.state_size, self.output_size | ||
| Q, S, R = self._get_qsr() | ||
| I_ny = torch.eye(ny, dtype=Q.dtype) | ||
| I_nu = torch.eye(nu, dtype=Q.dtype) | ||
|
|
| R1 = R + S @ D22 + D22.T @ S.T + D22.T @ Q @ D22 | ||
| mul_Q = torch.cat((self.C2, self.D21, torch.zeros((ny, nx), dtype=Q.dtype)), dim=1) | ||
| mul_R = torch.cat((C2_imp, D21_imp, self.B2.T), dim=1) | ||
| Gamma_Q = mul_Q.T @ Q @ mul_Q | ||
| Gamma_R = mul_R.T @ torch.linalg.solve(R1, mul_R) |
No description provided.