Replicating a famous optimisation gif
There is this chart from Alec Radford that became quite famous for showing a comparative run between different optimizers.
Let’s try replicating it ourselves. We will use the framework developed in our previous post for this.
Since this gif is 4 years old now, I’ll also add a few additional optimizers for comparison.
Code
Loading the playground
All code on this blog is released under the following license.
# Copyright 2020 Cristian Lungu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
The test functions
import numpy as np
import numpy as np
from mpl_toolkits import mplot3d
from matplotlib import pyplot as plt
import matplotlib.colors as colors
import jax.numpy as jnp
class Ifunction:
    """
    Interface for a 2D test function `z = f(x, y)` used by the plotting and
    optimisation utilities.

    Implementations must provide `__call__`, `_min` and `domain`.
    """

    def __init__(self):
        pass

    def __call__(self, *args) -> np.ndarray:
        """
        Evaluates the function. Implementations take `(x, y)` and return `z`.
        """
        raise NotImplementedError("Implementations must define __call__(x, y)")

    def _min(self) -> np.ndarray:
        """
        Returns a np.array of the shape (k, 2) with the (x, y) coordinates of the
        k minimum points of this function.
        """
        raise NotImplementedError("Implementations must define _min()")

    def min(self) -> np.ndarray:
        """
        Returns a np.array of the shape (k, 3) with all the minimum k points of this function.
        The three values of the second dimension are the (x,y,z) coordinates of the minimum values
        """
        return self.coord(self._min())

    def coord(self, points: np.ndarray) -> np.ndarray:
        """
        Returns a np.array of the shape (k, 3) with all the evaluations of the given
        k points of this function.
        The three values of the second dimension are the (x,y,z) coordinates of the given points,
        where z is this function evaluated at (x, y).
        """
        # z is computed as a flat vector; make it a column so it can be stacked next to (x, y)
        z = np.expand_dims(self(points[:, 0], points[:, 1]), axis=-1)
        return np.hstack((
            points,
            z
        ))

    def domain(self) -> np.ndarray:
        """
        Returns the ((x_min, x_max), (y_min, y_max)) values where this function
        is of most interest
        """
        raise NotImplementedError("Implementations must define domain()")
