Skip to content

Commit c230911

Browse files
committed
compile new estimators
Signed-off-by: Alex Kmetko <akmetko04@gmail.com>
1 parent 4e73e03 commit c230911

File tree

3 files changed

+205
-5
lines changed

3 files changed

+205
-5
lines changed

app/rdd2/src/casadi/rdd2.py

+133-2
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,13 @@
77
from pathlib import Path
88
import casadi as ca
99
import cyecca.lie as lie
10-
from cyecca.lie.group_so3 import SO3Quat, SO3EulerB321
11-
from cyecca.lie.group_se23 import SE23Quat, se23, SE23LieGroupElement, SE23LieAlgebraElement
10+
from cyecca.lie.group_so3 import SO3Quat, SO3EulerB321, so3
11+
from cyecca.lie.group_se23 import (
12+
SE23Quat,
13+
se23,
14+
SE23LieGroupElement,
15+
SE23LieAlgebraElement,
16+
)
1217
from cyecca.symbolic import SERIES
1318

1419
print('python: ', sys.executable)
@@ -33,6 +38,11 @@
3338
z_integral_max = 0 # 5.0
3439
ki_z = 0.05 # velocity z integral gain
3540

41+
# estimator params
42+
att_w_acc =0.2
43+
att_w_gyro_bias = 0.1
44+
param_att_w_mag = 0.2
45+
3646
def derive_control_allocation(
3747
):
3848
"""
@@ -457,6 +467,125 @@ def derive_strapdown_ins_propagation():
457467
}
458468
return eqs
459469

470+
def derive_attitude_estimator():
471+
# Define Casadi variables
472+
q0 = ca.SX.sym("q", 4)
473+
q = SO3Quat.elem(param=q0)
474+
mag = ca.SX.sym("mag", 3)
475+
mag_decl = ca.SX.sym("mag_decl", 1)
476+
gyro = ca.SX.sym("gyro", 3)
477+
accel = ca.SX.sym("accel", 3)
478+
dt = ca.SX.sym("dt", 1)
479+
480+
# Convert magnetometer to quat
481+
mag1 = SO3Quat.elem(ca.vertcat(0, mag))
482+
483+
# correction angular velocity vector
484+
correction = ca.SX.zeros(3, 1)
485+
486+
# Convert vector to world frame and extract xy component
487+
spin_rate = ca.norm_2(gyro)
488+
mag_earth = (q.inverse() * mag1 * q).param[1:]
489+
490+
mag_err = (
491+
ca.fmod(ca.atan2(mag_earth[1], mag_earth[0]) - mag_decl + ca.pi, 2 * ca.pi)
492+
- ca.pi
493+
)
494+
495+
# Change gain if spin rate is large
496+
fifty_dps = 0.873
497+
gain_mult = ca.if_else(spin_rate > fifty_dps, ca.fmin(spin_rate / fifty_dps, 10), 1)
498+
499+
# Move magnetometer correction in body frame
500+
correction += (
501+
(q.inverse() * SO3Quat.elem(ca.vertcat(0, 0, 0, mag_err)) * q).param[1:]
502+
* param_att_w_mag
503+
* gain_mult
504+
)
505+
506+
# Correction from accelerometer
507+
accel_norm_sq = ca.norm_2(accel) ** 2
508+
509+
# Correct accelerometer only if g between
510+
higher_lim_check = ca.if_else(accel_norm_sq < ((g * 1.1) ** 2), 1, 0)
511+
lower_lim_check = ca.if_else(accel_norm_sq > ((g * 0.9) ** 2), 1, 0)
512+
513+
# Correct gravity as z
514+
correction += (
515+
lower_lim_check
516+
* higher_lim_check
517+
* ca.cross(np.array([[0], [0], [-1]]), accel / ca.norm_2(accel))
518+
* att_w_acc
519+
)
520+
521+
## TODO add gyro bias stuff
522+
523+
# Add gyro to correction
524+
correction += gyro
525+
526+
# Make the correction
527+
q1 = q * so3.elem(correction * dt).exp(SO3Quat)
528+
529+
# Return estimator
530+
f_att_estimator = ca.Function(
531+
"attitude_estimator",
532+
[q0, mag, mag_decl, gyro, accel, dt],
533+
[q1.param],
534+
["q", "mag", "mag_decl", "gyro", "accel", "dt"],
535+
["q1"],
536+
)
537+
538+
return {"attitude_estimator": f_att_estimator}
539+
540+
def derive_position_correction():
541+
## Initilaizing measurments
542+
z = ca.SX.sym("gps", 3)
543+
dt = ca.SX.sym("dt", 1)
544+
P = ca.SX.sym("P", 6, 6)
545+
546+
# Initialize state
547+
est_x = ca.SX.sym("est", 10) # [x,y,z,u,v,w,q0,q1,q2,q3]
548+
x0 = est_x[0:6] # [x,y,z,u,v,w]
549+
550+
# Define the state transition matrix (A)
551+
A = ca.SX.eye(6)
552+
A[0:3, 3:6] = np.eye(3) * dt # The velocity elements multiply by dt
553+
554+
## TODO: may need to pass Q and R throught the casadi function
555+
Q = np.eye(6) * 1e-5 # Process noise (uncertainty in system model)
556+
R = np.eye(3) * 1e-2 # Measurement noise (uncertainty in sensors)
557+
558+
# Measurement matrix
559+
H = ca.horzcat(ca.SX.eye(3), ca.SX.zeros(3, 3))
560+
561+
# extrapolate uncertainty
562+
P_s = A @ P @ A.T + Q
563+
564+
## Measurment Update
565+
## vel is a basic integral given acceleration values. need to figure out how to get v0
566+
y = H @ P_s @ H.T + R
567+
568+
# Update Kalman Gain
569+
K = P_s @ H.T @ ca.inv(y)
570+
571+
# Update estimate w/ measurment
572+
x_new = x0 + K @ (z - H @ x0)
573+
574+
# Update the measurement uncertainty
575+
P_new = (np.eye(6) - (K @ H)) @ P_s
576+
577+
# Return to have attitude updated
578+
x_new = ca.vertcat(x_new, ca.SX.zeros(4))
579+
580+
f_pos_estimator = ca.Function(
581+
"position_correction",
582+
[est_x, z, dt, P],
583+
[x_new, P_new],
584+
["est_x", "gps", "dt", "P"],
585+
["x_new", "P_new"],
586+
)
587+
return {"position_correction": f_pos_estimator}
588+
460589

461590
def generate_code(eqs: dict, filename, dest_dir: str, **kwargs):
462591
"""
@@ -499,6 +628,8 @@ def generate_code(eqs: dict, filename, dest_dir: str, **kwargs):
499628
eqs.update(derive_joy_auto_level())
500629
eqs.update(derive_strapdown_ins_propagation())
501630
eqs.update(derive_control_allocation())
631+
eqs.update(derive_attitude_estimator())
632+
eqs.update(derive_position_correction())
502633

