Skip to content

Commit cc4f3d8

Browse files
committed
move checks to private functions
1 parent e77bbcf commit cc4f3d8

1 file changed

Lines changed: 44 additions & 50 deletions

File tree

gnss_lib_py/utils/filters.py

Lines changed: 44 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,14 @@
1-
"""Parent classes for Kalman filter algorithms
1+
"""Parent classes for Kalman filter algorithms.
22
33
"""
44

55
__authors__ = "Ashwin Kanhere, Shivam Soni"
66
__date__ = "20 January 2020"
77

8-
import numpy as np
9-
from scipy.linalg import sqrtm
108
from abc import ABC, abstractmethod
119

12-
13-
def check_col_vect(vect, dim):
14-
"""Boolean for whether input vector is column shaped or not
15-
16-
Parameters
17-
----------
18-
vect : np.ndarray
19-
Input vector
20-
dim : int
21-
Number of row elements in column vector
22-
"""
23-
check = False
24-
if np.shape(vect)[0] == dim and np.shape(vect)[1] == 1:
25-
check = True
26-
return check
27-
28-
29-
def check_square_mat(mat, dim):
30-
"""Boolean for whether input matrices are square or not
31-
32-
Parameters
33-
----------
34-
vect : np.ndarray
35-
Input matrix
36-
dim : int
37-
Number of elements for row and column = N for N x N
38-
"""
39-
check = False
40-
if np.shape(mat)[0] == dim and np.shape(mat)[1] == dim:
41-
check = True
42-
return check
10+
import numpy as np
11+
from scipy.linalg import sqrtm
4312

4413

4514
class BaseFilter(ABC):
@@ -57,8 +26,8 @@ class BaseFilter(ABC):
5726

5827
def __init__(self, x_dim, x0, P0):
5928
self.x_dim = x_dim
60-
assert check_col_vect(x0, self.x_dim), "Incorrect initial state shape"
61-
assert check_square_mat(P0, self.x_dim), "Incorrect initial cov shape"
29+
assert _check_col_vect(x0, self.x_dim), "Incorrect initial state shape"
30+
assert _check_square_mat(P0, self.x_dim), "Incorrect initial cov shape"
6231
self.x = x0
6332
self.P = P0
6433

@@ -90,7 +59,7 @@ class BaseExtendedKalmanFilter(BaseFilter):
9059

9160
def __init__(self, init_dict, params_dict):
9261
super().__init__(init_dict['x_dim'], init_dict['x0'], init_dict['P0'])
93-
assert check_square_mat(init_dict['Q'], self.x_dim)
62+
assert _check_square_mat(init_dict['Q'], self.x_dim)
9463
self.Q = init_dict['Q']
9564
self.R = init_dict['R']
9665
self.params_dict = params_dict
@@ -105,12 +74,12 @@ def predict(self, u, predict_dict=None):
10574
predict_dict : Dict
10675
Additional parameters needed to implement predict step
10776
"""
108-
assert check_col_vect(u, np.size(u)), "Control input is not a column vector"
77+
assert _check_col_vect(u, np.size(u)), "Control input is not a column vector"
10978
self.x = self.dyn_model(u, predict_dict) # Can pass parameters via predict_dict
11079
A = self.linearize_dynamics(predict_dict)
11180
self.P = A @ self.P @ A.T + self.Q
112-
assert check_col_vect(self.x, self.x_dim), "Incorrect state shape after prediction"
113-
assert check_square_mat(self.P, self.x_dim), "Incorrect covariance shape after prediction"
81+
assert _check_col_vect(self.x, self.x_dim), "Incorrect state shape after prediction"
82+
assert _check_square_mat(self.P, self.x_dim), "Incorrect covariance shape after prediction"
11483

