1- """Parent classes for Kalman filter algorithms
1+ """Parent classes for Kalman filter algorithms.
22
33"""
44
55__authors__ = "Ashwin Kanhere, Shivam Soni"
66__date__ = "20 January 2020"
77
8- import numpy as np
9- from scipy .linalg import sqrtm
108from abc import ABC , abstractmethod
119
12-
13- def check_col_vect (vect , dim ):
14- """Boolean for whether input vector is column shaped or not
15-
16- Parameters
17- ----------
18- vect : np.ndarray
19- Input vector
20- dim : int
21- Number of row elements in column vector
22- """
23- check = False
24- if np .shape (vect )[0 ] == dim and np .shape (vect )[1 ] == 1 :
25- check = True
26- return check
27-
28-
29- def check_square_mat (mat , dim ):
30- """Boolean for whether input matrices are square or not
31-
32- Parameters
33- ----------
34- vect : np.ndarray
35- Input matrix
36- dim : int
37- Number of elements for row and column = N for N x N
38- """
39- check = False
40- if np .shape (mat )[0 ] == dim and np .shape (mat )[1 ] == dim :
41- check = True
42- return check
10+ import numpy as np
11+ from scipy .linalg import sqrtm
4312
4413
4514class BaseFilter (ABC ):
@@ -57,8 +26,8 @@ class BaseFilter(ABC):
5726
5827 def __init__ (self , x_dim , x0 , P0 ):
5928 self .x_dim = x_dim
60- assert check_col_vect (x0 , self .x_dim ), "Incorrect initial state shape"
61- assert check_square_mat (P0 , self .x_dim ), "Incorrect initial cov shape"
29+ assert _check_col_vect (x0 , self .x_dim ), "Incorrect initial state shape"
30+ assert _check_square_mat (P0 , self .x_dim ), "Incorrect initial cov shape"
6231 self .x = x0
6332 self .P = P0
6433
@@ -90,7 +59,7 @@ class BaseExtendedKalmanFilter(BaseFilter):
9059
9160 def __init__ (self , init_dict , params_dict ):
9261 super ().__init__ (init_dict ['x_dim' ], init_dict ['x0' ], init_dict ['P0' ])
93- assert check_square_mat (init_dict ['Q' ], self .x_dim )
62+ assert _check_square_mat (init_dict ['Q' ], self .x_dim )
9463 self .Q = init_dict ['Q' ]
9564 self .R = init_dict ['R' ]
9665 self .params_dict = params_dict
@@ -105,12 +74,12 @@ def predict(self, u, predict_dict=None):
10574 predict_dict : Dict
10675 Additional parameters needed to implement predict step
10776 """
108- assert check_col_vect (u , np .size (u )), "Control input is not a column vector"
77+ assert _check_col_vect (u , np .size (u )), "Control input is not a column vector"
10978 self .x = self .dyn_model (u , predict_dict ) # Can pass parameters via predict_dict
11079 A = self .linearize_dynamics (predict_dict )
11180 self .P = A @ self .P @ A .T + self .Q
112- assert check_col_vect (self .x , self .x_dim ), "Incorrect state shape after prediction"
113- assert check_square_mat (self .P , self .x_dim ), "Incorrect covariance shape after prediction"
81+ assert _check_col_vect (self .x , self .x_dim ), "Incorrect state shape after prediction"
82+ assert _check_square_mat (self .P , self .x_dim ), "Incorrect covariance shape after prediction"
11483
11584 def update (self , z , update_dict = None ):
11685 """Update the state of the filter given a noisy measurement of the state
@@ -122,18 +91,18 @@ def update(self, z, update_dict=None):
12291 update_dict : Dict
12392 Additional parameters needed to implement update step
12493 """
125- assert check_col_vect (z , np .size (z )), "Measurements are not a column vector"
94+ assert _check_col_vect (z , np .size (z )), "Measurements are not a column vector"
12695 H = self .linearize_measurements (update_dict ) # Can pass arguments via update_dict
12796 S = H @ self .P @ H .T + self .R
12897 K = self .P @ H .T @ np .linalg .inv (S )
12998 z_expect = self .measure_model (update_dict ) # Can pass arguments via update_dict
130- assert check_col_vect (z_expect , np .size (z )), "Expected measurements are not a column vector"
99+ assert _check_col_vect (z_expect , np .size (z )), "Expected measurements are not a column vector"
131100 # Updating state
132101 self .x = self .x + K @ (z - z_expect )
133102 # Update covariance
134103 self .P = (np .eye (self .x_dim ) - K @ H ) @ self .P
135- assert check_col_vect (self .x , self .x_dim ), "Incorrect state shape after update"
136- assert check_square_mat (self .P , self .x_dim ), "Incorrect covariance shape after update"
104+ assert _check_col_vect (self .x , self .x_dim ), "Incorrect state shape after update"
105+ assert _check_square_mat (self .P , self .x_dim ), "Incorrect covariance shape after update"
137106
138107 @abstractmethod
139108 def linearize_dynamics (self , predict_dict = None ):
@@ -224,7 +193,7 @@ class BaseUnscentedKalmanFilter(BaseFilter):
224193
225194 def __init__ (self , init_dict , params_dict ):
226195 super ().__init__ (init_dict ['x_dim' ], init_dict ['x0' ], init_dict ['P0' ])
227- assert check_square_mat (init_dict ['Q' ], self .x_dim )
196+ assert _check_square_mat (init_dict ['Q' ], self .x_dim )
228197 self .Q = init_dict ['Q' ]
229198 self .R = init_dict ['R' ]
230199 if 'lam' in init_dict :
@@ -266,8 +235,8 @@ def predict(self, u, predict_dict=None):
266235 S_t_tm = S_t_tm + self .Q
267236 self .x = mu_t_tm
268237 self .P = S_t_tm
269- assert check_col_vect (self .x , self .x_dim ), "Incorrect state shape after prediction"
270- assert check_square_mat (self .P , self .x_dim ), "Incorrect covariance shape after prediction"
238+ assert _check_col_vect (self .x , self .x_dim ), "Incorrect state shape after prediction"
239+ assert _check_square_mat (self .P , self .x_dim ), "Incorrect covariance shape after prediction"
271240
272241 def update (self , z , update_dict = None ):
273242 """Update the state of the filter given a noisy measurement of the state
@@ -279,7 +248,7 @@ def update(self, z, update_dict=None):
279248 update_dict : Dict
280249 Additional parameters needed to implement update step
281250 """
282- assert check_col_vect (z , np .size (z )), "Measurements are not a column vector"
251+ assert _check_col_vect (z , np .size (z )), "Measurements are not a column vector"
283252 N = self .x_dim
284253 N_sig = self .N_sig
285254
@@ -301,8 +270,8 @@ def update(self, z, update_dict=None):
301270 self .x = self .x + S_xy_t_tm @ np .linalg .inv (S_y_t_tm ) @ meas_res
302271 self .P = self .P - S_xy_t_tm @ np .linalg .inv (S_y_t_tm ) @ S_xy_t_tm .T
303272
304- assert check_col_vect (self .x , self .x_dim ), "Incorrect state shape after update"
305- assert check_square_mat (self .P , self .x_dim ), "Incorrect covariance shape after update"
273+ assert _check_col_vect (self .x , self .x_dim ), "Incorrect state shape after update"
274+ assert _check_square_mat (self .P , self .x_dim ), "Incorrect covariance shape after update"
306275
307276 def U_transform (self ):
308277 """
@@ -350,3 +319,28 @@ def dyn_model(self, x, u, predict_dict=None):
350319 """Non-linear dynamics model
351320 """
352321 raise NotImplementedError
322+
323+ def _check_col_vect (vect , dim ):
324+ """Boolean for whether input vector is column shaped or not
325+
326+ Parameters
327+ ----------
328+ vect : np.ndarray
329+ Input vector
330+ dim : int
331+ Number of row elements in column vector
332+ """
333+ return np .shape (vect )[0 ] == dim and np .shape (vect )[1 ] == 1
334+
335+ def _check_square_mat (mat , dim ):
336+ """Boolean for whether input matrices are square or not
337+
338+ Parameters
339+ ----------
340+ vect : np.ndarray
341+ Input matrix
342+ dim : int
343+ Number of elements for row and column = N for N x N
344+ """
345+
346+ return np .shape (mat )[0 ] == dim and np .shape (mat )[1 ] == dim
0 commit comments