503634
for name, eq in eqs.items():
504635
print('eq: ', name)

app/rdd2/src/estimate.c

+71-2
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,15 @@ struct context {
4343
synapse_pb_Odometry odometry_ethernet;
4444
synapse_pb_Imu imu;
4545
synapse_pb_Odometry odometry;
46-
struct zros_sub sub_odometry_ethernet, sub_imu;
46+
struct zros_sub sub_odometry_ethernet, sub_imu, sub_mag;
4747
struct zros_pub pub_odometry;
4848
double x[3];
4949
struct k_sem running;
5050
size_t stack_size;
5151
k_thread_stack_t *stack_area;
5252
struct k_thread thread_data;
5353
struct perf_counter perf;
54+
synapse_pb_MagneticField mag;
5455
};
5556

5657
// private initialization
@@ -73,19 +74,22 @@ static struct context g_ctx = {
7374
},
7475
.sub_odometry_ethernet = {},
7576
.sub_imu = {},
77+
.sub_mag = {},
7678
.pub_odometry = {},
7779
.x = {},
7880
.running = Z_SEM_INITIALIZER(g_ctx.running, 1, 1),
7981
.stack_size = MY_STACK_SIZE,
8082
.stack_area = g_my_stack_area,
8183
.thread_data = {},
8284
.perf = {},
85+
.mag = {},
8386
};
8487

