Skip to content

Commit 3c19500

Browse files
Merge pull request #97 from Stanford-NavLab/derek/sort-where
Derek/sort where
2 parents e15e267 + 0eb91dc commit 3c19500

3 files changed

Lines changed: 151 additions & 87 deletions

File tree

gnss_lib_py/parsers/navdata.py

Lines changed: 111 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -265,20 +265,40 @@ def concat(self, navdata=None, axis=1, inplace=False):
265265
return new_navdata
266266

267267
def where(self, key_idx, value, condition="eq"):
268-
"""Return NavData where conditions are met for the given row
268+
"""Return NavData where conditions are met for the given row.
269+
270+
For string rows, only the "eq" and "neq" conditions are valid.
271+
The "value" argument can contain either a string, np.nan or an
272+
array-like object of strings. If an array-like object of strings
273+
is passed in then np.isin() is used to check the condition
274+
meaning that the returned subset will contain one of the values
275+
in the "value" array-like object for the "eq" condition or none
276+
of the values in the "value" array-like object for the "neq"
277+
condition.
278+
279+
For non-string rows, all valid conditions are listed in the
280+
"condition" argument description. The "value" argument can either
281+
contain a numeric or an array-like object of numerics for both
282+
the "eq" and "neq" conditions.
283+
If an array-like object is passed then the returned subset will
284+
contain one of the values in the "value" array-like object for
285+
the "eq" condition or none of the values in the "value"
286+
array-like object for the "neq" condition.
287+
For the "between" condition, the two limit values must be passed
288+
into the "value" argument as an array-like object.
269289
270290
Parameters
271291
----------
272292
key_idx : string/int
273293
Key or index of the row in which conditions will be checked
274-
value : float/list
275-
Number (or list of two numbers for ) to compare array values
276-
against
294+
value : float/int/str/array-like
295+
Value that the row is checked against. An array-like object is
296+
accepted for "eq" and "neq" and required for "between".
277297
condition : string
278298
Condition type (greater than ("greater")/ less than ("lesser")/
279299
equal to ("eq")/ greater than or equal to ("geq")/
280300
lesser than or equal to ("leq") / in between ("between")
281-
inclusive of the provided limits
301+
inclusive of the provided limits / not equal to ("neq"))
282302
283303
Returns
284304
-------
@@ -293,19 +313,40 @@ def where(self, key_idx, value, condition="eq"):
293313
return new_navdata
294314

