非線形システムの最適制御を勾配法のJAX実装で解く

Published

2024-09-19

Modified

2024-09-19

問題設定

状態ベクトルの次元n=2, 入力制御ベクトルの次元m=1とし, 線形システム

\dot{x}(t) = f(x(t),u(t)), \quad t\in (0,T)

を考える. ここで

f(x,t) = \begin{bmatrix} x_2\\ 2x_1(1-x_1^2) -x_2 + u\end{bmatrix}

とし, 初期値は

x(0)=\begin{bmatrix} 0.5\\ 0\end{bmatrix}

とする. 評価関数は

J=\dfrac12 x^T(T) S_Tx(T) + \dfrac12\int_0^T L(x(t), u(t))\, dt

において

L(x,t) = \left(x^TQx + u^TRu\right),\quad S_f=Q=\begin{bmatrix} 13& 0\\ 0&1\end{bmatrix},\quad R=\begin{bmatrix} 1 \end{bmatrix}

とする. すなわち

J=\dfrac12 \left(13x_1(T)^2 + x_2(T)^2 \right) +\dfrac12 \int_0^T\left(13x_1(t)^2 + x_2(t)^2 +u(t)^2\right)\, dt.

を最小化する最適制御を考える.

制御入力がない場合

まず, 制御入力を与えない場合(u=0)のダイナミクスを明らかにする

パラメータ設定

共通のパラメータはグローバル変数として定義しておく

import diffrax
import jax
import jax.numpy as jnp
from tqdm.notebook import tqdm_notebook as tqdm

# 問題設定
S_f = jnp.array([[13, 0], [0, 1]], dtype=float)
Q = jnp.array([[13, 0], [0, 1]], dtype=float)
R = jnp.array([[1]], dtype=float)

x_0 = jnp.array([[0.5], [0]], dtype=float)

# 解く区間
t0, t1 = 0, 4
dt = 0.01

評価関数の定義

@jax.jit
def compute_J(x, u):
    N = x.shape[0]
    dt_ = (t1 - t0) / N

    x_T = x[-1]  # 最後の時刻の状態, 形状は (n, 1)
    terminal_cost = 0.5 * jnp.matmul(x_T.T, jnp.matmul(S_f, x_T)).squeeze()

    xQx = jnp.einsum("nkj,ki,nij->n", x, Q, x)  # 形状は (N,)
    uRu = jnp.einsum("nkj,ki,nij->n", u, R, u)  # 形状は (N,)
    integral_cost = 0.5 * jnp.sum(xQx + uRu) * dt_

    J = terminal_cost + integral_cost
    return J

状態方程式の定義

@jax.jit
def function_f(x, u):
    """状態方程式における f"""
    x1, x2 = x
    u1 = u.squeeze()
    return jnp.array([x2, -2 * (x1**3) + 2 * x1 - x2 + u1], dtype=float)


def vector_field_x(t, x, args):
    u_t = args.evaluate(t)
    return function_f(x, u_t)


state_eq = diffrax.ODETerm(vector_field_x)

diffrax で常微分方程式を解く際の共通設定

N = 1000
ts = jnp.linspace(t0, t1, N)
solver = diffrax.Tsit5()
saveat = diffrax.SaveAt(ts=ts)

制御入力に対する状態ベクトルの可視化

状態ベクトルX(t)=(x(t),\dot{x}(t))を左のグラフに配置し, 制御入力を右のグラフに配置する.

制御入力がない場合, 安定平衡点(1,0)に漸近することがわかる

import matplotlib.pyplot as plt


def plot_control(u):
    u_func = diffrax.LinearInterpolation(ts=ts, ys=u)
    sol = diffrax.diffeqsolve(state_eq, solver, t0, t1, dt, x_0, args=u_func, saveat=saveat)
    x = sol.ys
    X = jnp.array(x).reshape(N, 2)
    U = jnp.array(u).reshape(N, 1)

    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    for k in range(2):
        plt.plot(ts, X[:, k], label=f"x_{k}")
    plt.axhline(1, color="blue", linestyle="--", linewidth=0.7)
    plt.axhline(0, color="black", linestyle="--", linewidth=0.7)
    plt.ylim(-0.5, 1.2)
    plt.legend()

    plt.subplot(1, 2, 2)
    for k in range(1):
        plt.plot(ts, U[:, k], linestyle="--", label="u")
    plt.axhline(0, color="black", linestyle="--", linewidth=0.7)
    plt.ylim(-3.0, 0.2)
    plt.legend()
    plt.show()

    score = float(compute_J(x, u))
    print(f"{score=}")
u = jnp.zeros((N, 1, 1))
plot_control(u)

score=28.224021911621094

オイラー・ラグランジュ方程式

ハミルトニアンをH=L+\lambda^T fとする. このときオイラー・ラグランジュ方程式は

\begin{aligned} &\dot{x}(t) = f(x(t), u(t)), \quad x(0)=x_0, \\ &\dot{\lambda}(t) = -\left(\dfrac{\partial H}{\partial x}\right)^T (x(t), u(t), \lambda(t)) , \quad \lambda(T) = S_f (x(T)), \\ &\dfrac{\partial H}{\partial u}(x(t),u(t),\lambda(t))=0 \end{aligned}

