def _is_madrid_zone_layout(root: Path) -> bool:
    """Return True when *root* looks like a Madrid LTE zone directory.

    The layout is recognised when *root* has at least one sub-directory whose
    name starts with ``f`` and which contains a ``downlink_*.csv`` file.
    """
    # ``next(..., None)`` stops at the first hit instead of materialising the
    # whole glob result just to test it for emptiness.
    return any(
        next(freq_dir.glob("downlink_*.csv"), None) is not None
        for freq_dir in root.glob("f*")
        if freq_dir.is_dir()
    )


def _per_second_frame(
    csv_path: "Path | None", value_col: str, out_col: str, how: str
) -> pd.DataFrame:
    """Aggregate one CSV to a single row per whole second.

    Args:
        csv_path: CSV with a ``timestamp`` column and a *value_col* column, or
            ``None`` when the file is absent.
        value_col: Name of the column to aggregate.
        out_col: Name given to the aggregated column in the result.
        how: Aggregation applied within each second (``"sum"`` or ``"mean"``).

    Returns:
        A frame with the columns ``second`` (nullable ``Int64``) and
        *out_col*. It is empty, but carries the same dtypes, when *csv_path*
        is ``None`` so that later merges always join like-typed keys.
    """
    if csv_path is None:
        return pd.DataFrame(
            {
                "second": pd.Series(dtype="Int64"),
                out_col: pd.Series(dtype="float64"),
            }
        )

    frame = pd.read_csv(csv_path)
    # Unparseable timestamps become <NA> and are dropped before grouping.
    frame["second"] = np.floor(
        pd.to_numeric(frame["timestamp"], errors="coerce")
    ).astype("Int64")
    frame[value_col] = pd.to_numeric(frame[value_col], errors="coerce")
    grouped = frame.dropna(subset=["second"]).groupby("second", as_index=False)
    return grouped[value_col].agg(how).rename(columns={value_col: out_col})


def _load_madrid_zone_wide(root: Path) -> pd.DataFrame:
    """Load a Madrid LTE zone directory into one wide per-second frame.

    Every ``f*`` sub-directory of *root* that has a ``downlink_*.csv`` file
    contributes the columns ``downlink_<freq>``, ``uplink_<freq>`` and
    ``users_<freq>``. Downlink and uplink ``tbs_sum`` values are summed per
    second, ``user_unique`` is averaged per second. Missing uplink/users files
    yield all-zero columns. The frames are aligned on the union of observed
    seconds; gaps are forward-filled and any leading gap is set to ``0.0``.

    Args:
        root: Zone directory; its name is stored in the ``scheduling_policy``
            and ``reservation`` columns.

    Returns:
        A frame sorted by ``second`` with the per-frequency columns plus
        ``timestamp``, ``time_ms``, ``traffic_load``, ``num_ues``,
        ``ul_buffer_bytes``, ``dl_buffer_bytes``, ``scheduling_policy`` and
        ``reservation``.

    Raises:
        RuntimeError: If no frequency directory with a downlink CSV exists.
    """
    per_freq = {}
    for freq_dir in sorted(p for p in root.glob("f*") if p.is_dir()):
        # ``min`` picks the lexicographically first match, like sorting and
        # taking the head, without building the sorted list.
        dl_path = min(freq_dir.glob("downlink_*.csv"), default=None)
        if dl_path is None:
            continue  # a frequency without downlink data is not usable
        ul_path = min(freq_dir.glob("uplink_*.csv"), default=None)
        users_path = min(freq_dir.glob("users_*.csv"), default=None)

        freq = freq_dir.name
        downlink = _per_second_frame(dl_path, "tbs_sum", f"downlink_{freq}", "sum")
        uplink = _per_second_frame(ul_path, "tbs_sum", f"uplink_{freq}", "sum")
        users = _per_second_frame(
            users_path, "user_unique", f"users_{freq}", "mean"
        )
        per_freq[freq] = (
            downlink.merge(uplink, on="second", how="outer")
            .merge(users, on="second", how="outer")
            .sort_values("second")
        )

    if not per_freq:
        raise RuntimeError(f"No usable frequency data found in {root}")

    seconds = set()
    for merged in per_freq.values():
        seconds.update(merged["second"].dropna().astype(int).tolist())

    # Explicit dtype keeps the key numeric even when no rows survived.
    base = pd.DataFrame({"second": pd.Series(sorted(seconds), dtype="int64")})
    for freq in sorted(per_freq):
        base = base.merge(per_freq[freq], on="second", how="left")

    feature_cols = [c for c in base.columns if c != "second"]
    for col in feature_cols:
        base[col] = pd.to_numeric(base[col], errors="coerce")
    base = base.sort_values("second").reset_index(drop=True)
    base[feature_cols] = base[feature_cols].ffill().fillna(0.0)

    down_cols = [c for c in base.columns if c.startswith("downlink_f")]
    up_cols = [c for c in base.columns if c.startswith("uplink_f")]
    user_cols = [c for c in base.columns if c.startswith("users_f")]

    base["timestamp"] = base["second"].astype(float)
    base["time_ms"] = (base["second"].astype(float) * 1000.0).astype("int64")
    base["traffic_load"] = base[down_cols].sum(axis=1) if down_cols else 0.0
    base["num_ues"] = base[user_cols].sum(axis=1) if user_cols else 0.0
    base["ul_buffer_bytes"] = base[up_cols].sum(axis=1) if up_cols else 0.0
    base["dl_buffer_bytes"] = base["traffic_load"]
    base["scheduling_policy"] = root.name
    base["reservation"] = root.name

    return base
def test_train_pipeline_with_madrid_generated_data(tmp_path: Path) -> None:
    """Generate a Madrid zone dataset, train on it, and check the artifacts.

    Both steps run as child processes with no explicit ``cwd``, so
    ``generate_data.py`` and the ``scripts.train`` module are resolved
    relative to the directory the test session is started from.
    """
    dataset_csv = tmp_path / "traffic_data.csv"
    model_dir = tmp_path / "model"

    generate_cmd = [
        sys.executable,
        "generate_data.py",
        "--steps",
        "200",
        "--input",
        str(ZONE_I),
        "--output",
        str(dataset_csv),
    ]
    train_cmd = [
        sys.executable,
        "-m",
        "scripts.train",
        "--csv",
        str(dataset_csv),
        "--out_dir",
        str(model_dir),
        "--model",
        "lightweight-32",
        "--epochs",
        "1",
    ]

    # Order matters: training consumes the CSV written by the first command.
    for command in (generate_cmd, train_cmd):
        subprocess.run(command, check=True)

    for artifact in ("model.joblib", "features.json"):
        assert (model_dir / artifact).exists()