# =========================
# Function implementations
# =========================
class himmelblau(Ifunction):
    """
    Himmelblau's function: (x^2 + y - 11)^2 + (x + y^2 - 7)^2.
    """

    def __call__(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """
        Evaluates the function element-wise at the given (x, y) coordinates.
        """
        first_term = (x**2 + y - 11)**2
        second_term = (x + y**2 - 7)**2
        return first_term + second_term

    def _min(self) -> np.ndarray:
        """
        The (x, y) coordinates of the four minima, as an array of shape (4, 2).
        """
        minima = [
            [3.0, 2.0],
            [-2.805118, 3.131312],
            [-3.779310, -3.283186],
            [3.584428, -1.848126],
        ]
        return np.array(minima)

    def domain(self) -> np.ndarray:
        """
        The ((x_min, x_max), (y_min, y_max)) window in which the function is plotted.
        """
        x_range = [-5, 5]
        y_range = [-5, 5]
        return np.array([x_range, y_range])
class mc_cormick(Ifunction):
    """
    McCormick function: sin(x + y) + (x - y)^2 - 1.5x + 2.5y + 1.
    """

    def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> np.ndarray:
        """
        Evaluates the function element-wise at the given (x, y) coordinates.
        """
        wave = jnp.sin(x + y)
        bowl = (x - y)**2
        return wave + bowl - 1.5*x + 2.5*y + 1

    def _min(self) -> np.ndarray:
        """
        The (x, y) coordinates of the single listed minimum, as an array of shape (1, 2).
        """
        minima = [
            [-0.54719, -1.54719],
        ]
        return np.array(minima)

    def domain(self) -> np.ndarray:
        """
        The ((x_min, x_max), (y_min, y_max)) window in which the function is plotted.
        """
        x_range = [-1.5, 4]
        y_range = [-3, 4]
        return np.array([x_range, y_range])
class holder_table(Ifunction):
    """
    Holder table function: -|sin(x) cos(y) exp(|1 - sqrt(x^2 + y^2) / pi|)|.
    """

    def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> np.ndarray:
        """
        Evaluates the function element-wise at the given (x, y) coordinates.
        """
        radius = jnp.sqrt(x**2 + y**2)
        envelope = jnp.exp(jnp.abs(1 - radius/jnp.pi))
        return -jnp.abs(jnp.sin(x)*jnp.cos(y)*envelope)

    def _min(self) -> np.ndarray:
        """
        The (x, y) coordinates of the four listed minima, as an array of shape (4, 2).
        """
        minima = [
            [8.05502, 9.66459],
            [8.05502, -9.66459],
            [-8.05502, 9.66459],
            [-8.05502, -9.66459],
        ]
        return np.array(minima)

    def domain(self) -> np.ndarray:
        """
        The ((x_min, x_max), (y_min, y_max)) window in which the function is plotted.
        """
        x_range = [-10, 10]
        y_range = [-10, 10]
        return np.array([x_range, y_range])
class beale(Ifunction):
    """
    Beale function: (1.5 - x + xy)^2 + (2.25 - x + xy^2)^2 + (2.625 - x + xy^3)^2.
    """

    def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> np.ndarray:
        """
        Evaluates the function element-wise at the given (x, y) coordinates.
        """
        linear_term = (1.5 - x + x*y)**2
        quadratic_term = (2.25 - x + x*y**2)**2
        cubic_term = (2.625 - x + x*y**3)**2
        return linear_term + quadratic_term + cubic_term

    def _min(self) -> np.ndarray:
        """
        The (x, y) coordinates of the single listed minimum, as an array of shape (1, 2).
        """
        minima = [
            [3, 0.5],
        ]
        return np.array(minima)

    def domain(self) -> np.ndarray:
        """
        The ((x_min, x_max), (y_min, y_max)) window in which the function is plotted.
        """
        x_range = [-4, 4]
        y_range = [-4, 4]
        return np.array([x_range, y_range])
class saddle_point(Ifunction):
    """
    The saddle surface x^2 - y^2.
    """

    def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> np.ndarray:
        """
        Computes the given function
        """
        return x**2 - y**2

    def _min(self) -> np.ndarray:
        """
        Returns a np.array of the shape (1, 2) with the (x,y) coordinates of the
        point of interest of this function.

        x^2 - y^2 is unbounded below along the y axis, so it has no finite minimum.
        The saddle point (0, 0) is returned instead, so the plots mark the
        location the optimizers have to escape from.
        """
        return np.array([
            [0.0, 0.0],
        ])

    def domain(self) -> np.ndarray:
        """
        Returns the ((x_min, x_max), (y_min, y_max)) values where this function
        is of most interest
        """
        return np.array([
            [-1.5, 1],
            [-1.5, 1]
        ])
class eggholder(Ifunction):
    """
    Eggholder function:
    -(y + 47) sin(sqrt(|x/2 + (y + 47)|)) - x sin(sqrt(|x - (y + 47)|)).
    """

    def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> np.ndarray:
        """
        Computes the given function
        """
        return -(y+47)*jnp.sin(jnp.sqrt(jnp.abs(x/2+(y+47)))) - x*jnp.sin(jnp.sqrt(jnp.abs(x-(y+47))))

    def _min(self) -> np.ndarray:
        """
        Returns a np.array of the shape (k, 2) with all the minimum k points of this function.
        The two values of the second dimension are the (x,y) coordinates of the minimum values
        """
        # published global minimum of the Eggholder function on [-512, 512]^2
        return np.array([
            [512, 404.2319],
        ])

    def domain(self) -> np.ndarray:
        """
        Returns the ((x_min, x_max), (y_min, y_max)) values where this function
        is of most interest
        """
        return np.array([
            [-512, 512],
            [-512, 512]
        ])
# Registry of every test function: maps each direct Ifunction subclass name
# to a fresh instance of that subclass.
Function = {implementation.__name__: implementation() for implementation in Ifunction.__subclasses__()}
Function
{'beale': <__main__.beale at 0x7f45b01016a0>,
'eggholder': <__main__.eggholder at 0x7f45b0101710>,
'himmelblau': <__main__.himmelblau at 0x7f45b01014a8>,
'holder_table': <__main__.holder_table at 0x7f45b01015f8>,
'mc_cormick': <__main__.mc_cormick at 0x7f45b0101588>,
'saddle_point': <__main__.saddle_point at 0x7f45b01016d8>}
The charting utils
from typing import Tuple
from functools import lru_cache
from tqdm.notebook import tqdm
@lru_cache(maxsize=None)
def contour(function: Ifunction, x_min=-5, x_max=5, y_min=-5, y_max=5, mesh_size=100):
    """
    Evaluates `function` on a regular grid and returns the `(xx, yy, zz)` mesh,
    where `zz = function(xx, yy)`.

    The grid has `mesh_size` points per axis, spread linearly over
    `[x_min, x_max]` and `[y_min, y_max]`. Results are memoised on the argument
    values, so all arguments must be hashable.
    This function is usually used for displaying the contour of the given function.
    """
    x_axis = np.linspace(x_min, x_max, num=mesh_size)
    y_axis = np.linspace(y_min, y_max, num=mesh_size)
    xx, yy = np.meshgrid(x_axis, y_axis)
    return xx, yy, function(xx, yy)
def zoom(x_domain, y_domain, zoom_factor):
    """
    Shrinks (or, for a negative `zoom_factor`, grows) both domains around their centres.

    Each bound is moved towards the centre of its interval by `zoom_factor`
    times its distance to that centre: 0 keeps the domain unchanged and 1
    collapses it onto the centre.

    Returns ((x_min, x_max), (y_min, y_max)).
    """
    def scale_interval(bounds):
        low, high = bounds
        centre = (low + high) / 2
        return low + (centre - low) * zoom_factor, high - (high - centre) * zoom_factor

    return scale_interval(x_domain), scale_interval(y_domain)
def rotate(_x: np.ndarray, _y: np.ndarray, angle=45) -> Tuple[np.ndarray, np.ndarray]:
    """
    Rotates 2D coordinates around the origin and returns the rotated `(x, y)`.

    Two kinds of input are accepted:
    * a pair of 2D meshes (as produced by `np.meshgrid`), rotated point by point;
    * a pair of point lists that squeeze down to at most one dimension.

    The angle is given in degrees and is offset by 90 degrees before being applied.
    Raises AssertionError for any other combination of shapes.
    """
    theta = ((angle + 90) % 360) * np.pi/180
    cos_t, sin_t = np.cos(theta), np.sin(theta)

    # clockwise rotation matrix, applied to the meshes as `matrix @ point`
    clockwise = np.array([
        [cos_t, sin_t],
        [-sin_t, cos_t]
    ])
    # anti-clockwise rotation matrix, applied to the point lists as `points @ matrix`
    anti_clockwise = np.array([
        [cos_t, -sin_t],
        [sin_t, cos_t]
    ])

    def looks_like_mesh(values: np.ndarray) -> bool:
        # a meshgrid output is 2D and repeats the same vector along one of its axes
        if len(values.shape) != 2:
            return False
        repeats_along_axis_0 = np.allclose(values.mean(axis=0), values[0, :])
        repeats_along_axis_1 = np.allclose(values.mean(axis=1), values[:, 0])
        return repeats_along_axis_0 or repeats_along_axis_1

    def is_point_list(values: np.ndarray) -> bool:
        # a function with a single minimum yields shape (1,), which squeezes to shape ()
        return len(np.squeeze(values).shape) <= 1

    def drop_unit_axes(values: np.ndarray) -> np.ndarray:
        # remove the redundant axes of size 1, but never go below shape (1,)
        squeezed = np.squeeze(values)
        if len(squeezed.shape) == 0:
            return np.array([squeezed])
        return squeezed

    if looks_like_mesh(_x) and looks_like_mesh(_y):
        stacked = np.dstack([_x, _y])
        _x, _y = np.einsum('ij, mnj -> imn', clockwise, stacked)
        return _x, _y

    if is_point_list(_x) and is_point_list(_y):
        _x, _y = drop_unit_axes(_x), drop_unit_axes(_y)
        points = np.hstack((_x[:, np.newaxis], _y[:, np.newaxis]))
        _x, _y = np.einsum('mi, ij -> jm', points, anti_clockwise)
        return _x, _y

    raise AssertionError(f"Unknown rotation types for shapes {_x.shape} and {_y.shape}")
def log_contour_levels(zz, number_of_contour_lines=35):
    """
    Returns `number_of_contour_lines` contour levels for the mesh values `zz`,
    spaced with `np.logspace` and shifted to start just above `zz.min()`.
    """
    # the mesh minimum is used because some functions could actually go to -inf;
    # only the minimum inside the viewable window matters
    lowest = zz.min()
    # NOTE(review): this upper reference is the 5th percentile of the mesh, and
    # np.logspace treats its bounds as base-10 exponents, so the levels grow as
    # 10**(reference - lowest) -- confirm this is the intended scale
    upper_reference = np.percentile(zz, 5)
    offsets = np.logspace(0, upper_reference - lowest, number_of_contour_lines)
    return lowest + offsets
def plot_function_2d(function: Ifunction, ax=None, angle=45, zoom_factor=0, contour_log_scale=True):
    """
    Draws the rotated 2D contour of `function`, plus its minimum points, on `ax`
    (or on the current axes when `ax` is not given).

    `contour_log_scale` switches between the levels from `log_contour_levels`
    with a logarithmic colour scale, and 200 automatically chosen levels with a
    linear colour scale.
    """
    (x_min, x_max), (y_min, y_max) = zoom(*function.domain(), zoom_factor)
    xx, yy, zz = contour(function, x_min, x_max, y_min, y_max)
    ax = ax or plt.gca()

    # only the coordinates are rotated: every zz value stays attached to its grid point
    xx, yy = rotate(xx, yy, angle=angle)

    if contour_log_scale:
        levels = log_contour_levels(zz)
        colour_norm = colors.LogNorm(vmin=levels.min(), vmax=levels.max())
    else:
        levels = 200
        colour_norm = colors.Normalize(vmin=zz.min(), vmax=zz.max())
    ax.contour(xx, yy, zz, levels=levels, cmap='Spectral', norm=colour_norm, alpha=0.5)

    minima = function.min()
    ax.scatter(*rotate(minima[:, 0], minima[:, 1], angle=angle))

    ax.axis("off")
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
def plot_function_3d(function: Ifunction, ax=None, azimuth=45, angle=45, zoom_factor=0, show_projections=False, contour_log_scale=True):
    """
    Draws the 3D surface of `function` on `ax` (a new 3D axes is created when
    `ax` is not given), together with its minimum points and a filled contour
    placed on a floor below the surface.

    With `show_projections` the surface is also projected, in gray, along the
    x and y directions.
    """
    (x_min, x_max), (y_min, y_max) = zoom(*function.domain(), zoom_factor)
    xx, yy, zz = contour(function, x_min, x_max, y_min, y_max)

    # evaluate once, use in multiple places
    lowest, highest = zz.min(), zz.max()
    surface_norm = colors.Normalize(vmin=lowest, vmax=highest)

    # put the 2d contour floor a bit lower than the minimum, to look like a reflection
    floor_gap = int((highest - lowest) * 0.065)
    floor_level = lowest - floor_gap

    # create 3D axis if not provided
    ax = ax or plt.axes(projection='3d')
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_zlim(floor_level, highest)

    minima = function.min()
    ax.scatter3D(minima[:, 0], minima[:, 1], minima[:, 2], marker='.', color='black', s=120, alpha=1, zorder=1)

    if contour_log_scale:
        levels = log_contour_levels(zz)
        projection_norm = colors.LogNorm(vmin=levels.min(), vmax=levels.max())
    else:
        levels = 200
        projection_norm = colors.Normalize(vmin=zz.min(), vmax=zz.max())

    ax.contourf(xx, yy, zz, zdir='z', levels=levels, offset=floor_level, cmap='Spectral', norm=surface_norm, alpha=0.5, zorder=1)
    if show_projections:
        for direction, offset in (('x', xx.max()+1), ('y', yy.max()+1)):
            ax.contourf(xx, yy, zz, zdir=direction, levels=300, offset=offset, cmap='gray', norm=projection_norm, alpha=0.05, zorder=1)

    ax.plot_surface(xx, yy, zz, cmap='Spectral', rcount=100, ccount=100, norm=surface_norm, shade=False, antialiased=True, alpha=0.6)

    # apply rotation
    # NOTE(review): matplotlib's view_init takes (elev, azim), so `azimuth` is
    # used as the elevation and `angle` as the azimuth -- confirm this is intended
    ax.view_init(azimuth, angle)
def plot_function(function: Ifunction, angle=45, zoom_factor=0, azimuth_3d=30, fig=None, ax_2d=None, ax_3d=None, contour_log_scale=True):
    """
    Plots `function` as a 3D surface (left) next to its 2D contour (right).

    The figure and either of the axes are created when not supplied.
    Returns the `(fig, ax_3d, ax_2d)` that were drawn on.
    """
    if fig is None:
        fig = plt.figure(figsize=(26, 10))
    if ax_3d is None:
        ax_3d = fig.add_subplot(1, 2, 1, projection='3d')
    if ax_2d is None:
        ax_2d = fig.add_subplot(1, 2, 2)

    shared_options = dict(function=function, angle=angle, zoom_factor=zoom_factor, contour_log_scale=contour_log_scale)
    plot_function_3d(ax=ax_3d, azimuth=azimuth_3d, **shared_options)
    plot_function_2d(ax=ax_2d, **shared_options)
    return fig, ax_3d, ax_2d
def plot_all_functions(functions: dict):
    """
    Plots every function of the given `{name: function}` mapping, one per row,
    with the 3D surface on the left and the 2D contour on the right.

    Each function is first drawn with logarithmic contour levels. When that
    fails, it is drawn again with linear levels.
    """
    nr_functions = len(functions)
    fig = plt.figure(figsize=(13, 5*nr_functions))
    for i, (name, function) in enumerate(tqdm(functions.items()), start=1):
        # subplot indices are 1-based: odd slots hold the 3D plots, even slots the 2D ones
        ax_3d = fig.add_subplot(nr_functions, 2, i*2-1, projection='3d')
        ax_2d = fig.add_subplot(nr_functions, 2, i*2)
        ax_3d.title.set_text(f"{name}")
        try:
            plot_function(function, fig=fig, ax_2d=ax_2d, ax_3d=ax_3d, angle=225)
        except Exception:
            # Best-effort fallback to linear contour levels. `Exception` (rather than a
            # bare `except`) keeps KeyboardInterrupt / SystemExit propagating, so a
            # long-running plot can still be interrupted.
            plot_function(function, fig=fig, ax_2d=ax_2d, ax_3d=ax_3d, angle=225, contour_log_scale=False)
        ax_2d.text(1, 0, 'www.clungu.com', transform=ax_2d.transAxes, ha='right',
                   color='#777777', bbox=dict(facecolor='white', alpha=0.8, edgecolor='white'))
# Preview: draw Himmelblau's function (3D surface and 2D contour) with angle=225.
plot_function(himmelblau(), angle=225)
The jax
optimizer utils
from jax import grad
class optimize:
    """
    Runs one optimizer on a 2D function and records the visited (x, y) points.

    Usage: `optimize(function).using(optimizer, name).start_from([x0, y0]).update(n)`.
    """

    def __init__(self, function):
        self.function = function
        # gradient of the function with respect to both of its arguments (x, y)
        self.grad_function = grad(function, argnums=(0, 1))
        self.x, self.y = list(), list()
        # number of optimizer updates applied since the last `start_from`
        self._step = 0

    def using(self, optimizer, name='sgd'):
        """
        Sets the optimizer, given as an `(init, update, get_params)` triple,
        and the name under which it is reported. Returns `self` for chaining.
        """
        self._init, self._update, self._get_params = optimizer
        self.optimizer = optimizer
        self.optimizer_name = name
        return self

    def start_from(self, params):
        """
        Initialises the optimizer state from the starting point `params` and
        restarts the step counter. Returns `self` for chaining.
        """
        self.state = self._init(tuple(params))
        self._step = 0
        return self

    def update(self, nr_iterations=1):
        """
        Applies `nr_iterations` optimizer updates and returns the `(x, y)` trace lists.

        The point held before each update is appended to the trace, so the result
        of the last update only shows up in the trace on the next call.
        """
        for _ in range(nr_iterations):
            params = self._get_params(self.state)
            self.__add_point(*params)
            grads = self.grad_function(*params)
            # The step index must keep increasing across calls: the animation calls
            # `update()` once per frame, and a loop-local index would hand step 0
            # to the optimizer every time (which breaks step-dependent optimizers).
            self.state = self._update(self._step, grads, self.state)
            self._step += 1
        return self.x, self.y

    def __add_point(self, _x, _y):
        """
        Adds the x and y coordinates of this point to the trace lists
        """
        self.x.append(float(_x))
        self.y.append(float(_y))
class optimize_multi:
    """
    Builder that prepares one `optimize` run per optimizer, all on the same
    function and from the same starting point.
    """

    def __init__(self, function):
        self.function = function

    def using(self, optimizers):
        """
        Stores the list of `(optimizer, name)` pairs to compare. Returns `self`.
        """
        self.optimizers = optimizers
        return self

    def start_from(self, params):
        """
        Stores the starting point shared by all the runs. Returns `self`.
        """
        self.params = params
        return self

    def tolist(self):
        """
        Returns one initialised `optimize` instance per stored `(optimizer, name)` pair.
        """
        runs = []
        for optimizer, name in self.optimizers:
            run = optimize(self.function).using(optimizer, name=name)
            runs.append(run.start_from(self.params))
        return runs
The animation utils
import matplotlib.animation as animation
from IPython.display import HTML, display
from itertools import cycle
from typing import List
from cycler import cycler
from functools import partial
from mpl_toolkits.mplot3d.art3d import Line3D, Poly3DCollection
class FixZorderLine3D(Line3D):
    """
    A `Line3D` whose `zorder` always reads as 1000 and cannot be changed.
    """

    @property
    def zorder(self):
        """The pinned z-order, regardless of any value assigned to it."""
        return 1000

    @zorder.setter
    def zorder(self, _ignored):
        # assignments are accepted but deliberately discarded
        return None
# Flat colours (single-letter matplotlib colour codes), one per optimizer trace.
color_cycles = 'bgrcmyk'
# NOTE: this first colormap list is dead code -- the assignment right below
# replaces it before anything reads it.
cmap_cycles = [
'binary', 'gist_yarg', 'gist_gray', 'gray', 'bone', 'pink',
'spring', 'summer', 'autumn', 'winter', 'cool', 'Wistia',
'hot', 'afmhot', 'gist_heat', 'copper']
# Colormap names actually used, one per optimizer trace, when flat colours are disabled.
cmap_cycles = [
'Greys', 'Purples', 'Blues', 'Greens', 'Oranges', 'Reds'
]
def single_frame(i, optimisations: List[optimize], fig, _ax_3d, _ax_2d, angle, use_flat_colors=True, contour_log_scale=True, legend_location="upper right", azimuth_3d=30, zoom_factor=0, force_line_zorder=True):
    """
    Draws one animation frame: clears both axes, redraws the function, advances
    every optimisation by one update and plots the traces gathered so far.

    `i` is the frame number supplied by the animation; it is not used.
    With `use_flat_colors` each trace is a line in a flat colour ending in a
    marker. Otherwise each update point is drawn on its own, coloured by its
    position in the trace.
    """
    _ax_3d.clear()
    _ax_2d.clear()

    assert len(optimisations) >= 1, f"We need at least one optimisation to animate, but {len(optimisations)} given."
    # all the optimisations are assumed to share the function of the first one
    plot_function(optimisations[0].function, angle=angle, fig=fig, ax_3d=_ax_3d, ax_2d=_ax_2d, contour_log_scale=contour_log_scale, azimuth_3d=azimuth_3d, zoom_factor=zoom_factor)

    palette = color_cycles if use_flat_colors else cmap_cycles
    for optimisation, color in zip(optimisations, cycle(palette)):
        trace_x, trace_y = optimisation.update()
        x, y = np.array(trace_x), np.array(trace_y)
        label = optimisation.optimizer_name

        if not use_flat_colors:
            # draws only the update points, not connected with lines
            _ax_2d.scatter(*rotate(x, y, angle=angle), c=np.arange(len(x)), cmap=color, label=label)
            _ax_3d.scatter3D(x, y, optimisation.function(x, y), c=np.arange(len(x)), cmap=color, label=label)
            continue

        # draw the line paths
        _ax_2d.plot(*rotate(x, y, angle=angle), color=color, label=label)
        lines = _ax_3d.plot3D(x, y, optimisation.function(x, y), color=color, label=label)
        if force_line_zorder:
            # swap the class in place so the z-order of the drawn lines stays pinned
            for line in lines:
                line.__class__ = FixZorderLine3D

        # draw the last points
        last_x, last_y = x[-1:], y[-1:]
        _ax_2d.scatter(*rotate(last_x, last_y, angle=angle), color=color)
        _ax_3d.scatter3D(last_x, last_y, optimisation.function(last_x, last_y), color=color)

    _ax_2d.legend(loc=legend_location)

    # add a credits watermark such as not to overlap with the legend
    watermark_y = 0 if legend_location == "upper right" else 1
    _ax_2d.text(1, watermark_y, 'www.clungu.com', transform=_ax_2d.transAxes, ha='right',
                color='#777777', bbox=dict(facecolor='white', alpha=0.8, edgecolor='white'))
    _ax_2d.plot()

    # progress indicator, one dot per rendered frame
    print(".", end ="")
def animate(optimisations, frames=20, interval=250, use_flat_colors=True, contour_log_scale=True, legend_location="upper right", angle=225, azimuth_3d=30, zoom_factor=0, force_line_zorder=True):
    """
    Renders the given optimisations as an animation, one update per frame,
    displays it as an HTML5 video and returns the video markup.

    `frames` is the number of frames (and so of updates per optimisation) and
    `interval` is passed to `FuncAnimation` as the delay between frames.
    """
    assert len(optimisations) >= 1, f"We need at least one optimisation to animate, but {len(optimisations)} given."

    fig = plt.figure(figsize=(13,5))
    fig, ax_3d, ax_2d = plot_function(optimisations[0].function, fig=fig, angle=angle, contour_log_scale=contour_log_scale, zoom_factor=zoom_factor)
    fig.tight_layout()

    # keyword options are bound up front; the positional ones travel through `fargs`
    frame_options = dict(
        use_flat_colors=use_flat_colors,
        contour_log_scale=contour_log_scale,
        legend_location=legend_location,
        azimuth_3d=azimuth_3d,
        zoom_factor=zoom_factor,
        force_line_zorder=force_line_zorder,
    )
    draw_frame = partial(single_frame, **frame_options)
    animator = animation.FuncAnimation(fig, draw_frame, fargs=(optimisations, fig, ax_3d, ax_2d, angle), frames=frames, interval=interval, blit=False)

    video = animator.to_html5_video()
    display(HTML(video))
    plt.close()
    return video
# `adam` is used below, so it has to be imported together with `sgd`.
from jax.experimental.optimizers import adam, sgd

# Quick check of the animation pipeline: two optimizers, four frames.
animate(
    optimize_multi(holder_table())
    .using([
        (sgd(step_size=0.3), "sgd"),
        (adam(step_size=0.3), "adam"),
    ])
    .start_from([-1., -1.])
    .tolist(),
    frames=4,
    contour_log_scale=False,
    legend_location="lower right"
);
# `momentum` is used below, so it is imported along with the other optimizers.
from jax.experimental.optimizers import adam, adagrad, rmsprop, sgd, rmsprop_momentum, adamax, momentum

# Replication of the original gif: every optimizer starts on the saddle
# surface, a tiny distance away from the x axis, with the same step size.
step_size = 0.03
animate(
    optimize_multi(saddle_point())
    .using([
        (sgd(step_size=step_size), "sgd"),
        (momentum(step_size=step_size, mass=0.9), "momentum"),
        (adam(step_size=step_size), "adam"),
        (adagrad(step_size=step_size), "adagrad"),
        (rmsprop(step_size=step_size), "rmsprop"),
        (rmsprop_momentum(step_size=step_size), "rmsprop_momentum"),
        (adamax(step_size=step_size), "adamax"),
    ])
    .start_from([.5, -0.000000001])
    .tolist(),
    frames=200,
    use_flat_colors=True,
    contour_log_scale=False,
    interval=50
);
I wasn’t actually able to find the exact setup used for generating the GIF (the initial starting point, and the learning rates used on all the optimizers) but I did a best effort approach on finding these and I think these results replicate the behavior shown in the original image.
Since we’ve come so far, let’s play around and learn how these optimizers behave in different settings!
Experiments
We will use a few functions to study the behavior of the optimizers, and you can see them below:
# Draw every registered test function, one row per function.
plot_all_functions(Function)
SGD with “just right” step_size
From the animation above we saw that SGD did poorly compared with the others. This is partly a design weakness, but also a failure to set the learning rate properly.
Let’s see how each optimizer fares, when compared to an SGD instance that has the “just right” learning rate set.
from jax.experimental.optimizers import adam, adagrad, rmsprop, sgd, rmsprop_momentum, adamax, momentum, nesterov
# Himmelblau with an SGD step size tuned for this function, 100 frames.
# NOTE(review): adamax runs with step_size=0.1 while every other optimizer uses 0.01 -- confirm this is intended.
animate(
optimize_multi(himmelblau())\
.using([
(sgd(step_size=0.01), "sgd"),
(adam(step_size=0.01), "adam"),
(adagrad(step_size=0.01), "adagrad"),
(adamax(step_size=0.1), "adamax"),
(rmsprop(step_size=0.01), "rmsprop"),
])\
.start_from([-1., -1.])\
.tolist(),
frames=100,
use_flat_colors=True,
interval=50
);
It seems that SGD is really good at this function.
Adamax did almost equally well but wasn’t as stable as SGD was near the minimum.
You can’t really see the traces of the other optimizers because they are superimposed and follow roughly the same path.
My conclusion from this experiment is that SGD, when set up correctly, surpasses the other optimizers.
A well tuned SGD beats all the other adaptive optimizers
Beale on large step_size
Finding a good step_size
(also known as learning_rate
) is hard (or not trivial). Let’s see how all of the optimizers fare when given a rather large (but not too large) step_size
.
from jax.experimental.optimizers import adam, adagrad, rmsprop, sgd, rmsprop_momentum, adamax, sm3
# Beale, all optimizers with the same step size of 0.01, starting from (-2, -2).
animate(
optimize_multi(beale())\
.using([
(sgd(step_size=0.01), "sgd"),
(adam(step_size=0.01), "adam"),
(adagrad(step_size=0.01), "adagrad"),
(rmsprop(step_size=0.01), "rmsprop"),
(rmsprop_momentum(step_size=0.01), "rmsprop_momentum"),
(adamax(step_size=0.01), "adamax"),
])\
.start_from([-2., -2.])\
.tolist(),
frames=200,
use_flat_colors=True,
contour_log_scale=True,
interval=50,
angle=85
);
SGD with 0.01 is off the charts and doesn’t converge, while the others, even with a suboptimal step_size, manage to run towards the minimum.
For example, here is the output of the first 5 update steps that SGD computes if we start from (-2, -2)
. It’s obvious that the step_size
is too large, as the optimization jumps from negative to positive and back, increasing in magnitude each time and ending with a -inf
value by its fifth update.
# Run five plain SGD updates on Beale from (-2, -2); the result is the (x, y) trace lists.
optimize(beale())\
.using(sgd(step_size=0.01))\
.start_from([-2., -2.])\
.update(5)
([-2.0, 2.3874998092651367, -22429.189453125, 1.6297707243559544e+28, nan],
[-2.0, 8.800000190734863, -18198.984375, 6.025792522646275e+28, -inf])
Let’s fine tune the SGD variants so we can still compare them with the others. Just in case, we will zoom out a bit so we can see the larger picture in case SGD goes off the charts again.
from jax.experimental.optimizers import adam, adagrad, rmsprop, sgd, rmsprop_momentum, adamax, momentum, nesterov
# Beale again: the SGD variants get a much smaller step size (0.00001) than the
# adaptive optimizers (0.1), and the negative zoom_factor widens the plotted window.
animate(
optimize_multi(beale())\
.using([
(sgd(step_size=0.00001), "sgd"),
(momentum(step_size=0.00001, mass=0.9), "momentum"),
(nesterov(step_size=0.00001, mass=0.9), "nesterov"),
(adam(step_size=0.1), "adam"),
(adagrad(step_size=0.1), "adagrad"),
(rmsprop(step_size=0.1), "rmsprop"),
(rmsprop_momentum(step_size=0.1), "rmsprop_momentum"),
(adamax(step_size=0.1), "adamax"),
])\
.start_from([4., 3.])\
.tolist(),
frames=200,
use_flat_colors=True,
contour_log_scale=True,
interval=50,
angle=45,
zoom_factor=-2
);
Conclusions:
- rmsprop:
- most stable of all
- rmsprop with momentum:
- stumbles a bit but gets to the minimum quickly
- when the minimum is reached, it follows a rather erratic path
- all SGD variants had to have their learning rate lowered to 0.00001 so as not to explode, so this comparison is rather misleading.
What happens when we have really large step sizes?
In the setup above we’ve assumed a “large but not too large” step_size
, but what happens with an obviously too large value?
It’s clear that the SGD variants need to have the right step_size
set in order to work properly, so we will ignore them in this experiment.
from jax.experimental.optimizers import adam, adagrad, rmsprop, sgd, rmsprop_momentum, adamax, sm3, momentum, nesterov, sm3
# Beale with a step size of 0.5 for the adaptive optimizers only (the SGD variants are left out).
animate(
optimize_multi(beale())\
.using([
(adam(step_size=0.5), "adam"),
(adagrad(step_size=0.5), "adagrad"),
(rmsprop(step_size=0.5), "rmsprop"),
(rmsprop_momentum(step_size=0.5), "rmsprop_momentum"),
(adamax(step_size=0.5), "adamax"),
])\
.start_from([4., 3.])\
.tolist(),
frames=200,
use_flat_colors=True,
contour_log_scale=True,
interval=50,
angle=45,
zoom_factor=-2
);
Conclusions:
- rmsprop is still stable, except near the optimum value
- rmsprop with momentum is quite erratic but manages to get to the correct result
- ada* variants found a local minimum and settled there
What happens when there is more than one minimum and the step_size is too large?
# Himmelblau (several minima) with a step size of 0.5, starting from the origin.
# The SGD variants are commented out of this run.
animate(
optimize_multi(himmelblau())\
.using([
# (sgd(step_size=0.5), "sgd"),
# (momentum(step_size=0.5, mass=0.9), "momentum"),
# (nesterov(step_size=0.5, mass=0.9), "nesterov"),
(adam(step_size=0.5), "adam"),
(adagrad(step_size=0.5), "adagrad"),
(rmsprop(step_size=0.5), "rmsprop"),
(rmsprop_momentum(step_size=0.5), "rmsprop_momentum"),
(adamax(step_size=0.5), "adamax"),
])\
.start_from([0., 0.])\
.tolist(),
frames=100,
use_flat_colors=True,
contour_log_scale=False,
interval=50,
);
- rmsprop found a solution quickly
- rmsprop with momentum shoots off with so much momentum that it soon reaches a really high point, from which it falls back down, passes near the closest minimum, and goes on to settle on the second most distant one. Its beginning was quite erratic, but in a high-dimensional optimization (like a neural network, where the minimum is quite likely far away from the initialization) the initial large momentum might not be a problem, because it takes a few hundred batches (updates) to get near a minimum anyway. Also, in this instance, it takes a lot of steps.
- sgd, sgd with momentum, and sgd with nesterov fell off the grid so I removed them.
- adagrad fell smoothly to the nearest minimum but with some drag
- adamax, because of the momentum, moved past the minimum (while finding it!) and went to another minimum.
- this means that, were we to start from a high point close to a local minimum, we would get out of it due to momentum.
- adam almost did the same as adamax but went for the closest minimum in the end.
What happens to a really hard optimization function when the learning rate is too high?
# Holder table with a step size of 0.5 for every optimizer, 20 frames.
animate(
optimize_multi(holder_table())\
.using([
(sgd(step_size=0.5), "sgd"),
(momentum(step_size=0.5, mass=0.9), "momentum"),
(nesterov(step_size=0.5, mass=0.9), "nesterov"),
(adam(step_size=0.5), "adam"),
(adagrad(step_size=0.5), "adagrad"),
(rmsprop(step_size=0.5), "rmsprop"),
(rmsprop_momentum(step_size=0.5), "rmsprop_momentum"),
(adamax(step_size=0.5), "adamax"),
])\
.start_from([-1., -1.])\
.tolist(),
frames=20,
use_flat_colors=True,
contour_log_scale=False,
interval=50,
);
Basically all of them, except rmsprop with momentum, settle into a local minimum, while rmsprop with momentum diverges off the chart.
Who can solve holder_table, if we tune the step_size correctly?
# Holder table with only rmsprop_momentum, using a step size of 0.47, 15 frames.
animate(
optimize_multi(holder_table())\
.using([
(rmsprop_momentum(step_size=0.47), "rmsprop_momentum"),
])\
.start_from([-1., -1.])\
.tolist(),
frames=15,
use_flat_colors=True,
contour_log_scale=False,
interval=50,
);