となる.

ここでは第一式(状態方程式)と第二式(随伴方程式)を使用する以下の2種類の勾配法を検証する.

最急降下法

以下のアルゴリズムを構築する

  1. 適当なuを制御入力の初期推定解とする
  2. uを用いて状態方程式を解いてxを, 随伴方程式を解いて\lambdaを求める
  3. x,u,\lambdaから\frac{\partial H}{\partial u}を計算する.
  4. s=-\left(\frac{\partial H}{\partial u}\right)^Tとおく
  5. 制御入力を u+\alpha sとしたときの評価関数値J[u+\alpha s]が最小になるスカラー\alphaを求め, u=u+\alpha sと更新してステップ2に戻る
    • ただし, 以下の条件を満たす場合は収束したとみなす
      • 勾配のノルム\left(\int_{0}^{T}\left\|s(t)\right\|^2\, dt\right)^{\frac{1}{2}}が十分小さい再場合
      • 制御の変更のノルム\alpha \left(\int_{0}^{T}\left\|s(t)\right\|^2\, dt\right)^{\frac{1}{2}}が十分小さい場合

随伴方程式の定義

ハミルトニアンの偏微分を用いて書けることから, 自動微分を用いて定義する.

また, 時間逆方向に解くので符号を逆転しておく.

@jax.jit
def hamiltonian(x, u, lambda_):
    L = 0.5 * (jnp.matmul(x.T, jnp.matmul(Q, x)) + jnp.matmul(u.T, jnp.matmul(R, u)))
    f = function_f(x, u)
    H = L + jnp.matmul(lambda_.T, f)
    return H.squeeze()


grad_H_x = jax.grad(hamiltonian, argnums=0)


def vector_field_lambda(t, lambda_, args):
    x_t = args[0].evaluate(t)
    u = args[1].evaluate(t)
    dot_lambda = grad_H_x(x_t, u, lambda_)
    return dot_lambda


lambda_eq = diffrax.ODETerm(vector_field_lambda)

目的関数の勾配計算

制御入力に対する変分として, ハミルトニアンの勾配を用いるのでこちらも自動微分を用いて定義する.

直線探索の最適値\alpha\in (0,1)を求めるために scipy.optimize.minimize_scalar を使用するため, 制御入力に対応する評価関数を計算できるようにしておく

@jax.jit
def compute_sequential_hamiltonian_and_gradients(x, u, lambda_):
    """
    各時刻におけるハミルトニアンとその u に関する勾配を計算します。

    Parameters:
    x (jnp.array): 状態変数、形状は (N, n, 1)
    u (jnp.array): 制御入力、形状は (N, m, 1)
    lambda_ (jnp.array): ラグランジュ乗数、形状は (N, n, 1)

    Returns:
    Tuple[jnp.array, jnp.array]: ハミルトニアンの配列と勾配の配列
    """
    # ベクトル化されたハミルトニアン関数
    H = jax.vmap(hamiltonian, in_axes=(0, 0, 0))(x, u, lambda_)
    # ベクトル化された勾配関数
    hamiltonian_grad_u = jax.grad(hamiltonian, argnums=1)
    grad_u = jax.vmap(hamiltonian_grad_u, in_axes=(0, 0, 0))(x, u, lambda_)
    return H, grad_u


@jax.jit
def J_alpha(alpha, u, s):
    """argmin_{alpha} J(u + alpha * s) を算出する"""
    u_new = u + alpha * s
    u_func = diffrax.LinearInterpolation(ts=ts, ys=u_new)
    sol = diffrax.diffeqsolve(state_eq, solver, t0, t1, dt, x_0, args=u_func, saveat=saveat)
    x_new = sol.ys
    return compute_J(x_new, u_new)

最急降下法アルゴリズムの実行

from scipy.optimize import minimize_scalar

u = jnp.zeros((N, 1, 1))
eps1 = 1e-2
eps2 = 1e-7

for i in tqdm(range(10**3)):
    u_func = diffrax.LinearInterpolation(ts=ts, ys=u)
    sol = diffrax.diffeqsolve(state_eq, solver, t0, t1, dt, x_0, args=u_func, saveat=saveat)
    x = sol.ys

    lambda_tf = S_f @ x[-1]
    x_func = diffrax.LinearInterpolation(ts=ts, ys=x[::-1])
    sol = diffrax.diffeqsolve(lambda_eq, solver, t0, t1, dt, lambda_tf, args=[x_func, u_func], saveat=saveat)
    lambda_ = sol.ys[::-1]

    H, grad_u = compute_sequential_hamiltonian_and_gradients(x, u, lambda_)

    grad_norm = jnp.sqrt(jnp.sum(grad_u**2) * dt)
    if grad_norm < eps1:
        print(f"roop done for {i=}, because {grad_norm=:.5f}")
        break

    s = -grad_u
    alpha_opt = minimize_scalar(J_alpha, bounds=(0, 1), args=(u, s)).x

    diff_norm = alpha_opt * jnp.sqrt(jnp.sum(s**2) * dt)
    if diff_norm < eps2:
        print(f"roop done for {i=}, because {diff_norm=:.8f}")
        break
    if i % 100 == 0:
        score = compute_J(x, u)
        print(f"{score=:.7f}, {grad_norm=:.5f}, {diff_norm=:.8f}")
    u += alpha_opt * s