295315
def argwhere(self, key_idx, value, condition="eq"):
296-
"""Return columns where conditions are met for the given row
316+
"""Return columns where conditions are met for the given row.
317+
318+
For string rows, only the "eq" and "neq" conditions are valid.
319+
The "value" argument can contain either a string, np.nan or an
320+
array-like object of strings. If an array-like object of strings
321+
is passed in then np.isin() is used to check the condition
322+
meaning that the returned subset will contain one of the values
323+
in the "value" array-like object for the "eq" condition or none
324+
of the values in the "value" array-like object for the "neq"
325+
condition.
326+
327+
For non-string rows, all valid conditions are listed in the
328+
"condition" argument description. The "value" argument can either
329+
contain a numeric or an array-like object of numerics for both
330+
the "eq" and "neq" conditions.
331+
If an array-like object is passed then the returned subset will
332+
contain one of the values in the "value" array-like object for
333+
the "eq" condition or none of the values in the "value"
334+
array-like object for the "neq" condition.
335+
For the "between" condition, the two limit values must be passed
336+
into the "value" argument as an array-like object.
297337
298338
Parameters
299339
----------
300340
key_idx : string/int
301341
Key or index of the row in which conditions will be checked
302-
value : float/list
303-
Number (or list of two numbers for ) to compare array values against
342+
value : float/int/str/array-like
343+
Value that the row is checked against. An array-like object is
343+
accepted for "eq" and "neq" and required for "between".
304345
condition : string
305346
Condition type (greater than ("greater")/ less than ("lesser")/
306347
equal to ("eq")/ greater than or equal to ("geq")/
307-
lesser than or equal to ("leq") / in between ("between")/
308-
not equal to ("neq"))
348+
lesser than or equal to ("leq") / in between ("between")
349+
inclusive of the provided limits / not equal to ("neq"))
309350
310351
Returns
311352
-------
@@ -314,41 +355,52 @@ def argwhere(self, key_idx, value, condition="eq"):
314355
for specified row
315356
"""
316357
rows, _ = self._parse_key_idx(key_idx)
317-
inv_map = self.inv_map
318358
row_list, row_str = self._get_str_rows(rows)
319359
if len(row_list)>1:
320360
error_msg = "where does not currently support multiple rows"
321361
raise NotImplementedError(error_msg)
322362
row = row_list[0]
323363
row_str = row_str[0]
324-
new_cols = None
364+
new_cols = np.array([])
325365
if row_str:
326366
# Values in row are strings
327367
if condition not in ("eq","neq"):
328-
raise ValueError("Inequality comparison not valid for strings")
329-
key = inv_map[row]
330-
for str_key, str_value in self.str_map[key].items():
331-
if str_value==str(value):
332-
if condition == "eq":
333-
new_cols = np.argwhere(self.array[row, :]==str_key)
334-
break
335-
# condition == "neq"
336-
new_cols = np.argwhere(self.array[row, :]!=str_key)
337-
break
338-
if new_cols is None:
339-
new_cols = np.array([])
368+
raise ValueError("Unsupported where condition for strings")
369+
if isinstance(value,str):
370+
str_check = [str(value)]
371+
elif isinstance(value,(np.ndarray,list,tuple,set)):
372+
str_check = [str(v) for v in value]
373+
elif np.isnan(value):
374+
str_check = [str(np.nan)]
375+
else:
376+
raise ValueError("Value must be string or array-like" \
377+
+ "for string condition checks")
340378
# Extract columns where condition holds true and return new NavData
379+
if condition == "eq":
380+
new_cols = np.argwhere(np.isin(self[row, :],str_check))
381+
else:
382+
# condition == "neq"
383+
new_cols = np.argwhere(~np.isin(self[row, :],str_check))
384+
341385
else:
342386
# Values in row are numerical
343387
# Find columns where value can be found and return new NavData
344388
if condition=="eq":
345-
if not isinstance(value,str) and np.isnan(value):
389+
if isinstance(value,(np.ndarray,list,tuple,set)):
390+
# use numpy's isin() condition if list of values
391+
new_cols = np.argwhere(np.isin(self.array[row, :],
392+
value))
393+
elif not isinstance(value,str) and np.isnan(value):
346394
# check isinstance b/c np.isnan can't handle strings
347395
new_cols = np.argwhere(np.isnan(self.array[row, :]))
348396
else:
349397
new_cols = np.argwhere(self.array[row, :]==value)
350398
elif condition=="neq":
351-
if not isinstance(value,str) and np.isnan(value):
399+
if isinstance(value,(np.ndarray,list,tuple,set)):
400+
# use numpy's isin() condition if list of values
401+
new_cols = np.argwhere(~np.isin(self.array[row, :],
402+
value))
403+
elif not isinstance(value,str) and np.isnan(value):
352404
# check isinstance b/c np.isnan can't handle strings
353405
new_cols = np.argwhere(~np.isnan(self.array[row, :]))
354406
else:
@@ -370,71 +422,50 @@ def argwhere(self, key_idx, value, condition="eq"):
370422
new_cols = np.squeeze(new_cols)
371423
return new_cols
372424

373-
374-
def keep_cols_where(self, key_idx, values, condition='eq'):
375-
"""Return NavData containing columns that contain value from given array
376-
377-
Given a list of values, for the equality condition the returned
378-
subset of measurements will contain one of the given values in
379-
the given key_idx; for the inequality condition the returned
380-
subset of measurements will contain none of the given values in
381-
the given key_idx.
425+
def sort(self, order=None, ind=None, ascending=True,
426+
inplace=False):
427+
"""Sort values along given row or using given index
382428
383429
Parameters
384430
----------
385-
key_idx : string/int
386-
Key or index of the row in which conditions will be checked
387-
values : list
388-
List of values that form equality/inequality criteria for the
389-
NavData subset that is returned.
390-
condition : string
391-
Only equality ("eq") and not-equal ("neq") condition types
392-
are supported
431+
order : string/int
432+
Key or index of the row on which NavData will be sorted
433+
ind : list/np.ndarray
434+
Ordering of indices to be used for sorting
435+
ascending : bool
436+
If true, sorts "ascending", otherwise sorts "descending"
437+
inplace : bool
438+
If False, will return new NavData instance with columns
439+
sorted. If True, will sort the columns of the
440+
current NavData instance.
393441
394442
Returns
395443
-------
396-
subset_navdata : gnss_lib_py.parsers.navdata.NavData
397-
NavData containing values in the key_idx that satisfy the
398-
condition for the given values list
399-
"""
400-
keep_cols = []
401-
for value in values:
402-
cols = self.argwhere(key_idx, value, condition)
403-
if isinstance(cols, np.ndarray) and np.size(cols)==1:
404-
cols = [cols]
405-
keep_cols.extend(cols)
406-
keep_cols = np.sort(keep_cols)
407-
subset_navdata = self.copy(cols=keep_cols)
408-
return subset_navdata
409-
410-
411-
def sort(self, key_idx=None, ind=None, order="ascending"):
412-
"""Sort values along given row or using given index
444+
new_navdata : gnss_lib_py.parsers.navdata.NavData or None
445+
If inplace is False, returns NavData instance after sorting
446+
its columns. If inplace is True, returns
447+
None.
413448
414-
Parameters
415-
----------
416-
key_idx : string/int
417-
Key or index of the row in which conditions will be checked
418-
ind : list/np.ndarray
419-
Ordering of indices to be used for sorting
420-
order : string
421-
Order in which to sort: "ascending" or "descending"
422449
"""
423450
if ind is None:
424-
assert key_idx is not None, \
425-
"Provide row along which to sort because index is not given"
426-
if order=="ascending":
427-
ind = np.argsort(self[key_idx])
428-
elif order=="descending":
429-
ind = np.argsort(-self[key_idx])
451+
assert order is not None, \
452+
"Provide 'order' arg as row on which NavData is sorted"
453+
if ascending:
454+
ind = np.argsort(self[order])
430455
else:
431-
raise RuntimeError("Can only sort in ascending or ", \
432-
"descending order")
433-
new_navdata = NavData()
434-
for row in self.rows:
435-
new_navdata[row] = self[row][ind]
436-
return new_navdata
456+
ind = np.argsort(-self[order])
457+
458+
if not inplace:
459+
new_navdata = self.copy() # create copy to return
460+
for row_idx in range(self.shape[0]):
461+
if inplace:
462+
self.array[row_idx,:] = self.array[row_idx,ind]
463+
else:
464+
new_navdata.array[row_idx,:] = new_navdata.array[row_idx,ind]
437465

