diff --git a/examples/cut_and_paste_app.py b/examples/cut_and_paste_app.py index 53fb3aa..8298422 100644 --- a/examples/cut_and_paste_app.py +++ b/examples/cut_and_paste_app.py @@ -5,6 +5,7 @@ from bokeh.plotting import figure, curdoc from bokeh.models import ( LassoSelectTool, + PolySelectTool, TapTool, TextInput, Button, @@ -32,7 +33,6 @@ import vis from vis import maximize_window - from pyRATS._utils import ( procrustes, compute_final_embedding, @@ -98,32 +98,42 @@ def __init__( self, model, y, - color_of_pts_on_tear, - metadata_fname=None, + metadata_fpath=None, save_progress_dir="", cmap_interior="summer", cmap_tear="jet", + tear_color_eig_inds=[5, 1, 2], max_refinement_iter=10, force_compute=False, ): - self.model = model - self.y = y - self.color_of_pts_on_tear = color_of_pts_on_tear self.pts = None - if metadata_fname is None: + if metadata_fpath is None: self.metadata = [] else: - self.metadata = read(metadata_fname) + _, self.metadata = read(metadata_fpath) self.save_progress_dir = save_progress_dir if self.save_progress_dir: makedirs(self.save_progress_dir) self.cur_iter = 0 + + self.model = model + self.y = y self.cmap_interior = cmap_interior self.cmap_tear = cmap_tear + self.tear_color_eig_inds = tear_color_eig_inds self.max_refinement_iter = max_refinement_iter self.force_compute = force_compute self.selected_pts_to_move = np.array([]) + if len(self.metadata) > 0: + # simulate cut and paste + print('Simulating cut and paste for', len(self.metadata), 'iterations using previously stored metadata.') + _ = self.get_current_embedding_data() + for i in range(len(self.metadata)): + self.cut_clusters_to_move(None) + self.paste_clusters(None) + + def init_metadata_for_current_iter(self): if len(self.metadata) <= self.cur_iter: self.metadata.append( @@ -136,30 +146,29 @@ def init_metadata_for_current_iter(self): def get_current_embedding_data(self): y = self.y.copy() - color_of_pts_on_tear = self.color_of_pts_on_tear.copy() + color_of_pts_on_tear = self.model.compute_color_of_pts_on_tear(y, self.tear_color_eig_inds) cmap_interior = self.cmap_interior cmap_tear = self.cmap_tear matplotlib.use("Agg") - color_of_pts_on_tear = self.color_of_pts_on_tear pts_on_tear = ~np.isnan(color_of_pts_on_tear[:, -1]) interior_handle, tear_handle = vis.Visualize().global_embedding_for_gui( y, y[:, 0], cmap0=cmap_interior, - color_of_pts_on_tear=color_of_pts_on_tear[:, -1], + color_of_pts_on_tear=color_of_pts_on_tear[:, self.tear_color_eig_inds], cmap1=cmap_tear, set_title=True, figsize=(3, 3), - s=20, + s=40, ) if self.save_progress_dir: save_fn = ( self.save_progress_dir + "/reg_tear_y_at_iter=" + str(self.cur_iter) - + ".png" + + ".eps" ) plt.savefig(save_fn, bbox_inches="tight", dpi=400) @@ -234,6 +243,7 @@ def paste_clusters(self, final_polyg): selected_cluster_mask = self.metadata[self.cur_iter]["selected_cluster_mask"] points_in_selected_clusters = selected_cluster_mask[self.model.c] y = self.y.copy() + y_new = self.recompute_embedding( y, points_in_selected_clusters, @@ -278,12 +288,12 @@ def finish_pasting(self, y_new): ) self.color_of_pts_on_tear = model.compute_color_of_pts_on_tear(self.y) - self.save_buml_obj( - "buml_obj_and_metadata_after_iter=" + str(self.cur_iter) + ".dat", + self.save_model( + "model_and_metadata_after_iter=" + str(self.cur_iter) + ".dat", ) self.cur_iter += 1 - def save_buml_obj(self, fname): + def save_model(self, fname): if self.save_progress_dir: new_emb_info = { "y": self.y, @@ -309,7 +319,7 @@ def recompute_embedding( ) return y - def save_polygon(self, y, y_face_color, y_polyg, stage, s=20): + def save_polygon(self, y, y_face_color, y_polyg, stage, s=40): assert stage in ["init", "final"] matplotlib.use("Agg") _, ax = plt.subplots() @@ -320,16 +330,16 @@ def save_polygon(self, y, y_face_color, y_polyg, stage, s=20): plt.axis("image") plt.axis("off") plt.tight_layout() - ax.set_rasterized(True) save_fn = ( self.save_progress_dir + "/" + stage + "_polyg_at_iter=" + str(self.cur_iter) - + ".png" + + ".eps" ) plt.savefig(save_fn, bbox_inches="tight", dpi=400) + ax.set_rasterized(True) def get_convex_covering_polygon(y): @@ -339,8 +349,11 @@ def get_convex_covering_polygon(y): # APP itself ############################################################### +from datetime import datetime -REG_TEAR_TAG = "reg_tear" +# Format as YYYY-MM-DD HH:MM:SS +string_format = datetime.now().strftime("%Y-%m-%d %H:%M:%S") +REG_TEAR_TAG = "reg_tear_" + string_format # Spinner using CSS ################################################################ @@ -446,21 +459,26 @@ def print_main_log(message): ################################################################ gen_data_path_input = TextInput( title="Generated data path", - value="/Users/joshuaoffergeld/Documents/pyRATS/examples", + value="/data2/dhruv/RATS_code/RATS/code/R2OFday1/notebooks/generated_data/", + width=550, +) +cut_and_paste_metadata_path_input = TextInput( + title="Cut and paste metadata path (leave it blank if running for the first time)", + value="/data2/dhruv/RATS_code/RATS/code/R2OFday1/notebooks/generated_data/rats_mishnelab/34_3_reg_tear_2026-06-16 15:04:51/model_and_metadata_after_iter=1.dat", width=550, ) -algo_input = TextInput(title="Algorithm Name: (for ex: rats)", value="rats") -hyp_param_input = TextInput(title="Hyperparameter: (for ex: 28_5)", value="28_5") +algo_input = TextInput(title="Algorithm Name: (for ex: rats)", value="rats_mishnelab") +hyp_param_input = TextInput(title="Hyperparameter: (for ex: 28_5)", value="34_3") options_input = TextAreaInput( title="Additional options:", - value="{'force_compute': True, 'max_refinement_iter': 10, 'cmap_interior': 'summer', 'cmap_tear': 'jet'}", + value="{'force_compute': False, 'max_refinement_iter': 10, 'cmap_interior': 'summer', 'cmap_tear': 'jet', 'tear_color_eig_inds': [5, 2, 3]}", rows=5, ) cut_and_paste_obj = None rot_deg = 5 -lasso_tool = LassoSelectTool() +select_tool = PolySelectTool() tap_tool = TapTool() update_max_fig_height_button = Button( @@ -478,15 +496,16 @@ def print_main_log(message): ) source = ColumnDataSource(data=dict(x=[], y=[], color=[])) +cluster_patches_source = ColumnDataSource(data=dict(xs=[], ys=[])) patch_source = ColumnDataSource(data=dict(x=[], y=[])) -MAX_FIG_HEIGHT = 900 -MAX_FIG_WIDTH = 900 +MAX_FIG_HEIGHT = 200 +MAX_FIG_WIDTH = 200 max_fig_height_input = TextInput(title="Max figure height", value=str(MAX_FIG_HEIGHT)) fig = figure( title="Cut & Paste operation on the embedding", - tools=[lasso_tool], + tools=[select_tool], x_axis_label="x", y_axis_label="y", width=MAX_FIG_HEIGHT, @@ -500,22 +519,22 @@ def print_main_log(message): # Adjust figure size using the scatter x and y coordinates ############################################################################## def adjust_figure_size(): - print(source.data, source) - x_min = 1.25 * np.min(source.data["x"]) - x_max = 1.25 * np.max(source.data["x"]) - y_min = 1.25 * np.min(source.data["y"]) - y_max = 1.25 * np.max(source.data["y"]) + scale = 2 + x_min = scale * np.min(source.data["x"]) + x_max = scale * np.max(source.data["x"]) + y_min = scale * np.min(source.data["y"]) + y_max = scale * np.max(source.data["y"]) aspect_ratio = (x_max - x_min) / (y_max - y_min) print("Aspect ratio (x/y) of the embedding: " + str(aspect_ratio)) print("Adjusting figure width based on the aspect ratio.") print("MAX_FIG_WIDTH = " + str(MAX_FIG_WIDTH)) print("MAX_FIG_HEIGHT = " + str(MAX_FIG_HEIGHT)) if aspect_ratio >= 1: - fig.height = int(1.25 * MAX_FIG_WIDTH / aspect_ratio) - fig.width = int(1.25 * MAX_FIG_WIDTH) + fig.height = int(scale * MAX_FIG_WIDTH / aspect_ratio) + fig.width = int(scale * MAX_FIG_WIDTH) else: - fig.height = int(1.25 * MAX_FIG_HEIGHT) - fig.width = int(1.25 * aspect_ratio * MAX_FIG_HEIGHT) + fig.height = int(scale * MAX_FIG_HEIGHT) + fig.width = int(scale * aspect_ratio * MAX_FIG_HEIGHT) fig.x_range = Range1d(x_min, x_max, bounds=(x_min, x_max)) fig.y_range = Range1d(y_min, y_max, bounds=(y_min, y_max)) @@ -668,26 +687,36 @@ def read(fpath, verbose=True): def start_main(): gen_data_path = gen_data_path_input.value.strip() + cut_and_paste_metadata_path = cut_and_paste_metadata_path_input.value.strip() algo = algo_input.value.strip() hyp_param = hyp_param_input.value.strip() - buml_obj_info_path = ( + model_info_path = ( gen_data_path + "/" + algo + "/" + hyp_param + "/" + algo + ".dat" ) - save_dir = gen_data_path + "/" + algo + "/" + hyp_param + "_" + REG_TEAR_TAG + if cut_and_paste_metadata_path != "": + print("cut_and_paste_metadata_path=" + cut_and_paste_metadata_path) + save_dir = gen_data_path + "/" + algo + "/" + cut_and_paste_metadata_path.split('/')[-2] + else: + cut_and_paste_metadata_path = None + save_dir = gen_data_path + "/" + algo + "/" + hyp_param + "_" + REG_TEAR_TAG + options = process_additional_options() - print("buml_obj_info_path=" + buml_obj_info_path) + print("model_info_path=" + model_info_path) print("save_dir=" + save_dir) global cut_and_paste_obj - emb_info, metadata = read(buml_obj_info_path) + emb_info, metadata = read(model_info_path) + + cut_and_paste_obj = CutAndPaste( model=emb_info["model"], y=emb_info["y"], - color_of_pts_on_tear=emb_info["color_of_pts_on_tear"], + metadata_fpath=cut_and_paste_metadata_path, save_progress_dir=save_dir, max_refinement_iter=options["max_refinement_iter"], force_compute=options["force_compute"], cmap_interior=options["cmap_interior"], cmap_tear=options["cmap_tear"], + tear_color_eig_inds=options["tear_color_eig_inds"] ) source.data = cut_and_paste_obj.get_current_embedding_data() @@ -698,8 +727,9 @@ def start_end(): cut_button.disabled = False update_max_fig_height_button.disabled = False fig.toolbar.active_inspect = None - fig.toolbar.active_drag = lasso_tool - print_main_log("Select region to cut using lasso and then press cut.") + #fig.toolbar.active_drag = select_tool + fig.toolbar.active_tap = select_tool + print_main_log("Select polygonal region to cut (tap at a location to add a vertex), press Enter/Esc to finish selection and then press cut button.") def start(): @@ -721,13 +751,16 @@ def on_start_button_click(): # Cut ####################################### def cut_start(): - fig.toolbar.active_drag = None + #fig.toolbar.active_drag = None + fig.toolbar.active_tap = None cut_button.disabled = True save_button.disabled = True def cut_main(): global cut_and_paste_obj + print(source.selected.indices) + print(len(source.selected.indices)) patch_source.data = cut_and_paste_obj.cut_clusters_to_move(source.selected.indices) source.selected.indices = cut_and_paste_obj.selected_pts_to_move.tolist() print("len(source.selected.indices) = " + str(len(source.selected.indices))) @@ -768,7 +801,8 @@ def paste_end(): cut_button.disabled = False save_button.disabled = False patch_source.data = dict(x=[], y=[]) - fig.toolbar.active_drag = lasso_tool + #fig.toolbar.active_drag = select_tool + fig.toolbar.active_tap = select_tool def paste_main(): @@ -800,7 +834,7 @@ def on_paste_button_click(): def save_(): algo = algo_input.value.strip() global cut_and_paste_obj - cut_and_paste_obj.save_buml_obj(algo + ".dat") + cut_and_paste_obj.save_model(algo + ".dat") def on_save_button_click(): @@ -833,6 +867,7 @@ def on_update_max_fig_height_button_click(): column( main_log_div, row(gen_data_path_input), + row(cut_and_paste_metadata_path_input), row(algo_input, hyp_param_input, options_input, max_fig_height_input), row( update_max_fig_height_button,