Skip to content

Commit 80fae69

Browse files
committed
add and test dataset iterator
1 parent 7ce8d57 commit 80fae69

2 files changed

Lines changed: 97 additions & 14 deletions

File tree

gnss_lib_py/parsers/android.py

Lines changed: 73 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,11 @@ def solve_kaggle_baseline(navdata):
504504
navdata : gnss_lib_py.parsers.android.AndroidDerived2022
505505
Instance of the AndroidDerived2022 class.
506506
507+
Returns
508+
-------
509+
state_estimate : gnss_lib_py.parsers.navdata.NavData
510+
Baseline state estimate.
511+
507512
"""
508513

509514
columns = ["unix_millis",
@@ -527,37 +532,91 @@ def solve_kaggle_baseline(navdata):
527532

528533
return state_estimate
529534

530-
def prepare_kaggle_submission(state_wls, trip_id):
535+
def prepare_kaggle_submission(state_estimate, trip_id):
531536
"""Converts from gnss_lib_py receiver state to Kaggle submission.
532537
533-
534-
receiver_state : gnss_lib_py.parsers.navdata.NavData
538+
Parameters
539+
----------
540+
state_estimate : gnss_lib_py.parsers.navdata.NavData
535541
Estimated receiver position in latitude and longitude as an
536542
instance of the NavData class with the following
537-
rows: ``lat_*_deg``, ``lon_*_deg``.
538-
tripId : string
543+
rows: ``gps_millis``, ``lat_*_deg``, ``lon_*_deg``.
544+
trip_id : string
539545
Value for the tripId column in kaggle submission which is a
540-
fusion of the data and phone type
546+
fusion of the data and phone type.
541547
542548
Returns
543549
-------
544550
output : gnss_lib_py.parsers.navdata.NavData
545-
NavData structure ready for Kaggle submission
551+
NavData structure ready for Kaggle submission.
546552
547553
"""
548554

549-
state_wls.in_rows("gps_millis")
550-
wildcards = state_wls.find_wildcard_indexes(["lat_*_deg",
555+
state_estimate.in_rows("gps_millis")
556+
wildcards = state_estimate.find_wildcard_indexes(["lat_*_deg",
551557
"lon_*_deg"],max_allow = 1)
552558

553559
output = NavData()
554-
output["tripId"] = np.array([trip_id] * state_wls.shape[1])
555-
output["UnixTimeMillis"] = gps_to_unix_millis(state_wls["gps_millis"])
560+
output["tripId"] = np.array([trip_id] * state_estimate.shape[1])
561+
output["UnixTimeMillis"] = gps_to_unix_millis(state_estimate["gps_millis"])
556562
output.orig_dtypes["UnixTimeMillis"] = np.int64
557-
output["LatitudeDegrees"] = state_wls[wildcards["lat_*_deg"]]
558-
output["LongitudeDegrees"] = state_wls[wildcards["lon_*_deg"]]
563+
output["LatitudeDegrees"] = state_estimate[wildcards["lat_*_deg"]]
564+
output["LongitudeDegrees"] = state_estimate[wildcards["lon_*_deg"]]
559565

560566
output.interpolate("UnixTimeMillis",["LatitudeDegrees",
561567
"LongitudeDegrees"])
562-
563568
return output
569+
570+
def solve_kaggle_dataset(folder_path, solver, *args):
571+
"""Run solver on all kaggle traces.
572+
573+
Additional ``*args`` arguments are passed into the ``solver``
574+
function.
575+
576+
Parameters
577+
----------
578+
folder_path: string or path-like
579+
Path to folder containing all traces (e.g. full path to "train"
580+
or "test" directories.
581+
solver : function
582+
State estimate solver that takes an instance of
583+
AndroidDerived2022 and outputs a state_estimate NavData object.
584+
Additional ``*args`` arguments are passed into this ``solver``
585+
function.
586+
587+
Returns
588+
-------
589+
submission : gnss_lib_py.parsers.navdata.NavData
590+
Full solution submission across all traces. Can then be saved
591+
using submission.to_csv().
592+
593+
"""
594+
595+
# create solution NavData object
596+
solution = NavData()
597+
598+
# iterate through all trace options
599+
for trace_name in sorted(os.listdir(folder_path)):
600+
trace_path = os.path.join(folder_path, trace_name)
601+
# iterate through all phone types
602+
for phone_type in sorted(os.listdir(trace_path)):
603+
data_path = os.path.join(folder_path,trace_name,
604+
phone_type,"device_gnss.csv")
605+
try:
606+
# convert data to Measurement class
607+
derived_data = AndroidDerived2022(data_path)
608+
609+
# compute state estimate using provided solver function
610+
state_estimate = solver(derived_data, *args)
611+
612+
trip_id = "/".join([trace_name,phone_type])
613+
output = prepare_kaggle_submission(state_estimate,
614+
trip_id)
615+
616+
# concatenate solution to previous solutions
617+
solution.concat(navdata=output, inplace=True)
618+
619+
except FileNotFoundError:
620+
continue
621+
622+
return solution

tests/parsers/test_android.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
from gnss_lib_py.parsers import android
1515
from gnss_lib_py.parsers.navdata import NavData
16+
from gnss_lib_py.algorithms.snapshot import solve_wls
17+
from gnss_lib_py.algorithms.gnss_filters import solve_gnss_ekf
1618

1719
# pylint: disable=protected-access
1820

@@ -698,3 +700,25 @@ def test_prepare_kaggle_submission(state_estimate):
698700
expected = np.array([1619735725999,1619735726999,1619735727999,
699701
1619735728999,1619735729999,1619735730999])
700702
np.testing.assert_array_equal(output["UnixTimeMillis"], expected)
703+
704+
def test_solve_kaggle_dataset(root_path):
705+
"""Test kaggle solver.
706+
707+
"""
708+
709+
folder_path = os.path.join(root_path,"..","..")
710+
for solver in [android.solve_kaggle_baseline,
711+
solve_wls,
712+
solve_gnss_ekf,
713+
]:
714+
solution = android.solve_kaggle_dataset(folder_path, solver)
715+
716+
solution.in_rows(["tripId","UnixTimeMillis",
717+
"LatitudeDegrees","LongitudeDegrees"])
718+
719+
assert solution.shape[1] == 6
720+
721+
expected = np.array([1619735725999,1619735726999,1619735727999,
722+
1619735728999,1619735729999,1619735730999])
723+
np.testing.assert_array_equal(solution["UnixTimeMillis"],
724+
expected)

0 commit comments

Comments
 (0)