Skip to content

Commit 18cdc7f

Browse files
committed
Added where eq, neq for single column when satisfied
1 parent 527faf5 commit 18cdc7f

2 files changed

Lines changed: 33 additions & 7 deletions

File tree

gnss_lib_py/parsers/navdata.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -377,19 +377,21 @@ def argwhere(self, key_idx, value, condition="eq"):
377377
+ "for string condition checks")
378378
# Extract columns where condition holds true and return new NavData
379379
if condition == "eq":
380-
new_cols = np.argwhere(np.isin(self[row, :],str_check))
380+
new_cols = np.argwhere(np.atleast_1d(np.isin(self[row, :],
381+
str_check)))
381382
else:
382383
# condition == "neq"
383-
new_cols = np.argwhere(~np.isin(self[row, :],str_check))
384+
new_cols = np.argwhere(np.atleast_1d(~np.isin(self[row, :],
385+
str_check)))
384386

385387
else:
386388
# Values in row are numerical
387389
# Find columns where value can be found and return new NavData
388390
if condition=="eq":
389391
if isinstance(value,(np.ndarray,list,tuple,set)):
390392
# use numpy's isin() condition if list of values
391-
new_cols = np.argwhere(np.isin(self.array[row, :],
392-
value))
393+
new_cols = np.argwhere(np.atleast_1d(np.isin(self.array[row, :],
394+
value)))
393395
elif not isinstance(value,str) and np.isnan(value):
394396
# check isinstance b/c np.isnan can't handle strings
395397
new_cols = np.argwhere(np.isnan(self.array[row, :]))
@@ -398,11 +400,11 @@ def argwhere(self, key_idx, value, condition="eq"):
398400
elif condition=="neq":
399401
if isinstance(value,(np.ndarray,list,tuple,set)):
400402
# use numpy's isin() condition if list of values
401-
new_cols = np.argwhere(~np.isin(self.array[row, :],
402-
value))
403+
new_cols = np.argwhere(np.atleast_1d(~np.isin(self.array[row, :],
404+
value)))
403405
elif not isinstance(value,str) and np.isnan(value):
404406
# check isinstance b/c np.isnan can't handle strings
405-
new_cols = np.argwhere(~np.isnan(self.array[row, :]))
407+
new_cols = np.argwhere(np.atleast_1d(~np.isnan(self.array[row, :])))
406408
else:
407409
new_cols = np.argwhere(self.array[row, :]!=value)
408410
elif condition == "leq":

tests/parsers/test_navdata.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1690,6 +1690,19 @@ def test_where_str(csv_simple):
16901690
compare_df = compare_df[compare_df['strings']!="gps"].reset_index(drop=True)
16911691
pd.testing.assert_frame_equal(data_small.pandas_df(), compare_df)
16921692

1693+
#Test equality for cases where there is only one column
1694+
data_single_column = data.where('strings', 'glonass', 'eq')
1695+
data_new = data_single_column.where('strings', 'glonass', 'eq')
1696+
compare_df = data.pandas_df()
1697+
compare_df = compare_df[compare_df['strings']=="glonass"].reset_index(drop=True)
1698+
pd.testing.assert_frame_equal(data_new.pandas_df(), compare_df)
1699+
1700+
#Test inequality for cases where there is only one column
1701+
data_new = data_single_column.where('strings', 'gps', 'neq')
1702+
# Both cases should return the same dataframe as before
1703+
pd.testing.assert_frame_equal(data_new.pandas_df(), compare_df)
1704+
1705+
16931706
def test_where_empty(df_simple):
16941707
"""Verify empty slices.
16951708
@@ -1750,6 +1763,17 @@ def test_where_numbers(csv_simple):
17501763
compare_df = compare_df.iloc[pd_rows[idx], :].reset_index(drop=True)
17511764
pd.testing.assert_frame_equal(data_small.pandas_df(), compare_df)
17521765

1766+
#Test equality for cases where there is only one column
1767+
data_single_column = data.where('integers', 10, 'eq')
1768+
data_new = data_single_column.where('integers', 10, 'eq')
1769+
compare_df = data.pandas_df()
1770+
compare_df = compare_df[compare_df['integers']==10].reset_index(drop=True)
1771+
pd.testing.assert_frame_equal(data_new.pandas_df(), compare_df)
1772+
1773+
#Test inequality for cases where there is only one column
1774+
data_new = data_single_column.where('integers', 56, 'neq')
1775+
pd.testing.assert_frame_equal(data_new.pandas_df(), compare_df)
1776+
17531777
def test_where_errors(csv_simple):
17541778
"""Testing error cases for NavData.where
17551779

0 commit comments

Comments
 (0)