@@ -18,7 +18,7 @@ class BaseFilter(ABC):
1818 ----------
1919 state_0 : np.ndarray
2020 Initial state estimate
21- P : np.ndarray
21+ sigma_0 : np.ndarray
2222 Current uncertainty estimated for state estimate (2D covariance)
2323 """
2424
@@ -62,7 +62,7 @@ def __init__(self, init_dict, params_dict):
6262 self .R = init_dict ['R' ]
6363 self .params_dict = params_dict
6464
65- def predict (self , u = None , predict_dict = dict () ):
65+ def predict (self , u = None , predict_dict = None ):
6666 """Predict the state of the filter given the control input
6767
6868 Parameters
@@ -73,15 +73,17 @@ def predict(self, u=None, predict_dict=dict()):
7373 Additional parameters needed to implement predict step
7474 """
7575 if u is None :
76- u = np .zeros ((1 ,self .state_dim ))
76+ u = np .zeros ((self .state_dim ,1 ))
77+ if predict_dict is None :
78+ predict_dict = {}
7779 assert _check_col_vect (u , np .size (u )), "Control input is not a column vector"
7880 self .state = self .dyn_model (u , predict_dict ) # Can pass parameters via predict_dict
7981 A = self .linearize_dynamics (predict_dict )
8082 self .sigma = A @ self .sigma @ A .T + self .Q
8183 assert _check_col_vect (self .state , self .state_dim ), "Incorrect state shape after prediction"
8284 assert _check_square_mat (self .sigma , self .state_dim ), "Incorrect covariance shape after prediction"
8385
84- def update (self , z , update_dict = dict () ):
86+ def update (self , z , update_dict = None ):
8587 """Update the state of the filter given a noisy measurement of the state
8688
8789 Parameters
@@ -91,9 +93,16 @@ def update(self, z, update_dict=dict()):
9193 update_dict : dict
9294 Additional parameters needed to implement update step
9395 """
96+ if update_dict is None :
97+ update_dict = {}
98+
99+ # uses process_noise from update_dict if exists, otherwise use
100+ # process_noise from the class initialization.
101+ measurement_noise = update_dict .get ('measurement_noise' , self .R )
102+
94103 assert _check_col_vect (z , np .size (z )), "Measurements are not a column vector"
95104 H = self .linearize_measurements (update_dict ) # Can pass arguments via update_dict
96- S = H @ self .sigma @ H .T + self . R
105+ S = H @ self .sigma @ H .T + measurement_noise
97106 K = self .sigma @ H .T @ np .linalg .inv (S )
98107 z_expect = self .measure_model (update_dict ) # Can pass arguments via update_dict
99108 assert _check_col_vect (z_expect , np .size (z )), "Expected measurements are not a column vector"
@@ -105,25 +114,25 @@ def update(self, z, update_dict=dict()):
105114 assert _check_square_mat (self .sigma , self .state_dim ), "Incorrect covariance shape after update"
106115
107116 @abstractmethod
108- def linearize_dynamics (self , predict_dict = dict () ):
117+ def linearize_dynamics (self , predict_dict = None ):
109118 """Linearization of system dynamics, should return A matrix
110119 """
111120 raise NotImplementedError
112121
113122 @abstractmethod
114- def linearize_measurements (self , update_dict = dict () ):
123+ def linearize_measurements (self , update_dict = None ):
115124 """Linearization of measurement model, should return H matrix
116125 """
117126 raise NotImplementedError
118127
119128 @abstractmethod
120- def measure_model (self , update_dict = dict () ):
129+ def measure_model (self , update_dict = None ):
121130 """Non-linear measurement model
122131 """
123132 raise NotImplementedError
124133
125134 @abstractmethod
126- def dyn_model (self , u , predict_dict = dict () ):
135+ def dyn_model (self , u , predict_dict = None ):
127136 """Non-linear dynamics model
128137 """
129138 raise NotImplementedError
@@ -135,7 +144,7 @@ class BaseKalmanFilter(BaseExtendedKalmanFilter):
135144 model
136145 """
137146
138- def dyn_model (self , u , predict_dict = dict () ):
147+ def dyn_model (self , u , predict_dict = None ):
139148 """Linear dynamics model
140149
141150 Parameters
@@ -149,12 +158,14 @@ def dyn_model(self, u, predict_dict=dict()):
149158 -------
150159 new_x : State after propagation
151160 """
161+ if predict_dict is None :
162+ predict_dict = {}
152163 A = self .linearize_dynamics (predict_dict )
153164 B = self .get_B (predict_dict )
154165 new_x = A @ self .state + B @ u
155166 return new_x
156167
157- def measure_model (self , update_dict = dict () ):
168+ def measure_model (self , update_dict = None ):
158169 """Linear measurement model
159170
160171 Parameters
@@ -166,12 +177,14 @@ def measure_model(self, update_dict=dict()):
166177 -------
167178 z_expect : Measurement expected for current state
168179 """
180+ if update_dict is None :
181+ update_dict = {}
169182 H = self .linearize_measurements (update_dict )
170183 z_expect = H @ self .state
171184 return z_expect
172185
173186 @abstractmethod
174- def get_B (self , predict_dict = dict () ):
187+ def get_B (self , predict_dict = None ):
175188 """Map from control to state, should return B matrix
176189 """
177190 raise NotImplementedError
@@ -207,7 +220,7 @@ def __init__(self, init_dict, params_dict):
207220 self .N_sig = int (2 * self .state_dim + 1 )
208221 self .params_dict = params_dict
209222
210- def predict (self , u , predict_dict = dict () ):
223+ def predict (self , u , predict_dict = None ):
211224 """Predict the state of the filter given the control input
212225
213226 Parameters
@@ -218,6 +231,9 @@ def predict(self, u, predict_dict=dict()):
218231 Additional parameters needed to implement predict step
219232 """
220233
234+ if predict_dict is None :
235+ predict_dict = {}
236+
221237 N = self .state_dim
222238 N_sig = self .N_sig
223239 x_t_tm = np .zeros ((N , N_sig ))
@@ -238,7 +254,7 @@ def predict(self, u, predict_dict=dict()):
238254 assert _check_col_vect (self .state , self .state_dim ), "Incorrect state shape after prediction"
239255 assert _check_square_mat (self .sigma , self .state_dim ), "Incorrect covariance shape after prediction"
240256
241- def update (self , z , update_dict = dict () ):
257+ def update (self , z , update_dict = None ):
242258 """Update the state of the filter given a noisy measurement of the state
243259
244260 Parameters
@@ -248,6 +264,9 @@ def update(self, z, update_dict=dict()):
248264 update_dict : dict
249265 Additional parameters needed to implement update step
250266 """
267+ if update_dict is None :
268+ update_dict = {}
269+
251270 assert _check_col_vect (z , np .size (z )), "Measurements are not a column vector"
252271 N = self .state_dim
253272 N_sig = self .N_sig
@@ -309,13 +328,13 @@ def inv_U_transform(self, W, x_t_tm):
309328 return np .expand_dims (mu , axis = 1 ), S
310329
311330 @abstractmethod
312- def measure_model (self , x , update_dict = dict () ):
331+ def measure_model (self , x , update_dict = None ):
313332 """Non-linear measurement model
314333 """
315334 raise NotImplementedError
316335
317336 @abstractmethod
318- def dyn_model (self , x , u , predict_dict = dict () ):
337+ def dyn_model (self , x , u , predict_dict = None ):
319338 """Non-linear dynamics model
320339 """
321340 raise NotImplementedError
0 commit comments