Skip to content

Commit 1e8a996

Browse files
committed
Merge branch 'main' into ashwin/android_derived
2 parents 3b944d0 + 832b776 commit 1e8a996

6 files changed

Lines changed: 27 additions & 16 deletions

File tree

gnss_lib_py/algorithms/residuals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def solve_residuals(measurements, state_estimate):
2727

2828
unique_timesteps = np.unique(measurements["gps_millis",:])
2929
for t_idx, timestep in enumerate(unique_timesteps):
30-
idxs = np.where(measurements["gps_millis",:] == timestep)[1]
30+
idxs = np.where(measurements["millisSinceGpsEpoch",:] == timestep)[0]
3131

3232
pos_sv_m = np.hstack((measurements["x_sv_m",idxs].reshape(-1,1),
3333
measurements["y_sv_m",idxs].reshape(-1,1),

gnss_lib_py/algorithms/snapshot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def solve_wls(measurements, weight_type = None,
6565
states = np.nan*np.ones((4,len(unique_timesteps)))
6666

6767
for t_idx, timestep in enumerate(unique_timesteps):
68-
idxs = np.where(measurements["gps_millis",:] == timestep)[1]
68+
idxs = np.where(measurements["millisSinceGpsEpoch",:] == timestep)[0]
6969
pos_sv_m = np.hstack((measurements["x_sv_m",idxs].reshape(-1,1),
7070
measurements["y_sv_m",idxs].reshape(-1,1),
7171
measurements["z_sv_m",idxs].reshape(-1,1)))

gnss_lib_py/parsers/navdata.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def from_pandas_df(self, pandas_df):
107107

108108
for _, col_name in enumerate(pandas_df.columns):
109109
newvalue = pandas_df[col_name].to_numpy()
110-
self.__setitem__(col_name, newvalue)
110+
self[col_name] = newvalue
111111

112112
def from_numpy_array(self, numpy_array):
113113
"""Build attributes of NavData using np.ndarray.
@@ -125,7 +125,7 @@ def from_numpy_array(self, numpy_array):
125125
self.build_navdata()
126126

127127
for row_num in range(numpy_array.shape[0]):
128-
self.__setitem__(str(row_num), numpy_array[row_num,:])
128+
self[str(row_num)] = numpy_array[row_num,:]
129129

130130
@staticmethod
131131
def _row_map():
@@ -277,7 +277,8 @@ def __getitem__(self, key_idx):
277277
-------
278278
arr_slice : np.ndarray
279279
Array of data containing row names and time indexed
280-
columns
280+
columns. The return is squeezed meaning that all dimensions
281+
of the output that are length of one are removed
281282
"""
282283
rows, cols = self._parse_key_idx(key_idx)
283284
row_list, row_str = self._get_str_rows(rows)
@@ -293,6 +294,10 @@ def __getitem__(self, key_idx):
293294
# arr_slice.append(str_arr[ cols])
294295
else:
295296
arr_slice = self.array[rows, cols]
297+
298+
# remove all dimensions of length one
299+
arr_slice = np.squeeze(arr_slice)
300+
296301
return arr_slice
297302

298303
def __setitem__(self, key_idx, newvalue):
@@ -330,7 +335,7 @@ def __setitem__(self, key_idx, newvalue):
330335
else:
331336
# print("\n",key_idx,"\n")#,newvalue)
332337
if not isinstance(newvalue, int) and not isinstance(newvalue, float):
333-
assert not isinstance(np.asarray(newvalue)[0], str), \
338+
assert not isinstance(np.asarray(newvalue).item(0), str), \
334339
"Cannot set a row with list of strings, please use np.ndarray with dtype=object"
335340
# Adding numeric values
336341
self.str_map[key_idx] = {}
@@ -408,7 +413,10 @@ def _str_2_val(self, new_str_vals, newvalue, key):
408413
dtype=self.arr_dtype)
409414
# Set unassigned value to int not accessed by string map
410415
for str_key, str_val in str_dict.items():
411-
new_str_vals[newvalue==str_val] = str_key
416+
if new_str_vals.size == 1:
417+
new_str_vals = np.array(str_key,dtype=self.arr_dtype)
418+
else:
419+
new_str_vals[newvalue==str_val] = str_key
412420
# Copy set to false to prevent memory overflows
413421
new_str_vals = np.round(new_str_vals.astype(self.arr_dtype,
414422
copy=False))
@@ -693,7 +701,10 @@ def fillna(self, array):
693701
694702
"""
695703
nan_str = np.array([np.nan]).astype(str)[0]
696-
array[np.where(array.astype(str)==nan_str)] = ""
704+
if array.size > 1:
705+
array[np.where(array.astype(str)==nan_str)] = ""
706+
elif array.size == 1 and array == nan_str:
707+
array = np.array("")
697708

698709
def rename(self, mapper):
699710
"""Rename rows of NavData class.

gnss_lib_py/utils/visualizations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def plot_metric(navdata, *args, save=True, prefix=""):
132132
plt_title = y_metric
133133
plt.title(plt_title)
134134
data = navdata[y_metric]
135-
axes.scatter(range(data.shape[1]),data,s=5.)
135+
axes.scatter(range(data.shape[0]),data,s=5.)
136136
plt.xlabel("index")
137137
plt.ylabel(y_metric)
138138
else:
@@ -319,7 +319,7 @@ def plot_skyplot(navdata, state_estimate, save=True, prefix=""):
319319
navdata["z_sv_m",:].reshape(-1,1)))
320320

