diff --git a/src/bioset/analysis/loader.py b/src/bioset/analysis/loader.py index bb13302..43d2c33 100644 --- a/src/bioset/analysis/loader.py +++ b/src/bioset/analysis/loader.py @@ -5,6 +5,7 @@ import sqlite3 import tempfile from dataclasses import dataclass, field +from itertools import combinations as iter_combinations from pathlib import Path from typing import Optional @@ -73,6 +74,8 @@ def __init__(self): self.metadata: Optional[AnalysisMetadata] = None self._loaded = False self._total_tiles_cache: dict[int, int] = {} # level -> total tile count + self._channel_totals_voxels_cache: dict[ + tuple[str, float, int], int] = {} # (channel, dilation, level) -> total voxels @property def is_loaded(self) -> bool: @@ -230,8 +233,9 @@ def get_top_combinations( Get top N combinations by aggregated IoU. Aggregates across all tiles: - global_iou = SUM(total_count) / SUM(total_union) - + global_iou = SUM(total_count) / SUM(total_union) + global_overlap_coeff = SUM(total_count) / MIN(SUM(ch_a), SUM(ch_b), ...) + Sorted by global_iou DESC. Self-pairs (e.g. CD8|CD8) are excluded. """ @@ -266,10 +270,16 @@ def get_top_combinations( sum_union = row["sum_union"] or 1 agg_iou = sum_inter / sum_union if sum_union > 0 else 0.0 + # Calculate overlap coefficient from channel_stats + overlap_coeff = self._compute_overlap_coeff( + channels, sum_inter, dilation, hierarchy_level + ) + results.append(CombinationData( channels=channels, total_count=row["agg_count"] or 0, iou=agg_iou, + overlap_coeff=overlap_coeff, )) if len(results) >= limit: @@ -287,6 +297,7 @@ def get_filtered_combinations( ) -> list[CombinationData]: """ Get combinations containing ANY of the specified channels, aggregated by IoU. + Also computes aggregated overlap coefficient. """ if not self.is_loaded or not channel_filter: return [] @@ -348,10 +359,16 @@ def get_filtered_combinations( sum_union = row["sum_union"] or 1 agg_iou = sum_inter / sum_union if sum_union > 0 else 0.0 + # Calculate overlap coefficient from channel_stats + overlap_coeff = self._compute_overlap_coeff( + channels, sum_inter, dilation, hierarchy_level + ) + results.append(CombinationData( channels=channels, total_count=row["agg_count"] or 0, iou=agg_iou, + overlap_coeff=overlap_coeff, )) if len(results) >= limit: @@ -359,6 +376,38 @@ def get_filtered_combinations( return results + def _compute_overlap_coeff( + self, + channels: list[str], + sum_inter: int, + dilation: float, + hierarchy_level: int, + ) -> float: + """ + Compute aggregated overlap coefficient. + + overlap_coeff = SUM(intersection) / MIN(SUM(ch_a), SUM(ch_b), ...) + + Queries channel_stats to get total voxels per channel. + """ + if not channels or sum_inter == 0: + return 0.0 + + # Get total voxels for each channel + channel_totals = [] + for ch in channels: + total = self.get_channel_total_voxels(ch, dilation, hierarchy_level) + if total > 0: + channel_totals.append(total) + + if not channel_totals: + return 0.0 + + min_voxels = min(channel_totals) + return sum_inter / min_voxels if min_voxels > 0 else 0.0 + + + # ────────────────────────────────────────────── # Bar chart: coverage percentage # ────────────────────────────────────────────── @@ -595,46 +644,190 @@ def get_tile_combinations( # Dilation curve # ────────────────────────────────────────────── - def get_dilation_curve( + def get_subcombination_dilation_curves( self, channels: list[str], - hierarchy_level: int, - ) -> list[dict]: - """Get aggregated IoU across all dilations for a combination.""" - if not self.is_loaded: - return [] + hierarchy_level: int = 0, + ) -> dict[str, list[dict]]: + """ + Get dilation curves for all subcombinations of the given channels. + + For channels ["A", "B", "C"], attempts to find curves for: + - Single channels: A, B, C + - Pairs: A|B, A|C, B|C (if they exist in the database) + - Triple: A|B|C (if it exists) + + Only returns subcombinations that actually exist in the database. + + Returns: + Dict mapping channel string (e.g., "A|B") to list of dilation points: + { + "A": [{"dilation": 0.0, "count": ..., "iou": 1.0, "overlap_coeff": 1.0, "density": ...}, ...], + "A|B": [{"dilation": 0.0, "count": ..., "iou": ..., "overlap_coeff": ..., "density": ...}, ...], + ... + } + + Each dilation point dict contains: + - dilation: float + - count: int (voxel count or intersection count) + - iou: float + - overlap_coeff: float + - density: float (percentage of total volume) + """ + if not self.is_loaded or not channels: + return {} + # Get channel ordering for consistent key generation channel_order = self.metadata.channels if self.metadata else [] - sorted_channels = sorted( - channels, - key=lambda c: channel_order.index(c) if c in channel_order else 999 - ) - channels_str = "|".join(sorted_channels) - cursor = self._conn.execute(''' - SELECT - dilation, - SUM(total_count) as sum_inter, - SUM(total_union) as sum_union - FROM combinations - WHERE channels = ? AND hierarchy_level = ? - GROUP BY dilation - ORDER BY dilation - ''', (channels_str, hierarchy_level)) + def sort_channels(ch_list: list[str]) -> list[str]: + return sorted( + ch_list, + key=lambda c: channel_order.index(c) if c in channel_order else 999 + ) - results = [] - for row in cursor: - sum_inter = row["sum_inter"] or 0 - sum_union = row["sum_union"] or 1 - agg_iou = sum_inter / sum_union if sum_union > 0 else 0.0 - results.append({ - "dilation": row["dilation"], - "count": sum_inter, - "iou": agg_iou, - }) + def make_key(ch_list: list[str]) -> str: + return "|".join(sort_channels(ch_list)) + + results = {} + + # Generate all subcombinations of size 1 to len(channels) + for size in range(1, len(channels) + 1): + for combo in iter_combinations(channels, size): + combo_list = list(combo) + combo_key = make_key(combo_list) + + # Get dilation curve for this subcombination + if size == 1: + curve = self._get_single_channel_dilation_curve( + combo_list[0], hierarchy_level + ) + else: + curve = self._get_multi_channel_dilation_curve_full( + combo_list, hierarchy_level + ) + + # Only include if data exists + if curve: + results[combo_key] = curve return results + def _get_single_channel_dilation_curve( + self, + channel: str, + hierarchy_level: int, + ) -> list[dict]: + """Get voxel count across dilations for a single channel.""" + + # Get total volume for density calculation + bounds = self.metadata.volume_bounds + total_volume = ( + (bounds["x"][1] - bounds["x"][0]) * + (bounds["y"][1] - bounds["y"][0]) * + (bounds["z"][1] - bounds["z"][0]) + ) + + try: + cursor = self._conn.execute(''' + SELECT + dilation, + SUM(voxel_count) as total_voxels + FROM channel_stats + WHERE channel = ? AND hierarchy_level = ? + GROUP BY dilation + ORDER BY dilation + ''', (channel, hierarchy_level)) + + results = [] + for row in cursor: + total_voxels = row["total_voxels"] or 0 + density = (total_voxels / total_volume * 100) if total_volume > 0 else 0.0 + + results.append({ + "dilation": row["dilation"], + "count": total_voxels, + "iou": 1.0, # Self-overlap is always 1.0 + "overlap_coeff": 1.0, # Self-overlap is always 1.0 + "density": density, + }) + + return results + + except sqlite3.OperationalError as e: + if "no such table: channel_stats" in str(e): + return [] + raise + + + def _get_multi_channel_dilation_curve_full( + self, + channels: list[str], + hierarchy_level: int, + ) -> list[dict]: + """ + Get IoU, overlap coefficient, and density across dilations for a multi-channel combination. + + Returns empty list if the combination doesn't exist in the database. + """ + # Get total volume for density calculation + bounds = self.metadata.volume_bounds + total_volume = ( + (bounds["x"][1] - bounds["x"][0]) * + (bounds["y"][1] - bounds["y"][0]) * + (bounds["z"][1] - bounds["z"][0]) + ) + + channel_order = self.metadata.channels if self.metadata else [] + sorted_channels = sorted( + channels, + key=lambda c: channel_order.index(c) if c in channel_order else 999 + ) + channels_str = "|".join(sorted_channels) + + cursor = self._conn.execute(''' + SELECT + dilation, + SUM(total_count) as sum_inter, + SUM(total_union) as sum_union + FROM combinations + WHERE channels = ? AND hierarchy_level = ? + GROUP BY dilation + ORDER BY dilation + ''', (channels_str, hierarchy_level)) + + rows = cursor.fetchall() + + if not rows: + return [] # Combination doesn't exist in database + + results = [] + for row in rows: + dilation = row["dilation"] + sum_inter = row["sum_inter"] or 0 + sum_union = row["sum_union"] or 1 + + # IoU + iou = sum_inter / sum_union if sum_union > 0 else 0.0 + + # Overlap coefficient: intersection / min(channel_voxels) + overlap_coeff = self._compute_overlap_coeff( + channels, sum_inter, dilation, hierarchy_level + ) + + # Density of intersection + density = (sum_inter / total_volume * 100) if total_volume > 0 else 0.0 + + results.append({ + "dilation": dilation, + "count": sum_inter, + "iou": iou, + "overlap_coeff": overlap_coeff, + "density": density, + }) + + return results + # ────────────────────────────────────────────── # Channel voxel totals (for reference) # ────────────────────────────────────────────── @@ -645,6 +838,11 @@ def get_channel_total_voxels( """Get total voxels for a single channel across all tiles.""" if not self.is_loaded: return 0 + + cache_entry = (channel, dilation, level) + if cache_entry in self._channel_totals_voxels_cache: + return self._channel_totals_voxels_cache[cache_entry] + try: cursor = self._conn.execute(''' SELECT SUM(voxel_count) as total @@ -652,7 +850,10 @@ def get_channel_total_voxels( WHERE channel = ? AND dilation = ? AND hierarchy_level = ? ''', (channel, dilation, level)) row = cursor.fetchone() - return row["total"] if row and row["total"] else 0 + total = row["total"] if row and row["total"] else 0 + + self._channel_totals_voxels_cache[cache_entry] = total + return total except sqlite3.OperationalError as e: if "no such table: channel_stats" in str(e): return 0 @@ -697,6 +898,7 @@ def close(self): self._loaded = False self.metadata = None self._total_tiles_cache.clear() + self._channel_totals_voxels_cache.clear() def __del__(self): self.close() diff --git a/src/bioset/ui/callbacks.py b/src/bioset/ui/callbacks.py index 7f1a449..2d7a3c8 100644 --- a/src/bioset/ui/callbacks.py +++ b/src/bioset/ui/callbacks.py @@ -5,8 +5,6 @@ import os import tempfile -import requests - from bioset.NOV import register_nov_callbacks from bioset.bookmark import register_bookmark_callbacks, capture_screenshot_png_bytes from bioset.llm import BiomniLocalClient @@ -387,6 +385,7 @@ def load_analysis_file(file_info): # Initialize plot channel selections with all channels state.upset_selected_channels = [ch for ch in state.analysis_channels] state.bar_selected_channels = [ch for ch in state.analysis_channels] + state.dilation_selected_channels = [ch for ch in state.analysis_channels] state.analysis_loaded = True state.right_drawer_open = True @@ -398,6 +397,7 @@ def load_analysis_file(file_info): update_heatmap_combinations() update_upset_data() update_bar_data() + update_dilation_data() if _refs["view"]: _refs["view"].update() @@ -650,6 +650,37 @@ def update_heatmap(): if _refs["view"]: _refs["view"].update() + def print_dilation_curve(): + """Print dilation curves for all subcombinations of the selected channels.""" + loader = _refs.get("analysis_loader") + if not loader or not loader.is_loaded: + return + + channels = state.heatmap_combination or [] + if not channels: + return + + level = state.current_hierarchy_level + curves = loader.get_subcombination_dilation_curves(channels, level) + if not curves: + print(f"[dilation-curve] No data for {channels}") + return + + print(f"[dilation-curve] Subcombination curves for {' | '.join(channels)} (hierarchy_level={level})") + for combo_key, curve in curves.items(): + is_single = "|" not in combo_key + print(f"\n --- {combo_key} ---") + if is_single: + print(f" {'Dilation':>10} {'Voxels':>12} {'Density':>10}") + print(f" {'-'*10} {'-'*12} {'-'*10}") + for pt in curve: + print(f" {pt['dilation']:>10.1f} {pt['count']:>12} {pt.get('density', 0):>10.6f}") + else: + print(f" {'Dilation':>10} {'Intersection':>12} {'IoU':>10} {'Overlap':>10}") + print(f" {'-'*10} {'-'*12} {'-'*10} {'-'*10}") + for pt in curve: + print(f" {pt['dilation']:>10.1f} {pt['count']:>12} {pt['iou']:>10.6f} {pt.get('overlap_coeff', 0):>10.6f}") + def _filter_combinations_by_channel_selection(combinations, selected_channels): """ Filter combinations to only include those whose channels are all @@ -697,6 +728,7 @@ def update_upset_data(): mapped_combinations.append({ "channels": combination.channels, "iou": combination.iou, + "overlap_coeff": combination.overlap_coeff, }) state.upset_data = mapped_combinations @@ -750,6 +782,7 @@ def update_upset_data_local(): mapped_combinations.append({ "channels": combination.channels, "iou": combination.iou, + "overlap_coeff": combination.overlap_coeff, }) state.upset_data_local = mapped_combinations @@ -811,6 +844,50 @@ def update_bar_data_local(): state.bar_data_local = local_bar_data print(f"[callbacks] Bar local data updated: {len(local_bar_data)} channels") + + def update_dilation_data(): + """Update dilation curve data for the line plot.""" + loader = _refs.get("analysis_loader") + if not loader or not loader.is_loaded: + print("[callbacks] Cannot update dilation data - loader not ready") + state.dilation_data = {} + return + + print( + f"[callbacks] Updating dilation data: mode={state.dilation_view_mode}, level={state.current_hierarchy_level}") + + active_ids = state.active_channels or [] + if not active_ids: + state.dilation_data = {} + print("[callbacks] Dilation data cleared (no channels selected)") + return + + channels_list = state.channels or [] + id_to_name = {ch["id"]: ch["name"] for ch in channels_list} + selected = [id_to_name[ch_id] for ch_id in active_ids if ch_id in id_to_name] + + if not selected: + state.dilation_data = {} + return + + view_mode = getattr(state, "dilation_view_mode", "single") + level = getattr(state, "current_hierarchy_level", 0) + + dilation_curves = loader.get_subcombination_dilation_curves(selected, hierarchy_level=level) + + result = {} + if view_mode == "single": + for curve_key in dilation_curves: + if "|" not in curve_key: + result[curve_key] = dilation_curves[curve_key] + + else: + for curve_key in dilation_curves: + if "|" in curve_key: + result[curve_key] = dilation_curves[curve_key] + state.dilation_data = result + + print(f"[callbacks] Dilation data updated: keys={list(state.dilation_data.keys())}") def reset_camera(): """Reset camera to initial position (from when data was first loaded). Use after opening a Bookmark to return to default view.""" @@ -1817,6 +1894,7 @@ def generate_pdf_report(report_data=None): ctrl.load_analysis_file = load_analysis_file ctrl.update_heatmap = update_heatmap ctrl.update_heatmap_combinations = update_heatmap_combinations + ctrl.print_dilation_curve = print_dilation_curve ctrl.toggle_channel = toggle_channel ctrl.update_active_channels = update_active_channels ctrl.reset_camera = reset_camera @@ -1831,6 +1909,7 @@ def generate_pdf_report(report_data=None): ctrl.update_upset_data_local = update_upset_data_local ctrl.update_bar_data = update_bar_data ctrl.update_bar_data_local = update_bar_data_local + ctrl.update_dilation_data = update_dilation_data ctrl.chatbot_login = chatbot_login ctrl.biomni_add_data = biomni_add_data ctrl.biomni_upload_file = biomni_upload_file diff --git a/src/bioset/ui/components/right_drawer.py b/src/bioset/ui/components/right_drawer.py index 187a8a6..6c5e6ba 100644 --- a/src/bioset/ui/components/right_drawer.py +++ b/src/bioset/ui/components/right_drawer.py @@ -134,6 +134,7 @@ def right_drawer(state, ctrl): :dataLocal="upset_data_local" :channelData="channels" :view-mode="upset_view_mode" + :metric="upset_metric" :offset="upset_expanded_offset" :limit="upset_expanded_limit" :width="1200" @@ -150,6 +151,28 @@ def right_drawer(state, ctrl): vuetify.VDivider() with vuetify.VCardText(): + with html.Div(classes="mb-4 mt-2"): + html.Div("Metric", classes="text-caption mb-2 text-left", style="color: white;") + with html.Div(classes="d-flex justify-space-between", style="width: 100%; gap: 8px;"): + vuetify.VBtn( + "IoU", + click="upset_metric = 'iou'", + color=("upset_metric === 'iou' ? 'white' : 'grey darken-3'",), + dark=("upset_metric !== 'iou'",), + title="Intersection over Union", + class_="flex-grow-1 rounded px-4", + style="flex: 1;" + ) + vuetify.VBtn( + "Overlap Coefficient", + click="upset_metric = 'overlap_coeff'", + color=("upset_metric === 'overlap_coeff' ? 'white' : 'grey darken-3'",), + dark=("upset_metric !== 'overlap_coeff'",), + title="Overlap Coefficient", + class_="flex-grow-1 rounded px-4", + style="flex: 1;" + ) + with html.Div(classes="mb-4 mt-2"): html.Div("Minimum Number of Channels", classes="text-caption mb-2 text-left", style="color: white;") with html.Div(classes="d-flex justify-space-between", style="width: 100%; gap: 8px;"): @@ -214,6 +237,7 @@ def right_drawer(state, ctrl): :dataLocal="upset_data_local" :channelData="channels" :view-mode="upset_view_mode" + :metric="upset_metric" :offset="upset_offset" :limit="upset_limit" @click="upset_click = $event" @@ -402,3 +426,78 @@ def right_drawer(state, ctrl): /> """ ) + + vuetify.VDivider() + + with html.Div(classes="px-4 py-3"): + html.Div("Dilation Curves", classes="text-overline mb-2 text-center", style="color: white;") + + with html.Div(classes="d-flex justify-center mb-2 align-center"): + with vuetify.VBtnToggle( + v_model=("dilation_view_mode", "single"), + mandatory=True, + dense=True, + classes="mr-2", + style="background: transparent;", + ): + vuetify.VBtn("Single", value="single", small=True, classes="text-capitalize", outlined=True) + vuetify.VBtn("Multiple", value="multiple", small=True, classes="text-capitalize", outlined=True) + + # Filter button + with vuetify.VBtn(icon=True, small=True, v_show="dilation_view_mode === 'multiple'", + click="dilation_filter_dialog = true"): + vuetify.VIcon("mdi-filter-variant", small=True) + + with vuetify.VDialog(v_model=("dilation_filter_dialog",), max_width="900px", scrollable=True): + with vuetify.VCard(classes="grey darken-4 white--text"): + vuetify.VCardTitle("Settings - Dilation Plot", classes="headline grey darken-3") + vuetify.VDivider() + + with vuetify.VCardText(): + with html.Div(v_show="dilation_view_mode === 'multiple'", classes="mb-4"): + html.Div("Intersection Metric", classes="text-overline mb-1", style="color: white;") + with html.Div(classes="d-flex justify-space-between", style="width: 100%; gap: 8px;"): + vuetify.VBtn( + "IoU", + click="dilation_metric_multiple = 'iou'", + color=("dilation_metric_multiple === 'iou' ? 'white' : 'grey darken-3'",), + dark=("dilation_metric_multiple !== 'iou'",), + title="Intersection over Union", + class_="flex-grow-1 rounded px-4", + style="flex: 1;" + ) + vuetify.VBtn( + "Overlap Coefficient", + click="dilation_metric_multiple = 'overlap_coeff'", + color=("dilation_metric_multiple === 'overlap_coeff' ? 'white' : 'grey darken-3'",), + dark=("dilation_metric_multiple !== 'overlap_coeff'",), + title="Overlap Coefficient", + class_="flex-grow-1 rounded px-4", + style="flex: 1;" + ) + vuetify.VBtn( + "Count", + click="dilation_metric_multiple = 'count'", + color=("dilation_metric_multiple === 'count' ? 'white' : 'grey darken-3'",), + dark=("dilation_metric_multiple !== 'count'",), + title="Count", + class_="flex-grow-1 rounded px-4", + style="flex: 1;" + ) + + vuetify.VDivider() + with vuetify.VCardActions(classes="grey darken-3"): + vuetify.VSpacer() + vuetify.VBtn("Close", color="surface-variant", click="dilation_filter_dialog = false") + + # Vue component for Line plot + vuetify.Template( + """ + + """ + ) diff --git a/src/bioset/ui/scripts/__init__.py b/src/bioset/ui/scripts/__init__.py index 3e86d9e..0697750 100644 --- a/src/bioset/ui/scripts/__init__.py +++ b/src/bioset/ui/scripts/__init__.py @@ -146,5 +146,6 @@ def register_scripts(client): client.Script(NOV_DRAG_SCRIPT) # Load first so window.novStartDrag exists when lens is clicked client.Script(_read_js("upset.js")) client.Script(_read_js("bar.js")) + client.Script(_read_js("linechart.js")) client.Script(_read_js("mousemove.js")) client.Script(_read_js("histogram.js")) diff --git a/src/bioset/ui/scripts/linechart.js b/src/bioset/ui/scripts/linechart.js new file mode 100644 index 0000000..9c439f3 --- /dev/null +++ b/src/bioset/ui/scripts/linechart.js @@ -0,0 +1,204 @@ +Vue.component('linechart', { + props: { + data: Object, + channelData: Array, + viewMode: String, + metric: String, + width: { + type: Number, + default: 330 + }, + height: { + type: Number, + default: 280 + } + }, + template: '
', + watch: { + data: { handler: 'render', deep: true }, + metric: 'render', + viewMode: 'render', + channelData: { handler: 'render', deep: true } + }, + mounted() { + this.render(); + }, + methods: { + render() { + if (!this.$refs.container || !window.d3) return; + + const container = this.$refs.container; + d3.select(container).selectAll("*").remove(); + + if (!this.data || Object.keys(this.data).length === 0) { + container.innerHTML = '
Select a channel to see the dilation plot
'; + return; + } + + const width = this.width; + const height = this.height; + const marginTop = 30; + const marginBottom = 50; + + const svg = d3.select(container) + .append("svg") + .attr("width", width) + .attr("height", height) + .attr("viewBox", [0, 0, width, height]); + + // Gather all points to compute scales + let allDilations = []; + let allValues = []; + + const lines = Object.keys(this.data).map(key => { + const curveData = this.data[key]; + const points = curveData.map(d => ({ + x: d.dilation !== undefined ? d.dilation : 0, + y: d[this.metric] !== undefined ? d[this.metric] : null + })).filter(d => d.y !== null && !isNaN(d.y)); + + let displayLabel = key; + const isCombo = key.includes('|'); + if (isCombo) { + const channelNames = key.split('|'); + if (channelNames.length > 3) { + displayLabel = `[${channelNames.slice(0, 3).join(', ')}, ...]`; + } else { + displayLabel = `[${channelNames.join(', ')}]`; + } + } + + points.forEach(p => { + allDilations.push(p.x); + allValues.push(p.y); + }); + + return { + key: key, + displayLabel: displayLabel, + points: points, + isCombo: isCombo + }; + }); + + if (allDilations.length === 0) { + return; + } + + const maxLabelLength = d3.max(lines, d => d.displayLabel.length); + const marginRight = Math.max(15, maxLabelLength * 6.5); + + const xDomain = d3.extent(allDilations); + if (xDomain[0] === xDomain[1]) { + xDomain[0] -= 1; + xDomain[1] += 1; + } + + const yMax = d3.max(allValues); + const yMin = d3.min(allValues); + const padding = (yMax - yMin) * 0.1 || (yMax * 0.1) || 0.1; + const yDomain = [Math.max(0, yMin - padding), yMax + padding]; + + const tickFormat = d => { + if(d >= 1000) { + return (d/1000).toFixed(0) + "k"; + } + return d.toString(); + }; + const sampleTicks = d3.ticks(yDomain[0], yDomain[1], 5); + const maxTickLabelLength = d3.max(sampleTicks, d => tickFormat(d).length) || 0; + const marginLeft = Math.max(50, maxTickLabelLength * 7 + 25); + + const x = d3.scaleLinear() + .domain(xDomain) + .range([marginLeft, width - marginRight]); + + const y = d3.scaleLinear() + .domain(yDomain) + .range([height - marginBottom, marginTop]); + + svg.append("g") + .attr("transform", `translate(0,${height - marginBottom})`) + .call(d3.axisBottom(x).ticks(5)) + .selectAll("text") + .attr("fill", "white") + .style("font-size", "12px"); + + svg.append("text") + .attr("x", marginLeft + (width - marginLeft - marginRight) / 2) + .attr("y", height - 10) + .attr("text-anchor", "middle") + .style("fill", "white") + .style("font-size", "12px") + .text("Dilation"); + + const metricLabels = { + "iou": "IoU", + "overlap_coeff": "Overlap Coefficient", + "density": "Density", + "count": "Voxel Count" + }; + + svg.append("g") + .attr("transform", `translate(${marginLeft},0)`) + .call(d3.axisLeft(y).ticks(5).tickFormat(tickFormat)) + .call(g => g.select(".domain").remove()) + .selectAll("text") + .attr("fill", "white") + .style("font-size", "12px"); + + svg.append("text") + .attr("x", -marginTop - (height - marginTop - marginBottom) / 2) + .attr("y", Math.max(12, marginLeft - (maxTickLabelLength * 7) - 20)) + .attr("transform", "rotate(-90)") + .attr("text-anchor", "middle") + .style("fill", "white") + .style("font-size", "12px") + .text(metricLabels[this.metric]); + + svg.selectAll(".domain").attr("stroke", "white"); + svg.selectAll(".tick line").attr("stroke", "#444"); + + const lineGen = d3.line() + .x(d => x(d.x)) + .y(d => y(d.y)); + + lines.forEach(lineObj => { + let color = "#FFFFFF"; + let strokeWidth = 2; + + const ch = this.channelData.find(c => c.name === lineObj.key); + if (ch) { + color = ch.color; + } + + svg.append("path") + .datum(lineObj.points) + .attr("fill", "none") + .attr("stroke", color) + .attr("stroke-width", strokeWidth) + .attr("d", lineGen); + + if (lineObj.points.length > 0) { + const lastPoint = lineObj.points[lineObj.points.length - 1]; + svg.append("text") + .attr("x", x(lastPoint.x) + 5) + .attr("y", y(lastPoint.y)) + .attr("alignment-baseline", "middle") + .attr("fill", color) + .style("font-size", "10px") + .style("font-weight", lineObj.isCombo ? "bold" : "normal") + .text(lineObj.displayLabel); + } + + lineObj.points.forEach(p => { + svg.append("circle") + .attr("cx", x(p.x)) + .attr("cy", y(p.y)) + .attr("r", 4) + .attr("fill", color); + }); + }); + } + } +}); diff --git a/src/bioset/ui/scripts/upset.js b/src/bioset/ui/scripts/upset.js index e7cacea..e0eb25a 100644 --- a/src/bioset/ui/scripts/upset.js +++ b/src/bioset/ui/scripts/upset.js @@ -1,11 +1,13 @@ - -// UpSet Plot Component Vue.component('upset-plot', { props: { data: Array, dataLocal: Array, channelData: Array, viewMode: String, + metric: { + type: String, + default: 'iou' + }, offset: { type: Number, default: 0 @@ -28,6 +30,7 @@ Vue.component('upset-plot', { data: 'render', dataLocal: 'render', viewMode: 'render', + metric: 'render', offset: 'render', limit: 'render', channelData: { @@ -47,7 +50,11 @@ Vue.component('upset-plot', { return; } - const maxIou = sourceData.length > 0 ? Math.max(...sourceData.map(d => d.iou)) : 0; + let maxMetricValue = 0; + if (sourceData.length > 0) { + const values = sourceData.map(d => d[this.metric]); + maxMetricValue = Math.max(...values); + } const start = this.offset; const end = start + this.limit; @@ -55,7 +62,7 @@ Vue.component('upset-plot', { const mappedData = renderData.map(item => ({ sets: item.channels, - cardinality: item.iou + cardinality: item[this.metric] })); const { sets, combinations } = UpSetJS.extractFromExpression(mappedData); @@ -107,7 +114,7 @@ Vue.component('upset-plot', { widthRatios: [0, 0.35], heightRatios: [0.4], exportButtons: false, - yDomain: [0, maxIou], + yDomain: [0, maxMetricValue], onClick: (clickedItem) => { setTimeout(() => { if (!clickedItem) { diff --git a/src/bioset/ui/state.py b/src/bioset/ui/state.py index 63506ef..ce63d7c 100644 --- a/src/bioset/ui/state.py +++ b/src/bioset/ui/state.py @@ -111,6 +111,13 @@ def init_state(state): state.setdefault("bar_offset", 0) state.setdefault("bar_limit", 10) + # Dilation Lineplot + state.setdefault("dilation_data", {}) + state.setdefault("dilation_view_mode", "single") # one of: ["single", "multiple"] + state.setdefault("dilation_metric_single", "density") # always fixed to density + state.setdefault("dilation_metric_multiple", "iou") # one of: ["iou", "overlap_coeff", "count", "density"] + state.setdefault("dilation_filter_dialog", False) + # Expanded View States state.setdefault("upset_expanded_offset", 0) state.setdefault("upset_expanded_limit", 40) @@ -170,6 +177,7 @@ def init_state(state): state.setdefault("upset_filtered_channels", []) # Channels shown in filter list state.setdefault("upset_filter_dialog", False) state.setdefault("upset_expanded", False) + state.setdefault("upset_metric", "iou") # one of ["iou", "overlap_coeff"] state.setdefault("upset_min_channels", 2) # default minimum combination limit # Bar Plot filtering @@ -307,6 +315,8 @@ def on_active_channels_change(active_channels, **kwargs): ctrl.update_bar_data_local() if hasattr(ctrl, 'nov_recompute_scores_if_visible'): ctrl.nov_recompute_scores_if_visible() + if hasattr(ctrl, 'update_dilation_data'): + ctrl.update_dilation_data() @state.change("channels") def on_channels_change(channels, **kwargs): @@ -388,6 +398,8 @@ def on_heatmap_combination_change(heatmap_combination, **kwargs): print(f"[state] Heatmap combination changed: {heatmap_combination}") if hasattr(ctrl, 'update_heatmap'): ctrl.update_heatmap() + if hasattr(ctrl, 'print_dilation_curve'): + ctrl.print_dilation_curve() @state.change("heatmap_outline_only") def on_heatmap_outline_only_change(heatmap_outline_only, **kwargs): @@ -425,6 +437,8 @@ def on_hierarchy_change(current_hierarchy_level, **kwargs): ctrl.update_upset_data() if hasattr(ctrl, 'update_bar_data'): ctrl.update_bar_data() + if hasattr(ctrl, 'update_dilation_data'): + ctrl.update_dilation_data() @state.change("upset_data") def on_upset_data_change(upset_data, **kwargs): @@ -450,6 +464,17 @@ def on_bar_view_mode_change(bar_view_mode, **kwargs): state.bar_offset = 0 state.bar_expanded_offset = 0 + @state.change("dilation_view_mode") + def on_dilation_view_mode_change(dilation_view_mode, **kwargs): + if hasattr(ctrl, 'update_dilation_data'): + ctrl.update_dilation_data() + + @state.change("dilation_selected_channels") + def on_dilation_selected_channels_change(dilation_selected_channels, **kwargs): + print(f"[state] Dilation selected channels changed: {len(dilation_selected_channels)} channels") + if hasattr(ctrl, 'update_dilation_data'): + ctrl.update_dilation_data() + @state.change("upset_selected_channels") def on_upset_selected_channels_change(upset_selected_channels, **kwargs): print(f"[state] UpSet selected channels changed: {len(upset_selected_channels)} channels") @@ -462,6 +487,14 @@ def on_upset_min_channels_change(upset_min_channels, **kwargs): if hasattr(ctrl, 'update_upset_data'): ctrl.update_upset_data() + @state.change("upset_metric") + def on_upset_metric_change(upset_metric, **kwargs): + print(f"[state] UpSet metric changed: {upset_metric}") + if hasattr(ctrl, 'update_upset_data'): + ctrl.update_upset_data() + if hasattr(ctrl, 'update_upset_data_local'): + ctrl.update_upset_data_local() + @state.change("bar_selected_channels") def on_bar_selected_channels_change(bar_selected_channels, **kwargs): print(f"[state] Bar selected channels changed: {len(bar_selected_channels)} channels") @@ -529,4 +562,3 @@ def on_analysis_channels_change(analysis_channels, **kwargs): state.bar_filtered_channels = list(analysis_channels) state.upset_search = "" state.bar_search = "" - \ No newline at end of file