-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_predictions.py
More file actions
59 lines (47 loc) · 1.73 KB
/
plot_predictions.py
File metadata and controls
59 lines (47 loc) · 1.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import argparse
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

from graph_model import BlockModel, WideNet, UnSqueeze, OneHotNet
# Target device for the model and its inputs: the current CUDA device when
# CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def plot(model_args, save_path, coords, model_path="", model=None,
         output_dir="eval_image_plots"):
    """Classify each point in ``coords`` with a model and save a scatter plot.

    Each point is drawn at its first two coordinates. It is blue when the
    predicted class index is 1 and red otherwise.

    Args:
        model_args: Keyword arguments forwarded to ``BlockModel(...)`` when the
            model has to be built and loaded from ``model_path``; ignored when
            ``model`` is given.
        save_path: Image file name, written as ``f"{output_dir}/{save_path}"``.
        coords: Iterable of tensors. Each one is moved to ``device`` and passed
            to the model on its own. NOTE(review): assumes the model maps a
            single point to a 1-D vector of class scores (the argmax is taken
            over dim 0) — confirm against BlockModel.
        model_path: Path to a saved ``state_dict``. Used only when ``model`` is
            ``None``.
        model: An already-constructed model. Takes precedence over
            ``model_path``.
        output_dir: Directory the image is written to; created if missing.

    Raises:
        ValueError: if neither ``model`` nor ``model_path`` is provided.
    """
    if model is None:
        if model_path == "":
            raise ValueError("plot() needs either `model` or `model_path`")
        model = BlockModel(**model_args)
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
        model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    points = []
    labels = []
    # Pure inference: disable autograd so no computation graphs are kept around.
    with torch.no_grad():
        for c in coords:
            point = c.to(device)
            scores = model(point)
            # log_softmax is monotonic, so the argmax of the raw scores is the
            # same predicted class.
            labels.append(torch.argmax(scores, dim=0).item())
            points.append([point[0].item(), point[1].item()])

    y = np.array(labels)
    # reshape keeps X two-dimensional even when coords is empty.
    X = np.array(points, dtype=float).reshape(-1, 2)
    print(X.shape)
    print(y.shape)

    colors = ["blue" if label == 1 else "red" for label in y]
    fig, ax = plt.subplots(figsize=(15, 15), dpi=160)
    ax.scatter(X[:, 0], X[:, 1], c=colors)
    ax.grid()

    out_path = f"{output_dir}/{save_path}"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path)
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)
if __name__ == "__main__":
    # CLI entry point: load a BlockModel checkpoint and plot its predicted class
    # for every integer point of a grid_size x grid_size grid.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", "-m", type=str, required=True)
    # BlockModel's constructor arguments, given as a JSON object,
    # e.g. '{"hidden": 64}'. The string default is run through json.loads.
    parser.add_argument("--model-args", type=json.loads, default="{}")
    parser.add_argument("--save-path", "-o", type=str, default="predictions.png")
    parser.add_argument("--grid-size", type=int, default=100)
    args = parser.parse_args()

    grid = [
        torch.Tensor([i, j])
        for i in range(args.grid_size)
        for j in range(args.grid_size)
    ]
    plot(args.model_args, args.save_path, grid, model_path=args.model_path)