Skip to content

Add BiLipREN implementations in Jax and Pytorch#48

Open
yuruizhang06 wants to merge 4 commits into
acfr:mainfrom
yuruizhang06:main
Open

Add BiLipREN implementations in Jax and Pytorch#48
yuruizhang06 wants to merge 4 commits into
acfr:mainfrom
yuruizhang06:main

Conversation

@yuruizhang06

Copy link
Copy Markdown
Collaborator

No description provided.

Copilot AI review requested due to automatic review settings June 30, 2026 09:35
@yuruizhang06 yuruizhang06 requested review from RoverLiu and removed request for Copilot June 30, 2026 09:36
Comment thread robustnn/solvers.py Outdated

@RoverLiu RoverLiu left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not explicitly look into ren related files (ren_base_jax, ren_composition_jax, ren_composition_torch, ren_jax, ren_torch). @nic-barbara , could you please have a look at it?

Comment thread robustnn/solvers.py Outdated
return w_eq_final


def peaceman_rachford_layer(activation, D11, b, tol=1e-9, alpha=0.8,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

Comment thread robustnn/solvers.py Outdated
return w_eq_final


def douglas_rachford_layer(activation, D11, b, tol=1e-9, alpha=0.6,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

Comment thread robustnn/orthogonal_jax.py Outdated
Comment thread robustnn/orthogonal_torch.py Outdated
Comment thread robustnn/utils.py Outdated
return total_elements


def compute_lipschitz_constants(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if we actually need this function in robustnn package. This feels like an evaluation approach and is not related to model itself

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When using BiLipREN, we might want to estimate the empirical Lip bounds from data as a baseline for choosing Lip bounds for BiLipREN

Copilot AI review requested due to automatic review settings July 1, 2026 15:51

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds bi-Lipschitz Recurrent Equilibrium Network (BiLipREN) implementations and supporting layers in both JAX and PyTorch, along with shared solver utilities and tests to validate forward/inverse behavior, bounds, and gradients.

Changes:

  • Added JAX and PyTorch BiLipschitz REN implementations (including explicit inverse support) plus composition networks with (dynamic) unitary layers.
  • Consolidated fixed-point / splitting solvers into a shared robustnn/solvers.py module and updated MonLipNet imports accordingly.
  • Added new JAX and PyTorch test suites covering forward evaluation, inverse round-trips, empirical bi-Lipschitz bounds, and gradient flow.

Reviewed changes

Copilot reviewed 16 out of 16 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
test/test_bilipren.py New JAX tests for BiLipschitzREN and CompositionREN, including inverse and bound checks.
test/test_bilipren_torch.py New PyTorch tests for BiLipschitzREN and CompositionREN, including inverse and bound checks.
robustnn/utils.py Removes trailing stray whitespace lines.
robustnn/solvers.py New shared solver module (Davis–Yin + equilibrium-layer splitting solvers).
robustnn/solver_DYS.py Removed; Davis–Yin solver moved into robustnn/solvers.py.
robustnn/ren_torch.py New PyTorch REN + bi-Lipschitz parameterization and inverse evaluation utilities.
robustnn/ren_jax.py Adds _get_qsr override hook and introduces BiLipschitzREN as a GeneralREN specialization.
robustnn/ren_composition_torch.py New PyTorch composition of BiLipschitz REN blocks with (dynamic) unitary layers and inverses.
robustnn/ren_composition_jax.py New JAX composition of BiLipschitz REN blocks with (dynamic) unitary layers and inverses.
robustnn/ren_base_jax.py Adds full (non-triangular) equilibrium solver support with an IFT-based custom VJP + explicit inverse plumbing.
robustnn/orthogonal_torch.py Trailing newline addition.
robustnn/orthogonal_jax.py Trailing newline addition.
robustnn/monlipnet_torch.py Updates import to use robustnn.solvers.DavisYinSplit.
robustnn/monlipnet_jax.py Updates import to use robustnn.solvers.DavisYinSplit.
robustnn/dyn_orthogonal_torch.py New PyTorch dynamic unitary (stateful orthogonal) layer.
robustnn/dyn_orthogonal_jax.py New JAX dynamic unitary (stateful orthogonal) layer.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread robustnn/ren_base_jax.py
Comment on lines +82 to +86
for i in range(w_eq_bar.shape[0]):
ji = j_diag[i, ...]
y_bar_i = y_bar[i, ...]
w_grad = jnp.linalg.solve(I - (ji * D11.T), y_bar_i.T).T
w_eq_bar = w_eq_bar.at[i, ...].set(w_grad)
Comment thread robustnn/ren_base_jax.py
Comment on lines +67 to +69
def _equilibrium_ift_grad_fwd(activation, D11, v, w_eq):
I = jnp.identity(v.shape[-1])
return w_eq, (D11, v, I)
Comment thread robustnn/ren_torch.py
Comment on lines +296 to +304
def _get_qsr(self):
n = self.input_size
I = torch.eye(n, dtype=self._dtype)
alpha1 = 2.0 * (self.mu * self.nu) / (self.mu + self.nu)
alpha2 = 2.0 / (self.mu + self.nu)
Q = -alpha2 * I
S = I.clone()
R = -alpha1 * I
return Q, S, R
Comment thread robustnn/ren_torch.py
Comment on lines +307 to +311
nu, nx, ny = self.input_size, self.state_size, self.output_size
Q, S, R = self._get_qsr()
I_ny = torch.eye(ny, dtype=Q.dtype)
I_nu = torch.eye(nu, dtype=Q.dtype)

Comment thread robustnn/ren_torch.py
Comment on lines +329 to +333
R1 = R + S @ D22 + D22.T @ S.T + D22.T @ Q @ D22
mul_Q = torch.cat((self.C2, self.D21, torch.zeros((ny, nx), dtype=Q.dtype)), dim=1)
mul_R = torch.cat((C2_imp, D21_imp, self.B2.T), dim=1)
Gamma_Q = mul_Q.T @ Q @ mul_Q
Gamma_R = mul_R.T @ torch.linalg.solve(R1, mul_R)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants