Skip to content

Commit c49febd

Browse files
committed
nominally working ekf
1 parent cd4b015 commit c49febd

5 files changed

Lines changed: 235 additions & 39 deletions

File tree

gnss_lib_py/algorithms/gnss_filters.py

Lines changed: 65 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010
import numpy as np
1111

1212
from gnss_lib_py.parsers.navdata import NavData
13+
from gnss_lib_py.algorithms.snapshot import solve_wls
1314
from gnss_lib_py.utils.coordinates import ecef_to_geodetic
1415
from gnss_lib_py.utils.filters import BaseExtendedKalmanFilter
1516

16-
def solve_ekf_gnss(measurements, init_dict = dict(),
17-
params_dict = dict()):
17+
def solve_gnss_ekf(measurements, init_dict = None,
18+
params_dict = None):
1819
"""Runs a GNSS Extended Kalman Filter across each timestep.
1920
2021
Runs an Extended Kalman Filter across each timestep and adds a new
@@ -44,21 +45,59 @@ def solve_ekf_gnss(measurements, init_dict = dict(),
4445
"x_sv_m","y_sv_m","z_sv_m",
4546
])
4647

48+
if init_dict is None:
49+
init_dict = {}
50+
51+
if "state_0" not in init_dict:
52+
pos_0 = None
53+
for _, _, measurement_subset in measurements.loop_time("gps_millis"):
54+
pos_0 = solve_wls(measurement_subset)
55+
if pos_0 is not None:
56+
break
57+
58+
state_0 = np.zeros((7,1))
59+
if pos_0 is not None:
60+
state_0[:3,0] = pos_0[["x_rx_m","y_rx_m","z_rx_m"]]
61+
state_0[6,0] = pos_0[["b_rx_m"]]
62+
63+
init_dict["state_0"] = state_0
64+
65+
if "sigma_0" not in init_dict:
66+
sigma_0 = np.eye(init_dict["state_0"].size)
67+
init_dict["sigma_0"] = sigma_0
68+
69+
if "Q" not in init_dict:
70+
process_noise = np.eye(init_dict["state_0"].size)
71+
init_dict["Q"] = process_noise
72+
73+
if "R" not in init_dict:
74+
measurement_noise = np.eye(1) # gets overwritten
75+
init_dict["R"] = measurement_noise
76+
77+
# initialize parameter dictionary
78+
if params_dict is None:
79+
params_dict = {}
80+
81+
if "motion_type" not in params_dict:
82+
params_dict["motion_type"] = "constant_velocity"
83+
84+
if "measure_type" not in params_dict:
85+
params_dict["measure_type"] = "pseudorange"
86+
4787
# create initialization parameters.
4888
gnss_ekf = GNSSEKF(init_dict, params_dict)
4989

5090
states = []
5191

5292
for timestamp, delta_t, measurement_subset in measurements.loop_time("gps_millis"):
53-
54-
pos_sv_m = measurement_subset[["x_sv_m","y_sv_m","z_sv_m"]].T
93+
pos_sv_m = measurement_subset[["x_sv_m","y_sv_m","z_sv_m"]]
5594
pos_sv_m = np.atleast_2d(pos_sv_m)
5695

5796
corr_pr_m = measurement_subset["corr_pr_m"].reshape(-1,1)
5897

5998
# remove NaN indexes
60-
not_nan_indexes = ~np.isnan(pos_sv_m).any(axis=1)
61-
pos_sv_m = pos_sv_m[not_nan_indexes]
99+
not_nan_indexes = ~np.isnan(pos_sv_m).any(axis=0)
100+
pos_sv_m = pos_sv_m[:,not_nan_indexes]
62101
corr_pr_m = corr_pr_m[not_nan_indexes]
63102

64103
# prediction step
@@ -67,14 +106,15 @@ def solve_ekf_gnss(measurements, init_dict = dict(),
67106

68107
# update step
69108
update_dict = {"pos_sv_m" : pos_sv_m}
109+
update_dict["measurement_noise"] = np.eye(pos_sv_m.shape[1])
70110
gnss_ekf.update(corr_pr_m, update_dict=update_dict)
71111

72112
states.append([timestamp] + np.squeeze(gnss_ekf.state).tolist())
73113

74114
states = np.array(states)
75115

76116
if states.size == 0:
77-
warnings.warn("No valid state estimate computed in solve_ekf_gnss, "\
117+
warnings.warn("No valid state estimate computed in solve_gnss_ekf, "\
78118
+ "returning None.", RuntimeWarning)
79119
return None
80120

@@ -83,11 +123,13 @@ def solve_ekf_gnss(measurements, init_dict = dict(),
83123
state_estimate["x_rx_m"] = states[:,1]
84124
state_estimate["y_rx_m"] = states[:,2]
85125
state_estimate["z_rx_m"] = states[:,3]
86-
state_estimate["b_rx_m"] = states[:,4]
126+
state_estimate["vx_rx_mps"] = states[:,4]
127+
state_estimate["vy_rx_mps"] = states[:,5]
128+
state_estimate["vz_rx_mps"] = states[:,6]
129+
state_estimate["b_rx_m"] = states[:,7]
87130

88-
lat,lon,alt = ecef_to_geodetic(state_estimate[["x_rx_m",
89-
"y_rx_m",
90-
"z_rx_m"]])
131+
lat,lon,alt = ecef_to_geodetic(state_estimate[["x_rx_m","y_rx_m",
132+
"z_rx_m"]].reshape(3,-1))
91133
state_estimate["lat_rx_deg"] = lat
92134
state_estimate["lon_rx_deg"] = lon
93135
state_estimate["alt_rx_deg"] = alt
@@ -118,7 +160,7 @@ def __init__(self, init_dict, params_dict):
118160
self.motion_type = params_dict.get('motion_type','stationary')
119161
self.measure_type = params_dict.get('measure_type','pseudorange')
120162

121-
def dyn_model(self, u, predict_dict=dict()):
163+
def dyn_model(self, u, predict_dict=None):
122164
"""Nonlinear dynamics
123165
124166
Parameters
@@ -134,6 +176,9 @@ def dyn_model(self, u, predict_dict=dict()):
134176
new_x : np.ndarray
135177
Propagated state
136178
"""
179+
if predict_dict is None:
180+
predict_dict = {}
181+
137182
A = self.linearize_dynamics(predict_dict)
138183
new_x = A @ self.state
139184
return new_x
@@ -178,7 +223,7 @@ def measure_model(self, update_dict):
178223
raise NotImplementedError
179224
return z
180225

181-
def linearize_dynamics(self, predict_dict=dict()):
226+
def linearize_dynamics(self, predict_dict=None):
182227
"""Linearization of dynamics model
183228
184229
Parameters
@@ -194,6 +239,9 @@ def linearize_dynamics(self, predict_dict=dict()):
194239
Dictionary of prediction parameters.
195240
"""
196241

242+
if predict_dict is None:
243+
predict_dict = {}
244+
197245
# uses delta_t from predict_dict if exists, otherwise delta_t
198246
# from the class initialization.
199247
delta_t = predict_dict.get('delta_t', self.delta_t)
@@ -213,12 +261,14 @@ def linearize_measurements(self, update_dict):
213261
Parameters
214262
----------
215263
update_dict : dict
216-
Update dictionary containing satellite positions with key 'pos_sv_m'
264+
Update dictionary containing satellite positions with key
265+
``pos_sv_m``.
217266
218267
Returns
219268
-------
220269
H : np.ndarray
221-
Jacobian of measurement model, dimension M x N
270+
Jacobian of measurement model, of dimension
271+
#measurements x #states
222272
"""
223273
if self.measure_type == 'pseudorange':
224274
pos_sv_m = update_dict['pos_sv_m']