321321
for t_idx, timestep in enumerate(np.unique(navdata["gps_millis",:])):
322-
idxs = np.where(navdata["gps_millis",:] == timestep)[1]
322+
idxs = np.where(navdata["gps_millis",:] == timestep)[0]
323323
for m_idx in idxs:
324324

325325
if signal_types[m_idx] not in skyplot_data:

tests/parsers/test_android.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def test_get_and_set_num(derived):
219219
key = 'testing123'
220220
value = np.zeros(len(derived))
221221
derived[key] = value
222-
np.testing.assert_equal(derived[key, :], np.reshape(value, [1, -1]))
222+
np.testing.assert_equal(derived[key, :], value)
223223

224224

225225
def test_get_and_set_str(derived):
@@ -240,7 +240,7 @@ def test_get_and_set_str(derived):
240240
value = np.concatenate((np.asarray(value1, dtype=object), np.asarray(value2, dtype=object)))
241241
derived[key] = value
242242

243-
np.testing.assert_equal(derived[key, :], [value])
243+
np.testing.assert_equal(derived[key, :], value)
244244

245245

246246
def test_imu_raw(android_raw_path):

tests/parsers/test_navdata.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ def test_get_item(data, index, exp_value):
546546
exp_value : np.ndarray
547547
Expected value at queried indices
548548
"""
549-
np.testing.assert_array_equal(data[index], exp_value)
549+
np.testing.assert_array_equal(data[index], np.squeeze(exp_value))
550550

551551

552552
def test_get_all_numpy(numpy_array):
@@ -648,7 +648,7 @@ def test_set_get_item(data, index, new_value, exp_value):
648648
Expected value at queried indices
649649
"""
650650
data[index] = new_value
651-
np.testing.assert_array_equal(data[index], exp_value)
651+
np.testing.assert_array_equal(data[index], np.squeeze(exp_value))
652652

653653
@pytest.mark.parametrize("row_idx",
654654
[slice(7, 8),
@@ -709,8 +709,8 @@ def test_add_numpy_1d():
709709
"""
710710
data = NavData(numpy_array=np.zeros([1,6]))
711711
data.add(numpy_array=np.ones(8))
712-
np.testing.assert_array_equal(data[0, :], np.hstack((np.zeros([1,6]),
713-
np.ones([1, 8]))))
712+
np.testing.assert_array_equal(data[0, :], np.hstack((np.zeros(6),
713+
np.ones(8))))
714714

715715
# test adding to empty NavData
716716
data_empty = NavData()

0 commit comments

Comments
 (0)