Deriving the MAML Objective - 2

Vaasudev Narayanan - Thu 14 November 2019 - meta-learning

In this post, we'll be deriving the meta-gradient of the MAML objective for the Mean Squared Error (MSE) loss.
This is a follow-up to Deriving the MAML Objective - 1.

In our setup, we have $n$ tasks and the meta-update is defined as,

$$ \begin{equation} \label{eq1} \theta := \theta - \beta \nabla_\theta\sum_{T_{i = 1}}^n L_{T_i}(f_{\theta_i'}) \end{equation} $$

We are interested in calculating $ \nabla_\theta\sum_{T_{i = 1}}^n L_{T_i}(f_{\theta_i'}) $, where $L$ is the squared-error loss over the $K$ data points $(x^{(j)}, y^{(j)})$ of task $T_i$ (the MSE up to a constant factor; the $\frac{1}{2}$ simplifies the derivative), defined as $ \newcommand{\norm}[1]{\left\lVert #1 \right\rVert} $ $$ \begin{equation} \label{eq2} L_{T_i}(f_{\theta_i'}) = \frac{1}{2} \sum_{j = 1}^K \norm{f_{\theta_i'}(x^{(j)}) - y^{(j)}}_{2}^{2} \end{equation} $$

We will consider the simpler case of a single-output regression task, where $f_{\theta_i'}(x)$ is a scalar and the squared norm reduces to a squared difference:

$$ \begin{equation} \label{eq3} L_{T_i}(f_{\theta_i'}) = \frac{1}{2} \sum_{j = 1}^K (f_{\theta_i'}(x^{(j)}) - y^{(j)})^{2} \end{equation} $$
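To make equation \ref{eq3} concrete, here is a minimal Python sketch of the loss for a single task. The model $f_\theta(x) = \theta_0 e^{\theta_1 x}$ (so $D = 1$), the data, and the function names are hypothetical choices for illustration; any differentiable single-output model could take its place.

```python
import math


def model(theta, x):
    # Hypothetical single-output model f_theta(x) = theta_0 * exp(theta_1 * x), so D = 1.
    return theta[0] * math.exp(theta[1] * x)


def task_loss(theta, xs, ys):
    # Equation (3): half the sum of squared errors over the K points of one task.
    return 0.5 * sum((model(theta, x) - y) ** 2 for x, y in zip(xs, ys))


# Made-up task with K = 2 points: the residuals are 0 and -2, so the loss is 0.5 * 4.
print(task_loss([1.0, 0.0], [0.0, 1.0], [1.0, 3.0]))  # 2.0
```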

By the chain rule, the gradient of the $i^{th}$ task's loss w.r.t. the meta-parameters $\theta$ is

$$ \begin{equation} \label{eq4} \frac{\partial L_{T_i}(f_{\theta_i'})}{\partial \theta} = \frac{\partial L_{T_i}(f_{\theta_i'})}{\partial \theta_i'} \cdot \frac{\partial \theta_i'}{\partial \theta} \end{equation} $$

The first term of equation \ref{eq4}, $ \frac{\partial L_{T_i}(f_{\theta_i'})}{\partial \theta_i'} $, is the derivative of the task loss function w.r.t. the task-adapted parameters:

$$ \begin{equation} \begin{split} \frac{\partial L_{T_i}(f_{\theta_i'})}{\partial \theta_i'} &= \frac{\partial L_{T_i}(f_{\theta_i'})}{\partial f_{\theta_i'}} \cdot \frac{\partial f_{\theta_i'}}{\partial \theta_i'} \\ &= \sum_{k = 1}^K (f_{\theta_i'}(x^{(k)}) - y^{(k)}) \cdot \frac{\partial f_{\theta_i'}}{\partial \theta_i'} (x^{(k)}) \end{split} \label{eq5} \end{equation} $$

Here, $K$ is the number of data points per task in the $K$-shot setting.
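A minimal sketch of equation \ref{eq5}, assuming the same hypothetical model $f_\theta(x) = \theta_0 e^{\theta_1 x}$ and made-up data (the names `model`, `model_grad` and `loss_grad` are illustrative, not from any library). The analytic gradient is compared against central finite differences, a cheap way to catch sign or index mistakes in a derivation like this one.

```python
import math


def model(theta, x):
    # Hypothetical model f_theta(x) = theta_0 * exp(theta_1 * x), so D = 1.
    return theta[0] * math.exp(theta[1] * x)


def model_grad(theta, x):
    # [df/dtheta_0, df/dtheta_1] evaluated at x.
    e = math.exp(theta[1] * x)
    return [e, theta[0] * x * e]


def task_loss(theta, xs, ys):
    # Equation (3).
    return 0.5 * sum((model(theta, x) - y) ** 2 for x, y in zip(xs, ys))


def loss_grad(theta, xs, ys):
    # Equation (5): sum over the K points of residual * df/dtheta.
    grad = [0.0] * len(theta)
    for x, y in zip(xs, ys):
        residual = model(theta, x) - y
        for q, df_dq in enumerate(model_grad(theta, x)):
            grad[q] += residual * df_dq
    return grad


def finite_diff_grad(theta, xs, ys, h=1e-6):
    # Central differences of the loss, used only to sanity-check loss_grad.
    grad = []
    for q in range(len(theta)):
        plus, minus = list(theta), list(theta)
        plus[q] += h
        minus[q] -= h
        grad.append((task_loss(plus, xs, ys) - task_loss(minus, xs, ys)) / (2 * h))
    return grad


# Made-up parameters and task data.
theta, xs, ys = [0.8, 0.3], [-1.0, 0.5, 2.0], [0.2, 1.1, 2.5]
print(loss_grad(theta, xs, ys))
print(finite_diff_grad(theta, xs, ys))  # should agree to several decimal places
```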

The second term of equation \ref{eq4}, $ \frac{\partial \theta_i'}{\partial \theta} $, is the derivative of the task-adapted parameters w.r.t. the meta-parameters.

To calculate $ \frac{\partial \theta_i'}{\partial \theta} $, let us assume that $ \theta_i' $ and $ \theta $ are both $(D+1)$-dimensional vectors,

$$ \begin{equation} \begin{split} &\theta_i' = [\theta_{i0}', \theta_{i1}', \dots, \theta_{iD}'] \\ &\theta = [\theta_0, \theta_1, \dots, \theta_D] \end{split} \label{eq6} \end{equation} $$

Then the Jacobian is defined as,

$$ \begin{equation} \label{eq7} \frac{\partial \theta_i'}{\partial \theta} = \begin{bmatrix} \frac{\partial \theta_{i0}'}{\partial \theta_0} & \frac{\partial \theta_{i0}'}{\partial \theta_1} & \dots & \frac{\partial \theta_{i0}'}{\partial \theta_D} \\ \frac{\partial \theta_{i1}'}{\partial \theta_0} & \frac{\partial \theta_{i1}'}{\partial \theta_1} & \dots & \frac{\partial \theta_{i1}'}{\partial \theta_D} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial \theta_{iD}'}{\partial \theta_0} & \frac{\partial \theta_{iD}'}{\partial \theta_1} & \dots & \frac{\partial \theta_{iD}'}{\partial \theta_D} \end{bmatrix} \end{equation} $$

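By the definition of the inner-loop (task-adaptation) update with step size $\alpha$, each component of $\theta_i'$ is

$$ \begin{equation} \label{eq8} \theta_{ip}' = \theta_p - \alpha \frac{\partial L_{T_i}(\theta)}{\partial \theta_p} \end{equation} $$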
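The inner-loop update in equation \ref{eq8} is a single gradient step on the task's own loss. A sketch, again assuming the hypothetical model $f_\theta(x) = \theta_0 e^{\theta_1 x}$, made-up data, and an arbitrary illustrative step size $\alpha = 0.1$:

```python
import math


def model(theta, x):
    # Hypothetical model f_theta(x) = theta_0 * exp(theta_1 * x).
    return theta[0] * math.exp(theta[1] * x)


def loss_grad(theta, xs, ys):
    # Gradient of equation (3): sum_j (f_theta(x_j) - y_j) * df/dtheta_q (x_j).
    grad = [0.0, 0.0]
    for x, y in zip(xs, ys):
        e = math.exp(theta[1] * x)
        residual = model(theta, x) - y
        grad[0] += residual * e
        grad[1] += residual * theta[0] * x * e
    return grad


def adapt(theta, xs, ys, alpha):
    # Equation (8): theta'_i = theta - alpha * gradient of the task loss at theta.
    return [t - alpha * g for t, g in zip(theta, loss_grad(theta, xs, ys))]


print(adapt([1.0, 0.0], [0.0, 1.0], [1.0, 3.0], alpha=0.1))  # [1.2, 0.2]
```

With $\theta = [1, 0]$ the model predicts 1 at both points, so only the second point (residual $-2$) contributes to the gradient, and both parameters move up by $0.2$.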

Calculating a few entries of the Jacobian, $$ \begin{equation*} \begin{aligned} \frac{\partial \theta_{i0}'}{\partial \theta_0} &= 1 - \alpha \frac{\partial^2 L_{T_i}(\theta)}{\partial \theta_0^2} \\ \frac{\partial \theta_{i0}'}{\partial \theta_1} &= -\alpha \frac{\partial^2 L_{T_i}(\theta)}{\partial \theta_0 \partial \theta_1} \end{aligned} \end{equation*} $$

Generalizing,

$$ \begin{equation} \label{eq9} \frac{\partial \theta_{ip}'}{\partial \theta_q} = \begin{cases} 1 - \alpha \frac{\partial^2 L_{T_i}(\theta)}{\partial \theta_q^2}, & \text{if } p = q \\ -\alpha \frac{\partial^2 L_{T_i}(\theta)}{\partial \theta_p \partial \theta_q}, & \text{if } p \neq q \end{cases} \end{equation} $$
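Equation \ref{eq9} says the Jacobian is the identity minus $\alpha$ times the matrix of second derivatives of the task loss. One way to check this numerically is to finite-difference the adaptation map $\theta \mapsto \theta_i'$ directly. The sketch below assumes a hypothetical linear model $f_\theta(x) = \theta_0 + \theta_1 x$, for which the second derivatives of equation \ref{eq3} are the constants $K$, $\sum_j x^{(j)}$ and $\sum_j (x^{(j)})^2$; the data points and $\alpha$ are made up.

```python
def adapt(theta, xs, ys, alpha):
    # Equation (8) for the hypothetical linear model f_theta(x) = theta_0 + theta_1 * x.
    g0 = sum(theta[0] + theta[1] * x - y for x, y in zip(xs, ys))
    g1 = sum((theta[0] + theta[1] * x - y) * x for x, y in zip(xs, ys))
    return [theta[0] - alpha * g0, theta[1] - alpha * g1]


def jacobian_eq9(xs, alpha):
    # Equation (9) with the linear model's second derivatives:
    # d2L/dtheta_0^2 = K, d2L/dtheta_0 dtheta_1 = sum(x), d2L/dtheta_1^2 = sum(x^2).
    hess = [[len(xs), sum(xs)], [sum(xs), sum(x * x for x in xs)]]
    return [[(1.0 if p == q else 0.0) - alpha * hess[p][q] for q in range(2)]
            for p in range(2)]


def jacobian_numeric(theta, xs, ys, alpha, h=1e-6):
    # Central differences of the adaptation map; column q is d theta'_i / d theta_q.
    jac = [[0.0, 0.0], [0.0, 0.0]]
    for q in range(2):
        plus, minus = list(theta), list(theta)
        plus[q] += h
        minus[q] -= h
        a_plus, a_minus = adapt(plus, xs, ys, alpha), adapt(minus, xs, ys, alpha)
        for p in range(2):
            jac[p][q] = (a_plus[p] - a_minus[p]) / (2 * h)
    return jac


# Made-up task data and step size.
xs, ys, alpha = [0.0, 1.0, 2.0], [1.0, 2.0, 2.5], 0.1
print(jacobian_eq9(xs, alpha))                        # approximately [[0.7, -0.3], [-0.3, 0.5]]
print(jacobian_numeric([0.5, -0.2], xs, ys, alpha))  # should match the line above closely
```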


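In matrix form, equation \ref{eq9} reads $ \frac{\partial \theta_i'}{\partial \theta} = I - \alpha \frac{\partial^2 L_{T_i}(f_{\theta})}{\partial \theta^2} $, so what remains is the second-derivative (Hessian) term $ \frac{\partial^2 L_{T_i}(f_{\theta})}{\partial \theta^2} $,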

$$ \begin{equation} \label{eq10} \frac{\partial^2 L_{T_i}(f_{\theta})}{\partial \theta^2} = \begin{bmatrix} \frac{\partial^2 L_{T_i}(f_{\theta})}{\partial \theta_0^2} & \frac{\partial^2 L_{T_i}(f_{\theta})}{\partial \theta_0 \partial \theta_1} & \dots & \frac{\partial^2 L_{T_i}(f_{\theta})}{\partial \theta_0 \partial \theta_D} \\ \frac{\partial^2 L_{T_i}(f_{\theta})}{\partial \theta_1 \partial \theta_0} & \frac{\partial^2 L_{T_i}(f_{\theta})}{\partial \theta_1^2} & \dots & \frac{\partial^2 L_{T_i}(f_{\theta})}{\partial \theta_1 \partial \theta_D} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\partial^2 L_{T_i}(f_{\theta})}{\partial \theta_D \partial \theta_0} & \frac{\partial^2 L_{T_i}(f_{\theta})}{\partial \theta_D \partial \theta_1} & \dots & \frac{\partial^2 L_{T_i}(f_{\theta})}{\partial \theta_D^2} \end{bmatrix} \end{equation} $$

Differentiating equation \ref{eq3} with respect to $\theta_q$, with $\theta$ in place of $\theta_i'$, gives $$ \begin{equation} \label{eq11} \frac{\partial L_{T_i}(\theta)}{\partial \theta_q} = \sum_{j = 1}^K (f_\theta (x^{(j)}) - y^{(j)}) \cdot \frac{\partial f_{\theta}}{\partial \theta_q} (x^{(j)}) \end{equation} $$

Differentiating once more with respect to $\theta_p$ and applying the product rule gives each entry of equation \ref{eq10},

$$ \begin{equation} \label{eq12} \frac{\partial^2 L_{T_i}(\theta)}{\partial \theta_p \partial \theta_q} = \sum_{j = 1}^K \left[ (f_\theta (x^{(j)}) - y^{(j)}) \cdot \frac{\partial^2 f_{\theta}}{\partial \theta_p \partial \theta_q} (x^{(j)}) + \frac{\partial f_{\theta}}{\partial \theta_p} (x^{(j)}) \cdot \frac{\partial f_{\theta}}{\partial \theta_q} (x^{(j)}) \right] \end{equation} $$
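A sketch of equation \ref{eq12} for the hypothetical model $f_\theta(x) = \theta_0 e^{\theta_1 x}$, whose second derivatives are $\frac{\partial^2 f}{\partial \theta_0^2} = 0$, $\frac{\partial^2 f}{\partial \theta_0 \partial \theta_1} = x e^{\theta_1 x}$ and $\frac{\partial^2 f}{\partial \theta_1^2} = \theta_0 x^2 e^{\theta_1 x}$. The result is checked against central differences of the gradient from equation \ref{eq11}; the data points are made up.

```python
import math


def model_derivs(theta, x):
    # Hypothetical model f_theta(x) = theta_0 * exp(theta_1 * x).
    # Returns f, its gradient [df/dtheta_0, df/dtheta_1] and its 2x2 second-derivative matrix.
    e = math.exp(theta[1] * x)
    f = theta[0] * e
    df = [e, theta[0] * x * e]
    d2f = [[0.0, x * e], [x * e, theta[0] * x * x * e]]
    return f, df, d2f


def loss_grad(theta, xs, ys):
    # Equation (11).
    grad = [0.0, 0.0]
    for x, y in zip(xs, ys):
        f, df, _ = model_derivs(theta, x)
        for q in range(2):
            grad[q] += (f - y) * df[q]
    return grad


def loss_hessian(theta, xs, ys):
    # Equation (12): residual * d2f/(dtheta_p dtheta_q) + df/dtheta_p * df/dtheta_q, summed over j.
    hess = [[0.0, 0.0], [0.0, 0.0]]
    for x, y in zip(xs, ys):
        f, df, d2f = model_derivs(theta, x)
        for p in range(2):
            for q in range(2):
                hess[p][q] += (f - y) * d2f[p][q] + df[p] * df[q]
    return hess


def finite_diff_hessian(theta, xs, ys, h=1e-6):
    # Central differences of loss_grad; row p holds d(gradient)/d theta_p.
    hess = []
    for p in range(2):
        plus, minus = list(theta), list(theta)
        plus[p] += h
        minus[p] -= h
        g_plus, g_minus = loss_grad(plus, xs, ys), loss_grad(minus, xs, ys)
        hess.append([(a - b) / (2 * h) for a, b in zip(g_plus, g_minus)])
    return hess


# Made-up parameters and task data.
theta, xs, ys = [0.8, 0.3], [-1.0, 0.5, 2.0], [0.2, 1.1, 2.5]
print(loss_hessian(theta, xs, ys))
print(finite_diff_hessian(theta, xs, ys))  # should agree to several decimal places
```

For a model that is linear in $\theta$, the first term inside the brackets of equation \ref{eq12} vanishes and only the products of first derivatives remain.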


Therefore, using equations \ref{eq5}, \ref{eq9} and \ref{eq12}, we can now compute the gradient of the $i^{th}$ task's loss w.r.t. the meta-parameters, $ \frac{\partial L_{T_i}(f_{\theta_i'})}{\partial \theta} $ (equation \ref{eq4}):

$$ \begin{equation} \label{eq13} \frac{\partial L_{T_i}(f_{\theta_i'})}{\partial \theta} = \sum_{k = 1}^K (f_{\theta_i'}(x^{(k)}) - y^{(k)}) \cdot \frac{\partial f_{\theta_i'}}{\partial \theta_i'} (x^{(k)}) \cdot \frac{\partial \theta_i'}{\partial \theta} \end{equation} $$
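Putting equations \ref{eq5}, \ref{eq8}, \ref{eq9} and \ref{eq12} together, here is a sketch of the per-task meta-gradient in equation \ref{eq13}, still assuming the hypothetical model $f_\theta(x) = \theta_0 e^{\theta_1 x}$, made-up data and an arbitrary $\alpha$. Note that the outer gradient is evaluated at $\theta_i'$ while the Hessian is evaluated at $\theta$. The finite-difference check differentiates the whole composition $\theta \mapsto L_{T_i}(f_{\theta_i'})$, so it would catch a missing second-order term.

```python
import math


def model_derivs(theta, x):
    # Hypothetical model f_theta(x) = theta_0 * exp(theta_1 * x): value, gradient, second derivatives.
    e = math.exp(theta[1] * x)
    return (theta[0] * e,
            [e, theta[0] * x * e],
            [[0.0, x * e], [x * e, theta[0] * x * x * e]])


def task_loss(theta, xs, ys):
    # Equation (3).
    return 0.5 * sum((model_derivs(theta, x)[0] - y) ** 2 for x, y in zip(xs, ys))


def loss_grad(theta, xs, ys):
    # Equations (5) and (11).
    grad = [0.0, 0.0]
    for x, y in zip(xs, ys):
        f, df, _ = model_derivs(theta, x)
        for q in range(2):
            grad[q] += (f - y) * df[q]
    return grad


def loss_hessian(theta, xs, ys):
    # Equation (12).
    hess = [[0.0, 0.0], [0.0, 0.0]]
    for x, y in zip(xs, ys):
        f, df, d2f = model_derivs(theta, x)
        for p in range(2):
            for q in range(2):
                hess[p][q] += (f - y) * d2f[p][q] + df[p] * df[q]
    return hess


def adapt(theta, xs, ys, alpha):
    # Equation (8): one inner-loop gradient step.
    return [t - alpha * g for t, g in zip(theta, loss_grad(theta, xs, ys))]


def task_meta_grad(theta, xs, ys, alpha):
    # Equation (13): sum_p dL/dtheta'_p * dtheta'_p/dtheta_q, with the Jacobian
    # entries from equation (9): delta_pq - alpha * H[p][q], H evaluated at theta.
    outer = loss_grad(adapt(theta, xs, ys, alpha), xs, ys)  # equation (5) at theta'_i
    hess = loss_hessian(theta, xs, ys)                      # equation (12) at theta
    return [sum(outer[p] * ((1.0 if p == q else 0.0) - alpha * hess[p][q]) for p in range(2))
            for q in range(2)]


def finite_diff_meta_grad(theta, xs, ys, alpha, h=1e-6):
    # Central differences of theta -> L_{T_i}(f_{theta'_i}), used only as a sanity check.
    out = []
    for q in range(2):
        plus, minus = list(theta), list(theta)
        plus[q] += h
        minus[q] -= h
        out.append((task_loss(adapt(plus, xs, ys, alpha), xs, ys)
                    - task_loss(adapt(minus, xs, ys, alpha), xs, ys)) / (2 * h))
    return out


# Made-up parameters, task data and inner step size.
theta, xs, ys, alpha = [0.8, 0.3], [-1.0, 0.5, 2.0], [0.2, 1.1, 2.5], 0.05
print(task_meta_grad(theta, xs, ys, alpha))
print(finite_diff_meta_grad(theta, xs, ys, alpha))  # should agree closely
```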

Calculating this gradient for all $n$ tasks and summing the results gives $ \nabla_\theta\sum_{T_{i = 1}}^n L_{T_i}(f_{\theta_i'}) $, the meta-gradient in equation \ref{eq1}.

We can now perform our meta-update,

$$ \begin{equation*} \theta := \theta - \beta \nabla_\theta\sum_{T_{i = 1}}^n L_{T_i}(f_{\theta_i'}) \end{equation*} $$
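Finally, a sketch of repeated meta-updates as in equation \ref{eq1}, summing the per-task meta-gradients over $n$ tasks. To keep it short it swaps in a hypothetical linear model $f_\theta(x) = \theta_0 + \theta_1 x$, for which the first term of equation \ref{eq12} vanishes; the tasks (lines $y = ax + b$), $\alpha$, $\beta$ and the number of steps are all made-up illustration values.

```python
def loss_grad(theta, xs, ys):
    # Equation (11) for the hypothetical linear model f_theta(x) = theta_0 + theta_1 * x.
    r = [theta[0] + theta[1] * x - y for x, y in zip(xs, ys)]
    return [sum(r), sum(ri * x for ri, x in zip(r, xs))]


def task_loss(theta, xs, ys):
    # Equation (3).
    return 0.5 * sum((theta[0] + theta[1] * x - y) ** 2 for x, y in zip(xs, ys))


def task_meta_grad(theta, xs, ys, alpha):
    # Equations (8), (9), (12) and (13). For a linear model d2f/dtheta^2 = 0, so the
    # Hessian of equation (12) is just [[K, sum x], [sum x, sum x^2]].
    theta_prime = [t - alpha * g for t, g in zip(theta, loss_grad(theta, xs, ys))]
    outer = loss_grad(theta_prime, xs, ys)
    hess = [[len(xs), sum(xs)], [sum(xs), sum(x * x for x in xs)]]
    return [sum(outer[p] * ((1.0 if p == q else 0.0) - alpha * hess[p][q]) for p in range(2))
            for q in range(2)]


def meta_objective(theta, tasks, alpha):
    # Sum over tasks of the loss after one adaptation step.
    total = 0.0
    for xs, ys in tasks:
        theta_prime = [t - alpha * g for t, g in zip(theta, loss_grad(theta, xs, ys))]
        total += task_loss(theta_prime, xs, ys)
    return total


def meta_update(theta, tasks, alpha, beta):
    # Equation (1): sum the per-task meta-gradients, then take one step of size beta.
    total = [0.0, 0.0]
    for xs, ys in tasks:
        g = task_meta_grad(theta, xs, ys, alpha)
        total = [a + b for a, b in zip(total, g)]
    return [t - beta * g for t, g in zip(theta, total)]


# Made-up tasks: n = 3 lines y = a * x + b, each observed at the same K = 4 points.
xs = [-1.0, -0.5, 0.5, 1.0]
tasks = [(xs, [a * x + b for x in xs]) for a, b in [(1.0, 0.5), (-0.5, 1.0), (2.0, -1.0)]]

theta, alpha, beta = [0.0, 0.0], 0.1, 0.05
before = meta_objective(theta, tasks, alpha)
for _ in range(100):
    theta = meta_update(theta, tasks, alpha, beta)
print(before, meta_objective(theta, tasks, alpha))  # the meta-objective should decrease
```

Because the adaptation step is affine in $\theta$ for a linear model, the meta-objective here is a quadratic in $\theta$, so plain gradient descent with a small enough $\beta$ converges to its minimizer.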

References:

Chelsea Finn, Pieter Abbeel, and Sergey Levine. “Model-agnostic meta-learning for fast adaptation of deep networks.” ICML 2017.

