From f702be5eaf984b3873950752b7ec4b37856e30d4 Mon Sep 17 00:00:00 2001 From: yurekami Date: Wed, 31 Dec 2025 16:22:23 +0900 Subject: [PATCH] Add MiMo-7B model support for introspective interpretation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Integrate Xiaomi's MiMo-7B reasoning model as an alternative to Llama for self-explanation tasks. MiMo-7B achieves 95.8% on MATH500 and 68.2% on AIME 2024, rivaling OpenAI o1-mini performance. Changes: - Add ContinuousMiMo adapter class based on Qwen2 architecture - Register MiMo models in MODEL_TYPE_TO_VANILLA_MODEL_MAPPING - Add config files for feature descriptions, activation patching, and input ablation tasks - Support all MiMo variants (Base, SFT, RL, RL-Zero, RL-0530) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- config/act_patch/mimo_mimo_act_patch_cf.yaml | 80 ++++ config/feature_descriptions/mimo_131k.yaml | 80 ++++ config/input_ablation/mimo_mimo_hint.yaml | 61 +++ model/__init__.py | 2 + model/continuous_mimo.py | 422 +++++++++++++++++++ model/utils.py | 10 + 6 files changed, 655 insertions(+) create mode 100644 config/act_patch/mimo_mimo_act_patch_cf.yaml create mode 100644 config/feature_descriptions/mimo_131k.yaml create mode 100644 config/input_ablation/mimo_mimo_hint.yaml create mode 100644 model/continuous_mimo.py diff --git a/config/act_patch/mimo_mimo_act_patch_cf.yaml b/config/act_patch/mimo_mimo_act_patch_cf.yaml new file mode 100644 index 0000000..2d18021 --- /dev/null +++ b/config/act_patch/mimo_mimo_act_patch_cf.yaml @@ -0,0 +1,80 @@ +# MiMo-7B Activation Patching Configuration +# Uses MiMo-7B for activation patching predictions + +target_model_path: XiaomiMiMo/MiMo-7B-RL +model_path: XiaomiMiMo/MiMo-7B-RL +trust_remote_code: true + +# MiMo-specific settings +generation_config: + temperature: 0.6 + do_sample: true + +continuous_tokens: + "begin_continuous": "<|reserved_special_token_0|>" + 
"end_continuous": "<|reserved_special_token_1|>" + "continuous_rep": "<|reserved_special_token_2|>" + +output_dir: /data/artifacts/mimo/checkpoints/mimo_act_patch_cf +cache_dir: /data/artifacts/mimo/cache/ + +train: + num_samples: 100000000 + batch_size: 16 # Reduced for 7B model memory + save_strategy: steps + save_total_limit: 10 + save_steps: 2000 + learning_rate: !!float 5e-5 + num_epochs: 20 + eval_strategy: steps + eval_steps: 2000 + peft_lora: true + lora_r: 128 + dataset: ["counterfact"] + evaluation_type: mixed + intervention_path: Transluce/act_patch_mimo_7b_counterfact + hf_data_cache_dir: /data/artifacts/mimo/datasets/ + # MiMo-7B has 32 layers (similar to Llama-3.1-8B) + layers: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31] + bf16: true + tasks: + act_patch: + num_samples: 100000000 + question_types: + generative_explanation: + evaluation_type: exact_match + weight: 1.0 + prompts: + - messages: + - role: "user" + content: "If feature {begin_continuous}{feature}{end_continuous} at layer {layer} is added to tokens {tokens} when processing the text <<<{input}>>>, how would the output change?" + - role: "assistant" + content: "" + - messages: + - role: "user" + content: "When feature {begin_continuous}{feature}{end_continuous} at layer {layer} is added at tokens {tokens} in the input <<<{input}>>>, what happens to the model's output?" + - role: "assistant" + content: "" + - messages: + - role: "user" + content: "Consider the input text: <<<{input}>>>. If we steer layer {layer} towards feature {begin_continuous}{feature}{end_continuous} at tokens {tokens}, how does this affect the generated continuation?" + - role: "assistant" + content: "" + - messages: + - role: "user" + content: "Given the text <<<{input}>>>, what would be the effect on the output if feature {begin_continuous}{feature}{end_continuous} at layer {layer} is added to tokens {tokens}?" 
+ - role: "assistant" + content: "" + - messages: + - role: "user" + content: "If we steer towards feature {begin_continuous}{feature}{end_continuous} at layer {layer} and tokens {tokens} when processing <<<{input}>>>, how would the model's response differ?" + - role: "assistant" + content: "" + +test: + batch_size: 8 + tasks: + act_patch: + num_samples: 3200 + intervention_path: Transluce/act_patch_mimo_7b_counterfact + evaluation_type: exact_match diff --git a/config/feature_descriptions/mimo_131k.yaml b/config/feature_descriptions/mimo_131k.yaml new file mode 100644 index 0000000..0243a0d --- /dev/null +++ b/config/feature_descriptions/mimo_131k.yaml @@ -0,0 +1,80 @@ +# MiMo-7B Feature Descriptions Configuration +# Uses MiMo-7B as the explainer model for SAE feature descriptions + +target_model_path: meta-llama/Llama-3.1-8B +model_path: XiaomiMiMo/MiMo-7B-RL +trust_remote_code: true + +# MiMo-specific settings +# MiMo recommends empty system prompt and temperature 0.6 +generation_config: + temperature: 0.6 + do_sample: true + +continuous_tokens: + "begin_continuous": "<|reserved_special_token_0|>" + "end_continuous": "<|reserved_special_token_1|>" + "continuous_rep": "<|reserved_special_token_2|>" + +output_dir: /data/artifacts/mimo/checkpoints/mimo_feature_descriptions +cache_dir: /data/artifacts/mimo/cache/ + +use_embed_proj: true + +train: + num_samples: 1000000000 + batch_size: 32 # Reduced for 7B model memory + save_strategy: steps + save_total_limit: 10 + save_steps: 8000 + learning_rate: !!float 5e-5 + num_epochs: 100 + eval_strategy: steps + eval_steps: 8000 + peft_lora: false + lora_r: 128 + dataset: ["sae_explanations"] + evaluation_type: mixed + explanation_dir: /data/artifacts/bzl/autointerp/datasets/SAE_feature_explanations_llama3.1_8b/ + split_keys: explanations/{split}_layer_feature_idxs.pkl + split_explanations_data: explanations/{split}_explanations.pkl + sae_save_path: features + layers: 
[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31] + bf16: true + tasks: + features_explain: + question_types: + generative_explanation: + evaluation_type: semantic_similarity + weight: 1.0 + layers: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31] + prompts: + - messages: + - role: "user" + content: "Generate a description of this feature at layer {layer}: {begin_continuous}{feature}{end_continuous}" + - role: "assistant" + content: "" + - messages: + - role: "user" + content: "What does {begin_continuous}{feature}{end_continuous} encode at layer {layer}?" + - role: "assistant" + content: "" + - messages: + - role: "user" + content: "{begin_continuous}{feature}{end_continuous} activates at layer {layer} for inputs with the following features:" + - role: "assistant" + - messages: + - role: "user" + content: "What does {begin_continuous}{feature}{end_continuous} mean at layer {layer}?" + - role: "assistant" + - messages: + - role: "user" + content: "Layer {layer}, {begin_continuous}{feature}{end_continuous} means?" 
+ - role: "assistant" + +test: + batch_size: 32 + tasks: + features_explain: + num_samples: 1600 + evaluation_type: semantic_similarity diff --git a/config/input_ablation/mimo_mimo_hint.yaml b/config/input_ablation/mimo_mimo_hint.yaml new file mode 100644 index 0000000..c99c36a --- /dev/null +++ b/config/input_ablation/mimo_mimo_hint.yaml @@ -0,0 +1,61 @@ +# MiMo-7B Input Ablation Configuration +# Uses MiMo-7B for predicting input ablation effects + +target_model_path: XiaomiMiMo/MiMo-7B-RL +model_path: XiaomiMiMo/MiMo-7B-RL +trust_remote_code: true + +# MiMo-specific settings - empty system prompt recommended +generation_config: + temperature: 0.6 + do_sample: true + +continuous_tokens: + "begin_continuous": "<|reserved_special_token_0|>" + "end_continuous": "<|reserved_special_token_1|>" + "continuous_rep": "<|reserved_special_token_2|>" + +output_dir: /data/artifacts/mimo/checkpoints/mimo_input_ablation +cache_dir: /data/artifacts/mimo/cache/ + +use_embed_proj: false + +train: + num_samples: 100000000 + batch_size: 8 + save_strategy: steps + save_total_limit: 10 + save_steps: 2000 + learning_rate: !!float 5e-5 + num_epochs: 20 + eval_strategy: steps + eval_steps: 2000 + peft_lora: false + lora_r: 128 + dataset: ["hint"] + evaluation_type: mixed + hint_path: Transluce/input_ablation_mimo_7b_mmlu_hint + hf_data_cache_dir: /data/artifacts/mimo/datasets/ + layers: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31] + bf16: true + tasks: + hint_attribution: + num_samples: 100000000 + weight: 1.0 + self_consistency: false + evaluation_type: exact_match + question_types: + generative_explanation: + evaluation_type: exact_match + prompts: + - messages: + - role: "user" + content: "{user_prompt}\n\nIf the hint were removed how would the assistant answer change?" 
+ - role: "assistant" + content: "" + +test: + batch_size: 8 + tasks: + hint_attribution: + num_samples: 3200 diff --git a/model/__init__.py b/model/__init__.py index b6ea55c..ded5f23 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -2,6 +2,7 @@ from model.continuous_gemma2 import ContinuousGemma2ForCausalLM from model.continuous_llama import ContinuousLlama from model.continuous_qwen import ContinuousQwen3ForCausalLM +from model.continuous_mimo import ContinuousMiMo from model.nearest_neighbor import NearestNeighborModel from model.continuous_peft import ContinuousPeft from model.self_explanations import SelfExplanationsModel @@ -11,6 +12,7 @@ "ContinuousGemma3ForCausalLM", "ContinuousGemma2ForCausalLM", "ContinuousQwen3ForCausalLM", + "ContinuousMiMo", "ContinuousPeft", "NearestNeighborModel", "SelfExplanationsModel", diff --git a/model/continuous_mimo.py b/model/continuous_mimo.py new file mode 100644 index 0000000..efddaf9 --- /dev/null +++ b/model/continuous_mimo.py @@ -0,0 +1,422 @@ +""" +MiMo-7B adapter for introspective-interp framework. + +This module provides the ContinuousMiMo class that integrates Xiaomi's MiMo-7B +reasoning model into the introspective-interp framework for self-explanation tasks. + +MiMo-7B features: +- 7B parameter reasoning-focused model based on Qwen2 architecture +- Multiple-Token Prediction (MTP) auxiliary objective +- Supports vLLM and SGLang deployment +- Achieves 95.8% on MATH500, 68.2% on AIME 2024 + +Usage: + model = ContinuousMiMo.from_pretrained( + "XiaomiMiMo/MiMo-7B-RL", + trust_remote_code=True, + ... 
from typing import Optional, Tuple, Union, Dict, Any

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoConfig
from transformers.cache_utils import Cache

from .continuous_base import ContinuousCausalLMBase, ContinuousCausalLMOutputWithPast


class ContinuousMiMo(ContinuousCausalLMBase, nn.Module):
    """Wrap a MiMo causal LM so it can be driven through ContinuousCausalLMBase.

    The Hugging Face model is loaded with ``AutoModelForCausalLM`` (always with
    ``trust_remote_code=True``) and stored in ``self._base_model``. Most
    ``nn.Module`` / ``PreTrainedModel`` style methods are overridden to
    delegate to that wrapped model, so parameter and state-dict keys are the
    wrapped model's own keys (no ``_base_model.`` prefix), plus optional
    ``embed_projs.<i>.<name>`` entries.

    Construction is two-phase: ``__init__`` only records its arguments;
    ``from_pretrained`` loads the weights and then runs
    ``ContinuousCausalLMBase.__init__`` via ``_init_continuous_base``.

    NOTE(review): ``embed_projs`` and ``shared_forward`` are not defined in
    this class; they are presumably provided by ``ContinuousCausalLMBase``.
    ``embed_projs`` is treated as optional everywhere it is used here.
    """

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(
        self,
        config,
        batch_size: int = 24,
        special_tokens_ids: Dict[str, int] | None = None,
        subject_embed_dim: int | None = None,
        use_embed_proj: bool = False,
        target_model_layers: int | None = None,
        **kwargs,
    ):
        """Create an empty wrapper; no weights are loaded here.

        Args:
            config: Model config object, stored as ``self.config``.
            batch_size: Forwarded later to ``ContinuousCausalLMBase.__init__``.
            special_tokens_ids: Forwarded later to the base initializer.
            subject_embed_dim: Forwarded later to the base initializer.
            use_embed_proj: Forwarded later to the base initializer.
            target_model_layers: Forwarded later to the base initializer.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        nn.Module.__init__(self)

        # Store config for compatibility
        self.config = config

        # The wrapped HF model is attached by from_pretrained().
        self._base_model = None
        self._model_initialized = False

        # Arguments for the deferred ContinuousCausalLMBase.__init__ call.
        self._init_params = {
            "batch_size": batch_size,
            "special_tokens_ids": special_tokens_ids,
            "subject_embed_dim": subject_embed_dim,
            "use_embed_proj": use_embed_proj,
            "target_model_layers": target_model_layers,
        }

        # State for hidden state interventions
        self._intervention_hidden_states = None
        self._intervention_active = False
        self._intervention_is_batched = False

        # State for attention head interventions
        self._head_intervention_layer = None
        self._head_intervention_indices_positions = None
        self._head_intervention_active = False
        self._head_intervention_is_batched = False
        self._head_patching_handles = []

        # State for attribution patching
        self._attribution_active = False
        self._attribution_hooks = []
        self._attribution_gradients = {}

    def _init_continuous_base(self):
        """Run ``ContinuousCausalLMBase.__init__`` with the stored arguments."""
        ContinuousCausalLMBase.__init__(
            self,
            batch_size=self._init_params["batch_size"],
            special_tokens_ids=self._init_params["special_tokens_ids"],
            subject_embed_dim=self._init_params["subject_embed_dim"],
            use_embed_proj=self._init_params["use_embed_proj"],
            target_model_layers=self._init_params["target_model_layers"],
        )
        self._model_initialized = True

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _get_embed_projs(self):
        """Return ``self.embed_projs`` if it exists and is not None, else None."""
        return getattr(self, "embed_projs", None)

    def _managed_modules(self):
        """Return the wrapped model and ``embed_projs`` (those that are set)."""
        modules = []
        if self._base_model is not None:
            modules.append(self._base_model)
        projs = self._get_embed_projs()
        if projs is not None:
            modules.append(projs)
        return modules

    # ------------------------------------------------------------------
    # Accessors for the wrapped model
    # ------------------------------------------------------------------

    @property
    def model(self):
        """Return ``_base_model.model`` when present, else ``_base_model``.

        Raises:
            RuntimeError: If no model has been loaded yet.
        """
        if self._base_model is None:
            raise RuntimeError("Model not initialized. Call from_pretrained first.")
        if hasattr(self._base_model, 'model'):
            return self._base_model.model
        return self._base_model

    @property
    def lm_head(self):
        """Return the wrapped model's ``lm_head``.

        Raises:
            RuntimeError: If no model has been loaded yet.
        """
        if self._base_model is None:
            raise RuntimeError("Model not initialized. Call from_pretrained first.")
        return self._base_model.lm_head

    @classmethod
    def from_pretrained(
        cls,
        model_path: str,
        batch_size: int = 24,
        special_tokens_ids: Dict[str, int] | None = None,
        subject_embed_dim: int | None = None,
        use_embed_proj: bool = False,
        target_model_layers: int | None = None,
        cache_dir: str | None = None,
        **kwargs,
    ):
        """Load a MiMo checkpoint and return an initialized wrapper.

        Args:
            model_path: Model identifier or path passed to ``AutoConfig`` and
                ``AutoModelForCausalLM`` (e.g. "XiaomiMiMo/MiMo-7B-RL").
            batch_size: Forwarded to the wrapper constructor.
            special_tokens_ids: Forwarded to the wrapper constructor.
            subject_embed_dim: Forwarded to the wrapper constructor.
            use_embed_proj: Forwarded to the wrapper constructor.
            target_model_layers: Forwarded to the wrapper constructor.
            cache_dir: Optional cache directory for config and weights.
            **kwargs: Extra arguments for
                ``AutoModelForCausalLM.from_pretrained``.

        Returns:
            ContinuousMiMo: Wrapper with the model loaded and the continuous
            base initialized.
        """
        # NOTE: trust_remote_code is forced on, overriding any caller value.
        kwargs['trust_remote_code'] = True

        if cache_dir is not None:
            kwargs['cache_dir'] = cache_dir

        config = AutoConfig.from_pretrained(
            model_path,
            trust_remote_code=True,
            cache_dir=cache_dir,
        )

        instance = cls(
            config=config,
            batch_size=batch_size,
            special_tokens_ids=special_tokens_ids,
            subject_embed_dim=subject_embed_dim,
            use_embed_proj=use_embed_proj,
            target_model_layers=target_model_layers,
        )

        instance._base_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            **kwargs,
        )

        # The base initializer runs only after the weights are attached.
        instance._init_continuous_base()

        return instance

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        inputs_continuous_tokens: Optional[list[torch.LongTensor]] = None,
        labels_continuous_tokens: Optional[list[torch.LongTensor]] = None,
        debug: bool = False,
        extra_args: Optional[list[dict]] = None,
        **kwargs,
    ) -> Union[Tuple, ContinuousCausalLMOutputWithPast]:
        """Delegate to ``self.shared_forward`` with all arguments passed through.

        ``num_logits_to_keep`` is the older name for ``logits_to_keep``; it is
        used only when ``logits_to_keep`` is left at 0 and it is non-zero.

        inputs_continuous_tokens: List[torch.LongTensor] of dim batch_size x (seq_len)
        labels_continuous_tokens: List[torch.LongTensor] of dim batch_size x (seq_len)
        """
        if logits_to_keep == 0 and num_logits_to_keep != 0:
            logits_to_keep = num_logits_to_keep

        return self.shared_forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            inputs_continuous_tokens=inputs_continuous_tokens,
            labels_continuous_tokens=labels_continuous_tokens,
            debug=debug,
            extra_args=extra_args,
            **kwargs,
        )

    def generate(self, *args, **kwargs):
        """Call ``generate`` on the wrapped model with the given arguments.

        Raises:
            RuntimeError: If no model has been loaded yet.
        """
        if self._base_model is None:
            raise RuntimeError("Model not initialized. Call from_pretrained first.")
        return self._base_model.generate(*args, **kwargs)

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save the wrapped model, plus ``embed_projs.pt`` when projections exist.

        Raises:
            RuntimeError: If no model has been loaded yet.
        """
        if self._base_model is None:
            raise RuntimeError("Model not initialized. Call from_pretrained first.")
        self._base_model.save_pretrained(save_directory, **kwargs)

        projs = self._get_embed_projs()
        if projs is not None:
            import os
            embed_proj_path = os.path.join(save_directory, "embed_projs.pt")
            torch.save(projs.state_dict(), embed_proj_path)

        print(f"Saved MiMo model to {save_directory}")

    # ------------------------------------------------------------------
    # Mode, device and dtype management
    # ------------------------------------------------------------------

    def train(self, mode: bool = True):
        """Set training mode on the wrapped model, then on the parent classes."""
        if self._base_model is not None:
            self._base_model.train(mode)
        # nn.Module is a base class, so super().train always exists.
        super().train(mode)
        return self

    def eval(self):
        """Set evaluation mode on the wrapped model, then on the parent classes."""
        if self._base_model is not None:
            self._base_model.eval()
        super().eval()
        return self

    def to(self, *args, **kwargs):
        """Apply ``.to(*args, **kwargs)`` to the wrapped model and embed_projs."""
        for module in self._managed_modules():
            module.to(*args, **kwargs)
        return self

    def cuda(self, device=None):
        """Move the wrapped model and embed_projs to CUDA."""
        for module in self._managed_modules():
            module.cuda(device)
        return self

    def cpu(self):
        """Move the wrapped model and embed_projs to CPU."""
        for module in self._managed_modules():
            module.cpu()
        return self

    def half(self):
        """Convert the wrapped model and embed_projs to float16.

        embed_projs is included so the dtype stays consistent with the wrapped
        model, matching what ``to``/``cuda``/``cpu`` already do.
        """
        for module in self._managed_modules():
            module.half()
        return self

    def bfloat16(self):
        """Convert the wrapped model and embed_projs to bfloat16."""
        for module in self._managed_modules():
            module.bfloat16()
        return self

    def float(self):
        """Convert the wrapped model and embed_projs to float32."""
        for module in self._managed_modules():
            module.float()
        return self

    # ------------------------------------------------------------------
    # Parameters and state dicts
    # ------------------------------------------------------------------

    def parameters(self, recurse: bool = True):
        """Iterate the wrapped model's parameters, then those of embed_projs."""
        params = []
        for module in self._managed_modules():
            params.extend(module.parameters(recurse))
        return iter(params)

    def named_parameters(self, prefix: str = '', recurse: bool = True):
        """Iterate ``(name, parameter)`` pairs for the model and embed_projs.

        Wrapped-model names are whatever its own ``named_parameters`` yields.
        Projection names are ``embed_projs.<i>.<name>``, with ``<prefix>.`` in
        front when a prefix is given.
        """
        named = []
        if self._base_model is not None:
            named.extend(self._base_model.named_parameters(prefix, recurse))
        projs = self._get_embed_projs()
        if projs is not None:
            # Let torch join the prefix so the "." separator is always present.
            proj_prefix = f"{prefix}.embed_projs" if prefix else "embed_projs"
            named.extend(projs.named_parameters(proj_prefix, recurse))
        return iter(named)

    def state_dict(self, *args, **kwargs):
        """Return the wrapped model's state dict plus ``embed_projs.<i>.<name>``.

        ``prefix`` and ``keep_vars`` are applied to the projection entries as
        well as to the wrapped model. Both the keyword form and torch's
        deprecated positional form ``(destination, prefix, keep_vars)`` are
        recognised.
        """
        state = {}
        if self._base_model is not None:
            state.update(self._base_model.state_dict(*args, **kwargs))

        projs = self._get_embed_projs()
        if projs is not None:
            prefix = kwargs.get("prefix", args[1] if len(args) > 1 else "")
            keep_vars = kwargs.get("keep_vars", args[2] if len(args) > 2 else False)
            for i, proj in enumerate(projs):
                for name, tensor in proj.state_dict(keep_vars=keep_vars).items():
                    state[f"{prefix}embed_projs.{i}.{name}"] = tensor
        return state

    def load_state_dict(self, state_dict, strict: bool = True):
        """Load a state dict in the format produced by ``state_dict``.

        Keys starting with ``embed_projs.`` go to the matching projection
        module; all other keys go to the wrapped model. Projection tensors are
        grouped per module and loaded in a single call, so ``strict=True``
        works for projections with more than one tensor.

        Projection keys with fewer than three dotted parts, or with an index
        beyond ``len(embed_projs)``, are ignored.

        Returns:
            The wrapped model's ``load_state_dict`` result, with the
            projections' missing/unexpected keys appended under their
            ``embed_projs.<i>.`` names, or ``None`` when nothing was loaded
            into the wrapped model.
        """
        base_state = {}
        proj_states = {}  # projection index -> {tensor name: tensor}
        for key, value in state_dict.items():
            if not key.startswith("embed_projs."):
                base_state[key] = value
                continue
            parts = key.split(".")
            if len(parts) >= 3:
                proj_states.setdefault(int(parts[1]), {})[".".join(parts[2:])] = value

        result = None
        if self._base_model is not None and base_state:
            result = self._base_model.load_state_dict(base_state, strict=strict)

        projs = self._get_embed_projs()
        if projs is not None:
            for idx, sub_state in proj_states.items():
                if idx >= len(projs):
                    continue
                sub_result = projs[idx].load_state_dict(sub_state, strict=strict)
                if result is not None:
                    result.missing_keys.extend(
                        f"embed_projs.{idx}.{k}" for k in sub_result.missing_keys
                    )
                    result.unexpected_keys.extend(
                        f"embed_projs.{idx}.{k}" for k in sub_result.unexpected_keys
                    )
        return result

    @property
    def device(self):
        """Device of the wrapped model's first parameter (CPU if not loaded)."""
        if self._base_model is not None:
            return next(self._base_model.parameters()).device
        return torch.device('cpu')

    @property
    def dtype(self):
        """Dtype of the wrapped model's first parameter (float32 if not loaded)."""
        if self._base_model is not None:
            return next(self._base_model.parameters()).dtype
        return torch.float32

    # ------------------------------------------------------------------
    # Hugging Face style pass-throughs
    # ------------------------------------------------------------------

    def get_input_embeddings(self):
        """Return the wrapped model's input embeddings, or None if unavailable."""
        if self._base_model is not None and hasattr(self._base_model, 'get_input_embeddings'):
            return self._base_model.get_input_embeddings()
        return None

    def set_input_embeddings(self, value):
        """Set the wrapped model's input embeddings; no-op if unavailable."""
        if self._base_model is not None and hasattr(self._base_model, 'set_input_embeddings'):
            self._base_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the wrapped model's output embeddings, else ``self.lm_head``."""
        if self._base_model is not None and hasattr(self._base_model, 'get_output_embeddings'):
            return self._base_model.get_output_embeddings()
        return self.lm_head

    def resize_token_embeddings(self, new_num_tokens: int):
        """Resize the wrapped model's token embeddings; None if unavailable."""
        if self._base_model is not None and hasattr(self._base_model, 'resize_token_embeddings'):
            return self._base_model.resize_token_embeddings(new_num_tokens)
        return None

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Enable gradient checkpointing on the wrapped model, if supported."""
        if self._base_model is not None and hasattr(self._base_model, 'gradient_checkpointing_enable'):
            self._base_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs)

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing on the wrapped model, if supported."""
        if self._base_model is not None and hasattr(self._base_model, 'gradient_checkpointing_disable'):
            self._base_model.gradient_checkpointing_disable()

    def __repr__(self):
        return f"ContinuousMiMo(config={self.config})"
