Beyond Convexity #2: Barygradient flow
Published:
In my latest paper, I introduced a generalized proximal point algorithm (PPA): given a point $(x,q) \in \mathbb{R}^m \times \mathring \Delta_S$ (with $m\ge 1$, $S \ge 2$, and $\Delta_S$ the probability simplex), the next iterate $(x',q')$ is given by \(( \nabla f + \lambda A )(x',q') = \nabla f(x,q) + c \begin{pmatrix} 0_m \\ 1_S \end{pmatrix}\) where \(c = \log(\sum_s e^{\log(q'_s)-\lambda \ell_s(x')}) -\log(\sum_s e^{\log(q_s)}),\) for step-size $\lambda>0$, $f(x,q)=\frac12 ||x||^2 + h(q)$ with $h(q)=\sum_{s=1}^S q_s \log(q_s)$ the negentropy, and \(A(x,q) = \begin{pmatrix} J_\ell(x)^\intercal q \\ -\ell(x) \end{pmatrix}\) where $J_\ell$ denotes the Jacobian matrix of $\ell=(\ell_1,\dots,\ell_S):\mathbb{R}^m \rightarrow \mathbb{R}^S$, with each $\ell_s$ convex for $1\le s \le S$.
In particular, I showed that $A$ is a monotone operator and that the $f$-resolvent $(\nabla f + \lambda A)^{-1} \circ \nabla f$ is Bregman firmly nonexpansive with respect to the Bregman divergence $D_f$. This ensures that the generalized PPA converges to a fixed point whenever one exists.
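To make the implicit update concrete, here is a minimal NumPy sketch of a single step. Written block by block, the update reads $x' + \lambda J_\ell(x')^\intercal q' = x$ and $\log q'_s - \lambda \ell_s(x') = \log q_s + c$, so $q'$ is proportional to $q_s e^{\lambda \ell_s(x')}$. The quadratic losses $\ell_s(x) = \frac12 ||x - a_s||^2$, the array `ANCHORS`, and the naive fixed-point solver `ppa_step` are all assumptions made for illustration; they are not taken from the paper, and the solver is only expected to behave well for small $\lambda$.

```python
import numpy as np

# Hypothetical toy problem: ell_s(x) = 0.5 * ||x - a_s||^2, convex in x.
ANCHORS = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]])  # S = 3, m = 2

def ell(x):
    """Vector of losses (ell_1(x), ..., ell_S(x))."""
    return 0.5 * np.sum((x - ANCHORS) ** 2, axis=1)

def jac_ell(x):
    """Jacobian J_ell(x), shape (S, m); row s is the gradient of ell_s."""
    return x - ANCHORS

def q_update(x_new, q, lam):
    """q-block: q'_s is proportional to q_s * exp(lam * ell_s(x'))."""
    logits = np.log(q) + lam * ell(x_new)
    w = np.exp(logits - logits.max())  # subtract the max for numerical stability
    return w / w.sum()

def ppa_step(x, q, lam, n_iter=200):
    """Solve the implicit update by naive fixed-point iteration (small lam)."""
    x_new = x.copy()
    for _ in range(n_iter):
        q_new = q_update(x_new, q, lam)
        # x-block: x' + lam * J_ell(x')^T q' = x
        x_new = x - lam * jac_ell(x_new).T @ q_new
    return x_new, q_update(x_new, q, lam)

x0 = np.array([0.5, -0.3])
q0 = np.full(3, 1.0 / 3.0)
lam = 0.1
x1, q1 = ppa_step(x0, q0, lam)

# Residuals of the two blocks of the update equation.
res_x = x1 + lam * jac_ell(x1).T @ q1 - x0
shift = np.log(q1) - lam * ell(x1) - np.log(q0)  # should equal c * 1_S
c = np.log(np.sum(np.exp(np.log(q1) - lam * ell(x1)))) - np.log(np.sum(q0))
```

If the iteration has converged, `res_x` is numerically zero and every entry of `shift` coincides with the scalar `c` from the formula above.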
In this blog post, we propose to generalize the gradient flow ordinary differential equation (ODE) by letting $\lambda \rightarrow 0$ in our generalized PPA (for a great introduction to gradient flow, see Bach’s blog post).
Definition: Let $F(x,q) = q^\intercal \ell(x)$. We define the barygradient flow ODE as \(\dot \zeta(t) = - \begin{pmatrix} I_m & 0 \\ 0 & -I_S \end{pmatrix} \nabla F( (\nabla f)^{-1}( \zeta(t) -\log(\sum_s e^{\xi_s(t)-1}) \begin{pmatrix} 0_m \\ 1_S \end{pmatrix} ) ) + \gamma(t) \begin{pmatrix} 0_m \\ 1_S \end{pmatrix} ,\) where $\zeta = (x,\xi) : \mathbb{R}_+ \rightarrow \mathbb{R}^m \times \mathbb{R}^S$ and \(\gamma(t) = \frac{\sum_s [ \dot \xi_s(t) - \ell_s(x(t)) ] e^{\xi_s(t)-1}}{\sum_s e^{\xi_s(t)-1}} = q(t)^\intercal [ \dot \xi(t) - \ell(x(t)) ]\) with $q(t)=(\nabla h)^{-1}(\xi(t)-\log(\sum_s e^{\xi_s(t)-1}) 1_S)$.
We point out that \(\begin{pmatrix} I_m & 0 \\ 0 & -I_S \end{pmatrix} \nabla F( (\nabla f)^{-1}( \zeta(t) -\log(\sum_s e^{\xi_s(t)-1}) \begin{pmatrix} 0_m \\ 1_S \end{pmatrix}) ) = A(x(t), q(t)) .\)
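As a rough illustration, the ODE can be integrated with a forward Euler scheme, using the identity above to evaluate the right-hand side as $-A(x(t),q(t)) + \gamma(t) (0_m, 1_S)$. The sketch below rests on several assumptions that are not part of the definition: toy quadratic losses $\ell_s(x) = \frac12 ||x - a_s||^2$, a constant user-chosen `gamma`, and the hypothetical helper names `softargmax`, `A` and `euler_flow`.

```python
import numpy as np

# Hypothetical toy losses: ell_s(x) = 0.5 * ||x - a_s||^2.
ANCHORS = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]])  # S = 3, m = 2

def ell(x):
    return 0.5 * np.sum((x - ANCHORS) ** 2, axis=1)

def jac_ell(x):
    return x - ANCHORS  # shape (S, m)

def softargmax(xi):
    """q = (grad h)^{-1}(xi - log(sum_s exp(xi_s - 1)) 1_S), i.e. softmax(xi)."""
    w = np.exp(xi - xi.max())
    return w / w.sum()

def A(x, q):
    """Operator A(x, q) = (J_ell(x)^T q, -ell(x))."""
    return np.concatenate([jac_ell(x).T @ q, -ell(x)])

def euler_flow(x0, xi0, gamma=0.0, dt=1e-3, n_steps=2000):
    """Forward Euler on zeta' = -A(x, q) + gamma * (0_m, 1_S)."""
    x, xi = x0.copy(), xi0.copy()
    m = x.size
    for _ in range(n_steps):
        a = A(x, softargmax(xi))
        x = x - dt * a[:m]
        xi = xi - dt * a[m:] + dt * gamma
    return x, xi

# Two runs that differ only in the (assumed constant) value of gamma.
x_a, xi_a = euler_flow(np.array([0.5, -0.3]), np.zeros(3), gamma=0.0)
x_b, xi_b = euler_flow(np.array([0.5, -0.3]), np.zeros(3), gamma=0.7)
```

In this discretization, changing `gamma` only translates $\xi$ along $1_S$, so the two runs should produce the same $x$ and the same $q$ up to rounding.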
Monotonicity analysis
Contrary to classic gradient flow, the function $F(x(t),q(t))$ is not necessarily nonincreasing along the flow. Indeed, \(\frac{d}{dt}F( (\nabla f)^{-1}( \zeta(t) -\log(\sum_s e^{\xi_s(t)-1}) \begin{pmatrix} 0_m \\ 1_S \end{pmatrix} ) ) = \frac{d}{dt}[(\nabla h)^{-1}(\xi(t)-\log(\sum_s e^{\xi_s(t)-1}) 1_S)]^\intercal \ell(x(t)) + q(t)^\intercal \frac{d}{dt} \ell(x(t)) ,\) where \(\frac{d}{dt}[(\nabla h)^{-1}(\xi(t)-\log(\sum_s e^{\xi_s(t)-1}) 1_S)] = [\nabla^2 h(q(t))]^{-1} \dot \xi(t) - \frac{\sum_s \dot \xi_s(t) e^{\xi_s(t)-1}}{\sum_s e^{\xi_s(t)-1}} q(t)\) and $\frac{d}{dt} \ell(x(t)) = J_\ell(x(t)) \dot x(t)$.
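Substituting $\dot x(t) = -J_\ell(x(t))^\intercal q(t)$ and $\dot \xi(t) = \ell(x(t)) + \gamma(t) 1_S$, and using $[\nabla^2 h(q(t))]^{-1} = \text{Diag}(q(t))$, the terms in $\gamma(t)$ cancel and we obtain \(\frac{d}{dt}F(x(t),q(t)) = \underbrace{ \ell(x(t))^\intercal [\nabla^2 h(q(t))]^{-1} \ell(x(t)) - F(x(t),q(t))^2 }_{\text{Var}_{\tau \sim q(t)}(\ell_\tau(x(t)))} - ||J_\ell(x(t))^\intercal q(t)||^2 ,\) which is not necessarily nonpositive.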
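The variance-minus-squared-norm formula can be sanity-checked numerically with a central finite difference along the flow direction. This is only a sketch: the quadratic losses, the evaluation point, and the choice $\gamma = 0$ are assumptions made for the check.

```python
import numpy as np

# Hypothetical toy losses: ell_s(x) = 0.5 * ||x - a_s||^2.
ANCHORS = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]])

def ell(x):
    return 0.5 * np.sum((x - ANCHORS) ** 2, axis=1)

def jac_ell(x):
    return x - ANCHORS

def softargmax(xi):
    w = np.exp(xi - xi.max())
    return w / w.sum()

def F(x, xi):
    """F(x, q) = q^T ell(x) with q = softargmax(xi)."""
    return softargmax(xi) @ ell(x)

x = np.array([0.5, -0.3])
xi = np.array([0.2, -0.1, 0.4])
q = softargmax(xi)

# Velocity of the flow at (x, xi), taking gamma = 0.
x_dot = -jac_ell(x).T @ q
xi_dot = ell(x)

# Central finite difference of F along the velocity.
eps = 1e-6
dF_numeric = (F(x + eps * x_dot, xi + eps * xi_dot)
              - F(x - eps * x_dot, xi - eps * xi_dot)) / (2 * eps)

# Closed form: Var_{tau ~ q}(ell_tau(x)) - ||J_ell(x)^T q||^2.
variance = q @ ell(x) ** 2 - (q @ ell(x)) ** 2
dF_formula = variance - np.sum((jac_ell(x).T @ q) ** 2)
```

The sign of `dF_formula` depends on which of the two terms dominates at the chosen point, which is exactly why $F$ need not decrease along the flow.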
Entropy analysis
Denote $\chi(t) = h(q(t))$. Then, \(\frac{d}{dt} \chi(t) = \dot q(t)^\intercal \nabla h(q(t)) = \{ [\nabla^2 h(q(t))]^{-1} \dot \xi(t) - [q(t)^\intercal \dot \xi(t)] q(t) \}^\intercal \{\xi(t)-\log(\sum_s e^{\xi_s(t)-1}) 1_S\} \\ = \xi(t)^\intercal \underbrace{[ \text{Diag}(q(t))-q(t)q(t)^\intercal ]}_{\text{Cov}(q(t))} \ell(x(t)),\) where $\text{Cov}(q(t))$ denotes the covariance matrix$^1$ of the categorical distribution $q(t)$.
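The entropy identity can be checked the same way, by differentiating $h(q)$ numerically along $\dot \xi = \ell + \gamma 1_S$. In the sketch below, the vectors `xi` and `losses` (standing in for $\xi(t)$ and $\ell(x(t))$) and the value of `gamma` are arbitrary illustrative choices, not values from the post.

```python
import numpy as np

def softargmax(xi):
    w = np.exp(xi - xi.max())
    return w / w.sum()

def negentropy(q):
    """h(q) = sum_s q_s log(q_s)."""
    return np.sum(q * np.log(q))

xi = np.array([0.2, -0.1, 0.4])      # hypothetical xi(t)
losses = np.array([1.3, 0.4, 2.1])   # hypothetical ell(x(t))
gamma = 0.7                          # arbitrary; it should drop out of the result
xi_dot = losses + gamma

q = softargmax(xi)
cov = np.diag(q) - np.outer(q, q)    # Cov(q) = Diag(q) - q q^T

# Central finite difference of chi = h(softargmax(xi)) along xi_dot.
eps = 1e-6
dchi_numeric = (negentropy(softargmax(xi + eps * xi_dot))
                - negentropy(softargmax(xi - eps * xi_dot))) / (2 * eps)

# Closed form from the derivation: xi^T Cov(q) ell(x).
dchi_formula = xi @ cov @ losses
```

Because $\text{Cov}(q) 1_S = 0$, the value of `gamma` has no effect on either quantity beyond rounding.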
Remark: The barygradient flow can be equivalently rewritten as the following natural gradient flow: \(\dot \zeta(t) = - \begin{pmatrix} I_m & 0 \\ 0 & -\text{Cov}(q(t))^\dagger \end{pmatrix} \nabla \tilde F( \zeta(t) ) + [ \gamma(t) + \frac{1_S^\intercal \ell(x(t))}{S} ] \begin{pmatrix} 0_m \\ 1_S \end{pmatrix} ,\) where $\dagger$ denotes the Moore–Penrose pseudoinverse and \(\tilde F(x,\xi) = (\nabla h)^{-1}(\xi - \log(\sum_s e^{\xi_s-1}) 1_S)^\intercal \ell(x).\)
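The extra term $\frac{1_S^\intercal \ell(x(t))}{S}$ in the remark comes from the fact that, for $q(t)$ in the interior of the simplex, $\text{Cov}(q(t))^\dagger \text{Cov}(q(t))$ is the orthogonal projection onto the complement of $1_S$, so it returns $\ell$ minus its mean. A small numerical check of that identity, with arbitrary illustrative values for `xi` and `losses` and an explicit `rcond` cutoff chosen for this sketch:

```python
import numpy as np

def softargmax(xi):
    w = np.exp(xi - xi.max())
    return w / w.sum()

xi = np.array([0.2, -0.1, 0.4])      # hypothetical xi(t)
losses = np.array([1.3, 0.4, 2.1])   # hypothetical ell(x(t))

q = softargmax(xi)
cov = np.diag(q) - np.outer(q, q)    # Cov(q), singular along 1_S

# xi-block of the gradient of F tilde: Jacobian of softargmax (= Cov) times ell.
grad_xi = cov @ losses

# Preconditioning with the pseudoinverse; rcond discards the zero singular value.
natural = np.linalg.pinv(cov, rcond=1e-10) @ grad_xi

# Expected result: ell with its mean removed.
centered = losses - losses.mean()
```

Adding back `losses.mean()` times $1_S$ recovers $\ell(x(t))$, which matches the $\xi$-block of the original barygradient flow.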
$^1$ $\text{Cov}(q(t))$ is also the Jacobian matrix of the softargmax function $\xi \mapsto (\nabla h)^{-1}(\xi-\log(\sum_s e^{\xi_s-1})1_S)$ evaluated at $\xi(t)$.