466+
if inplace:
467+
return None
468+
return new_navdata
438469

439470
def loop_time(self, time_row, delta_t_decimals=2):
440471
"""Generator object to loop over columns from same times.

gnss_lib_py/utils/sv_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ def _filter_ephemeris_measurements(measurements, constellations, ephemeris_path)
314314
req_const_set = set(constellations)
315315
keep_consts = req_const_set.intersection(rx_const_set)
316316

317-
measurements_subset = measurements.keep_cols_where('gnss_id', keep_consts, condition="eq")
317+
measurements_subset = measurements.where('gnss_id', keep_consts, condition="eq")
318318

319319
# preprocessing of received quantities for downloading ephemeris file
320320
eph_sv = _combine_gnss_sv_ids(measurements)
@@ -381,7 +381,7 @@ def _sort_ephem_measures(measure_frame, ephem):
381381
sorted_sats_ind = np.argsort(gnss_sv_id)
382382
inv_sort_order = np.argsort(sorted_sats_ind)
383383
sorted_sats = gnss_sv_id[sorted_sats_ind]
384-
rx_ephem = ephem.keep_cols_where('gnss_sv_id', sorted_sats, condition="eq")
384+
rx_ephem = ephem.where('gnss_sv_id', sorted_sats, condition="eq")
385385
return rx_ephem, sorted_sats_ind, inv_sort_order
386386

387387

tests/parsers/test_navdata.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1756,6 +1756,9 @@ def test_where_errors(csv_simple):
17561756
# Test condition that is not defined
17571757
with pytest.raises(ValueError):
17581758
_ = data.where("integers", 10, condition="eq_sqrt")
1759+
# Test passing float in for string check
1760+
with pytest.raises(ValueError):
1761+
_ = data.where("names", 0.342, condition="eq")
17591762

17601763
def test_time_looping(csv_simple):
17611764
"""Testing implementation to loop over times
@@ -2265,16 +2268,34 @@ def test_interpolate_fails():
22652268

