import diffrax
import jax.numpy as jnp
from tqdm.notebook import tqdm_notebook as tqdm
# 問題設定
A = jnp.array([[0, 1], [-6, -2]], dtype=float)
B = jnp.array([[0], [1]], dtype=float)
S_f = jnp.array([[13, 0], [0, 1]], dtype=float)
Q = jnp.array([[13, 0], [0, 1]], dtype=float)
R = jnp.array([[1]], dtype=float)
x_0 = jnp.array([[-1], [0]], dtype=float)
# 解く区間
t0, t1 = 0, 1.5
dt = 0.01
# diffrax の出力数
N = 1000LQ制御の最適制御を勾配法で求める
問題設定
状態ベクトルの次元n=2, 入力制御ベクトルの次元m=1とし, 線形システム
\dot{x}(t) = Ax(t) + B u(t)
を考える. ここで A=\begin{bmatrix} 0 & 1\\ -6&-2\end{bmatrix},\quad B=\begin{bmatrix} 0\\1 \end{bmatrix} とし, 初期値は x(0)=\begin{bmatrix} -1\\ 0\end{bmatrix}
とする. 評価関数は
J=\dfrac12 x^T(T) S_Tx(T) + \int_0^T\dfrac12\left(x^T(t)Qx(t) + u^T(t)Ru(t)\right)\, dt, \quad T=1.5
において S_f=Q=\begin{bmatrix} 13& 0\\ 0&1\end{bmatrix},\quad R=\begin{bmatrix} 1 \end{bmatrix} とする. すなわち
J=\dfrac12 \left(13x_1(T)^2 + x_2(T)^2 \right) +\dfrac12 \int_0^T\left(13x_1(t)^2 + x_2(t)^2 +u(t)^2\right)\, dt.
を最小化する最適制御を考える.
制御入力がない場合
まず, 制御入力がなかった場合(u=0)のダイナミクスを明らかにする
import jax
@jax.jit
def compute_J(x, u):
N = x.shape[0] # サンプル数
dt = (t1 - t0) / N # 時間刻み
x_T = x[-1] # 最後の時刻の状態, 形状は (n, 1)
terminal_cost = 0.5 * jnp.matmul(x_T.T, jnp.matmul(S_f, x_T)).squeeze()
xQx = jnp.einsum("nkj,ki,nij->n", x, Q, x) # 形状は (N,)
uRu = jnp.einsum("nkj,ki,nij->n", u, R, u) # 形状は (N,)
integral_cost = 0.5 * jnp.sum(xQx + uRu) * dt
# 総コスト J の計算
J = terminal_cost + integral_cost
return jnp.sum(J)# 汎用設定
ts = jnp.linspace(t0, t1, N)
solver = diffrax.Tsit5()
saveat = diffrax.SaveAt(ts=ts)def function_f(x, u):
return A @ x + B @ u
def vector_field_x(t, x, args):
u_t = args.evaluate(t)
return function_f(x, u_t)
state_eq = diffrax.ODETerm(vector_field_x)u = jnp.zeros((N, 1, 1))
u_func = diffrax.LinearInterpolation(ts=ts, ys=u)
sol = diffrax.diffeqsolve(state_eq, solver, t0, t1, dt, x_0, args=u_func, saveat=saveat)
x = sol.ysimport matplotlib.pyplot as plt
def plot_control(x, u):
X = jnp.array(x).reshape(N, 2)
U = jnp.array(u).reshape(N, 1)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
for k in range(2):
plt.plot(ts, X[:, k], label=f"x_{k}")
plt.axhline(0, color="black", linestyle="--", linewidth=0.7)
plt.legend()
plt.subplot(1, 2, 2)
for k in range(1):
plt.plot(ts, U[:, k], linestyle="--", label="u")
plt.axhline(0, color="black", linestyle="--", linewidth=0.7)
plt.legend()
plt.show()plot_control(x, u)
print(compute_J(x, u))
3.6685557
オイラー・ラグランジュ方程式
\begin{aligned} &\dot{x}(t) = Ax(t) + Bu(t), \quad x(0)=x_0, \\ &\dot{\lambda}(t) =-Qx(t)-A^T \lambda(t) , \quad \lambda(t_f) = S_f (x(t_f)), \\ &u^T(t)R + \lambda^T(t)B=0. \end{aligned}
に対して, ハミルトニアンをH=L+\lambda^T fとする.
最急降下法
以下のアルゴリズムを構築する
- 適当なuを制御入力の初期推定解とする
- uを用いて状態方程式を解いてxを, 随伴方程式を解いて\lambdaを求める
- x,u,\lambdaから\frac{\partial H}{\partial u}を計算する.
- 勾配のノルム\left(\int_{t_0}^{t_f}\left\|\frac{\partial H}{\partial u}\right\|^2\right)^{\frac{1}{2}}が十分小さければ停止する
- s=-\left(\frac{\partial H}{\partial u}\right)^Tとおく
- 制御入力を u+\alpha sとしたときの評価関数値J[u+\alpha s]が最小になるスカラー\alphaを求め, u=u+\alpha sと更新してステップ2に戻る
import jax
@jax.jit
def hamiltonian(x, u, lambda_):
L = 0.5 * (jnp.matmul(x.T, jnp.matmul(Q, x)) + jnp.matmul(u.T, jnp.matmul(R, u)))
f = function_f(x, u)
H = L + jnp.matmul(lambda_.T, f)
return H.squeeze()
grad_H_x = jax.grad(hamiltonian, argnums=0)
def vector_field_lambda(t, lambda_, args):
x_t = args[0].evaluate(t)
u = args[1].evaluate(t)
dot_lambda = grad_H_x(x_t, u, lambda_)
return dot_lambda
lambda_eq = diffrax.ODETerm(vector_field_lambda)
@jax.jit
def compute_hamiltonian_and_gradients(x, u, lambda_):
"""
各時刻におけるハミルトニアンとその u に関する勾配を計算します。
Parameters:
x (jnp.array): 状態変数、形状は (N, n, 1)
u (jnp.array): 制御入力、形状は (N, m, 1)
lambda_ (jnp.array): ラグランジュ乗数、形状は (N, n, 1)
Returns:
Tuple[jnp.array, jnp.array]: ハミルトニアンの配列と勾配の配列
"""
# ベクトル化されたハミルトニアン関数
H = jax.vmap(hamiltonian, in_axes=(0, 0, 0))(x, u, lambda_)
# ベクトル化された勾配関数
hamiltonian_grad_u = jax.grad(hamiltonian, argnums=1)
grad_u = jax.vmap(hamiltonian_grad_u, in_axes=(0, 0, 0))(x, u, lambda_)
return H, grad_u@jax.jit
def J_alpha(alpha, u, s):
u_new = u + alpha * s
u_func = diffrax.LinearInterpolation(ts=ts, ys=u_new)
sol = diffrax.diffeqsolve(state_eq, solver, t0, t1, dt, x_0, args=u_func, saveat=saveat)
x = sol.ys
return compute_J(x, u_new)from scipy.optimize import minimize_scalar
u = jnp.zeros((N, 1, 1))
for _ in tqdm(range(20)):
u_func = diffrax.LinearInterpolation(ts=ts, ys=u)
sol = diffrax.diffeqsolve(state_eq, solver, t0, t1, dt, x_0, args=u_func, saveat=saveat)
x = sol.ys
lambda_tf = S_f @ x[-1]
x_func = diffrax.LinearInterpolation(ts=ts, ys=x[::-1])
sol = diffrax.diffeqsolve(lambda_eq, solver, t0, t1, dt, lambda_tf, args=[x_func, u_func], saveat=saveat)
lambda_ = sol.ys[::-1]
H, grad_u = compute_hamiltonian_and_gradients(x, u, lambda_)
# u の更新
grad_norm = jnp.sqrt(jnp.sum(grad_u**2) * dt)
s = -grad_u
minimize_result = minimize_scalar(J_alpha, bounds=(0, 1), args=(u, s))
alpha_opt = minimize_result.x
score = compute_J(x, u)
print(f"{score=:.5f}, {grad_norm=:.5f}")
u += alpha_opt * sscore=3.66856, grad_norm=2.67171
score=3.35222, grad_norm=0.35953
score=3.34473, grad_norm=0.07673
score=3.34446, grad_norm=0.01657
score=3.34446, grad_norm=0.00531
score=3.34446, grad_norm=0.00107
score=3.34446, grad_norm=0.00069
score=3.34446, grad_norm=0.00046
score=3.34446, grad_norm=0.00011
score=3.34446, grad_norm=0.00004
score=3.34446, grad_norm=0.00001
score=3.34446, grad_norm=0.00000
score=3.34446, grad_norm=0.00000
score=3.34446, grad_norm=0.00000
score=3.34446, grad_norm=0.00000
score=3.34446, grad_norm=0.00000
score=3.34446, grad_norm=0.00000
score=3.34446, grad_norm=0.00000
score=3.34446, grad_norm=0.00000
score=3.34446, grad_norm=0.00000
plot_control(x, u)
共役勾配法
以下のアルゴリズムを構築する
- 適当なuを制御入力の初期推定解, s_-=0, d_-は適当(計算に使われない)に置く
- 制御入力を u としたときの状態方程式の解x, 随伴方程式の解\lambda に対してd = -\frac{\partial H}{\partial u} とする
- 勾配のノルム\left(\int_{t_0}^{t_f}\left\|d(t)\right\|^2\,dt\right)^{\frac{1}{2}}が十分小さければ停止する
- ポラックリビエポリャック法やフレッチャーリーブス法により\betaをd_-, d から定める
- s=d + \beta s_-とする
- 制御入力を u+\alpha sとしたときの評価関数値J[u+\alpha s]が最小になるスカラー\alphaを求め, u=u+\alpha sとおく
- d_-=d, s_-=sと代入する
from scipy.optimize import minimize_scalar
u = jnp.zeros((N, 1, 1))
s_ = jnp.zeros((N, 1, 1))
d_ = jnp.ones((N, 1, 1))
for _ in tqdm(range(20)):
u_func = diffrax.LinearInterpolation(ts=ts, ys=u)
sol = diffrax.diffeqsolve(state_eq, solver, t0, t1, dt, x_0, args=u_func, saveat=saveat)
x = sol.ys
lambda_tf = S_f @ x[-1]
x_func = diffrax.LinearInterpolation(ts=ts, ys=x[::-1])
sol = diffrax.diffeqsolve(lambda_eq, solver, t0, t1, dt, lambda_tf, args=[x_func, u_func], saveat=saveat)
lambda_ = sol.ys[::-1]
H, grad_u = compute_hamiltonian_and_gradients(x, u, lambda_)
d = -grad_u
grad_norm = jnp.sqrt(jnp.sum(d**2) * dt).squeeze()
beta = jnp.einsum("ijk, ijk", d, d) / jnp.einsum("ijk, ijk", d_, d_) # フレッチャーリーブス法に
s = d + beta * s_
alpha_opt = minimize_scalar(J_alpha, bounds=(0, 1), args=(u, s)).x
score = compute_J(x, u)
print(f"{score=:.5f}, {grad_norm=:.5f}")
u += alpha_opt * sscore=3.66856, grad_norm=2.67171
score=3.35222, grad_norm=0.35953
score=3.34473, grad_norm=0.07673
score=3.34446, grad_norm=0.01657
score=3.34446, grad_norm=0.00531
score=3.34446, grad_norm=0.00107
score=3.34446, grad_norm=0.00069
score=3.34446, grad_norm=0.00046
score=3.34446, grad_norm=0.00011
score=3.34446, grad_norm=0.00004
score=3.34446, grad_norm=0.00001
score=3.34446, grad_norm=0.00000
score=3.34446, grad_norm=0.00000
score=3.34446, grad_norm=0.00000
score=3.34446, grad_norm=0.00000
score=3.34446, grad_norm=0.00000
score=3.34446, grad_norm=0.00000
score=3.34446, grad_norm=0.00000
score=3.34446, grad_norm=0.00000
score=3.34446, grad_norm=0.00000
plot_control(x, u)