-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathalignment_vis.py
More file actions
82 lines (67 loc) · 2.97 KB
/
Copy pathalignment_vis.py
File metadata and controls
82 lines (67 loc) · 2.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import numpy as np
from scipy.sparse.linalg import eigsh
import pickle
from scipy.linalg import eigh
from tqdm import tqdm
import os
import torch
import torch.nn.functional as F
import llm_research
import utils
def alignment_vis(model_name, dataset_name='emotion', i_layer=0):
# Load cached alignment data
data_path = f".cache/{dataset_name.replace('/', '_')}_train_{model_name.replace('/', '_')}_alignment_cache.pkl"
with open(data_path, 'rb') as f:
(meanX_meanXT_dict, meanX_YT_dict,
lag0_cum_dict, lag1_cum_dict,
mean_cache, last_cache) = pickle.load(f)
del mean_cache, last_cache
num_layers = len(meanX_meanXT_dict)
assert i_layer < num_layers, f"Layer {i_layer} exceeds the number of layers {num_layers} in the model."
M_t = meanX_meanXT_dict[i_layer]
X_t = meanX_YT_dict[i_layer]
C0_t = lag0_cum_dict[i_layer]
C1_t = lag1_cum_dict[i_layer]
M = M_t.detach().cpu().to(torch.float32).numpy()
X = X_t.detach().cpu().to(torch.float32).numpy()
C0 = C0_t.detach().cpu().to(torch.float32).numpy()
C1 = C1_t.detach().cpu().to(torch.float32).numpy()
# Compute generalized eigenvectors
eigvals_U, U = eigsh(X @ X.T, k=2, M=M, which='LA')
eigvals_V, V = eigsh(C1 @ C1.T, k=2, M=C0, which='LA')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer, model = utils.tokenizer_model_load(model_name)
model.to(device)
model.eval()
categories = {
'Joy': ['lovely', 'enjoyable', 'amazing', 'delighted'],
'Sad': ['horrible', 'disappointing', 'unhappy', 'lonely'],
'Be': ['am', 'is', 'are', 'was', 'were', ],
}
# Prepare projections
labels = ['U', 'V']
projected = {
label: {cat: [] for cat in categories}
for label in labels
}
# Project each word
for cat, word_list in categories.items():
for w in word_list:
toks = tokenizer(w, return_tensors='pt', padding=True, truncation=True).to(device)
out = model(toks['input_ids'], attention_mask=toks['attention_mask'], output_hidden_states=True)
h = out.hidden_states[i_layer].detach().cpu().numpy()
h_vec = h.mean(axis=1)
projected['U'][cat].append(h_vec @ U)
projected['V'][cat].append(h_vec @ V)
# Save the projections
save_path = f"alignment_vis_{model_name.replace('/', '_')}_{dataset_name}_{i_layer}.pkl"
with open(save_path, 'wb') as f:
pickle.dump(projected, f)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Alignment Visualization")
parser.add_argument('--model_name', type=str, default="apple/OpenELM-450M", help='Name of the model')
parser.add_argument('--dataset_name', type=str, default='emotion', help='Name of the dataset')
parser.add_argument('--i_layer', type=int, default=0, help='Layer index to visualize')
args = parser.parse_args()
alignment_vis(args.model_name, args.dataset_name, args.i_layer)