@@ -504,6 +504,11 @@ def solve_kaggle_baseline(navdata):
504504 navdata : gnss_lib_py.parsers.android.AndroidDerived2022
505505 Instance of the AndroidDerived2022 class.
506506
507+ Returns
508+ -------
509+ state_estimate : gnss_lib_py.parsers.navdata.NavData
510+ Baseline state estimate.
511+
507512 """
508513
509514 columns = ["unix_millis" ,
@@ -527,37 +532,91 @@ def solve_kaggle_baseline(navdata):
527532
528533 return state_estimate
529534
530- def prepare_kaggle_submission (state_wls , trip_id ):
535+ def prepare_kaggle_submission (state_estimate , trip_id ):
531536 """Converts from gnss_lib_py receiver state to Kaggle submission.
532537
533-
534- receiver_state : gnss_lib_py.parsers.navdata.NavData
538+ Parameters
539+ ----------
540+ state_estimate : gnss_lib_py.parsers.navdata.NavData
535541 Estimated receiver position in latitude and longitude as an
536542 instance of the NavData class with the following
537- rows: ``lat_*_deg``, ``lon_*_deg``.
538- tripId : string
543+ rows: ``gps_millis``, `` lat_*_deg``, ``lon_*_deg``.
544+ trip_id : string
539545 Value for the tripId column in kaggle submission which is a
540- fusion of the data and phone type
546+ fusion of the data and phone type.
541547
542548 Returns
543549 -------
544550 output : gnss_lib_py.parsers.navdata.NavData
545- NavData structure ready for Kaggle submission
551+ NavData structure ready for Kaggle submission.
546552
547553 """
548554
549- state_wls .in_rows ("gps_millis" )
550- wildcards = state_wls .find_wildcard_indexes (["lat_*_deg" ,
555+ state_estimate .in_rows ("gps_millis" )
556+ wildcards = state_estimate .find_wildcard_indexes (["lat_*_deg" ,
551557 "lon_*_deg" ],max_allow = 1 )
552558
553559 output = NavData ()
554- output ["tripId" ] = np .array ([trip_id ] * state_wls .shape [1 ])
555- output ["UnixTimeMillis" ] = gps_to_unix_millis (state_wls ["gps_millis" ])
560+ output ["tripId" ] = np .array ([trip_id ] * state_estimate .shape [1 ])
561+ output ["UnixTimeMillis" ] = gps_to_unix_millis (state_estimate ["gps_millis" ])
556562 output .orig_dtypes ["UnixTimeMillis" ] = np .int64
557- output ["LatitudeDegrees" ] = state_wls [wildcards ["lat_*_deg" ]]
558- output ["LongitudeDegrees" ] = state_wls [wildcards ["lon_*_deg" ]]
563+ output ["LatitudeDegrees" ] = state_estimate [wildcards ["lat_*_deg" ]]
564+ output ["LongitudeDegrees" ] = state_estimate [wildcards ["lon_*_deg" ]]
559565
560566 output .interpolate ("UnixTimeMillis" ,["LatitudeDegrees" ,
561567 "LongitudeDegrees" ])
562-
563568 return output
569+
570+ def solve_kaggle_dataset (folder_path , solver , * args ):
571+ """Run solver on all kaggle traces.
572+
573+ Additional ``*args`` arguments are passed into the ``solver``
574+ function.
575+
576+ Parameters
577+ ----------
578+ folder_path: string or path-like
579+ Path to folder containing all traces (e.g. full path to "train"
580+ or "test" directories.
581+ solver : function
582+ State estimate solver that takes an instance of
583+ AndroidDerived2022 and outputs a state_estimate NavData object.
584+ Additional ``*args`` arguments are passed into this ``solver``
585+ function.
586+
587+ Returns
588+ -------
589+ submission : gnss_lib_py.parsers.navdata.NavData
590+ Full solution submission across all traces. Can then be saved
591+ using submission.to_csv().
592+
593+ """
594+
595+ # create solution NavData object
596+ solution = NavData ()
597+
598+ # iterate through all trace options
599+ for trace_name in sorted (os .listdir (folder_path )):
600+ trace_path = os .path .join (folder_path , trace_name )
601+ # iterate through all phone types
602+ for phone_type in sorted (os .listdir (trace_path )):
603+ data_path = os .path .join (folder_path ,trace_name ,
604+ phone_type ,"device_gnss.csv" )
605+ try :
606+ # convert data to Measurement class
607+ derived_data = AndroidDerived2022 (data_path )
608+
609+ # compute state estimate using provided solver function
610+ state_estimate = solver (derived_data , * args )
611+
612+ trip_id = "/" .join ([trace_name ,phone_type ])
613+ output = prepare_kaggle_submission (state_estimate ,
614+ trip_id )
615+
616+ # concatenate solution to previous solutions
617+ solution .concat (navdata = output , inplace = True )
618+
619+ except FileNotFoundError :
620+ continue
621+
622+ return solution
0 commit comments