import torch.nn.functional as F
from abc import ABC, abstractmethod


class ActivationCapture(ABC):
    """Helper class to capture activations from model layers."""
    has_gate_proj: bool
    has_up_proj: bool

    def __init__(self):
        self.hidden_states = {}
        self.mlp_activations = {}
        self.handles = []

    @abstractmethod
    def _register_gate_hook(self, layer_idx, layer):
        """Register a hook on the layer's gate projection; return the handle."""
        pass

    @abstractmethod
    def _register_up_hook(self, layer_idx, layer):
        """Register a hook on the layer's up projection; return the handle."""
        pass

    @abstractmethod
    def get_layers(self, model):
        """Return the list of transformer layers for this model family."""
        pass

    def _register_hidden_state_hook(self, layer_idx, layer):
        def hook(module, args, kwargs, output):
            # args[0] is the input hidden states to the layer
            if len(args) > 0:
                # Just detach, don't clone or move to CPU yet
                self.hidden_states[layer_idx] = args[0].detach()
            return output
        # with_kwargs=True makes PyTorch call the hook as
        # hook(module, args, kwargs, output), so the layer's positional
        # inputs are available in `args`.
        handle = layer.register_forward_hook(
            hook,
            with_kwargs=True
        )
        return handle

    def register_hooks(self, model):
        """Register forward hooks to capture activations."""
        # Clear any existing hooks
        self.remove_hooks()

        # Hook into each transformer layer
        for i, layer in enumerate(self.get_layers(model)):

            # Capture hidden states before MLP
            handle = self._register_hidden_state_hook(i, layer)
            if handle is not None:
                self.handles.append(handle)

            # Capture MLP gate activations (after activation function)
            if self.has_gate_proj:
                handle = self._register_gate_hook(i, layer)
                if handle is not None:
                    self.handles.append(handle)

            # Also capture up_proj activations
            if self.has_up_proj:
                handle = self._register_up_hook(i, layer)
                if handle is not None:
                    self.handles.append(handle)

    def remove_hooks(self):
        """Remove all registered hooks."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def clear_captures(self):
        """Clear captured activations."""
        # Dropping the references releases the captured (possibly GPU) tensors.
        self.hidden_states = {}
        self.mlp_activations = {}

    @abstractmethod
    def get_mlp_activations(self, layer_idx):
        """Get combined MLP activations for a layer."""
        pass

    @abstractmethod
    def get_gate_activations(self, layer_idx):
        """Get post-activation gate projections for a layer."""
        pass


class ActivationCaptureDefault(ActivationCapture):
    """Capture for models whose MLP has separate gate_proj and up_proj
    modules (e.g. Llama-style gated MLPs)."""
    has_gate_proj: bool = True
    has_up_proj: bool = True

    def get_layers(self, model):
        return model.model.layers

    def _create_mlp_hook(self, layer_idx, proj_type):
        def hook(module, input, output):
            key = f"{layer_idx}_{proj_type}"
            # Just detach, don't clone or move to CPU yet
            self.mlp_activations[key] = output.detach()
            return output
        return hook

    def _register_gate_hook(self, layer_idx, layer):
        handle = layer.mlp.gate_proj.register_forward_hook(
            self._create_mlp_hook(layer_idx, 'gate')
        )
        return handle

    def _register_up_hook(self, layer_idx, layer):
        handle = layer.mlp.up_proj.register_forward_hook(
            self._create_mlp_hook(layer_idx, 'up')
        )
        return handle

    def get_mlp_activations(self, layer_idx):
        """Get combined MLP activations for a layer."""
        gate_key = f"{layer_idx}_gate"
        up_key = f"{layer_idx}_up"
        if gate_key in self.mlp_activations and up_key in self.mlp_activations:
            gate_act = self.mlp_activations[gate_key]
            up_act = self.mlp_activations[up_key]
            # Combine as in the gated MLP: SiLU(gate_proj(x)) * up_proj(x)
            gated_act = F.silu(gate_act) * up_act
            return gated_act

        return None

    def get_gate_activations(self, layer_idx):
        """Get post-activation gate projections for a layer."""
        gate_key = f"{layer_idx}_gate"
        if gate_key in self.mlp_activations:
            gate_act = self.mlp_activations[gate_key]
            return F.silu(gate_act)
        return None

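
# Illustrative sketch of an alternative backend: a hypothetical model whose
# MLP exposes a single fused gate/up projection. The attribute names below
# ("gate_up_proj", the layer path) are assumptions for the example, not taken
# from a specific model.
class ActivationCaptureFusedGateUp(ActivationCapture):
    has_gate_proj: bool = True   # the fused hook captures both halves
    has_up_proj: bool = False    # no separate up_proj module to hook

    def get_layers(self, model):
        return model.model.layers

    def _register_gate_hook(self, layer_idx, layer):
        def hook(module, input, output):
            # Split the fused output into its gate and up halves
            gate, up = output.chunk(2, dim=-1)
            self.mlp_activations[f"{layer_idx}_gate"] = gate.detach()
            self.mlp_activations[f"{layer_idx}_up"] = up.detach()
            return output
        return layer.mlp.gate_up_proj.register_forward_hook(hook)

    def _register_up_hook(self, layer_idx, layer):
        return None  # both halves are captured by the gate hook above

    def get_mlp_activations(self, layer_idx):
        gate = self.mlp_activations.get(f"{layer_idx}_gate")
        up = self.mlp_activations.get(f"{layer_idx}_up")
        if gate is not None and up is not None:
            return F.silu(gate) * up
        return None

    def get_gate_activations(self, layer_idx):
        gate = self.mlp_activations.get(f"{layer_idx}_gate")
        return F.silu(gate) if gate is not None else None
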

ACTIVATION_CAPTURE = {}


def register_activation_capture(model_name, activation_capture):
    ACTIVATION_CAPTURE[model_name] = activation_capture
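
# Minimal usage sketch. The registry key "default" and the `model`/`input_ids`
# objects below are illustrative assumptions, not names defined in this module:
#
#   register_activation_capture("default", ActivationCaptureDefault)
#   capture = ACTIVATION_CAPTURE["default"]()
#   capture.register_hooks(model)
#   _ = model(input_ids)                   # forward pass populates the captures
#   acts = capture.get_mlp_activations(0)  # SiLU(gate) * up for layer 0
#   capture.remove_hooks()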