import diffrax
import jax.numpy as jnp
# 問題設定
A = jnp.array([[0, 1], [-6, -2]], dtype=float)
B = jnp.array([[0], [1]], dtype=float)
S_f = jnp.array([[13, 0], [0, 1]], dtype=float)
Q = jnp.array([[13, 0], [0, 1]], dtype=float)
R = jnp.array([[1]], dtype=float)
x_0 = jnp.array([[-1], [0]], dtype=float)
# 解く区間
t0, t1 = 0, 1.5
dt = 0.1
# diffrax の出力数
N = 1000LQ制御と線形レギュレータの比較
オーバーシュートが起こる設定としておく. 解く時間区間が短く, 最適レギュレータでは平衡状態には至らないような状況では, 最適レギュレータによる評価値がよくないことがある
問題設定
状態ベクトルの次元n=2, 入力制御ベクトルの次元m=1とし, 線形システム
\dot{x}(t) = Ax(t) + B u(t)
を考える. ここで A=\begin{bmatrix} 0 & 1\\ -6&-2\end{bmatrix},\quad B=\begin{bmatrix} 0\\1 \end{bmatrix} とし, 初期値は x(0)=\begin{bmatrix} -1\\ 0\end{bmatrix}
とする. 評価関数は
J=\dfrac12 x^T(T) S_Tx(T) + \int_0^T\dfrac12\left(x^T(t)Qx(t) + u^T(t)Ru(t)\right)\, dt, \quad T=1.5
において S_f=Q=\begin{bmatrix} 13& 0\\ 0&1\end{bmatrix},\quad R=\begin{bmatrix} 1 \end{bmatrix} とする. すなわち
J=\dfrac12 \left(13x_1(T)^2 + x_2(T)^2 \right) +\dfrac12 \int_0^T\left(13x_1(t)^2 + x_2(t)^2 +u(t)^2\right)\, dt.
を最小化する最適制御を考える. このときリッカチ方程式は
\left\{\begin{aligned} &\dot{S}(t)+A^TS(t)+S(t)A-S(t)BR^{-1}B^TS(t)+Q=0,\\ &S(t_f)=S_f \end{aligned}\right.
である. さらに, \left\{\begin{aligned} &\dot{x}(t)=(A-BR^{-1}B^TS(t))x(t),\\ &x(t_0)=x_0 \end{aligned}\right.
により状態ベクトルが得られ, \lambda(t)=S(t)x(t), u(t)=-R^{-1}B^T\lambda(t)により随伴ベクトルも最適制御も得られる.
以下のステップで最適制御を求める
- リッカチ方程式の解S(t)を求める
- 状態方程式の解x(t)をS(t)を用いて求める
- \lambda(t)=S(t)x(t), u(t)=-R^{-1}B^T\lambda(t)により随伴ベクトルと最適制御を求める
- 可視化
1. リッカチ方程式の解S(t)を求める
def vector_field(t, S, args):
dot_S = A.T @ S + S @ A - S @ B @ jnp.linalg.inv(R) @ B.T @ S + Q
return dot_S
term = diffrax.ODETerm(vector_field)
solver = diffrax.Tsit5()
saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t1, N))
sol = diffrax.diffeqsolve(term, solver, t0, t1, dt, S_f, saveat=saveat)
S_t = sol.ys[::-1]2. x(t)を求める
先程得たS(t)を用いて, x(t)は次の常微分方程式に従う \left\{\begin{aligned} &\dot{x}(t)=(A-BR^{-1}B^TS(t))x(t),\\ &x(t_0)=x_0 \end{aligned}\right.
def return_value_S_at_t(t):
k = N * (t - t0) // (t1 - t0)
return jnp.array(k, int)
def vector_field(t, x, args):
k = return_value_S_at_t(t)
dot_x = (A - B @ jnp.linalg.inv(R) @ B.T * S_t[k]) @ x
return dot_x
term = diffrax.ODETerm(vector_field)
solver = diffrax.Tsit5()
saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t1, 1000))
sol = diffrax.diffeqsolve(term, solver, t0, t1, dt, x_0, saveat=saveat)
x_t = sol.ys
ts = sol.ts3. \lambda(t)=S(t)x(t), u(t)=-R^{-1}B^T\lambda(t)により随伴ベクトルも最適制御を得る
lambda_t = jnp.array([S_t[k] @ x_t[k] for k in range(len(ts))])
u_t = -jnp.linalg.inv(R) @ B.T @ jnp.array([lambda_t[k] for k in range(len(ts))])# スコアの計算
def calclate_score(x, u) -> float:
N = x.shape[0] # サンプル数
dt = (t1 - t0) / N # 時間刻み
x_T = x[-1] # 最後の時刻の状態, 形状は (n, 1)
terminal_cost = 0.5 * jnp.matmul(x_T.T, jnp.matmul(S_f, x_T)).squeeze()
xQx = jnp.einsum("nkj,ki,nij->n", x, Q, x) # 形状は (N,)
uRu = jnp.einsum("nkj,ki,nij->n", u, R, u) # 形状は (N,)
integral_cost = 0.5 * jnp.sum(xQx + uRu) * dt
# 総コスト J の計算
J = terminal_cost + integral_cost
return J4. 可視化
得られた状態ベクトル(左図)と制御入力(右図)をグラフを表示する
- 左図より, オーバーシュートは発生しているが抑えられていることがわかる
- 状態ベクトル, 制御入力ともにk=5あたりで0の十分近くに到達している
import matplotlib.pyplot as plt
X = jnp.array(x_t).reshape(N, 2)
U = jnp.array(u_t).reshape(N, 1)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
for k in range(2):
plt.plot(ts, X[:, k], label=f"x_{k}")
plt.axhline(0, color="black", linestyle="--", linewidth=0.7)
plt.legend()
plt.subplot(1, 2, 2)
for k in range(1):
plt.plot(ts, U[:, k], linestyle="--", label="u")
plt.axhline(0, color="black", linestyle="--", linewidth=0.7)
plt.legend()
plt.show()
print(calclate_score(x_t, u_t))
3.366081
リッカチ方程式の解法
リッカチ微分方程式を線形微分方程式に変換して解く手法は今回スキップ
最適レギュレーターとの比較
最適レギュレーターは項を分けて解説するが, フィードバックを
u = R^{-1}B^TP x
により与える制御である. ここでPは以下のリッカチ方程式の解
A^TP+PA-PBR^{-1}B^TP+Q=0
とする. リッカチの微分方程式と区別するために, 代数リッカチ方程式などと呼ぶ.
Arimoto-Potter の方法
説明省略
H = jnp.array([[0, -6, 13, 0], [1, -2, 0, 1], [0, 0, 0, -1], [0, 1, 6, 2]])
_, S = jnp.linalg.eig(H)
E = S[:2, 2:]
D = S[2:, 2:]
P = (E @ jnp.linalg.inv(D)).real # 複素行列となるため虚部を落としておくdef vector_field(t, x, args):
dot_x = (A - B @ jnp.linalg.inv(R) @ B.T * P.real) @ x
return dot_x
term = diffrax.ODETerm(vector_field)
solver = diffrax.Tsit5()
saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t1, 1000))
sol = diffrax.diffeqsolve(term, solver, t0, t1, dt, x_0, saveat=saveat)
x_t = sol.ys
ts = sol.ts
u_t = -jnp.linalg.inv(R) @ B.T @ P.real @ x_tX = jnp.array(x_t).reshape(N, 2)
U = jnp.array(u_t).reshape(N, 1)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
for k in range(2):
plt.plot(ts, X[:, k], label=f"x_{k}")
plt.axhline(0, color="black", linestyle="--", linewidth=0.7)
plt.legend()
plt.subplot(1, 2, 2)
for k in range(1):
plt.plot(ts, U[:, k], linestyle="--", label="u")
plt.axhline(0, color="black", linestyle="--", linewidth=0.7)
plt.legend()
plt.show()
print(calclate_score(x_t, u_t))
3.4015243