From 71a038902cd0ea4b07efba7808b5b98123f5aeb1 Mon Sep 17 00:00:00 2001 From: suyash469 Date: Tue, 3 Feb 2026 23:57:28 +0530 Subject: [PATCH] feat: Improve data loading and add unit tests for data_loader --- src/data_loader.py | 22 ++++++++++++---------- test/tests_data_loader.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 10 deletions(-) create mode 100644 test/tests_data_loader.py diff --git a/src/data_loader.py b/src/data_loader.py index 100aa05..0859854 100644 --- a/src/data_loader.py +++ b/src/data_loader.py @@ -7,8 +7,6 @@ Licensed under GNU LGPL.3, see LICENCE file ''' - - import os from typing import Optional, Union, Any import pandas as pd @@ -35,24 +33,28 @@ def load_data_msci(path: str = None, n: int = 24) -> dict[str, pd.DataFrame]: '''Loads MSCI daily returns data from 1999-01-01 to 2023-04-18''' path = os.path.join(os.getcwd(), f'data{os.sep}') if path is None else path - # Load msci country index return series + + # --- FILE 1: MSCI Country Indices --- df = pd.read_csv(os.path.join(path, 'msci_country_indices.csv'), - sep=';', + sep=',', # FIXED: Separator is comma index_col=0, header=0, parse_dates=True) - df.index = pd.to_datetime(df.index, format='%d/%m/%Y') + + # FIXED: Date format uses dashes + df.index = pd.to_datetime(df.index, format='%d-%m-%Y') + series_id = df.columns[0:n] X = df[series_id] - # Load msci world index return series + # --- FILE 2: World Index (NDDLWI) --- y = pd.read_csv(f'{path}NDDLWI.csv', - sep=';', + sep=',', index_col=0, header=0, parse_dates=True) - y.index = pd.to_datetime(y.index, format='%d/%m/%Y') - - return {'return_series': X, 'bm_series': y} + # FIXED: Date format uses dashes here too (Line 55 fixed) + y.index = pd.to_datetime(y.index, format='%d-%m-%Y') + return {'return_series': X, 'bm_series': y} \ No newline at end of file diff --git a/test/tests_data_loader.py b/test/tests_data_loader.py new file mode 100644 index 0000000..a4fff35 --- /dev/null +++ b/test/tests_data_loader.py @@ -0,0 +1,34 @@ +import sys +import os +import unittest +import pandas as pd +import numpy as np + +sys.path.insert(1, 'src') + +from data_loader import load_data_msci + +class TestDataLoader(unittest.TestCase): + + def setUp(self): + # This method is run before each test + self.data_path = os.path.join(os.getcwd(), 'data/') + + def test_load_data_msci(self): + # Test if data can be loaded without errors + try: + data = load_data_msci(self.data_path) + self.assertIsNotNone(data) + self.assertIsInstance(data, dict) + self.assertIn('return_series', data) + self.assertIn('bm_series', data) + self.assertIsInstance(data['return_series'], pd.DataFrame) + self.assertIsInstance(data['bm_series'], pd.DataFrame) + self.assertFalse(data['return_series'].empty) + self.assertFalse(data['bm_series'].empty) + print("\nSuccessfully loaded MSCI data.") + except Exception as e: + self.fail(f"load_data_msci failed with an error: {e}") + +if __name__ == '__main__': + unittest.main()