8588
static void rdd2_estimate_init(struct context *ctx)
8689
{
8790
zros_node_init(&ctx->node, "rdd2_estimate");
8891
zros_sub_init(&ctx->sub_imu, &ctx->node, &topic_imu, &ctx->imu, 300);
92+
zros_sub_init(&ctx->sub_mag, &ctx->node, &topic_magnetic_field, &ctx->mag, 300);
8993
zros_sub_init(&ctx->sub_odometry_ethernet, &ctx->node, &topic_odometry_ethernet,
9094
&ctx->odometry_ethernet, 10);
9195
zros_pub_init(&ctx->pub_odometry, &ctx->node, &topic_odometry_estimator, &ctx->odometry);
@@ -97,6 +101,7 @@ static void rdd2_estimate_init(struct context *ctx)
97101
static void rdd2_estimate_fini(struct context *ctx)
98102
{
99103
zros_sub_fini(&ctx->sub_imu);
104+
zros_sub_fini(&ctx->sub_mag);
100105
zros_sub_fini(&ctx->sub_odometry_ethernet);
101106
zros_pub_fini(&ctx->pub_odometry);
102107
zros_node_fini(&ctx->node);
@@ -135,10 +140,18 @@ static void rdd2_estimate_run(void *p0, void *p1, void *p2)
135140

136141
// estimator states
137142
double x[10] = {0, 0, 0, 0, 0, 0, 1, 0, 0, 0};
143+
double q[4] = {1,0,0,0};
144+
145+
// estimator covariance
146+
double P[36] = {1e-2,0,0,0,0,0,
147+
0,1e-2,0,0,0,0,
148+
0,0,1e-2,0,0,0,
149+
0,0,0,1e-2,0,0,
150+
0,0,0,0,1e-2,0,
151+
0,0,0,0,0,1e-2};
138152

139153
// poll on imu
140154
events[0] = *zros_sub_get_event(&ctx->sub_imu);
141-
142155
// int j = 0;
143156

144157
while (k_sem_take(&ctx->running, K_NO_WAIT) < 0) {
@@ -232,6 +245,62 @@ static void rdd2_estimate_run(void *p0, void *p1, void *p2)
232245
CASADI_FUNC_CALL(strapdown_ins_propagate)
233246
}
234247

248+
{
249+
CASADI_FUNC_ARGS(position_correction)
250+
251+
double gps[3] = {ctx->odometry_ethernet.pose.position.x,
252+
ctx->odometry_ethernet.pose.position.y,
253+
ctx->odometry_ethernet.pose.position.z};
254+
255+
args[0] = x;
256+
args[1] = gps;
257+
args[2] = &dt;
258+
args[3] = P;
259+
260+
res[0] = x;
261+
res[1] = P;
262+
263+
CASADI_FUNC_CALL(position_correction)
264+
}
265+
266+
/*
267+
f_att_estimator = ca.Function(
268+
"attitude_estimator",
269+
[q0, mag, mag_decl, gyro, accel, dt],
270+
[q1.param],
271+
["q", "mag", "mag_decl", "gyro", "accel", "dt"],
272+
["q1"],
273+
)*/
274+
{
275+
CASADI_FUNC_ARGS(attitude_estimator)
276+
277+
double a_b[3] = {ctx->imu.linear_acceleration.x,
278+
ctx->imu.linear_acceleration.y,
279+
ctx->imu.linear_acceleration.z};
280+
double omega_b[3] = {ctx->imu.angular_velocity.x,
281+
ctx->imu.angular_velocity.y,
282+
ctx->imu.angular_velocity.z};
283+
284+
double mag[3] = {ctx->mag.magnetic_field.x, ctx->mag.magnetic_field.y, ctx->mag.magnetic_field.z};
285+
const double decl_WL = -6.66;
286+
args[0] = q;
287+
args[1] = mag;
288+
args[2] = &decl_WL;
289+
args[3] = omega_b;
290+
args[4] = a_b;
291+
args[5] = &dt;
292+
res[0] = q;
293+
294+
CASADI_FUNC_CALL(attitude_estimator)
295+
296+
x[6] = q[0];
297+
x[7] = q[1];
298+
x[8] = q[2];
299+
x[9] = q[3];
300+
}
301+
302+
303+
235304
bool data_ok = true;
236305
for (int i = 0; i < 10; i++) {
237306
if (!isfinite(x[i])) {

west.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ manifest:
3636
path: modules/lib/zros
3737
- name: cyecca
3838
remote: cognipilot
39-
revision: 45f92d1107e350504ed8dad4e526aedf0843490b # main 12/1/24
39+
revision: 78acfbef5505da71d21b1a90ecbc2268869f0451 # main 3/8/25
4040
path: modules/lib/cyecca
4141
- name: synapse_pb
4242
remote: cognipilot

0 commit comments

Comments
 (0)