@@ -7,6 +7,7 @@ from model.continuous_gemma2 import ContinuousGemma2ForCausalLM from model.continuous_qwen import ContinuousQwen3ForCausalLM from model.continuous_llama import ContinuousLlama +from model.continuous_mimo import ContinuousMiMo from model.nearest_neighbor import NearestNeighborModel from model.continuous_peft import ContinuousPeft from model.self_explanations import SelfExplanationsModel @@ -27,6 +28,7 @@ "gemma3": ContinuousGemma3ForCausalLM, "gemma2": ContinuousGemma2ForCausalLM, "qwen3": ContinuousQwen3ForCausalLM, + "mimo": ContinuousMiMo, "nearest_neighbor": NearestNeighborModel, "self_explanations": SelfExplanationsModel, } @@ -97,6 +99,14 @@ def load_models( "meta-llama/Meta-Llama-3-8B", ]: model_type = "llama" + elif predictor_model_type in [ + "XiaomiMiMo/MiMo-7B-Base", + "XiaomiMiMo/MiMo-7B-RL", + "XiaomiMiMo/MiMo-7B-RL-0530", + "XiaomiMiMo/MiMo-7B-SFT", + "XiaomiMiMo/MiMo-7B-RL-Zero", + ] or "MiMo" in predictor_model_type: + model_type = "mimo" else: raise ValueError(f"Model {predictor_model_type} not supported for autointerp")