gnss_lib_py/algorithms/snapshot.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,8 @@ def solve_wls(measurements, weight_type = None,
102102
state_estimate["z_rx_m"] = states[:,3]
103103
state_estimate["b_rx_m"] = states[:,4]
104104

105-
lat,lon,alt = ecef_to_geodetic(state_estimate[["x_rx_m",
106-
"y_rx_m",
107-
"z_rx_m"]])
105+
lat,lon,alt = ecef_to_geodetic(state_estimate[["x_rx_m","y_rx_m",
106+
"z_rx_m"]].reshape(3,-1))
108107
state_estimate["lat_rx_deg"] = lat
109108
state_estimate["lon_rx_deg"] = lon
110109
state_estimate["alt_rx_deg"] = alt

gnss_lib_py/utils/filters.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class BaseFilter(ABC):
1818
----------
1919
state_0 : np.ndarray
2020
Initial state estimate
21-
P : np.ndarray
21+
sigma_0 : np.ndarray
2222
Current uncertainty estimated for state estimate (2D covariance)
2323
"""
2424

@@ -62,7 +62,7 @@ def __init__(self, init_dict, params_dict):
6262
self.R = init_dict['R']
6363
self.params_dict = params_dict
6464

65-
def predict(self, u=None, predict_dict=dict()):
65+
def predict(self, u=None, predict_dict=None):
6666
"""Predict the state of the filter given the control input
6767
6868
Parameters
@@ -73,15 +73,17 @@ def predict(self, u=None, predict_dict=dict()):
7373
Additional parameters needed to implement predict step
7474
"""
7575
if u is None:
76-
u = np.zeros((1,self.state_dim))
76+
u = np.zeros((self.state_dim,1))
77+
if predict_dict is None:
78+
predict_dict = {}
7779
assert _check_col_vect(u, np.size(u)), "Control input is not a column vector"
7880
self.state = self.dyn_model(u, predict_dict) # Can pass parameters via predict_dict
7981
A = self.linearize_dynamics(predict_dict)
8082
self.sigma = A @ self.sigma @ A.T + self.Q
8183
assert _check_col_vect(self.state, self.state_dim), "Incorrect state shape after prediction"
8284
assert _check_square_mat(self.sigma, self.state_dim), "Incorrect covariance shape after prediction"
8385

84-
def update(self, z, update_dict=dict()):
86+
def update(self, z, update_dict=None):
8587
"""Update the state of the filter given a noisy measurement of the state
8688
8789
Parameters
@@ -91,9 +93,16 @@ def update(self, z, update_dict=dict()):
9193
update_dict : dict
9294
Additional parameters needed to implement update step
9395
"""
96+
if update_dict is None:
97+
update_dict = {}
98+
99+
# uses process_noise from update_dict if exists, otherwise use
100+
# process_noise from the class initialization.
101+
measurement_noise = update_dict.get('measurement_noise', self.R)
102+
94103
assert _check_col_vect(z, np.size(z)), "Measurements are not a column vector"
95104
H = self.linearize_measurements(update_dict) # Can pass arguments via update_dict
96-
S = H @ self.sigma @ H.T + self.R
105+
S = H @ self.sigma @ H.T + measurement_noise
97106
K = self.sigma @ H.T @ np.linalg.inv(S)
98107
z_expect = self.measure_model(update_dict) # Can pass arguments via update_dict
99108
assert _check_col_vect(z_expect, np.size(z)), "Expected measurements are not a column vector"
@@ -105,25 +114,25 @@ def update(self, z, update_dict=dict()):
105114
assert _check_square_mat(self.sigma, self.state_dim), "Incorrect covariance shape after update"
106115

107116
@abstractmethod
108-
def linearize_dynamics(self, predict_dict=dict()):
117+
def linearize_dynamics(self, predict_dict=None):
109118
"""Linearization of system dynamics, should return A matrix
110119
"""
111120
raise NotImplementedError
112121

113122
@abstractmethod
114-
def linearize_measurements(self, update_dict=dict()):
123+
def linearize_measurements(self, update_dict=None):
115124
"""Linearization of measurement model, should return H matrix
116125
"""
117126
raise NotImplementedError
118127

119128
@abstractmethod
120-
def measure_model(self, update_dict=dict()):
129+
def measure_model(self, update_dict=None):
121130
"""Non-linear measurement model
122131
"""
123132
raise NotImplementedError
124133

125134
@abstractmethod
126-
def dyn_model(self, u, predict_dict=dict()):
135+
def dyn_model(self, u, predict_dict=None):
127136
"""Non-linear dynamics model
128137
"""
129138
raise NotImplementedError
@@ -135,7 +144,7 @@ class BaseKalmanFilter(BaseExtendedKalmanFilter):
135144
model
136145
"""
137146

138-
def dyn_model(self, u, predict_dict=dict()):
147+
def dyn_model(self, u, predict_dict=None):
139148
"""Linear dynamics model
140149
141150
Parameters
@@ -149,12 +158,14 @@ def dyn_model(self, u, predict_dict=dict()):
149158
-------
150159
new_x : State after propagation
151160
"""
161+
if predict_dict is None:
162+
predict_dict = {}
152163
A = self.linearize_dynamics(predict_dict)
153164
B = self.get_B(predict_dict)
154165
new_x = A @ self.state + B @ u
155166
return new_x
156167

157-
def measure_model(self, update_dict=dict()):
168+
def measure_model(self, update_dict=None):
158169
"""Linear measurement model
159170
160171
Parameters
@@ -166,12 +177,14 @@ def measure_model(self, update_dict=dict()):
166177
-------
167178
z_expect : Measurement expected for current state
168179
"""
180+
if update_dict is None:
181+
update_dict = {}
169182
H = self.linearize_measurements(update_dict)
170183
z_expect = H @ self.state
171184
return z_expect
172185

173186
@abstractmethod
174-
def get_B(self, predict_dict=dict()):
187+
def get_B(self, predict_dict=None):
175188
"""Map from control to state, should return B matrix
176189
"""
177190
raise NotImplementedError
@@ -207,7 +220,7 @@ def __init__(self, init_dict, params_dict):
207220
self.N_sig = int(2 * self.state_dim + 1)
208221
self.params_dict = params_dict
209222

210-
def predict(self, u, predict_dict=dict()):
223+
def predict(self, u, predict_dict=None):
211224
"""Predict the state of the filter given the control input
212225
213226
Parameters
@@ -218,6 +231,9 @@ def predict(self, u, predict_dict=dict()):
218231
Additional parameters needed to implement predict step
219232
"""
220233

234+
if predict_dict is None:
235+
predict_dict = {}
236+
221237
N = self.state_dim
222238
N_sig = self.N_sig
223239
x_t_tm = np.zeros((N, N_sig))
@@ -238,7 +254,7 @@ def predict(self, u, predict_dict=dict()):
238254
assert _check_col_vect(self.state, self.state_dim), "Incorrect state shape after prediction"
239255
assert _check_square_mat(self.sigma, self.state_dim), "Incorrect covariance shape after prediction"
240256

241-
def update(self, z, update_dict=dict()):
257+
def update(self, z, update_dict=None):
242258
"""Update the state of the filter given a noisy measurement of the state
243259
244260
Parameters
@@ -248,6 +264,9 @@ def update(self, z, update_dict=dict()):
248264
update_dict : dict
249265
Additional parameters needed to implement update step
250266
"""
267+
if update_dict is None:
268+
update_dict = {}
269+
251270
assert _check_col_vect(z, np.size(z)), "Measurements are not a column vector"
252271
N = self.state_dim
253272
N_sig = self.N_sig
@@ -309,13 +328,13 @@ def inv_U_transform(self, W, x_t_tm):
309328
return np.expand_dims(mu, axis=1), S
310329

311330
@abstractmethod
312-
def measure_model(self, x, update_dict=dict()):
331+
def measure_model(self, x, update_dict=None):
313332
"""Non-linear measurement model
314333
"""
315334
raise NotImplementedError
316335

317336
@abstractmethod
318-
def dyn_model(self, x, u, predict_dict=dict()):
337+
def dyn_model(self, x, u, predict_dict=None):
319338
"""Non-linear dynamics model
320339
"""
321340
raise NotImplementedError

0 commit comments

Comments
 (0)