11584
def update(self, z, update_dict=None):
11685
"""Update the state of the filter given a noisy measurement of the state
@@ -122,18 +91,18 @@ def update(self, z, update_dict=None):
12291
update_dict : Dict
12392
Additional parameters needed to implement update step
12493
"""
125-
assert check_col_vect(z, np.size(z)), "Measurements are not a column vector"
94+
assert _check_col_vect(z, np.size(z)), "Measurements are not a column vector"
12695
H = self.linearize_measurements(update_dict) # Can pass arguments via update_dict
12796
S = H @ self.P @ H.T + self.R
12897
K = self.P @ H.T @ np.linalg.inv(S)
12998
z_expect = self.measure_model(update_dict) # Can pass arguments via update_dict
130-
assert check_col_vect(z_expect, np.size(z)), "Expected measurements are not a column vector"
99+
assert _check_col_vect(z_expect, np.size(z)), "Expected measurements are not a column vector"
131100
# Updating state
132101
self.x = self.x + K @ (z - z_expect)
133102
# Update covariance
134103
self.P = (np.eye(self.x_dim) - K @ H) @ self.P
135-
assert check_col_vect(self.x, self.x_dim), "Incorrect state shape after update"
136-
assert check_square_mat(self.P, self.x_dim), "Incorrect covariance shape after update"
104+
assert _check_col_vect(self.x, self.x_dim), "Incorrect state shape after update"
105+
assert _check_square_mat(self.P, self.x_dim), "Incorrect covariance shape after update"
137106

138107
@abstractmethod
139108
def linearize_dynamics(self, predict_dict=None):
@@ -224,7 +193,7 @@ class BaseUnscentedKalmanFilter(BaseFilter):
224193

225194
def __init__(self, init_dict, params_dict):
226195
super().__init__(init_dict['x_dim'], init_dict['x0'], init_dict['P0'])
227-
assert check_square_mat(init_dict['Q'], self.x_dim)
196+
assert _check_square_mat(init_dict['Q'], self.x_dim)
228197
self.Q = init_dict['Q']
229198
self.R = init_dict['R']
230199
if 'lam' in init_dict:
@@ -266,8 +235,8 @@ def predict(self, u, predict_dict=None):
266235
S_t_tm = S_t_tm + self.Q
267236
self.x = mu_t_tm
268237
self.P = S_t_tm
269-
assert check_col_vect(self.x, self.x_dim), "Incorrect state shape after prediction"
270-
assert check_square_mat(self.P, self.x_dim), "Incorrect covariance shape after prediction"
238+
assert _check_col_vect(self.x, self.x_dim), "Incorrect state shape after prediction"
239+
assert _check_square_mat(self.P, self.x_dim), "Incorrect covariance shape after prediction"
271240

272241
def update(self, z, update_dict=None):
273242
"""Update the state of the filter given a noisy measurement of the state
@@ -279,7 +248,7 @@ def update(self, z, update_dict=None):
279248
update_dict : Dict
280249
Additional parameters needed to implement update step
281250
"""
282-
assert check_col_vect(z, np.size(z)), "Measurements are not a column vector"
251+
assert _check_col_vect(z, np.size(z)), "Measurements are not a column vector"
283252
N = self.x_dim
284253
N_sig = self.N_sig
285254

@@ -301,8 +270,8 @@ def update(self, z, update_dict=None):
301270
self.x = self.x + S_xy_t_tm @ np.linalg.inv(S_y_t_tm) @ meas_res
302271
self.P = self.P - S_xy_t_tm @ np.linalg.inv(S_y_t_tm) @ S_xy_t_tm.T
303272

304-
assert check_col_vect(self.x, self.x_dim), "Incorrect state shape after update"
305-
assert check_square_mat(self.P, self.x_dim), "Incorrect covariance shape after update"
273+
assert _check_col_vect(self.x, self.x_dim), "Incorrect state shape after update"
274+
assert _check_square_mat(self.P, self.x_dim), "Incorrect covariance shape after update"
306275

307276
def U_transform(self):
308277
"""
@@ -350,3 +319,28 @@ def dyn_model(self, x, u, predict_dict=None):
350319
"""Non-linear dynamics model
351320
"""
352321
raise NotImplementedError
322+
323+
def _check_col_vect(vect, dim):
324+
"""Boolean for whether input vector is column shaped or not
325+
326+
Parameters
327+
----------
328+
vect : np.ndarray
329+
Input vector
330+
dim : int
331+
Number of row elements in column vector
332+
"""
333+
return np.shape(vect)[0] == dim and np.shape(vect)[1] == 1
334+
335+
def _check_square_mat(mat, dim):
336+
"""Boolean for whether input matrices are square or not
337+
338+
Parameters
339+
----------
340+
vect : np.ndarray
341+
Input matrix
342+
dim : int
343+
Number of elements for row and column = N for N x N
344+
"""
345+
346+
return np.shape(mat)[0] == dim and np.shape(mat)[1] == dim

0 commit comments

Comments
 (0)