Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 236 additions & 34 deletions src/bioset/analysis/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sqlite3
import tempfile
from dataclasses import dataclass, field
from itertools import combinations as iter_combinations
from pathlib import Path
from typing import Optional

Expand Down Expand Up @@ -73,6 +74,8 @@ def __init__(self):
self.metadata: Optional[AnalysisMetadata] = None
self._loaded = False
self._total_tiles_cache: dict[int, int] = {} # level -> total tile count
self._channel_totals_voxels_cache: dict[
tuple[str, float, int], int] = {} # (channel, dilation, level) -> total voxels

@property
def is_loaded(self) -> bool:
Expand Down Expand Up @@ -230,8 +233,9 @@ def get_top_combinations(
Get top N combinations by aggregated IoU.

Aggregates across all tiles:
global_iou = SUM(total_count) / SUM(total_union)

global_iou = SUM(total_count) / SUM(total_union)
global_overlap_coeff = SUM(total_count) / MIN(SUM(ch_a), SUM(ch_b), ...)

Sorted by global_iou DESC.
Self-pairs (e.g. CD8|CD8) are excluded.
"""
Expand Down Expand Up @@ -266,10 +270,16 @@ def get_top_combinations(
sum_union = row["sum_union"] or 1
agg_iou = sum_inter / sum_union if sum_union > 0 else 0.0

# Calculate overlap coefficient from channel_stats
overlap_coeff = self._compute_overlap_coeff(
channels, sum_inter, dilation, hierarchy_level
)

results.append(CombinationData(
channels=channels,
total_count=row["agg_count"] or 0,
iou=agg_iou,
overlap_coeff=overlap_coeff,
))

if len(results) >= limit:
Expand All @@ -287,6 +297,7 @@ def get_filtered_combinations(
) -> list[CombinationData]:
"""
Get combinations containing ANY of the specified channels, aggregated by IoU.
Also computes aggregated overlap coefficient.
"""
if not self.is_loaded or not channel_filter:
return []
Expand Down Expand Up @@ -348,17 +359,55 @@ def get_filtered_combinations(
sum_union = row["sum_union"] or 1
agg_iou = sum_inter / sum_union if sum_union > 0 else 0.0

# Calculate overlap coefficient from channel_stats
overlap_coeff = self._compute_overlap_coeff(
channels, sum_inter, dilation, hierarchy_level
)

results.append(CombinationData(
channels=channels,
total_count=row["agg_count"] or 0,
iou=agg_iou,
overlap_coeff=overlap_coeff,
))

if len(results) >= limit:
break

return results

def _compute_overlap_coeff(
    self,
    channels: list[str],
    sum_inter: int,
    dilation: float,
    hierarchy_level: int,
) -> float:
    """
    Compute the aggregated overlap coefficient for a channel combination.

        overlap_coeff = SUM(intersection) / MIN(SUM(ch_a), SUM(ch_b), ...)

    Per-channel totals come from ``self.get_channel_total_voxels``.

    Args:
        channels: Channel names taking part in the combination.
        sum_inter: Intersection count already summed over all tiles.
        dilation: Dilation value the totals are looked up for.
        hierarchy_level: Hierarchy level the totals are looked up for.

    Returns:
        The overlap coefficient, or 0.0 when there are no channels, the
        intersection is empty, or no channel reports a positive total.
    """
    if not channels or not sum_inter:
        return 0.0

    totals = (
        self.get_channel_total_voxels(name, dilation, hierarchy_level)
        for name in channels
    )
    # Channels reporting no voxels are left out of the minimum rather than
    # forcing the denominator to zero; ``default=0`` covers the case where
    # every channel was left out.
    smallest = min((total for total in totals if total > 0), default=0)
    return sum_inter / smallest if smallest else 0.0



# ──────────────────────────────────────────────
# Bar chart: coverage percentage
# ──────────────────────────────────────────────
Expand Down Expand Up @@ -595,46 +644,190 @@ def get_tile_combinations(
# Dilation curve
# ──────────────────────────────────────────────

def get_subcombination_dilation_curves(
    self,
    channels: list[str],
    hierarchy_level: int = 0,
    *,
    max_size: Optional[int] = None,
) -> dict[str, list[dict]]:
    """
    Get dilation curves for all subcombinations of the given channels.

    For channels ["A", "B", "C"], attempts to find curves for:
    - Single channels: A, B, C
    - Pairs: A|B, A|C, B|C
    - Triple: A|B|C

    Only subcombinations whose helper returns a non-empty curve are kept.

    NOTE(review): the number of subcombinations is 2^N - 1 and each one
    costs at least one SQL query, so a long channel list gets expensive
    quickly. Pass ``max_size`` (e.g. 2 for singles + pairs) to bound it.

    Args:
        channels: Channel names to combine.
        hierarchy_level: Hierarchy level forwarded to the curve helpers.
        max_size: Largest subcombination size to generate. ``None``
            (the default) generates every size up to ``len(channels)``.

    Returns:
        Dict mapping a "|"-joined channel key (e.g. "A|B") to the list of
        dilation points returned by the single-/multi-channel helper. Each
        point is a dict with the keys "dilation", "count", "iou",
        "overlap_coeff" and "density". An empty dict is returned when
        nothing is loaded or ``channels`` is empty.
    """
    if not self.is_loaded or not channels:
        return {}

    # Rank every known channel once so key generation does not rescan the
    # channel list for each element of each subcombination.
    channel_order = self.metadata.channels if self.metadata else []
    rank: dict[str, int] = {}
    for position, name in enumerate(channel_order):
        rank.setdefault(name, position)  # first occurrence wins, like list.index

    def make_key(names: list[str]) -> str:
        # Channels missing from the metadata sort last (rank 999) and keep
        # their relative input order because sorted() is stable.
        return "|".join(sorted(names, key=lambda c: rank.get(c, 999)))

    largest = len(channels)
    if max_size is not None:
        largest = min(largest, max_size)

    results: dict[str, list[dict]] = {}
    for size in range(1, largest + 1):
        for combo in iter_combinations(channels, size):
            members = list(combo)
            if size == 1:
                curve = self._get_single_channel_dilation_curve(
                    members[0], hierarchy_level
                )
            else:
                curve = self._get_multi_channel_dilation_curve_full(
                    members, hierarchy_level
                )

            # Only include subcombinations that produced data.
            if curve:
                results[make_key(members)] = curve

    return results

def _get_single_channel_dilation_curve(
    self,
    channel: str,
    hierarchy_level: int,
) -> list[dict]:
    """
    Get the voxel count across dilations for a single channel.

    Args:
        channel: Channel name to look up in ``channel_stats``.
        hierarchy_level: Hierarchy level to restrict the query to.

    Returns:
        One dict per dilation, ordered by dilation, with the keys
        "dilation", "count" (summed voxel count), "iou" and
        "overlap_coeff" (both fixed at 1.0 for a single channel) and
        "density" (count as a percentage of the bounding volume).
        An empty list is returned when the ``channel_stats`` table does
        not exist.

    Raises:
        sqlite3.OperationalError: For any database error other than the
            missing ``channel_stats`` table.
    """
    # Bounding-box volume is the density denominator. Without metadata the
    # volume is unknown, so density falls back to 0.0 instead of raising.
    total_volume = 0
    if self.metadata is not None:
        bounds = self.metadata.volume_bounds
        total_volume = (
            (bounds["x"][1] - bounds["x"][0])
            * (bounds["y"][1] - bounds["y"][0])
            * (bounds["z"][1] - bounds["z"][0])
        )

    try:
        rows = self._conn.execute('''
            SELECT
                dilation,
                SUM(voxel_count) as total_voxels
            FROM channel_stats
            WHERE channel = ? AND hierarchy_level = ?
            GROUP BY dilation
            ORDER BY dilation
        ''', (channel, hierarchy_level)).fetchall()
    except sqlite3.OperationalError as e:
        # Databases without a channel_stats table simply have no curve.
        if "no such table: channel_stats" in str(e):
            return []
        raise

    curve = []
    for row in rows:
        total_voxels = row["total_voxels"] or 0
        density = (total_voxels / total_volume * 100) if total_volume > 0 else 0.0
        curve.append({
            "dilation": row["dilation"],
            "count": total_voxels,
            "iou": 1.0,  # Self-overlap is always 1.0
            "overlap_coeff": 1.0,  # Self-overlap is always 1.0
            "density": density,
        })
    return curve


def _get_multi_channel_dilation_curve_full(
    self,
    channels: list[str],
    hierarchy_level: int,
) -> list[dict]:
    """
    Get IoU, overlap coefficient, and density across dilations for a
    multi-channel combination.

    Args:
        channels: Channel names forming the combination, in any order;
            they are sorted by metadata channel order to build the
            "|"-joined lookup key.
        hierarchy_level: Hierarchy level to restrict the query to.

    Returns:
        One dict per dilation, ordered by dilation, with the keys
        "dilation", "count" (summed intersection), "iou",
        "overlap_coeff" and "density" (intersection as a percentage of
        the bounding volume). An empty list is returned if the
        combination doesn't exist in the database.
    """
    # Metadata supplies both the density denominator and the channel
    # ordering. Guard it once here so a missing metadata object degrades
    # to density 0.0 and input-order keys instead of raising.
    total_volume = 0
    channel_order = []
    if self.metadata is not None:
        bounds = self.metadata.volume_bounds
        total_volume = (
            (bounds["x"][1] - bounds["x"][0])
            * (bounds["y"][1] - bounds["y"][0])
            * (bounds["z"][1] - bounds["z"][0])
        )
        channel_order = self.metadata.channels

    channels_str = "|".join(sorted(
        channels,
        key=lambda c: channel_order.index(c) if c in channel_order else 999,
    ))

    rows = self._conn.execute('''
        SELECT
            dilation,
            SUM(total_count) as sum_inter,
            SUM(total_union) as sum_union
        FROM combinations
        WHERE channels = ? AND hierarchy_level = ?
        GROUP BY dilation
        ORDER BY dilation
    ''', (channels_str, hierarchy_level)).fetchall()

    if not rows:
        return []  # Combination doesn't exist in database

    results = []
    for row in rows:
        dilation = row["dilation"]
        sum_inter = row["sum_inter"] or 0
        # A NULL/zero union must yield IoU 0.0. Substituting 1 here would
        # report the raw intersection count as the IoU.
        sum_union = row["sum_union"] or 0
        iou = sum_inter / sum_union if sum_union > 0 else 0.0

        # Overlap coefficient: intersection / min(channel_voxels)
        overlap_coeff = self._compute_overlap_coeff(
            channels, sum_inter, dilation, hierarchy_level
        )

        # Density of intersection
        density = (sum_inter / total_volume * 100) if total_volume > 0 else 0.0

        results.append({
            "dilation": dilation,
            "count": sum_inter,
            "iou": iou,
            "overlap_coeff": overlap_coeff,
            "density": density,
        })

    return results

Comment on lines +764 to +830
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_get_multi_channel_dilation_curve_full has inconsistent indentation (parameters and body indented more than other methods in this file), which hurts readability and makes diffs noisier. Align the indentation with the surrounding class methods (4 spaces for method body indentation level).

Suggested change
self,
channels: list[str],
hierarchy_level: int,
) -> list[dict]:
"""
Get IoU, overlap coefficient, and density across dilations for a multi-channel combination.
Returns empty list if the combination doesn't exist in the database.
"""
# Get total volume for density calculation
bounds = self.metadata.volume_bounds
total_volume = (
(bounds["x"][1] - bounds["x"][0]) *
(bounds["y"][1] - bounds["y"][0]) *
(bounds["z"][1] - bounds["z"][0])
)
channel_order = self.metadata.channels if self.metadata else []
sorted_channels = sorted(
channels,
key=lambda c: channel_order.index(c) if c in channel_order else 999
)
channels_str = "|".join(sorted_channels)
cursor = self._conn.execute('''
SELECT
dilation,
SUM(total_count) as sum_inter,
SUM(total_union) as sum_union
FROM combinations
WHERE channels = ? AND hierarchy_level = ?
GROUP BY dilation
ORDER BY dilation
''', (channels_str, hierarchy_level))
rows = cursor.fetchall()
if not rows:
return [] # Combination doesn't exist in database
results = []
for row in rows:
dilation = row["dilation"]
sum_inter = row["sum_inter"] or 0
sum_union = row["sum_union"] or 1
# IoU
iou = sum_inter / sum_union if sum_union > 0 else 0.0
# Overlap coefficient: intersection / min(channel_voxels)
overlap_coeff = self._compute_overlap_coeff(
channels, sum_inter, dilation, hierarchy_level
)
# Density of intersection
density = (sum_inter / total_volume * 100) if total_volume > 0 else 0.0
results.append({
"dilation": dilation,
"count": sum_inter,
"iou": iou,
"overlap_coeff": overlap_coeff,
"density": density,
})
return results
self,
channels: list[str],
hierarchy_level: int,
) -> list[dict]:
"""
Get IoU, overlap coefficient, and density across dilations for a multi-channel combination.
Returns empty list if the combination doesn't exist in the database.
"""
# Get total volume for density calculation
bounds = self.metadata.volume_bounds
total_volume = (
(bounds["x"][1] - bounds["x"][0]) *
(bounds["y"][1] - bounds["y"][0]) *
(bounds["z"][1] - bounds["z"][0])
)
channel_order = self.metadata.channels if self.metadata else []
sorted_channels = sorted(
channels,
key=lambda c: channel_order.index(c) if c in channel_order else 999
)
channels_str = "|".join(sorted_channels)
cursor = self._conn.execute('''
SELECT
dilation,
SUM(total_count) as sum_inter,
SUM(total_union) as sum_union
FROM combinations
WHERE channels = ? AND hierarchy_level = ?
GROUP BY dilation
ORDER BY dilation
''', (channels_str, hierarchy_level))
rows = cursor.fetchall()
if not rows:
return [] # Combination doesn't exist in database
results = []
for row in rows:
dilation = row["dilation"]
sum_inter = row["sum_inter"] or 0
sum_union = row["sum_union"] or 1
# IoU
iou = sum_inter / sum_union if sum_union > 0 else 0.0
# Overlap coefficient: intersection / min(channel_voxels)
overlap_coeff = self._compute_overlap_coeff(
channels, sum_inter, dilation, hierarchy_level
)
# Density of intersection
density = (sum_inter / total_volume * 100) if total_volume > 0 else 0.0
results.append({
"dilation": dilation,
"count": sum_inter,
"iou": iou,
"overlap_coeff": overlap_coeff,
"density": density,
})
return results

Copilot uses AI. Check for mistakes.
# ──────────────────────────────────────────────
# Channel voxel totals (for reference)
# ──────────────────────────────────────────────
Expand All @@ -645,14 +838,22 @@ def get_channel_total_voxels(
"""Get total voxels for a single channel across all tiles."""
if not self.is_loaded:
return 0

cache_entry = (channel, dilation, level)
if cache_entry in self._channel_totals_voxels_cache:
return self._channel_totals_voxels_cache[cache_entry]

try:
cursor = self._conn.execute('''
SELECT SUM(voxel_count) as total
FROM channel_stats
WHERE channel = ? AND dilation = ? AND hierarchy_level = ?
''', (channel, dilation, level))
row = cursor.fetchone()
return row["total"] if row and row["total"] else 0
total = row["total"] if row and row["total"] else 0

self._channel_totals_voxels_cache[cache_entry] = total
return total
except sqlite3.OperationalError as e:
if "no such table: channel_stats" in str(e):
return 0
Expand Down Expand Up @@ -697,6 +898,7 @@ def close(self):
self._loaded = False
self.metadata = None
self._total_tiles_cache.clear()
self._channel_totals_voxels_cache.clear()

def __del__(self):
    """Finalizer hook: delegates to ``close()`` when the instance is collected."""
    self.close()
Loading
Loading