diff --git a/src/bioset/analysis/loader.py b/src/bioset/analysis/loader.py
index bb13302..43d2c33 100644
--- a/src/bioset/analysis/loader.py
+++ b/src/bioset/analysis/loader.py
@@ -5,6 +5,7 @@
import sqlite3
import tempfile
from dataclasses import dataclass, field
+from itertools import combinations as iter_combinations
from pathlib import Path
from typing import Optional
@@ -73,6 +74,8 @@ def __init__(self):
self.metadata: Optional[AnalysisMetadata] = None
self._loaded = False
self._total_tiles_cache: dict[int, int] = {} # level -> total tile count
+ self._channel_totals_voxels_cache: dict[
+ tuple[str, float, int], int] = {} # (channel, dilation, level) -> total voxels
@property
def is_loaded(self) -> bool:
@@ -230,8 +233,9 @@ def get_top_combinations(
Get top N combinations by aggregated IoU.
Aggregates across all tiles:
- global_iou = SUM(total_count) / SUM(total_union)
-
+ global_iou = SUM(total_count) / SUM(total_union)
+ global_overlap_coeff = SUM(total_count) / MIN(SUM(ch_a), SUM(ch_b), ...)
+
Sorted by global_iou DESC.
Self-pairs (e.g. CD8|CD8) are excluded.
"""
@@ -266,10 +270,16 @@ def get_top_combinations(
sum_union = row["sum_union"] or 1
agg_iou = sum_inter / sum_union if sum_union > 0 else 0.0
+ # Calculate overlap coefficient from channel_stats
+ overlap_coeff = self._compute_overlap_coeff(
+ channels, sum_inter, dilation, hierarchy_level
+ )
+
results.append(CombinationData(
channels=channels,
total_count=row["agg_count"] or 0,
iou=agg_iou,
+ overlap_coeff=overlap_coeff,
))
if len(results) >= limit:
@@ -287,6 +297,7 @@ def get_filtered_combinations(
) -> list[CombinationData]:
"""
Get combinations containing ANY of the specified channels, aggregated by IoU.
+ Also computes aggregated overlap coefficient.
"""
if not self.is_loaded or not channel_filter:
return []
@@ -348,10 +359,16 @@ def get_filtered_combinations(
sum_union = row["sum_union"] or 1
agg_iou = sum_inter / sum_union if sum_union > 0 else 0.0
+ # Calculate overlap coefficient from channel_stats
+ overlap_coeff = self._compute_overlap_coeff(
+ channels, sum_inter, dilation, hierarchy_level
+ )
+
results.append(CombinationData(
channels=channels,
total_count=row["agg_count"] or 0,
iou=agg_iou,
+ overlap_coeff=overlap_coeff,
))
if len(results) >= limit:
@@ -359,6 +376,38 @@ def get_filtered_combinations(
return results
+ def _compute_overlap_coeff(
+ self,
+ channels: list[str],
+ sum_inter: int,
+ dilation: float,
+ hierarchy_level: int,
+ ) -> float:
+ """
+ Compute aggregated overlap coefficient.
+
+ overlap_coeff = SUM(intersection) / MIN(SUM(ch_a), SUM(ch_b), ...)
+
+ Queries channel_stats to get total voxels per channel.
+ """
+ if not channels or sum_inter == 0:
+ return 0.0
+
+ # Get total voxels for each channel
+ channel_totals = []
+ for ch in channels:
+ total = self.get_channel_total_voxels(ch, dilation, hierarchy_level)
+ if total > 0:
+ channel_totals.append(total)
+
+ if not channel_totals:
+ return 0.0
+
+ min_voxels = min(channel_totals)
+ return sum_inter / min_voxels if min_voxels > 0 else 0.0
+
+
+
# ──────────────────────────────────────────────
# Bar chart: coverage percentage
# ──────────────────────────────────────────────
@@ -595,46 +644,190 @@ def get_tile_combinations(
# Dilation curve
# ──────────────────────────────────────────────
- def get_dilation_curve(
+ def get_subcombination_dilation_curves(
self,
channels: list[str],
- hierarchy_level: int,
- ) -> list[dict]:
- """Get aggregated IoU across all dilations for a combination."""
- if not self.is_loaded:
- return []
+ hierarchy_level: int = 0,
+ ) -> dict[str, list[dict]]:
+ """
+ Get dilation curves for all subcombinations of the given channels.
+
+ For channels ["A", "B", "C"], attempts to find curves for:
+ - Single channels: A, B, C
+ - Pairs: A|B, A|C, B|C (if they exist in the database)
+ - Triple: A|B|C (if it exists)
+
+ Only returns subcombinations that actually exist in the database.
+
+ Returns:
+ Dict mapping channel string (e.g., "A|B") to list of dilation points:
+ {
+ "A": [{"dilation": 0.0, "count": ..., "iou": 1.0, "overlap_coeff": 1.0, "density": ...}, ...],
+ "A|B": [{"dilation": 0.0, "count": ..., "iou": ..., "overlap_coeff": ..., "density": ...}, ...],
+ ...
+ }
+
+ Each dilation point dict contains:
+ - dilation: float
+ - count: int (voxel count or intersection count)
+ - iou: float
+ - overlap_coeff: float
+ - density: float (percentage of total volume)
+ """
+ if not self.is_loaded or not channels:
+ return {}
+ # Get channel ordering for consistent key generation
channel_order = self.metadata.channels if self.metadata else []
- sorted_channels = sorted(
- channels,
- key=lambda c: channel_order.index(c) if c in channel_order else 999
- )
- channels_str = "|".join(sorted_channels)
- cursor = self._conn.execute('''
- SELECT
- dilation,
- SUM(total_count) as sum_inter,
- SUM(total_union) as sum_union
- FROM combinations
- WHERE channels = ? AND hierarchy_level = ?
- GROUP BY dilation
- ORDER BY dilation
- ''', (channels_str, hierarchy_level))
+ def sort_channels(ch_list: list[str]) -> list[str]:
+ return sorted(
+ ch_list,
+ key=lambda c: channel_order.index(c) if c in channel_order else 999
+ )
- results = []
- for row in cursor:
- sum_inter = row["sum_inter"] or 0
- sum_union = row["sum_union"] or 1
- agg_iou = sum_inter / sum_union if sum_union > 0 else 0.0
- results.append({
- "dilation": row["dilation"],
- "count": sum_inter,
- "iou": agg_iou,
- })
+ def make_key(ch_list: list[str]) -> str:
+ return "|".join(sort_channels(ch_list))
+
+ results = {}
+
+ # Generate all subcombinations of size 1 to len(channels)
+ for size in range(1, len(channels) + 1):
+ for combo in iter_combinations(channels, size):
+ combo_list = list(combo)
+ combo_key = make_key(combo_list)
+
+ # Get dilation curve for this subcombination
+ if size == 1:
+ curve = self._get_single_channel_dilation_curve(
+ combo_list[0], hierarchy_level
+ )
+ else:
+ curve = self._get_multi_channel_dilation_curve_full(
+ combo_list, hierarchy_level
+ )
+
+ # Only include if data exists
+ if curve:
+ results[combo_key] = curve
return results
+ def _get_single_channel_dilation_curve(
+ self,
+ channel: str,
+ hierarchy_level: int,
+ ) -> list[dict]:
+ """Get voxel count across dilations for a single channel."""
+
+ # Get total volume for density calculation
+ bounds = self.metadata.volume_bounds
+ total_volume = (
+ (bounds["x"][1] - bounds["x"][0]) *
+ (bounds["y"][1] - bounds["y"][0]) *
+ (bounds["z"][1] - bounds["z"][0])
+ )
+
+ try:
+ cursor = self._conn.execute('''
+ SELECT
+ dilation,
+ SUM(voxel_count) as total_voxels
+ FROM channel_stats
+ WHERE channel = ? AND hierarchy_level = ?
+ GROUP BY dilation
+ ORDER BY dilation
+ ''', (channel, hierarchy_level))
+
+ results = []
+ for row in cursor:
+ total_voxels = row["total_voxels"] or 0
+ density = (total_voxels / total_volume * 100) if total_volume > 0 else 0.0
+
+ results.append({
+ "dilation": row["dilation"],
+ "count": total_voxels,
+ "iou": 1.0, # Self-overlap is always 1.0
+ "overlap_coeff": 1.0, # Self-overlap is always 1.0
+ "density": density,
+ })
+
+ return results
+
+ except sqlite3.OperationalError as e:
+ if "no such table: channel_stats" in str(e):
+ return []
+ raise
+
+
+ def _get_multi_channel_dilation_curve_full(
+ self,
+ channels: list[str],
+ hierarchy_level: int,
+ ) -> list[dict]:
+ """
+ Get IoU, overlap coefficient, and density across dilations for a multi-channel combination.
+
+ Returns empty list if the combination doesn't exist in the database.
+ """
+ # Get total volume for density calculation
+ bounds = self.metadata.volume_bounds
+ total_volume = (
+ (bounds["x"][1] - bounds["x"][0]) *
+ (bounds["y"][1] - bounds["y"][0]) *
+ (bounds["z"][1] - bounds["z"][0])
+ )
+
+ channel_order = self.metadata.channels if self.metadata else []
+ sorted_channels = sorted(
+ channels,
+ key=lambda c: channel_order.index(c) if c in channel_order else 999
+ )
+ channels_str = "|".join(sorted_channels)
+
+ cursor = self._conn.execute('''
+ SELECT
+ dilation,
+ SUM(total_count) as sum_inter,
+ SUM(total_union) as sum_union
+ FROM combinations
+ WHERE channels = ? AND hierarchy_level = ?
+ GROUP BY dilation
+ ORDER BY dilation
+ ''', (channels_str, hierarchy_level))
+
+ rows = cursor.fetchall()
+
+ if not rows:
+ return [] # Combination doesn't exist in database
+
+ results = []
+ for row in rows:
+ dilation = row["dilation"]
+ sum_inter = row["sum_inter"] or 0
+ sum_union = row["sum_union"] or 1
+
+ # IoU
+ iou = sum_inter / sum_union if sum_union > 0 else 0.0
+
+ # Overlap coefficient: intersection / min(channel_voxels)
+ overlap_coeff = self._compute_overlap_coeff(
+ channels, sum_inter, dilation, hierarchy_level
+ )
+
+ # Density of intersection
+ density = (sum_inter / total_volume * 100) if total_volume > 0 else 0.0
+
+ results.append({
+ "dilation": dilation,
+ "count": sum_inter,
+ "iou": iou,
+ "overlap_coeff": overlap_coeff,
+ "density": density,
+ })
+
+ return results
+
# ──────────────────────────────────────────────
# Channel voxel totals (for reference)
# ──────────────────────────────────────────────
@@ -645,6 +838,11 @@ def get_channel_total_voxels(
"""Get total voxels for a single channel across all tiles."""
if not self.is_loaded:
return 0
+
+ cache_entry = (channel, dilation, level)
+ if cache_entry in self._channel_totals_voxels_cache:
+ return self._channel_totals_voxels_cache[cache_entry]
+
try:
cursor = self._conn.execute('''
SELECT SUM(voxel_count) as total
@@ -652,7 +850,10 @@ def get_channel_total_voxels(
WHERE channel = ? AND dilation = ? AND hierarchy_level = ?
''', (channel, dilation, level))
row = cursor.fetchone()
- return row["total"] if row and row["total"] else 0
+ total = row["total"] if row and row["total"] else 0
+
+ self._channel_totals_voxels_cache[cache_entry] = total
+ return total
except sqlite3.OperationalError as e:
if "no such table: channel_stats" in str(e):
return 0
@@ -697,6 +898,7 @@ def close(self):
self._loaded = False
self.metadata = None
self._total_tiles_cache.clear()
+ self._channel_totals_voxels_cache.clear()
def __del__(self):
self.close()
diff --git a/src/bioset/ui/callbacks.py b/src/bioset/ui/callbacks.py
index 7f1a449..2d7a3c8 100644
--- a/src/bioset/ui/callbacks.py
+++ b/src/bioset/ui/callbacks.py
@@ -5,8 +5,6 @@
import os
import tempfile
-import requests
-
from bioset.NOV import register_nov_callbacks
from bioset.bookmark import register_bookmark_callbacks, capture_screenshot_png_bytes
from bioset.llm import BiomniLocalClient
@@ -387,6 +385,7 @@ def load_analysis_file(file_info):
# Initialize plot channel selections with all channels
state.upset_selected_channels = [ch for ch in state.analysis_channels]
state.bar_selected_channels = [ch for ch in state.analysis_channels]
+ state.dilation_selected_channels = [ch for ch in state.analysis_channels]
state.analysis_loaded = True
state.right_drawer_open = True
@@ -398,6 +397,7 @@ def load_analysis_file(file_info):
update_heatmap_combinations()
update_upset_data()
update_bar_data()
+ update_dilation_data()
if _refs["view"]:
_refs["view"].update()
@@ -650,6 +650,37 @@ def update_heatmap():
if _refs["view"]:
_refs["view"].update()
+ def print_dilation_curve():
+ """Print dilation curves for all subcombinations of the selected channels."""
+ loader = _refs.get("analysis_loader")
+ if not loader or not loader.is_loaded:
+ return
+
+ channels = state.heatmap_combination or []
+ if not channels:
+ return
+
+ level = state.current_hierarchy_level
+ curves = loader.get_subcombination_dilation_curves(channels, level)
+ if not curves:
+ print(f"[dilation-curve] No data for {channels}")
+ return
+
+ print(f"[dilation-curve] Subcombination curves for {' | '.join(channels)} (hierarchy_level={level})")
+ for combo_key, curve in curves.items():
+ is_single = "|" not in combo_key
+ print(f"\n --- {combo_key} ---")
+ if is_single:
+ print(f" {'Dilation':>10} {'Voxels':>12} {'Density':>10}")
+ print(f" {'-'*10} {'-'*12} {'-'*10}")
+ for pt in curve:
+ print(f" {pt['dilation']:>10.1f} {pt['count']:>12} {pt.get('density', 0):>10.6f}")
+ else:
+ print(f" {'Dilation':>10} {'Intersection':>12} {'IoU':>10} {'Overlap':>10}")
+ print(f" {'-'*10} {'-'*12} {'-'*10} {'-'*10}")
+ for pt in curve:
+ print(f" {pt['dilation']:>10.1f} {pt['count']:>12} {pt['iou']:>10.6f} {pt.get('overlap_coeff', 0):>10.6f}")
+
def _filter_combinations_by_channel_selection(combinations, selected_channels):
"""
Filter combinations to only include those whose channels are all
@@ -697,6 +728,7 @@ def update_upset_data():
mapped_combinations.append({
"channels": combination.channels,
"iou": combination.iou,
+ "overlap_coeff": combination.overlap_coeff,
})
state.upset_data = mapped_combinations
@@ -750,6 +782,7 @@ def update_upset_data_local():
mapped_combinations.append({
"channels": combination.channels,
"iou": combination.iou,
+ "overlap_coeff": combination.overlap_coeff,
})
state.upset_data_local = mapped_combinations
@@ -811,6 +844,50 @@ def update_bar_data_local():
state.bar_data_local = local_bar_data
print(f"[callbacks] Bar local data updated: {len(local_bar_data)} channels")
+
+ def update_dilation_data():
+ """Update dilation curve data for the line plot."""
+ loader = _refs.get("analysis_loader")
+ if not loader or not loader.is_loaded:
+ print("[callbacks] Cannot update dilation data - loader not ready")
+ state.dilation_data = {}
+ return
+
+ print(
+ f"[callbacks] Updating dilation data: mode={state.dilation_view_mode}, level={state.current_hierarchy_level}")
+
+ active_ids = state.active_channels or []
+ if not active_ids:
+ state.dilation_data = {}
+ print("[callbacks] Dilation data cleared (no channels selected)")
+ return
+
+ channels_list = state.channels or []
+ id_to_name = {ch["id"]: ch["name"] for ch in channels_list}
+ selected = [id_to_name[ch_id] for ch_id in active_ids if ch_id in id_to_name]
+
+ if not selected:
+ state.dilation_data = {}
+ return
+
+ view_mode = getattr(state, "dilation_view_mode", "single")
+ level = getattr(state, "current_hierarchy_level", 0)
+
+ dilation_curves = loader.get_subcombination_dilation_curves(selected, hierarchy_level=level)
+
+ result = {}
+ if view_mode == "single":
+ for curve_key in dilation_curves:
+ if "|" not in curve_key:
+ result[curve_key] = dilation_curves[curve_key]
+
+ else:
+ for curve_key in dilation_curves:
+ if "|" in curve_key:
+ result[curve_key] = dilation_curves[curve_key]
+ state.dilation_data = result
+
+ print(f"[callbacks] Dilation data updated: keys={list(state.dilation_data.keys())}")
def reset_camera():
"""Reset camera to initial position (from when data was first loaded). Use after opening a Bookmark to return to default view."""
@@ -1817,6 +1894,7 @@ def generate_pdf_report(report_data=None):
ctrl.load_analysis_file = load_analysis_file
ctrl.update_heatmap = update_heatmap
ctrl.update_heatmap_combinations = update_heatmap_combinations
+ ctrl.print_dilation_curve = print_dilation_curve
ctrl.toggle_channel = toggle_channel
ctrl.update_active_channels = update_active_channels
ctrl.reset_camera = reset_camera
@@ -1831,6 +1909,7 @@ def generate_pdf_report(report_data=None):
ctrl.update_upset_data_local = update_upset_data_local
ctrl.update_bar_data = update_bar_data
ctrl.update_bar_data_local = update_bar_data_local
+ ctrl.update_dilation_data = update_dilation_data
ctrl.chatbot_login = chatbot_login
ctrl.biomni_add_data = biomni_add_data
ctrl.biomni_upload_file = biomni_upload_file
diff --git a/src/bioset/ui/components/right_drawer.py b/src/bioset/ui/components/right_drawer.py
index 187a8a6..6c5e6ba 100644
--- a/src/bioset/ui/components/right_drawer.py
+++ b/src/bioset/ui/components/right_drawer.py
@@ -134,6 +134,7 @@ def right_drawer(state, ctrl):
:dataLocal="upset_data_local"
:channelData="channels"
:view-mode="upset_view_mode"
+ :metric="upset_metric"
:offset="upset_expanded_offset"
:limit="upset_expanded_limit"
:width="1200"
@@ -150,6 +151,28 @@ def right_drawer(state, ctrl):
vuetify.VDivider()
with vuetify.VCardText():
+ with html.Div(classes="mb-4 mt-2"):
+ html.Div("Metric", classes="text-caption mb-2 text-left", style="color: white;")
+ with html.Div(classes="d-flex justify-space-between", style="width: 100%; gap: 8px;"):
+ vuetify.VBtn(
+ "IoU",
+ click="upset_metric = 'iou'",
+ color=("upset_metric === 'iou' ? 'white' : 'grey darken-3'",),
+ dark=("upset_metric !== 'iou'",),
+ title="Intersection over Union",
+ class_="flex-grow-1 rounded px-4",
+ style="flex: 1;"
+ )
+ vuetify.VBtn(
+ "Overlap Coefficient",
+ click="upset_metric = 'overlap_coeff'",
+ color=("upset_metric === 'overlap_coeff' ? 'white' : 'grey darken-3'",),
+ dark=("upset_metric !== 'overlap_coeff'",),
+ title="Overlap Coefficient",
+ class_="flex-grow-1 rounded px-4",
+ style="flex: 1;"
+ )
+
with html.Div(classes="mb-4 mt-2"):
html.Div("Minimum Number of Channels", classes="text-caption mb-2 text-left", style="color: white;")
with html.Div(classes="d-flex justify-space-between", style="width: 100%; gap: 8px;"):
@@ -214,6 +237,7 @@ def right_drawer(state, ctrl):
:dataLocal="upset_data_local"
:channelData="channels"
:view-mode="upset_view_mode"
+ :metric="upset_metric"
:offset="upset_offset"
:limit="upset_limit"
@click="upset_click = $event"
@@ -402,3 +426,78 @@ def right_drawer(state, ctrl):
/>
"""
)
+
+ vuetify.VDivider()
+
+ with html.Div(classes="px-4 py-3"):
+ html.Div("Dilation Curves", classes="text-overline mb-2 text-center", style="color: white;")
+
+ with html.Div(classes="d-flex justify-center mb-2 align-center"):
+ with vuetify.VBtnToggle(
+ v_model=("dilation_view_mode", "single"),
+ mandatory=True,
+ dense=True,
+ classes="mr-2",
+ style="background: transparent;",
+ ):
+ vuetify.VBtn("Single", value="single", small=True, classes="text-capitalize", outlined=True)
+ vuetify.VBtn("Multiple", value="multiple", small=True, classes="text-capitalize", outlined=True)
+
+ # Filter button
+ with vuetify.VBtn(icon=True, small=True, v_show="dilation_view_mode === 'multiple'",
+ click="dilation_filter_dialog = true"):
+ vuetify.VIcon("mdi-filter-variant", small=True)
+
+ with vuetify.VDialog(v_model=("dilation_filter_dialog",), max_width="900px", scrollable=True):
+ with vuetify.VCard(classes="grey darken-4 white--text"):
+ vuetify.VCardTitle("Settings - Dilation Plot", classes="headline grey darken-3")
+ vuetify.VDivider()
+
+ with vuetify.VCardText():
+ with html.Div(v_show="dilation_view_mode === 'multiple'", classes="mb-4"):
+ html.Div("Intersection Metric", classes="text-overline mb-1", style="color: white;")
+ with html.Div(classes="d-flex justify-space-between", style="width: 100%; gap: 8px;"):
+ vuetify.VBtn(
+ "IoU",
+ click="dilation_metric_multiple = 'iou'",
+ color=("dilation_metric_multiple === 'iou' ? 'white' : 'grey darken-3'",),
+ dark=("dilation_metric_multiple !== 'iou'",),
+ title="Intersection over Union",
+ class_="flex-grow-1 rounded px-4",
+ style="flex: 1;"
+ )
+ vuetify.VBtn(
+ "Overlap Coefficient",
+ click="dilation_metric_multiple = 'overlap_coeff'",
+ color=("dilation_metric_multiple === 'overlap_coeff' ? 'white' : 'grey darken-3'",),
+ dark=("dilation_metric_multiple !== 'overlap_coeff'",),
+ title="Overlap Coefficient",
+ class_="flex-grow-1 rounded px-4",
+ style="flex: 1;"
+ )
+ vuetify.VBtn(
+ "Count",
+ click="dilation_metric_multiple = 'count'",
+ color=("dilation_metric_multiple === 'count' ? 'white' : 'grey darken-3'",),
+ dark=("dilation_metric_multiple !== 'count'",),
+ title="Count",
+ class_="flex-grow-1 rounded px-4",
+ style="flex: 1;"
+ )
+
+ vuetify.VDivider()
+ with vuetify.VCardActions(classes="grey darken-3"):
+ vuetify.VSpacer()
+ vuetify.VBtn("Close", color="surface-variant", click="dilation_filter_dialog = false")
+
+ # Vue component for Line plot
+ vuetify.Template(
+ """
+