Beyond Convexity #2: Barygradient flow
Published:
In my latest paper, I introduced a generalized proximal point algorithm (PPA): given a point $(x,q) \in \mathbb{R}^m \times \Delta_S$ (with $m\ge 1$, $S \ge 2$, and $\Delta_S$ the probability simplex), the next iterate $(x',q')$ is defined by \(( \nabla f + \lambda A )(x',q') = \nabla f(x,q)\) for a step size $\lambda>0$. Here $f(x,q)=\frac12 ||x||^2 + h(q)$, where $h(q)=\sum_{s=1}^S q_s \log(q_s)$ is the negentropy, and \(A(x,q) = \begin{pmatrix} J_\ell(x)^\intercal q \\ -\ell(x) \end{pmatrix}\) where $J_\ell$ denotes the Jacobian matrix of $\ell=(\ell_1,\dots,\ell_S):\mathbb{R}^m \rightarrow \mathbb{R}^S$, with each $\ell_s$ convex ($1\le s \le S$).
In particular, I showed that $A$ is a monotone operator and that the $f$-resolvent $(\nabla f + \lambda A)^{-1} \circ \nabla f$ is Bregman firmly nonexpansive with respect to the Bregman divergence $D_f$. As a consequence, this generalized PPA converges to a fixed point whenever one exists.
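To make the monotonicity claim concrete, here is a minimal numerical sketch that evaluates $\langle A(z_1) - A(z_2), z_1 - z_2 \rangle$ at random points. The quadratic losses $\ell_s(x) = \frac12 ||x - c_s||^2$, the centres `C`, and all function names are assumptions chosen for illustration; they are not taken from the paper. For this particular choice the inner product should equal $||x_1 - x_2||^2$, which is nonnegative.

```python
import numpy as np

# Hypothetical toy losses (not from the paper): l_s(x) = 0.5 * ||x - c_s||^2.
rng = np.random.default_rng(0)
m, S = 3, 4
C = rng.normal(size=(S, m))  # row s is the assumed centre c_s


def ell(x):
    """Vector of losses (l_1(x), ..., l_S(x))."""
    return 0.5 * np.sum((x - C) ** 2, axis=1)


def jac_ell(x):
    """Jacobian J_l(x), shape (S, m); row s is grad l_s(x) = x - c_s."""
    return x - C


def A(x, q):
    """Operator A(x, q) = (J_l(x)^T q, -l(x)), stacked into one vector."""
    return np.concatenate([jac_ell(x).T @ q, -ell(x)])


def monotonicity_gap(x1, q1, x2, q2):
    """Inner product <A(z1) - A(z2), z1 - z2>; nonnegative if A is monotone."""
    dz = np.concatenate([x1 - x2, q1 - q2])
    return float((A(x1, q1) - A(x2, q2)) @ dz)


def random_point():
    """Random x in R^m and random q in the probability simplex."""
    return rng.normal(size=m), rng.dirichlet(np.ones(S))


gaps = [monotonicity_gap(*random_point(), *random_point()) for _ in range(1000)]
min_gap = min(gaps)
print("smallest gap over 1000 random pairs:", min_gap)
```

This is only a sanity check on one family of losses, not a proof; the general argument relies on the convexity of each $\ell_s$ and the nonnegativity of $q$.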
In this blog post, we propose to generalize the gradient flow ordinary differential equation (ODE) by letting $\lambda \rightarrow 0$ in our generalized PPA (for a great introduction to gradient flow, see Bach’s blog post).
Definition: Let $F(x,q) = q^\intercal \ell(x)$. We define the barygradient flow ODE as \(\dot \zeta(t) = - \begin{pmatrix} I_m & 0 \\ 0 & -I_S \end{pmatrix} \nabla F( (\nabla f)^{-1}( \zeta(t) ) ) ,\) where $\zeta : \mathbb{R}_+ \rightarrow \mathbb{R}^m \times \nabla h(\mathring \Delta_S)$ and $\mathring \Delta_S$ denotes the interior of the simplex, on which $\nabla h$ is well defined.
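As an illustration of the definition, the following sketch integrates the ODE with a plain forward Euler scheme. It assumes toy quadratic losses $\ell_s(x) = \frac12 ||x - c_s||^2$ with hand-picked centres `C`, and uses the componentwise inverse $(\nabla h)^{-1}(\xi) = e^{\xi - 1}$ (since $\nabla h(q) = \log q + 1$). The step size, horizon, and function names are arbitrary choices of mine, not part of the paper.

```python
import numpy as np

# Hypothetical toy losses (not from the paper): l_s(x) = 0.5 * ||x - c_s||^2.
C = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]])
S, m = C.shape


def ell(x):
    """Vector of losses (l_1(x), ..., l_S(x))."""
    return 0.5 * np.sum((x - C) ** 2, axis=1)


def jac_ell(x):
    """Jacobian J_l(x), shape (S, m)."""
    return x - C


def grad_f_inv(zeta):
    """(grad f)^{-1}: x is unchanged, q = exp(xi - 1) componentwise."""
    return zeta[:m], np.exp(zeta[m:] - 1.0)


def flow_rhs(zeta):
    """Right-hand side -diag(I_m, -I_S) grad F(x, q) = (-J_l(x)^T q, +l(x))."""
    x, q = grad_f_inv(zeta)
    return np.concatenate([-jac_ell(x).T @ q, ell(x)])


def euler(zeta0, step, n_steps):
    """Forward Euler discretisation of the flow; returns the whole trajectory."""
    zeta = zeta0.copy()
    traj = [zeta.copy()]
    for _ in range(n_steps):
        zeta = zeta + step * flow_rhs(zeta)
        traj.append(zeta.copy())
    return np.array(traj)


# Start at x = 0 with the uniform distribution: xi = grad h(1/S) = log(1/S) + 1.
step, n_steps = 1e-3, 1000
zeta0 = np.concatenate([np.zeros(m), np.full(S, np.log(1.0 / S) + 1.0)])
traj = euler(zeta0, step, n_steps)
print("final x:", traj[-1, :m])
print("final q (unnormalised):", np.exp(traj[-1, m:] - 1.0))
```

Because the dual components of the right-hand side are the losses themselves, which are nonnegative here, $\xi$ only grows in this toy run, so $e^{\xi - 1}$ drifts away from the simplex; the sketch therefore keeps the time horizon short.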
Note that \(A((\nabla f)^{-1}(\zeta(t))) = \begin{pmatrix} I_m & 0 \\ 0 & -I_S \end{pmatrix} \nabla F( (\nabla f)^{-1}( \zeta(t) ) ) ,\) so the barygradient flow can be written compactly as $\dot \zeta(t) = -A((\nabla f)^{-1}(\zeta(t)))$. In contrast to classical gradient flow, the function $F((\nabla f)^{-1}(\zeta(t)))$ is not necessarily decreasing along the flow. Indeed, denoting $\zeta(t)=(x(t), \xi(t))$ with $x$ (resp. $\xi$) valued in $\mathbb{R}^m$ (resp. $\nabla h(\mathring \Delta_S )$), we have:
\[\frac{d}{dt}F( (\nabla f)^{-1}( \zeta(t) ) ) = \frac{d}{dt}[(\nabla h)^{-1}(\xi(t))]^\intercal \ell(x(t)) + (\nabla h)^{-1}(\xi(t))^\intercal \frac{d}{dt} \ell(x(t)) ,\] where $\frac{d}{dt}[(\nabla h)^{-1}(\xi(t))] = [\nabla^2 h((\nabla h)^{-1}(\xi(t)))]^{-1} \dot \xi(t)$ and $\frac{d}{dt} \ell(x(t)) = J_\ell(x(t)) \dot x(t)$.
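Substituting $\dot \xi(t) = \ell(x(t))$ and $\dot x(t) = -J_\ell(x(t))^\intercal (\nabla h)^{-1}(\xi(t))$ from the ODE gives \(\frac{d}{dt}F((\nabla f)^{-1}(\zeta(t))) = \underbrace{ \ell(x(t))^\intercal [\nabla^2 h((\nabla h)^{-1}(\xi(t)))]^{-1} \ell(x(t)) }_{||\ell(x(t))||^2_{(\nabla h)^{-1}(\xi(t))}} - ||J_\ell(x(t))^\intercal (\nabla h)^{-1}(\xi(t))||^2 ,\) which is a difference of two nonnegative quantities, so its sign is not determined in general. Since $\nabla^2 h(q) = \mathrm{diag}(1/q_1,\dots,1/q_S)$, the first term is simply $\sum_{s=1}^S q_s \ell_s(x(t))^2$ with $q = (\nabla h)^{-1}(\xi(t))$.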
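The rate formula can be checked numerically with a central finite difference along the flow direction. The sketch below again assumes toy quadratic losses $\ell_s(x) = \frac12 ||x - c_s||^2$ and the componentwise inverse $(\nabla h)^{-1}(\xi) = e^{\xi - 1}$; the test point and all names are illustrative choices, not taken from the paper.

```python
import numpy as np

# Hypothetical toy losses (not from the paper): l_s(x) = 0.5 * ||x - c_s||^2.
C = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]])
S, m = C.shape


def ell(x):
    return 0.5 * np.sum((x - C) ** 2, axis=1)


def jac_ell(x):
    return x - C


def F_dual(zeta):
    """F((grad f)^{-1}(zeta)) = q^T l(x) with q = exp(xi - 1)."""
    x, xi = zeta[:m], zeta[m:]
    return float(np.exp(xi - 1.0) @ ell(x))


def flow_rhs(zeta):
    """Flow direction (-J_l(x)^T q, +l(x))."""
    x, q = zeta[:m], np.exp(zeta[m:] - 1.0)
    return np.concatenate([-jac_ell(x).T @ q, ell(x)])


def closed_form_rate(zeta):
    """Both terms of dF/dt, using [hess h(q)]^{-1} = diag(q)."""
    x, q = zeta[:m], np.exp(zeta[m:] - 1.0)
    losses = ell(x)
    first = float(np.sum(q * losses ** 2))           # l^T diag(q) l
    second = float(np.sum((jac_ell(x).T @ q) ** 2))  # ||J_l^T q||^2
    return first - second, first, second


# Arbitrary test point: x in R^2 and xi = grad h(q) for a q in the simplex.
q0 = np.array([0.2, 0.5, 0.3])
zeta = np.concatenate([np.array([0.3, -0.2]), np.log(q0) + 1.0])

eps = 1e-6
v = flow_rhs(zeta)
numeric = (F_dual(zeta + eps * v) - F_dual(zeta - eps * v)) / (2 * eps)
exact, first, second = closed_form_rate(zeta)
print("finite difference:", numeric)
print("closed form:", exact, "=", first, "-", second)
```

The two printed rates should agree up to finite-difference error, and comparing `first` with `second` at different points shows how the sign of the rate can change.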
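PS: it might help to subtract $\log(\sum_s e^{\xi_s(t)-1})$ from every component of $\xi(t)$: after this shift, the components $e^{\xi_s(t)-1}$ of $(\nabla h)^{-1}(\xi(t))$ sum to one, so the primal point lies in the simplex $\Delta_S$.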
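A minimal sketch of this shift, written with the usual max-subtraction trick so that large entries of $\xi$ do not overflow. The function name `normalize_dual` and the sample vectors are hypothetical.

```python
import numpy as np


def normalize_dual(xi):
    """Subtract log(sum_s exp(xi_s - 1)) from every component of xi."""
    shifted = xi - 1.0
    mx = np.max(shifted)  # subtract the max before exponentiating, for stability
    log_z = mx + np.log(np.sum(np.exp(shifted - mx)))
    return xi - log_z


xi = np.array([0.3, 2.0, -1.5])
q = np.exp(normalize_dual(xi) - 1.0)  # (grad h)^{-1} applied to the shifted xi
print("q:", q, "sum:", q.sum())

# Large entries stay finite thanks to the max subtraction.
q_big = np.exp(normalize_dual(np.array([1000.0, 1001.0])) - 1.0)
print("q_big:", q_big)
```

After the shift, `q` is the softmax of `xi`, so it is positive and sums to one.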