Nesterov gradient descent



For Stanford AA222


April 12th, 2021



import numpy as np


class CountWillExceedWarning(RuntimeError):
    """Signal that the next evaluation would use up the evaluation budget.

    Raised by the budget-checking wrappers built in `optimize` and caught by
    the optimization loop, which treats it as the cue to stop iterating.
    """


def step_nesterov_momentum(f_fn, g_fn, x, v, alpha, beta):
    """Take one Nesterov-momentum step using a unit-length gradient.

    Source: Course textbook, Algorithm 5.4 (p. 76), with an extra
    normalization of the gradient before it is applied.

    Args:
        f_fn (function): Objective function. Not used by this step; kept so
            the signature stays the same for callers.
        g_fn (function): Gradient function. Called exactly once, at the
            look-ahead point ``x + beta * v``.
        x (np.array): Current position. Not modified.
        v (np.array): Current velocity. Overwritten IN PLACE with the new
            velocity.
        alpha (float): Step size applied to the (normalized) gradient.
        beta (float): Momentum decay applied to the previous velocity.

    Returns:
        tuple: ``(x_next, v)`` where ``x_next = x + v`` is a new array and
        ``v`` is the same array object that was passed in, now updated.

    Raises:
        Whatever ``g_fn`` raises is propagated unchanged; in that case ``v``
        has not been touched yet.
    """
    g = g_fn(x + beta * v)
    # normalization step
    g_norm = np.linalg.norm(g)
    # A zero gradient (stationary point) has no direction. Dividing by its
    # zero norm would yield NaN and corrupt every later iterate, so the
    # gradient is left as-is and the step is driven by momentum alone.
    if g_norm > 0:
        g = g / g_norm
    v[:] = beta * v - alpha * g
    return x + v, v


def impl_nesterov_momentum(f_fn, g_fn, x0, alpha, beta):
    """Iterate Nesterov-momentum steps until the evaluation budget runs out.

    There is no convergence test: the loop only ends when a step raises
    `CountWillExceedWarning`.

    Args:
        f_fn (function): Objective function, forwarded to each step.
        g_fn (function): Gradient function, forwarded to each step.
        x0 (np.array): Starting position.
        alpha (float): Step size forwarded to each step.
        beta (float): Momentum decay forwarded to each step.

    Returns:
        list: Every position visited, in order, starting with `x0` itself.
    """
    position = x0
    velocity = np.zeros(x0.size)  # start at rest, one entry per element of x0
    trajectory = [position]
    try:
        while True:
            position, velocity = step_nesterov_momentum(
                f_fn, g_fn, position, velocity, alpha, beta
            )
            trajectory.append(position)
    except CountWillExceedWarning:
        # Expected way out of the loop: the budget is spent.
        pass
    return trajectory


def optimize(f, g, x0, n, count, prob, return_history=False):
    """Minimize `f` with Nesterov momentum under a fixed evaluation budget.

    `f` and `g` are wrapped so that a call is refused (by raising
    `CountWillExceedWarning`) when the current count plus the call's cost
    would reach `n`. The optimizer runs until that happens.

    Args:
        f (function): Function to be optimized
        g (function): Gradient function for `f`
        x0 (np.array): Initial position to start from
        n (int): Number of evaluations allowed. Remember `g` costs twice of `f`
        count (function): takes no arguments and returns current count
        prob (str): Name of the problem. So you can use a different strategy
                 for each problem. `prob` can be `simple1`,`simple2`,`simple3`,
                 `secret1` or `secret2`. Currently not consulted: the same
                 settings are used for every problem.
        return_history (bool): If True, return every visited position
                 instead of only the final one.
    Returns:
        x_best (np.array): the last position reached, or, when
                 `return_history` is True, the list of all positions
                 visited (starting with `x0`).
    """

    def budgeted(fn, cost):
        # Wrap `fn` so it refuses to run once `cost` more evaluations would
        # bring the running count up to `n`.
        def checked(x):
            if count() + cost >= n:
                raise CountWillExceedWarning()
            return fn(x)

        return checked

    trajectory = impl_nesterov_momentum(
        budgeted(f, 1), budgeted(g, 2), x0, alpha=0.1, beta=0.5
    )

    return trajectory if return_history else trajectory[-1]

This site is open source. Improve this page »