Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/post_processing/dataclass/data_aplose.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,10 @@ def __init__(
).reset_index(drop=True)
self.annotators = sorted(set(self.df["annotator"])) if df is not None else None
self.labels = sorted(set(self.df["annotation"])) if df is not None else None
self.begin = min(self.df["start_datetime"]) if begin is None else begin
self.end = max(self.df["end_datetime"]) if end is None else end
self.begin = (
min(self.df["start_datetime"], default=None) if begin is None else begin
)
self.end = max(self.df["end_datetime"], default=None) if end is None else end
self.dataset = sorted(set(self.df["dataset"])) if df is not None else None
self.lat = None
self.lon = None
Expand Down Expand Up @@ -595,8 +597,7 @@ def reshape(self, begin: Timestamp = None, end: Timestamp = None) -> DataAplose:
]

if self.df.empty:
msg = "DataFrame is empty after reshaping."
raise ValueError(msg)
return self

self.dataset = get_dataset(self.df)
self.labels = get_labels(self.df)
Expand Down
13 changes: 9 additions & 4 deletions src/post_processing/dataclass/detection_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Literal

import numpy as np
import yaml
from pandas import Timedelta, Timestamp

Expand Down Expand Up @@ -38,6 +39,7 @@ class DetectionFilter:
score: float | None = None
box: bool = False
filename_format: str = None
confidence: float = None