score=28.2240219, grad_norm=14.06202, diff_norm=3.41370487
score=2.0676792, grad_norm=0.72200, diff_norm=0.00191596
score=2.0527604, grad_norm=1.48579, diff_norm=0.00058045
score=2.0234804, grad_norm=0.36918, diff_norm=0.00208106
score=2.0040922, grad_norm=0.02795, diff_norm=0.00001824
score=2.0040870, grad_norm=0.07580, diff_norm=0.00002677
score=2.0040841, grad_norm=0.05358, diff_norm=0.00003470
roop done for i=682, because grad_norm=0.00986

最急降下法の計算結果

不安定平衡点の近くに状態を定める制御が構成できている

plot_control(u)

score=2.0040838718414307

共役勾配法

以下のアルゴリズムを構築する

  1. 適当なuを制御入力の初期推定解, s_-=0, d_-は適当(計算に使われない)に置く
  2. 制御入力を u としたときの状態方程式の解x, 随伴方程式の解\lambda に対してd = -\frac{\partial H}{\partial u} とする
  3. ポラック・リビエ・ポリャック法やフレッチャー・リーブス法により\betad_-, d から定める
    • ポラック・リビエ・ポリャック法のほうが収束が速かったのでそちらを採用
  4. s=d + \beta s_-とする
  5. 制御入力を u+\alpha sとしたときの評価関数値J[u+\alpha s]が最小になるスカラー\alphaを求め, u=u+\alpha sとおく
    • ただし, 以下の条件を満たす場合は収束したとみなす
      • 勾配のノルム\left(\int_{0}^{R}\left\|d(t)\right\|^2\, dt\right)^{\frac{1}{2}}が十分小さい再場合
      • 制御の変更のノルム\alpha \left(\int_{0}^{T}\left\|d(t)\right\|^2\, dt\right)^{\frac{1}{2}}が十分小さい場合
  6. d_-=d, s_-=sと代入する
  7. 共役方向の誤差が蓄積するので定期的にリセットする(これよくわからなかったけどやらないと収束しない)

共役勾配法アルゴリズムの実行

u = jnp.zeros((N, 1, 1))
s_ = jnp.zeros((N, 1, 1))
d_ = jnp.ones((N, 1, 1))

eps1 = 1e-2
eps2 = 1e-7

for i in tqdm(range(10**3)):
    u_func = diffrax.LinearInterpolation(ts=ts, ys=u)
    sol = diffrax.diffeqsolve(state_eq, solver, t0, t1, dt, x_0, args=u_func, saveat=saveat)
    x = sol.ys

    lambda_tf = S_f @ x[-1]
    x_func = diffrax.LinearInterpolation(ts=ts, ys=x[::-1])
    sol = diffrax.diffeqsolve(lambda_eq, solver, t0, t1, dt, lambda_tf, args=[x_func, u_func], saveat=saveat)
    lambda_ = sol.ys[::-1]

    _, grad_u = compute_sequential_hamiltonian_and_gradients(x, u, lambda_)
    d = -grad_u

    grad_norm = jnp.sqrt(jnp.sum(d**2) * dt).squeeze()
    if grad_norm < eps1:
        print(f"roop done for {i=}, because {grad_norm=:.5f}")
        break

    beta = jnp.einsum("ijk, ijk", d, d - d_) / jnp.einsum("ijk, ijk", d_, d_)  # ポラック・リビエ・ポリャック法
    # beta = jnp.einsum("ijk, ijk", d, d) / jnp.einsum("ijk, ijk", d_, d_)  # フレッチャー・リーブス法
    s = d + beta * s_
    alpha_opt = minimize_scalar(J_alpha, bounds=(0, 1), args=(u, s)).x

    diff_norm = alpha_opt * jnp.sqrt(jnp.sum(s**2) * dt)
    if diff_norm < eps2:
        print(f"roop done for {i=}, because {diff_norm=:.8f}")
        break
    if i % 3 == 0:
        # s_ には共役方向の誤差が蓄積するので定期的にリセットする(?)
        s_ = jnp.zeros((N, 1, 1))
    if i % 100 == 0:
        score = compute_J(x, u)
        print(f"{score=:.7f}, {grad_norm=:.5f}, {diff_norm=:.8f}")
    u += alpha_opt * s
    d_, s_ = d, s
score=28.2240219, grad_norm=14.06202, diff_norm=3.41370487
roop done for i=71, because grad_norm=0.00364

共役勾配法の計算結果

最急降下法より速く収束判定され, こちらでも制御ができている

plot_control(u)

score=2.004102945327759