22662269

22672270
def test_keep_cols_where(data, df_simple):
2271+
"""Test keep columns with where.
2272+
2273+
"""
2274+
# test for strings
22682275
keep_cols = ['gps', 'glonass']
2269-
data_subset = data.keep_cols_where('strings', keep_cols,
2276+
2277+
data_subset = data.where('strings', keep_cols,
22702278
condition="eq")
22712279
df_simple_subset = df_simple.loc[df_simple['strings'].isin(keep_cols), :]
22722280

22732281
df_simple_subset = df_simple_subset.reset_index(drop=True)
22742282
pd.testing.assert_frame_equal(data_subset.pandas_df(), df_simple_subset, check_dtype=False)
22752283

2284+
# test for floats
2285+
keep_cols = [0.5, 0.45]
2286+
2287+
data_subset = data.where('floats', keep_cols,
2288+
condition="neq")
2289+
df_simple_subset = df_simple.loc[~df_simple['floats'].isin(keep_cols), :]
2290+
2291+
df_simple_subset = df_simple_subset.reset_index(drop=True)
2292+
pd.testing.assert_frame_equal(data_subset.pandas_df(), df_simple_subset, check_dtype=False)
22762293

22772294
def test_sort(data, df_simple):
2295+
"""Test sorting function across simple dataframe.
2296+
2297+
"""
2298+
22782299
df_sorted_int = df_simple.sort_values('integers').reset_index(drop=True)
22792300
df_sorted_float = df_simple.sort_values('floats').reset_index(drop=True)
22802301
data_sorted_int = data.sort('integers').pandas_df()
@@ -2284,10 +2305,22 @@ def test_sort(data, df_simple):
22842305
pd.testing.assert_frame_equal(data_sorted_int, df_sorted_int)
22852306
pd.testing.assert_frame_equal(df_sorted_float, data_sorted_float)
22862307
pd.testing.assert_frame_equal(df_sorted_float, data_sorted_ind)
2308+
# test strings as well:
2309+
df_sorted_names = df_simple.sort_values('names').reset_index(drop=True)
2310+
data_sorted_names = data.sort('names').pandas_df()
2311+
pd.testing.assert_frame_equal(df_sorted_names, data_sorted_names)
2312+
2313+
df_sorted_strings = df_simple.sort_values('strings').reset_index(drop=True)
2314+
data_sorted_strings = data.sort('strings').pandas_df()
2315+
pd.testing.assert_frame_equal(df_sorted_strings, data_sorted_strings)
2316+
22872317
# Test usecase when descending order is given
22882318
df_sorted_int_des = df_simple.sort_values('integers', ascending=False).reset_index(drop=True)
2289-
data_sorted_int_des = data.sort('integers', order="descending").pandas_df()
2319+
data_sorted_int_des = data.sort('integers', ascending=False).pandas_df()
2320+
pd.testing.assert_frame_equal(df_sorted_int_des, data_sorted_int_des)
2321+
2322+
# test inplace
2323+
data_sorted_int_des = data.copy()
2324+
data_sorted_int_des.sort('integers', ascending=False, inplace=True)
2325+
data_sorted_int_des = data_sorted_int_des.pandas_df()
22902326
pd.testing.assert_frame_equal(df_sorted_int_des, data_sorted_int_des)
2291-
# Test usecase when incorrect order is given
2292-
with pytest.raises(RuntimeError):
2293-
_ = data.sort('integers', order="equality")

0 commit comments

Comments
 (0)