def __getitem__(self, key: str):
"""Return the value of the given key."""
Expand Down Expand Up @@ -88,10 +90,13 @@ def from_dict(
filters = []
for detection_file, filters_dict in parameters.items():
df_preview = read_dataframe(Path(detection_file), rows=5)
filters_dict["timebin_origin"] = Timedelta(
max(df_preview["end_time"]),
"s",
)
if df_preview.empty:
filters_dict["timebin_origin"] = np.nan
else:
filters_dict["timebin_origin"] = Timedelta(
max(df_preview["end_time"]),
"s",
)
filters_dict["detection_file"] = Path(detection_file)
if filters_dict.get("timebin_new"):
filters_dict["timebin_new"] = Timedelta(
Expand Down
72 changes: 33 additions & 39 deletions src/post_processing/utils/filtering_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ def filter_strong_detection(
"""
if "type" in df.columns:
df = df[df["type"] == "WEAK"]
elif "is_box" in df.columns:
df = df[df["is_box"] == 0]
else:
msg = "Could not determine annotation type."
raise ValueError(msg)
Expand Down Expand Up @@ -114,15 +112,9 @@ def filter_by_time(
"""
if begin is not None:
df = df[df["start_datetime"] >= begin]
if df.empty:
msg = f"No detection found after '{begin}'."
raise ValueError(msg)

if end is not None:
df = df[df["end_datetime"] <= end]
if df.empty:
msg = f"No detection found before '{end}'."
raise ValueError(msg)

return df

Expand Down Expand Up @@ -218,47 +210,43 @@ def filter_by_freq(

"""
if f_min is not None:
df = df[df["start_frequency"] >= f_min]
df = df[df["min_frequency"] >= f_min]
if df.empty:
msg = f"No detection found above {f_min}Hz."
raise ValueError(msg)

if f_max is not None:
df = df[df["end_frequency"] <= f_max]
df = df[df["max_frequency"] <= f_max]
if df.empty:
msg = f"No detection found below {f_max}Hz."
raise ValueError(msg)
return df


def filter_by_score(df: DataFrame, score: float) -> DataFrame:
"""Filter detections by confidence score.
def filter_by_confidence(df: DataFrame, confidence: float) -> DataFrame:
"""Filter detections by confidence.

Parameters
----------
df : DataFrame
APLOSE-formatted DataFrame containing a 'score' column.
score : float
The minimum confidence score threshold (inclusive).
APLOSE-formatted DataFrame containing a 'confidence' column.
confidence : float
The minimum confidence threshold (inclusive).

Returns
-------
DataFrame
Filtered DataFrame containing only detections with score >= min_score.
Filtered DataFrame containing only detections with confidence >= min_confidence.

"""
if not score:
if not confidence:
return df

if "score" not in df.columns:
msg = "'score' column not present if DataFrame."
if "confidence" not in df.columns:
msg = "'confidence' column not present if DataFrame."
raise ValueError(msg)

df = df[df["score"] >= score]
if df.empty:
msg = f"No detection found with score above {score}."
raise ValueError(msg)
return df
return df[df["confidence"] >= confidence]


def read_dataframe(file: Path, rows: int | None = None) -> DataFrame:
Expand All @@ -278,36 +266,40 @@ def read_dataframe(file: Path, rows: int | None = None) -> DataFrame:
)


def get_annotators(df: DataFrame) -> list[str]:
def get_annotators(df: DataFrame) -> str | list[str]:
"""Return the annotator list of APLOSE DataFrame."""
if len(df) == 1:
return df["annotator"][0]
if df.empty:
return []
annotators = sorted(set(df["annotator"]))
return annotators if len(annotators) > 1 else annotators[0]


def get_labels(df: DataFrame) -> str | list[str]:
"""Return the label list of APLOSE DataFrame."""
if len(df) == 1:
return df["annotation"][0]
if df.empty:
return []
labels = sorted(set(df["annotation"]))
return labels if len(labels) > 1 else labels[0]


def get_max_freq(df: DataFrame) -> float:
"""Return the maximum frequency of APLOSE DataFrame."""
return df["end_frequency"].max()
if df.empty:
return []
return df["max_frequency"].max()


def get_max_time(df: DataFrame) -> float:
"""Return the maximum time of APLOSE DataFrame."""
if df.empty:
return []
return df["end_time"].max()


def get_dataset(df: DataFrame) -> str | list[str]:
"""Return dataset list of APLOSE DataFrame."""
if len(df) == 1:
return df["dataset"][0]
if df.empty:
return []
datasets = sorted(set(df["dataset"]))
return datasets if len(datasets) > 1 else datasets[0]

Expand Down Expand Up @@ -443,8 +435,8 @@ def _create_result_dataframe(
"filename": file_vector,
"start_time": [0] * len(file_vector),
"end_time": [timebin_new.total_seconds()] * len(file_vector),
"start_frequency": [0] * len(file_vector),
"end_frequency": [max_freq] * len(file_vector),
"min_frequency": [0] * len(file_vector),
"max_frequency": [max_freq] * len(file_vector),
"annotation": [label] * len(file_vector),
"annotator": [annotator] * len(file_vector),
"start_datetime": start_datetime,
Expand Down Expand Up @@ -545,8 +537,7 @@ def reshape_timebin(

"""
if df.empty:
msg = "DataFrame is empty"
raise ValueError(msg)
return df

if not timebin_new:
return df
Expand Down Expand Up @@ -648,13 +639,17 @@ def load_detections(filters: DetectionFilter) -> DataFrame:

"""
df = read_dataframe(filters.detection_file)

if df.empty:
return df

if filters.box:
df = filter_strong_detection(df)
df = filter_by_time(df, filters.begin, filters.end)
df = filter_by_annotator(df, annotator=filters.annotator)
df = filter_by_label(df, label=filters.annotation)
df = filter_by_freq(df, filters.f_min, filters.f_max)
df = filter_by_score(df, filters.score)
df = filter_by_confidence(df, filters.confidence)
filename_ts = get_filename_timestamps(df, filters.filename_format)
df = reshape_timebin(
df,
Expand Down Expand Up @@ -733,14 +728,13 @@ def add_weak_detection(
"start_time": 0,
"end_time": max_time.total_seconds(),
"min_frequency": 0,
"start_frequency": 0,
"max_frequency": max_freq,
"end_frequency": max_freq,
"annotation": lbl,
"annotator": ant,
"start_datetime": strftime_osmose_format(start_datetime),
"end_datetime": strftime_osmose_format(end_datetime),
"type": "WEAK",
"confidence": None,
})
new_row_df = DataFrame([new_row])
df = concat([df, new_row_df], ignore_index=True)
Expand Down
Loading