Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 82 additions & 47 deletions examples/cut_and_paste_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from bokeh.plotting import figure, curdoc
from bokeh.models import (
LassoSelectTool,
PolySelectTool,
TapTool,
TextInput,
Button,
Expand Down Expand Up @@ -32,7 +33,6 @@
import vis
from vis import maximize_window


from pyRATS._utils import (
procrustes,
compute_final_embedding,
Expand Down Expand Up @@ -98,32 +98,42 @@ def __init__(
self,
model,
y,
color_of_pts_on_tear,
metadata_fname=None,
metadata_fpath=None,
save_progress_dir="",
cmap_interior="summer",
cmap_tear="jet",
tear_color_eig_inds=[5, 1, 2],
max_refinement_iter=10,
force_compute=False,
):
self.model = model
self.y = y
self.color_of_pts_on_tear = color_of_pts_on_tear
self.pts = None
if metadata_fname is None:
if metadata_fpath is None:
self.metadata = []
else:
self.metadata = read(metadata_fname)
_, self.metadata = read(metadata_fpath)
self.save_progress_dir = save_progress_dir
if self.save_progress_dir:
makedirs(self.save_progress_dir)
self.cur_iter = 0

self.model = model
self.y = y
self.cmap_interior = cmap_interior
self.cmap_tear = cmap_tear
self.tear_color_eig_inds = tear_color_eig_inds
self.max_refinement_iter = max_refinement_iter
self.force_compute = force_compute
self.selected_pts_to_move = np.array([])

if len(self.metadata) > 0:
# simulate cut and paste
print('Simulating cut and paste for', len(self.metadata), 'iterations using previously stored metadata.')
_ = self.get_current_embedding_data()
for i in range(len(self.metadata)):
self.cut_clusters_to_move(None)
self.paste_clusters(None)


def init_metadata_for_current_iter(self):
if len(self.metadata) <= self.cur_iter:
self.metadata.append(
Expand All @@ -136,30 +146,29 @@ def init_metadata_for_current_iter(self):

def get_current_embedding_data(self):
y = self.y.copy()
color_of_pts_on_tear = self.color_of_pts_on_tear.copy()
color_of_pts_on_tear = self.model.compute_color_of_pts_on_tear(y, self.tear_color_eig_inds)
cmap_interior = self.cmap_interior
cmap_tear = self.cmap_tear

matplotlib.use("Agg")
color_of_pts_on_tear = self.color_of_pts_on_tear
pts_on_tear = ~np.isnan(color_of_pts_on_tear[:, -1])

interior_handle, tear_handle = vis.Visualize().global_embedding_for_gui(
y,
y[:, 0],
cmap0=cmap_interior,
color_of_pts_on_tear=color_of_pts_on_tear[:, -1],
color_of_pts_on_tear=color_of_pts_on_tear[:, self.tear_color_eig_inds],
cmap1=cmap_tear,
set_title=True,
figsize=(3, 3),
s=20,
s=40,
)
if self.save_progress_dir:
save_fn = (
self.save_progress_dir
+ "/reg_tear_y_at_iter="
+ str(self.cur_iter)
+ ".png"
+ ".eps"
)
plt.savefig(save_fn, bbox_inches="tight", dpi=400)

Expand Down Expand Up @@ -234,6 +243,7 @@ def paste_clusters(self, final_polyg):
selected_cluster_mask = self.metadata[self.cur_iter]["selected_cluster_mask"]
points_in_selected_clusters = selected_cluster_mask[self.model.c]
y = self.y.copy()

y_new = self.recompute_embedding(
y,
points_in_selected_clusters,
Expand Down Expand Up @@ -278,12 +288,12 @@ def finish_pasting(self, y_new):
)
self.color_of_pts_on_tear = model.compute_color_of_pts_on_tear(self.y)

self.save_buml_obj(
"buml_obj_and_metadata_after_iter=" + str(self.cur_iter) + ".dat",
self.save_model(
"model_and_metadata_after_iter=" + str(self.cur_iter) + ".dat",
)
self.cur_iter += 1

def save_buml_obj(self, fname):
def save_model(self, fname):
if self.save_progress_dir:
new_emb_info = {
"y": self.y,
Expand All @@ -309,7 +319,7 @@ def recompute_embedding(
)
return y

def save_polygon(self, y, y_face_color, y_polyg, stage, s=20):
def save_polygon(self, y, y_face_color, y_polyg, stage, s=40):
assert stage in ["init", "final"]
matplotlib.use("Agg")
_, ax = plt.subplots()
Expand All @@ -320,16 +330,16 @@ def save_polygon(self, y, y_face_color, y_polyg, stage, s=20):
plt.axis("image")
plt.axis("off")
plt.tight_layout()
ax.set_rasterized(True)
save_fn = (
self.save_progress_dir
+ "/"
+ stage
+ "_polyg_at_iter="
+ str(self.cur_iter)
+ ".png"
+ ".eps"
)
plt.savefig(save_fn, bbox_inches="tight", dpi=400)
ax.set_rasterized(True)


def get_convex_covering_polygon(y):
Expand All @@ -339,8 +349,11 @@ def get_convex_covering_polygon(y):

# APP itself
###############################################################
from datetime import datetime

REG_TEAR_TAG = "reg_tear"
# Format as YYYY-MM-DD HH:MM:SS
string_format = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
REG_TEAR_TAG = "reg_tear_" + string_format

# Spinner using CSS
################################################################
Expand Down Expand Up @@ -446,21 +459,26 @@ def print_main_log(message):
################################################################
gen_data_path_input = TextInput(
title="Generated data path",
value="/Users/joshuaoffergeld/Documents/pyRATS/examples",
value="/data2/dhruv/RATS_code/RATS/code/R2OFday1/notebooks/generated_data/",
width=550,
)
cut_and_paste_metadata_path_input = TextInput(
title="Cut and paste metadata path (leave it blank if running for the first time)",
value="/data2/dhruv/RATS_code/RATS/code/R2OFday1/notebooks/generated_data/rats_mishnelab/34_3_reg_tear_2026-06-16 15:04:51/model_and_metadata_after_iter=1.dat",
width=550,
)

algo_input = TextInput(title="Algorithm Name: (for ex: rats)", value="rats")
hyp_param_input = TextInput(title="Hyperparameter: (for ex: 28_5)", value="28_5")
algo_input = TextInput(title="Algorithm Name: (for ex: rats)", value="rats_mishnelab")
hyp_param_input = TextInput(title="Hyperparameter: (for ex: 28_5)", value="34_3")
options_input = TextAreaInput(
title="Additional options:",
value="{'force_compute': True, 'max_refinement_iter': 10, 'cmap_interior': 'summer', 'cmap_tear': 'jet'}",
value="{'force_compute': False, 'max_refinement_iter': 10, 'cmap_interior': 'summer', 'cmap_tear': 'jet', 'tear_color_eig_inds': [5, 2, 3]}",
rows=5,
)

cut_and_paste_obj = None
rot_deg = 5
lasso_tool = LassoSelectTool()
select_tool = PolySelectTool()
tap_tool = TapTool()

update_max_fig_height_button = Button(
Expand All @@ -478,15 +496,16 @@ def print_main_log(message):
)

source = ColumnDataSource(data=dict(x=[], y=[], color=[]))
cluster_patches_source = ColumnDataSource(data=dict(xs=[], ys=[]))
patch_source = ColumnDataSource(data=dict(x=[], y=[]))

MAX_FIG_HEIGHT = 900
MAX_FIG_WIDTH = 900
MAX_FIG_HEIGHT = 200
MAX_FIG_WIDTH = 200

max_fig_height_input = TextInput(title="Max figure height", value=str(MAX_FIG_HEIGHT))
fig = figure(
title="Cut & Paste operation on the embedding",
tools=[lasso_tool],
tools=[select_tool],
x_axis_label="x",
y_axis_label="y",
width=MAX_FIG_HEIGHT,
Expand All @@ -500,22 +519,22 @@ def print_main_log(message):
# Adjust figure size using the scatter x and y coordinates
##############################################################################
def adjust_figure_size():
print(source.data, source)
x_min = 1.25 * np.min(source.data["x"])
x_max = 1.25 * np.max(source.data["x"])
y_min = 1.25 * np.min(source.data["y"])
y_max = 1.25 * np.max(source.data["y"])
scale = 2
x_min = scale * np.min(source.data["x"])
x_max = scale * np.max(source.data["x"])
y_min = scale * np.min(source.data["y"])
y_max = scale * np.max(source.data["y"])
aspect_ratio = (x_max - x_min) / (y_max - y_min)
print("Aspect ratio (x/y) of the embedding: " + str(aspect_ratio))
print("Adjusting figure width based on the aspect ratio.")
print("MAX_FIG_WIDTH = " + str(MAX_FIG_WIDTH))
print("MAX_FIG_HEIGHT = " + str(MAX_FIG_HEIGHT))
if aspect_ratio >= 1:
fig.height = int(1.25 * MAX_FIG_WIDTH / aspect_ratio)
fig.width = int(1.25 * MAX_FIG_WIDTH)
fig.height = int(scale * MAX_FIG_WIDTH / aspect_ratio)
fig.width = int(scale * MAX_FIG_WIDTH)
else:
fig.height = int(1.25 * MAX_FIG_HEIGHT)
fig.width = int(1.25 * aspect_ratio * MAX_FIG_HEIGHT)
fig.height = int(scale * MAX_FIG_HEIGHT)
fig.width = int(scale * aspect_ratio * MAX_FIG_HEIGHT)

fig.x_range = Range1d(x_min, x_max, bounds=(x_min, x_max))
fig.y_range = Range1d(y_min, y_max, bounds=(y_min, y_max))
Expand Down Expand Up @@ -668,26 +687,36 @@ def read(fpath, verbose=True):

def start_main():
gen_data_path = gen_data_path_input.value.strip()
cut_and_paste_metadata_path = cut_and_paste_metadata_path_input.value.strip()
algo = algo_input.value.strip()
hyp_param = hyp_param_input.value.strip()
buml_obj_info_path = (
model_info_path = (
gen_data_path + "/" + algo + "/" + hyp_param + "/" + algo + ".dat"
)
save_dir = gen_data_path + "/" + algo + "/" + hyp_param + "_" + REG_TEAR_TAG
if cut_and_paste_metadata_path != "":
print("cut_and_paste_metadata_path=" + cut_and_paste_metadata_path)
save_dir = gen_data_path + "/" + algo + "/" + cut_and_paste_metadata_path.split('/')[-2]
else:
cut_and_paste_metadata_path = None
save_dir = gen_data_path + "/" + algo + "/" + hyp_param + "_" + REG_TEAR_TAG

options = process_additional_options()
print("buml_obj_info_path=" + buml_obj_info_path)
print("model_info_path=" + model_info_path)
print("save_dir=" + save_dir)
global cut_and_paste_obj
emb_info, metadata = read(buml_obj_info_path)
emb_info, metadata = read(model_info_path)


cut_and_paste_obj = CutAndPaste(
model=emb_info["model"],
y=emb_info["y"],
color_of_pts_on_tear=emb_info["color_of_pts_on_tear"],
metadata_fpath=cut_and_paste_metadata_path,
save_progress_dir=save_dir,
max_refinement_iter=options["max_refinement_iter"],
force_compute=options["force_compute"],
cmap_interior=options["cmap_interior"],
cmap_tear=options["cmap_tear"],
tear_color_eig_inds=options["tear_color_eig_inds"]
)
source.data = cut_and_paste_obj.get_current_embedding_data()

Expand All @@ -698,8 +727,9 @@ def start_end():
cut_button.disabled = False
update_max_fig_height_button.disabled = False
fig.toolbar.active_inspect = None
fig.toolbar.active_drag = lasso_tool
print_main_log("Select region to cut using lasso and then press cut.")
#fig.toolbar.active_drag = select_tool
fig.toolbar.active_tap = select_tool
print_main_log("Select polygonal region to cut (tap at a location to add a vertex), press Enter/Esc to finish selection and then press cut button.")


def start():
Expand All @@ -721,13 +751,16 @@ def on_start_button_click():
# Cut
#######################################
def cut_start():
fig.toolbar.active_drag = None
#fig.toolbar.active_drag = None
fig.toolbar.active_tap = None
cut_button.disabled = True
save_button.disabled = True


def cut_main():
global cut_and_paste_obj
print(source.selected.indices)
print(len(source.selected.indices))
patch_source.data = cut_and_paste_obj.cut_clusters_to_move(source.selected.indices)
source.selected.indices = cut_and_paste_obj.selected_pts_to_move.tolist()
print("len(source.selected.indices) = " + str(len(source.selected.indices)))
Expand Down Expand Up @@ -768,7 +801,8 @@ def paste_end():
cut_button.disabled = False
save_button.disabled = False
patch_source.data = dict(x=[], y=[])
fig.toolbar.active_drag = lasso_tool
#fig.toolbar.active_drag = select_tool
fig.toolbar.active_tap = select_tool


def paste_main():
Expand Down Expand Up @@ -800,7 +834,7 @@ def on_paste_button_click():
def save_():
algo = algo_input.value.strip()
global cut_and_paste_obj
cut_and_paste_obj.save_buml_obj(algo + ".dat")
cut_and_paste_obj.save_model(algo + ".dat")


def on_save_button_click():
Expand Down Expand Up @@ -833,6 +867,7 @@ def on_update_max_fig_height_button_click():
column(
main_log_div,
row(gen_data_path_input),
row(cut_and_paste_metadata_path_input),
row(algo_input, hyp_param_input, options_input, max_fig_height_input),
row(
update_max_fig_height_button,
Expand Down
Loading