

Goal

I’ve just stumbled upon this wiki page, which describes optimization methods for functions (or programs) whose derivative is unknown or hard to compute.

I plan to implement some optimizers from that page and see how they work, but before doing that I realized that I also need a simulation environment where I can see how the optimization is progressing.

In the end, the simulation looks something like this.

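import numpy as np
from matplotlib import pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML, display
from jax import grad  # assumption: `grad(..., argnums=(0, 1))` matches jax.grad's signature

# himmelblau, plot_function and rotate are defined further down in this post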

x, y = [0.], [0.]

def sgd_update(x, y, learning_rate=0.01):
    d_x, d_y = function_grad(x, y)
    x = x - learning_rate * d_x
    y = y - learning_rate * d_y
    return x, y


def single_frame(i, _ax_3d, _ax_2d):
    _ax_3d.clear()
    _ax_2d.clear()
    angle = 225
    _x, _y = sgd_update(x[-1], y[-1], learning_rate=0.01)
    x.append(float(_x))
    y.append(float(_y))

    plot_function(function, fig=fig, ax_3d=_ax_3d, ax_2d=_ax_2d, angle=angle)
    _ax_2d.scatter(*rotate(np.array(x), np.array(y), angle=angle), c=np.arange(len(x)), cmap='binary')

    _ax_2d.plot()


function = himmelblau()
function_grad = grad(function, argnums=(0, 1))

fig = plt.figure(figsize=(13,5))
fig, ax_3d, ax_2d = plot_function(function, fig=fig)
fig.tight_layout()

frames = 20
animator = animation.FuncAnimation(fig, single_frame, fargs=(ax_3d, ax_2d), frames=frames, interval=250, blit=False)
display(HTML(animator.to_html5_video()))
plt.close();
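The sgd_update step above gets its gradient from `grad`. Judging by the `argnums=(0, 1)` argument, this appears to be `jax.grad`. Below is a dependency-free sketch of the same update: gradient descent on Himmelblau’s function with the partial derivatives written out by hand. `himmelblau_grad` is a helper introduced only for this illustration.

```python
import numpy as np

def himmelblau(x, y):
    return (x**2 + y - 11)**2 + (x + y**2 - 7)**2

def himmelblau_grad(x, y):
    # hypothetical helper: analytic partial derivatives of Himmelblau's function
    a = x**2 + y - 11
    b = x + y**2 - 7
    d_x = 4 * x * a + 2 * b
    d_y = 2 * a + 4 * y * b
    return d_x, d_y

def sgd_update(x, y, learning_rate=0.01):
    # one plain gradient-descent step, same structure as sgd_update above
    d_x, d_y = himmelblau_grad(x, y)
    return x - learning_rate * d_x, y - learning_rate * d_y

x, y = 0.0, 0.0
trace = [(x, y)]
for _ in range(2000):
    x, y = sgd_update(x, y)
    trace.append((x, y))

print(trace[1])          # first step away from (0, 0)
print(himmelblau(x, y))  # close to 0 once a minimum is reached
```

At (0, 0) the gradient is (-14, -22), so the first step with a learning rate of 0.01 moves to (0.14, 0.22).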

And starting from a different point:

import matplotlib.animation as animation
from IPython.display import HTML, display


def single_frame(i, optimisation, _ax_3d, _ax_2d, angle):
    _ax_3d.clear()
    _ax_2d.clear()
    
    x, y = optimisation.update()        
    x, y = np.array(x), np.array(y)

    plot_function(optimisation.function, angle=angle, fig=fig, ax_3d=_ax_3d, ax_2d=_ax_2d)
    _ax_2d.scatter(*rotate(x, y, angle=angle), c=np.arange(len(x)), cmap='cool')
    _ax_3d.scatter3D(x, y, optimisation.function(x, y), c=np.arange(len(x)), cmap='cool')

    _ax_2d.plot()


angle=225
optimisation = (
    optimize(himmelblau())
    .using(sgd(step_size=0.01))
    .start_from([-0.5, -0.5])
)

fig = plt.figure(figsize=(13,5))
fig, ax_3d, ax_2d = plot_function(optimisation.function, fig=fig, angle=angle)
fig.tight_layout()

frames = 20
animator = animation.FuncAnimation(fig, single_frame, fargs=(optimisation, ax_3d, ax_2d, angle), frames=frames, interval=250, blit=False)
display(HTML(animator.to_html5_video()))
plt.close()
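The second animation drives everything through a chained `optimize(...).using(...).start_from(...)` object. Its `update()` method returns the whole trace so far. That object isn’t defined in this part of the post, so what follows is only a hypothetical minimal sketch of such an interface. The class names mirror the calls above, but the internals are assumptions, including the central finite-difference gradient inside `sgd`.

```python
import numpy as np

class sgd:
    """Hypothetical optimizer: gradient descent with a finite-difference gradient."""
    def __init__(self, step_size=0.01, eps=1e-6):
        self.step_size = step_size
        self.eps = eps

    def step(self, function, x, y):
        # central differences approximate df/dx and df/dy without autodiff
        d_x = (function(x + self.eps, y) - function(x - self.eps, y)) / (2 * self.eps)
        d_y = (function(x, y + self.eps) - function(x, y - self.eps)) / (2 * self.eps)
        return x - self.step_size * d_x, y - self.step_size * d_y

class optimize:
    """Hypothetical fluent builder that records every visited point."""
    def __init__(self, function):
        self.function = function
        self.optimizer = None
        self.xs, self.ys = [], []

    def using(self, optimizer):
        self.optimizer = optimizer
        return self

    def start_from(self, point):
        self.xs, self.ys = [float(point[0])], [float(point[1])]
        return self

    def update(self):
        # take one step from the last point, then return the whole trace
        x, y = self.optimizer.step(self.function, self.xs[-1], self.ys[-1])
        self.xs.append(float(x))
        self.ys.append(float(y))
        return self.xs, self.ys

def himmelblau(x, y):
    return (x**2 + y - 11)**2 + (x + y**2 - 7)**2

optimisation = optimize(himmelblau).using(sgd(step_size=0.01)).start_from([-0.5, -0.5])
for _ in range(3):
    xs, ys = optimisation.update()
print(len(xs), himmelblau(xs[-1], ys[-1]) < himmelblau(-0.5, -0.5))
```

Returning the full trace from update() is what lets single_frame redraw the entire path on every frame.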

This blog post mainly documents the exploration I did and the experiments I made that led to the final form of the simulation environment.

Test-ground

The performance of optimization algorithms (and of their implementations) is usually showcased on well-known mathematical functions with interesting (and hard) shapes. Some examples:

[Figure: examples of common optimization test functions]

At the same time, an optimization algorithm starts from somewhere (some start from multiple places at once) and iteratively progresses towards the final result. So the environment also needs a trace element, to show how things are moving along.

Starting off with Himmelblau’s function

I’ll start by replicating just one function from that list (Himmelblau’s) and implementing it in pure numpy.

This function will be used to experiment with different visualization aspects, but in the end we will want to be able to substitute it with any other function we wish to use.

Numpy implementation

Following the given definition above: \(f(x, y) = (x^2+y-11)^2 + (x+y^2-7)^2\) the numpy code should be quite simple to write.

import numpy as np
def himmelblau(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    return (x**2+y-11)**2 + (x+y**2-7)**2

himmelblau(0, 1)
136

Drawing the shape (contour) of the function

Now that we have an example function implemented, let’s write the plotting code that shows the 2D contour of it.

This is a function of two variables (and we will only deal with such functions in our simulation). Since every $(x, y)$ point also has a value $f(x, y)$, what we actually need to represent is an $(x, y, f(x, y))$ triple.

In plain 2D you can’t normally plot this kind of information (since you only have 2 dimensions) but you can artificially add a new (3rd) dimension by using colors.

The idea is that if you have some colormap like these,

[Figure: example colormaps]

you can replace the $f(x, y)$ value with the corresponding color from that colormap. This results in a contour plot. You can find a similar idea in other kinds of plots, like heat maps or choropleths.
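Here is a minimal sketch of that idea with matplotlib. `colors.Normalize` rescales the f(x, y) values to [0, 1], and calling a colormap on those numbers returns one RGBA color per value. `matplotlib.colormaps` is the colormap registry in recent matplotlib versions (3.5+). Older versions expose `plt.get_cmap` instead.

```python
import numpy as np
import matplotlib
import matplotlib.colors as colors

def himmelblau(x, y):
    return (x**2 + y - 11)**2 + (x + y**2 - 7)**2

# three sample points: a minimum (3, 2), the origin, and a corner of the domain
z = himmelblau(np.array([3.0, 0.0, -5.0]), np.array([2.0, 0.0, -5.0]))

# map [z.min(), z.max()] linearly onto [0, 1]
norm = colors.Normalize(vmin=z.min(), vmax=z.max())
cmap = matplotlib.colormaps['Spectral']

rgba = cmap(norm(z))  # one (r, g, b, a) row per function value
print(z)
print(rgba.shape)
```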

I’m no visualization expert, quite the opposite, so some things I say may be inaccurate, debatable or plain false. In such cases please leave a comment and I’ll try my best to correct this post.

2D Contour Plot

We can visualize the contour of any function by using the plt.contourf function from matplotlib.

This function takes (x, y, z) triples, which it uses to interpolate the shape of the function. This means that we need lots of probing points in the $(x, y)$ domain at which to evaluate $f(x, y)$.

What you usually do in this case is create a mesh of points: a 2D grid of points, like the intersections of the lines on a chessboard.

Numpy provides a way of creating these mesh values by using the np.meshgrid function in conjunction with np.linspace.

xx, yy = np.meshgrid( 
    np.linspace(-5, 5),
    np.linspace(-5, 5)
)

xx.shape, yy.shape
((50, 50), (50, 50))

The result is a (50, 50) array of x coordinates and a (50, 50) array of y coordinates (np.linspace produces 50 values by default). If you plotted them on a scatter plot (zoomed in on the $[-1, 1]$ square), this is how they would look.

plt.scatter(xx, yy)
plt.xlim(-1, 1)
plt.ylim(-1, 1)
(-1.0, 1.0)


The first few values of the yy variable show how it is built. Each row holds a single repeated value, so all the columns are identical. The values change linearly (between the minimum and the maximum) along the first axis, i.e. from one row to the next.

yy[:3, :3]
array([[-5.        , -5.        , -5.        ],
       [-4.79591837, -4.79591837, -4.79591837],
       [-4.59183673, -4.59183673, -4.59183673]])

The xx variable follows the opposite pattern. Every row is identical, and the values are linearly spaced along the second axis, i.e. from one column to the next.

xx[:3, :3]
array([[-5.        , -4.79591837, -4.59183673],
       [-5.        , -4.79591837, -4.59183673],
       [-5.        , -4.79591837, -4.59183673]])
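This row/column layout comes from `np.meshgrid`’s default `indexing='xy'` (Cartesian) convention. Passing `indexing='ij'` (matrix convention) transposes the outputs. A small check with 3 x values and 2 y values makes the layout easy to see:

```python
import numpy as np

x = np.linspace(-1, 1, num=3)   # [-1., 0., 1.]
y = np.linspace(10, 20, num=2)  # [10., 20.]

xx, yy = np.meshgrid(x, y)      # default indexing='xy'
print(xx.shape)                 # (len(y), len(x)) == (2, 3)
print(xx)                       # every row is a copy of x
print(yy)                       # every column is a copy of y

xx_ij, yy_ij = np.meshgrid(x, y, indexing='ij')
print(xx_ij.shape)              # (len(x), len(y)) == (3, 2)
```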

Matplotlib has a function called contourf that will draw our contours. The code below has some details about the color normalization, the colormap (cmap) and the number of contour levels. I won’t explain them, but they are easy to understand.

import numpy as np
from mpl_toolkits import mplot3d
from matplotlib import pyplot as plt
import matplotlib.colors as colors

xx, yy = np.meshgrid( 
    np.linspace(-5, 5, num=100),
    np.linspace(-5, 5, num=100)
)
zz = himmelblau(xx, yy)

plt.contourf(xx, yy, zz, levels=1000, cmap='Spectral', norm=colors.Normalize(vmin=zz.min(), vmax=zz.max()))


3D shape

Since the function takes two inputs (x, y) and its evaluation adds a third dimension, we should really plot it in 3D.

Matplotlib wasn’t initially designed for 3D plots, so the 3D add-on (mplot3d, included in the default package) is somewhat of a patch over the 2D machinery.

In any case, it was simple enough to plot the 3D shape of our function, as you can see below.

import matplotlib.colors as colors

plt.figure(figsize=(13, 10))
ax = plt.axes(projection='3d')
ax.plot_surface(xx, yy, zz, cmap='Spectral', rcount=100, ccount=100, norm=colors.Normalize(vmin=zz.min(), vmax=zz.max()))
ax.contourf(xx, yy, zz, levels=100, zdir='z', offset=np.min(zz)-100, cmap='Spectral', norm=colors.Normalize(vmin=zz.min(), vmax=zz.max()))

# ax.view_init(50, 95)


On the other hand, some other (seemingly reasonable) things I tried to do in 3D weren’t so easy, as you can see further below.

Making an animation of the function

Visualizing a 3D function by a static image is not that great. What would be great instead is if we could somehow see it in motion, rotating the graph in a 360 degree animation. In this way we could better understand the shape we are dealing with.

The two code snippets below quickly sketch how to use the matplotlib.animation package on a 3D plot.

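def single_frame(i, ax):
    # note: `min_coordinates` (the four known minima of Himmelblau's function) is defined further down in this post
    ax.clear()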
    ax.view_init(45, (i % 36) * 10 )
    ax.plot_surface(xx, yy, zz, cmap='Spectral', rcount=100, ccount=100, norm=colors.Normalize(vmin=zz.min(), vmax=zz.max()))     
    ax.contourf(xx, yy, zz, levels=300, offset=np.min(zz)-100, cmap='Spectral', norm=colors.Normalize(vmin=zz.min(), vmax=zz.max()), alpha=0.5)
    ax.scatter3D(min_coordinates[:, 0], min_coordinates[:, 1], himmelblau(min_coordinates[:, 0], min_coordinates[:, 1]), marker='.', color='black', s=60, alpha=1)
    ax.plot()

plt.figure(figsize=(8, 5))
ax = plt.axes(projection='3d')
single_frame(10, ax)


import matplotlib.animation as animation
from IPython.display import HTML, display

fig = plt.figure(figsize=(13, 10))
ax = plt.axes(projection='3d')

single_frame(0, ax)
fig.tight_layout()

frames = 36
animator = animation.FuncAnimation(fig, single_frame, fargs=(ax,), frames=frames, interval=100, blit=False)
display(HTML(animator.to_html5_video()))
plt.close()

Although it looks OK, the transparency of the graph makes it look a bit pixelated or rough. Let’s try another rendering of the same graph, with less transparency.

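import matplotlib.animation as animation
from IPython.display import HTML, display

def single_frame_opaque(i, ax):
    # same as single_frame above, but the projected contour floor is drawn opaque (no alpha=0.5)
    norm = colors.Normalize(vmin=zz.min(), vmax=zz.max())
    ax.clear()
    ax.view_init(45, (i % 36) * 10)
    ax.plot_surface(xx, yy, zz, cmap='Spectral', rcount=100, ccount=100, norm=norm)
    ax.contourf(xx, yy, zz, levels=300, offset=np.min(zz)-100, cmap='Spectral', norm=norm)
    ax.scatter3D(min_coordinates[:, 0], min_coordinates[:, 1], himmelblau(min_coordinates[:, 0], min_coordinates[:, 1]), marker='.', color='black', s=60, alpha=1)

fig = plt.figure(figsize=(13, 10))
ax = plt.axes(projection='3d')

single_frame_opaque(0, ax)
fig.tight_layout()

frames = 36
animator = animation.FuncAnimation(fig, single_frame_opaque, fargs=(ax,), frames=frames, interval=100, blit=False)
display(HTML(animator.to_html5_video()))
plt.close()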

Displaying minimum values

Since we will use these functions to see how optimization strategies minimize them, the charts should also show the points where each function reaches its lowest value.

Finding the minimum, among the mesh grid evaluations, automatically

One way of finding these points is to look for the smallest values among the evaluations on the meshgrid.

This means computing the argmin of the zz mesh. Somewhat surprisingly, this returns a single value: the index of the minimum element in the flattened array.

np.argmin(zz)
1712

To recover the 2D position (zz is 2D, remember?), we need to call np.unravel_index on the returned value, passing the shape of the array the flattened index came from.

By doing this we can recompute the coordinates (x, y) which yielded that minimum z value.

xx[np.unravel_index(np.argmin(zz), zz.shape)], yy[np.unravel_index(np.argmin(zz), zz.shape)]
(-3.787878787878788, -3.282828282828283)
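On a toy 2×3 array, the round trip is easy to follow. `np.argmin` indexes the flattened (row-major) array, and `np.unravel_index` turns that flat index back into a (row, column) pair. That pair can then index xx and yy as above.

```python
import numpy as np

z = np.array([[5.0, 3.0, 4.0],
              [2.0, 9.0, 0.5]])

flat_index = np.argmin(z)                         # position in z.ravel(), i.e. row-major order
row, col = np.unravel_index(flat_index, z.shape)  # back to a (row, column) pair
print(flat_index, row, col)                       # 5 1 2
```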

However, if we look at the original definition of the function, we see that it actually has 4 minima, and we only found one of them.

This happens because the meshgrid points don’t land exactly on the minima, only near them. argmin returns a single index: that of the grid point which, by chance, has the smallest evaluation.

Since we know that our function has 4 global minima, some experimentation shows that the 4 smallest evaluations all lie within 0.08 of the smallest one.

So, in theory, we could say that the minima are all the grid points whose value lies between the global minimum and the global minimum + 0.08.

min_value = zz.min()
np.sum((min_value <= zz) & (zz <= (min_value + 0.08)))
4

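The coordinates we get back for these 4 points do land close to true minima, but not to all four of them. The last two points, around (2.98, 1.97) and (2.98, 2.07), both belong to the minimum at (3, 2). Meanwhile, the minimum near (-2.81, 3.13) is missed entirely.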

xx[(min_value <= zz) & (zz <= (min_value + 0.08))], yy[(min_value <= zz) & (zz <= (min_value + 0.08))]
(array([-3.78787879,  3.58585859,  2.97979798,  2.97979798]),
 array([-3.28282828, -1.86868687,  1.96969697,  2.07070707]))
himmelblau(xx[(min_value <= zz) & (zz <= (min_value + 0.08))], yy[(min_value <= zz) & (zz <= (min_value + 0.08))])
array([0.00436989, 0.00616669, 0.04257242, 0.07413325])
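To see concretely why the threshold is fragile, the sketch below recomputes it on the same 100×100 grid. For each of the four known minima (the values hard-coded later in this post), it measures how far away the nearest selected grid point lies.

```python
import numpy as np

def himmelblau(x, y):
    return (x**2 + y - 11)**2 + (x + y**2 - 7)**2

xx, yy = np.meshgrid(np.linspace(-5, 5, num=100), np.linspace(-5, 5, num=100))
zz = himmelblau(xx, yy)

# grid points within 0.08 of the smallest evaluation, as (x, y) rows
mask = zz <= zz.min() + 0.08
candidates = np.column_stack((xx[mask], yy[mask]))

true_minima = np.array([[3.0, 2.0], [-2.805118, 3.131312],
                        [-3.779310, -3.283186], [3.584428, -1.848126]])

# distances[c, k]: distance between candidate c and true minimum k
distances = np.linalg.norm(candidates[:, None, :] - true_minima[None, :, :], axis=-1)
print(distances.min(axis=0).round(2))
```

Three of the four minima get a nearby grid point. The one near (-2.81, 3.13) gets none, because two of the four selected points sit next to (3, 2).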

Besides, the 0.08 threshold was tuned by hand for this function and this grid, so it will most likely be wrong for other functions.

A heuristic like this can’t work for every function we might implement, so we need a better way that doesn’t rely on it.

Using the already defined minimums

One simple and efficient way of doing this is to hard-code the values. This means turning our function into a class with multiple properties (the function evaluation, the minima and possibly others).

min_coordinates = np.array([
    [3.0, 2.0],
    [-2.805118, 3.131312],
    [-3.779310, -3.283186],
    [3.584428, -1.848126]
])

himmelblau(min_coordinates[:, 0], min_coordinates[:, 1])
array([0.00000000e+00, 1.09892967e-11, 3.79786108e-12, 8.89437650e-12])

Implementing the function as a python class

From the lessons above we conclude that a generic function will need 3 methods:

  • one for evaluating coordinates (the actual function)
  • one for specifying a region of interest, ROI (the bounding boxes for the mesh grid)
    • strictly speaking, the functions are continuous and unbounded, but only a certain region, near their minima, has an interesting shape worth looking at.
  • one for providing the minimum values of that function

import numpy as np
from mpl_toolkits import mplot3d
from matplotlib import pyplot as plt
import matplotlib.colors as colors
from functools import lru_cache

class Ifunction:
    def __call__(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """
        Evaluates the function at the given (x, y) coordinates
        """
        raise NotImplementedError

    def _min(self) -> np.ndarray:
        """
        Returns a np.array of the shape (k, 2) with the (x, y) coordinates of the k minimum points
        """
        raise NotImplementedError

    def min(self) -> np.ndarray:
        """
        Returns a np.array of the shape (k, 3) with all the minimum k points of this function. 
        The three values of the second dimension are the (x,y,z) coordinates of the minimum values 
        """
        return self.coord(self._min())

    def coord(self, points: np.ndarray) -> np.ndarray:
        """
        Returns a np.array of the shape (k, 3) with all the evaluations of the given
        k points of this function. 
        The three values of the second dimension are the (x,y,z) coordinates of the given points 
        """
        z = np.expand_dims(self(points[:, 0], points[:, 1]), axis=-1)
        return np.hstack((
            points,
            z
        ))

    def domain(self) -> np.ndarray:
        """
        Returns the ((x_min, x_max), (y_min, y_max)) values where this function 
        is of most interest
        """
        raise NotImplementedError

Assembling everything into a single object

class himmelblau(Ifunction):
    def __call__(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """
        Computes the given function
        """
        return (x**2+y-11)**2 + (x+y**2-7)**2
    
    def _min(self) -> np.ndarray:
        """
        Returns a np.array of the shape (k, 2) with all the minimum k points of this function. 
        The two values of the second dimension are the (x,y) coordinates of the minimum values 
        """
        return np.array([
            [3.0, 2.0],
            [-2.805118, 3.131312],
            [-3.779310, -3.283186],
            [3.584428, -1.848126]
        ])

    def domain(self) -> np.ndarray:
        """
        Returns the ((x_min, x_max), (y_min, y_max)) values where this function 
        is of most interest
        """
        return np.array([
            [-5, 5],
            [-5, 5]
        ])

himmelblau().min()
array([[ 3.00000000e+00,  2.00000000e+00,  0.00000000e+00],
       [-2.80511800e+00,  3.13131200e+00,  1.09892967e-11],
       [-3.77931000e+00, -3.28318600e+00,  3.79786108e-12],
       [ 3.58442800e+00, -1.84812600e+00,  8.89437650e-12]])

Refining the 3D contour plot

We will need to compute the contour of a function (that is, its evaluation on the meshgrid) in multiple places in our code. So we will extract this computation into a function and wrap it in an lru_cache, so we don’t have to redo the math when we reuse the same parameters.

@lru_cache(maxsize=None)
def contour(function: Ifunction, x_min=-5, x_max=5, y_min=-5, y_max=5, mesh_size=100):
    """
    Returns a (x, y, z) 3D coordinates, where `z = function(x,y)` evaluated on a 
    mesh of size (mesh_size, mesh_size) generated from the linear space defined by 
    the boundaries returned by `function.domain()`.

    This function is usually used for displaying the contour of the diven function.
    """
    xx, yy = np.meshgrid( 
        np.linspace(x_min, x_max, num=mesh_size),
        np.linspace(y_min, y_max, num=mesh_size)
    )
    zz = function(xx, yy)
    return xx, yy, zz
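One caveat: `lru_cache` builds its key from the arguments’ `__hash__` and `__eq__`, and a plain class instance hashes by identity. Every `himmelblau()` call creates a new object, so `contour(himmelblau())` misses the cache each time unless the same instance is reused. With `maxsize=None`, the cache also keeps every key alive. The sketch below shows this with a stand-in cached function. `evaluate` and `himmelblau_hashable` are illustrative names, not part of the post. One possible fix is to make instances of the same class compare and hash as equal.

```python
from functools import lru_cache

class himmelblau:
    """Minimal stand-in for the post's class: instances hash by identity."""
    def __call__(self, x, y):
        return (x**2 + y - 11)**2 + (x + y**2 - 7)**2

class himmelblau_hashable(himmelblau):
    """Hypothetical variant: all instances compare (and hash) equal."""
    def __eq__(self, other):
        return type(self) is type(other)

    def __hash__(self):
        return hash(type(self))

@lru_cache(maxsize=None)
def evaluate(function, x, y):
    # stand-in for `contour`: a pure computation cached on its arguments
    return function(x, y)

evaluate(himmelblau(), 0.0, 1.0)
evaluate(himmelblau(), 0.0, 1.0)            # a new instance is a new key: cache miss
info_plain = evaluate.cache_info()
print(info_plain)                           # hits=0, misses=2

evaluate.cache_clear()
evaluate(himmelblau_hashable(), 0.0, 1.0)
evaluate(himmelblau_hashable(), 0.0, 1.0)   # equal instances share the key: cache hit
info_hashable = evaluate.cache_info()
print(info_hashable)                        # hits=1, misses=1
```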

There are also some points that I’d like to improve:

  • adding the ability to zoom in and out of the plot
  • adding a 2D contour projection beneath the plot
  • adding the minimum values on the 3D plot
  • adding the ability to control the rotation

Let’s take them one by one.

Adding zoom param on the boundaries

The idea of zooming is actually closely related to the min and max boundaries set on the plots:

  • if we reduce the distance between them, we zoom in
  • if we increase the distance, we zoom out

Since zoom is usually a proportion of the current view, there’s a bit of math to do in order to get the computations correct.

This section is highly specific and beside the main point, but I like to keep it because it’s time spent that I don’t want to be lost to history. If you’re reading this, you probably want to skip it, since it’s neither clever nor informative.


So the main idea of the zooming that I’m going to implement is described by the image below:

  • we have some set ‘min’ and ‘max’ values
  • we have a mean (center) between them
  • we want to move both min and max closer to the mean by the same fraction f of their distance to it; f is called the zoom factor.

[Figure: min and max moving towards their mean by the zoom factor f]

I initially wrote this function, and it works quite well.

def zoom(x_domain, y_domain, zoom_factor):
    (x_min, x_max), (y_min, y_max) = x_domain, y_domain

    # zoom
    x_mean = (x_min + x_max) / 2
    y_mean = (y_min + y_max) / 2

    x_min = x_min + (x_mean - x_min) * zoom_factor
    x_max = x_max - (x_max - x_mean) * zoom_factor

    y_min = y_min + (y_mean - y_min) * zoom_factor
    y_max = y_max - (y_max - y_mean) * zoom_factor

    return (x_min, x_max), (y_min, y_max)

zoom((-5, 5), (-5, 5), 0.1)
((-4.5, 4.5), (-4.5, 4.5))

Interestingly, the same computation can be vectorized with the numpy operations below:

d = np.array([(-5, 5), (-5, 5)])
zoom_factor = 0.1

means = d.mean(axis=-1, keepdims=True)   # computing the means, shape (2, 1) so each row uses its own mean
distances = np.abs(d - means)
change = distances * zoom_factor    # compute the change needed
change_with_direction = change * [1, -1] # add signs for the direction of changes (mins should increase, maxes should decrease in value)
zoomed_d = d + change_with_direction

d + np.abs(d - d.mean(axis=-1, keepdims=True)) * [1, -1] * zoom_factor # single line transformation
array([[-4.5,  4.5],
       [-4.5,  4.5]])

Notice that the above computation takes an abs, which removes the sign, and right after that adds the sign back by multiplying with [1, -1]. If we think in terms of moving from the mean to the left and to the right, we can simplify the formula a bit:

d.mean(axis=-1, keepdims=True) - (d.mean(axis=-1, keepdims=True) - d) * (1 - zoom_factor)
array([[-4.5,  4.5],
       [-4.5,  4.5]])

If we expand the (1 - zoom_factor) term and simplify the result, we are left with the formula below:

d + (d.mean(axis=-1, keepdims=True) - d) * zoom_factor
array([[-4.5,  4.5],
       [-4.5,  4.5]])

This is exactly the formula of the scalar zoom function we started from, written for both axes at once. As we’ve observed, it no longer strips the signs only to add them back.
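Below, the vectorized version is packaged as a reusable helper. This is a sketch, and `zoom_domain` is a name chosen here. Keeping the averaged axis with `keepdims=True` matters once the bounds aren’t symmetric around 0. A plain shape-(2,) mean would be broadcast across columns, subtracting the x mean from both minima and the y mean from both maxima.

```python
import numpy as np

def zoom_domain(domain, zoom_factor):
    """Vectorized zoom: move each (min, max) pair towards its own midpoint."""
    domain = np.asarray(domain, dtype=float)
    # keepdims=True gives shape (2, 1), so each row is compared with its own mean
    means = domain.mean(axis=-1, keepdims=True)
    return domain + (means - domain) * zoom_factor

print(zoom_domain([(-5, 5), (-5, 5)], 0.1).tolist())  # [[-4.5, 4.5], [-4.5, 4.5]]
print(zoom_domain([(-1.5, 4), (-3, 4)], 0.1))         # approximately [[-1.225, 3.725], [-2.65, 3.65]]
```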

Wrapping function

OK, now let’s put everything together and show what we have so far.

def plot_function_3d(function: Ifunction, ax=None, azimuth=45, angle=45, zoom_factor=0, show_projections=False):    
    (x_min, x_max), (y_min, y_max) = zoom(*function.domain(), zoom_factor)

    xx, yy, zz = contour(function, x_min, x_max, y_min, y_max)

    # evaluate once, use in multiple places
    zz_min = zz.min()
    zz_max = zz.max()
    norm = colors.Normalize(vmin=zz_min, vmax=zz_max)

    # put the 2d contour floor a bit lower than the minimum, so it looks like a reflection
    zz_floor_offset = int((zz_max - zz_min) * 0.065)

    # create 3D axis if not provided
    ax = ax if ax else plt.axes(projection='3d')

    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_zlim(zz.min() - zz_floor_offset, zz.max())

    min_coordinates = function.min()
    ax.scatter3D(min_coordinates[:, 0], min_coordinates[:, 1], min_coordinates[:, 2], marker='.', color='black', s=120, alpha=1, zorder=1)
    
    ax.contourf(xx, yy, zz, zdir='z', levels=300, offset=zz_min-zz_floor_offset, cmap='Spectral', norm=norm, alpha=0.5, zorder=1)
    if show_projections:
        ax.contourf(xx, yy, zz, zdir='x', levels=300, offset=xx.max()+1, cmap='gray', norm=colors.Normalize(vmin=xx.min(), vmax=xx.max()), alpha=0.05, zorder=1)
        ax.contourf(xx, yy, zz, zdir='y', levels=300, offset=yy.max()+1, cmap='gray', norm=colors.Normalize(vmin=yy.min(), vmax=yy.max()), alpha=0.05, zorder=1)

    ax.plot_surface(xx, yy, zz, cmap='Spectral', rcount=100, ccount=100, norm=norm, shade=False, antialiased=True, alpha=0.6)     
    # ax.plot_wireframe(xx, yy, zz, cmap='Spectral', rcount=100, ccount=100, norm=norm, alpha=0.5)     

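    # apply rotation; view_init's signature is (elev, azim), so `azimuth` acts as the elevation
    # and `angle` as the azimuth
    ax.view_init(azimuth, angle)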

fig = plt.figure(figsize=(13, 10))
ax = plt.axes(projection='3d')
plot_function_3d(himmelblau(), ax=ax, azimuth=20, angle=225)


Let’s make it interactive so we can play with it a bit.


Note: This code will probably not work in the blog post

import ipywidgets as widgets
from ipywidgets import interact, interact_manual

@interact
def plot_interactive(azimuth=45, rotation=45, zoom_factor=(-1, 1, 0.1), show_projections=False):
    # a boolean default makes `interact` render show_projections as a checkbox
    return plot_function_3d(function=himmelblau(), azimuth=azimuth, angle=rotation, zoom_factor=zoom_factor, show_projections=show_projections)

A 2d contour with angle rotation

Now that we have a 3D plot that can handle rotations, we need to give the 2D plot the same capability, since we want the two charts to move in sync.

What we need to do is use linear algebra and rotate the initial meshgrid of points. We accomplish this by simply multiplying it with a rotation matrix.
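Here is a quick sanity check of the rotation-matrix idea. A 90° counter-clockwise rotation should send the point (1, 0) to (0, 1). Applied to a whole meshgrid with the einsum pattern this section uses for the contour, it should map every (x, y) to (-y, x). `rotation_matrix` here is a small helper written for this check.

```python
import numpy as np

def rotation_matrix(degrees):
    # counter-clockwise rotation of column vectors: p_rotated = R @ p
    radians = np.deg2rad(degrees)
    return np.array([[np.cos(radians), -np.sin(radians)],
                     [np.sin(radians),  np.cos(radians)]])

R = rotation_matrix(90)
print(np.round(R @ np.array([1.0, 0.0]), 6))  # [0. 1.]: the x axis is turned onto the y axis

# the same operation applied to every point of a meshgrid at once
xx, yy = np.meshgrid(np.linspace(-1, 1, 3), np.linspace(-1, 1, 3))
rx, ry = np.einsum('ji, mni -> jmn', R, np.dstack([xx, yy]))
```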

Adding rotation to the second plot

angle = 40
xx, yy, zz = contour(himmelblau())


radians = angle * np.pi/180

# counter-clockwise rotation matrix
# https://stackoverflow.com/questions/29708840/rotate-meshgrid-with-numpy
rotation_matrix = np.array([
    [np.cos(radians), -np.sin(radians)],
    [np.sin(radians), np.cos(radians)]
])
xx, yy = np.einsum('ji, mni -> jmn', rotation_matrix, np.dstack([xx, yy]))

plt.contourf(xx, yy, zz, levels=1000, cmap='Spectral', norm=colors.Normalize(vmin=zz.min(), vmax=zz.max()), alpha=0.3)


We’ve successfully rotated the contour, but the axes were left unchanged. So we didn’t rotate the plot, merely its contents.

At this point, I guess there are three options:

  • use some image processing packages (like PIL, imagemagick), export the 2D plot as an image, rotate it and then display it
  • hide the axes so we don’t see that we’ve actually rotated only the content.
  • use a 3D plot (which, as we’ve seen, easily supports rotation) and set a perpendicular viewing angle (bird’s-eye view) so the plot looks like a 2D one.

The first option looks like the biggest hack of all since it involves adding at least 2 new dependencies for this sole purpose (plus the additional computation).

The second one is easy to picture without further experiments, so we only need to see what option 3 looks like.

Drawing the 2D plot on a 3D axis, viewed from above, makes it possible to display the rotated axes as well, but it also makes the plot smaller.

fig = plt.figure()

ax_ = fig.add_subplot(1, 2, 2, projection='3d')

angle = 125

## plot_function_2d
xx, yy = np.meshgrid( 
    np.linspace(-5, 5, num=100),
    np.linspace(-5, 5, num=100)
)
zz = himmelblau()(xx, yy)

ax_.contour(xx, yy, zz, levels=200, offset=0, cmap='Spectral', norm=colors.Normalize(vmin=zz.min(), vmax=zz.max()), alpha=0.5, zorder=1)
ax_.view_init(90, angle)

plt.tight_layout()


I think I will stick with the 2D approach (option 2, hiding the axes), since the contour is bigger.

Rotation function

The rotate function below looks way more complex than what we sketched above. This is because it now handles two use cases that it needs to tell apart:

  • meshgrids (xx, yy), received when it is called to rotate the contour plots;
  • points we want to draw on the contour (like the minimum values or the trace of an optimizer). These don’t have a meshgrid’s shape, so a different operation is applied to them (mainly due to vectorization).
from typing import Tuple

def rotate(_x: np.ndarray, _y: np.ndarray, angle=45) -> Tuple[np.ndarray, np.ndarray]:
    def __is_mesh(x: np.ndarray) -> bool:
        __is_2d = len(x.shape) == 2
        if __is_2d:
            __is_repeated_on_axis_0 = np.allclose(x.mean(axis=0), x[0, :])
            __is_repeated_on_axis_1 = np.allclose(x.mean(axis=1), x[:, 0])
            __is_repeated_array = __is_repeated_on_axis_0 or __is_repeated_on_axis_1
            return __is_repeated_array
        else:
            return False

    def __is_single_dimension(x: np.ndarray) -> bool:
        # when the function only has one minimum the initial x will have the shape (1,)
        # and doing a np.squeeze before calling this function will result in a x of shape () 
        # when we reach this control point 
        _is_scalar_point = len(x.shape) == 0
        return len(x.shape) == 1 or _is_scalar_point

    def __rotate_mesh(xx: np.ndarray, yy: np.ndarray) -> np.ndarray:
        xx, yy = np.einsum('ij, mnj -> imn', rotation_matrix, np.dstack([xx, yy]))
        return xx, yy

    def __rotate_points(x: np.ndarray, y: np.ndarray) -> np.ndarray:
        points = np.hstack((x[:, np.newaxis], y[:, np.newaxis]))
        # anti-clockwise rotation matrix
        x, y = np.einsum('mi, ij -> jm',  points, np.array([
            [np.cos(radians), -np.sin(radians)],
            [np.sin(radians), np.cos(radians)]
        ]))
        return x, y

    # apply rotation
    angle = (angle + 90) % 360
    radians = angle * np.pi/180

    # clockwise rotation matrix
    rotation_matrix = np.array([
        [np.cos(radians), np.sin(radians)],
        [-np.sin(radians), np.cos(radians)]
    ])

    if __is_mesh(_x) and __is_mesh(_y):
        _x, _y = __rotate_mesh(_x, _y)
    elif __is_single_dimension(np.squeeze(_x)) and __is_single_dimension(np.squeeze(_y)):
        def __squeeze(_x):
            """
            We need to reduce the redundant 1 dimensions from shapes like (3, 1, 2) to (3, 2), 
            but at the same time, making sure we don't end up with scalar values (going from (1, 1) to a shape ())

            We need at least a shape of (1,)
            """
            if len(np.squeeze(_x).shape) == 0:
                return np.array([np.squeeze(_x)])
            else:
                return np.squeeze(_x)

        _x, _y = __squeeze(_x), __squeeze(_y) 
        _x, _y = __rotate_points(_x, _y)
    else:
        raise AssertionError(f"Unknown rotation types for shapes {_x.shape} and {_y.shape}")
    return _x, _y

The np.einsum call is the core of the function. Its operation is more involved, so I’ll leave its explanation for a future post.

Another interesting bit is how the return value of the line below gets unpacked.

xx, yy = np.einsum('ij, mnj -> imn', rotation_matrix, np.dstack([xx, yy]))

np.einsum still returns a single np.ndarray here, yet we can unpack it into two separate variables, xx and yy. This works because Python’s unpacking iterates over the first axis of the array, and the einsum output ('imn') puts the size-2 coordinate axis first. The np.dstack call is what sets up that size-2 axis.

Let’s see what np.dstack does:

print(np.dstack([xx, yy]).shape)
np.dstack([xx, yy])[:3, :3, :]
(100, 100, 2)
array([[[-5.       , -5.       ],
        [-4.8989899, -5.       ],
        [-4.7979798, -5.       ]],

       [[-5.       , -4.8989899],
        [-4.8989899, -4.8989899],
        [-4.7979798, -4.8989899]],

       [[-5.       , -4.7979798],
        [-4.8989899, -4.7979798],
        [-4.7979798, -4.7979798]]])

So dstack stacks its arguments along a new, rightmost axis (the trailing 2 in the shape). Combine this with the einsum specification above, where i and j both have size 2, and the result has shape (2, 100, 100). That is semantically similar to a tuple of two (100, 100) arrays, xx and yy. This shape is what lets Python unpack the result into the 2 independent values we actually wanted.
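The unpacking itself is plain Python. Iterating over an ndarray yields its slices along the first axis, so any array whose first axis has length 2 unpacks into two sub-arrays. A toy example:

```python
import numpy as np

stacked = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # a single ndarray whose first axis has length 2

first, second = stacked  # iterating an ndarray walks along its first axis
print(type(stacked).__name__, first.shape, second.shape)  # ndarray (3, 4) (3, 4)

unpack_failed = False
try:
    a, b = np.zeros((3, 4))  # first axis has length 3, so two names are not enough
except ValueError as error:
    unpack_failed = True
    print(error)
```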

Single points rotation

Rotating single points is, at its core, a matrix multiplication. But the meshgrid rotation is written in rather elegant einsum notation, not least because of the unpacking trick. So I’ll experiment a bit with replicating the matrix multiplication in einsum as well.

angle = 225

# apply rotation
angle = (angle) % 360
radians = angle * np.pi/180


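# the standard counter-clockwise rotation matrix; right-multiplying row vectors by it
# (as in `points @ rotation_matrix`) rotates them clockwise
rotation_matrix = np.array([
    [np.cos(radians), -np.sin(radians)],
    [np.sin(radians), np.cos(radians)]
])

# assumption: `coords` is built in a cell not shown here; these (x, y) points reproduce the output
# below, while the z column (Himmelblau's value at each point) is a guess and unused by the rotation
coords = himmelblau().coord(np.array([[1., 1.], [-2., 2.], [3., -3.], [4., 4.]]))

coords[:, [0, 1]] @ rotation_matrix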
(array([[-1.41421356e+00, -2.22044605e-16],
        [ 4.44089210e-16, -2.82842712e+00],
        [-4.44089210e-16,  4.24264069e+00],
        [-5.65685425e+00, -8.88178420e-16]]),)

OK, this is the (correct) result we get from the plain matrix multiplication. Our goal is an einsum operation that computes the same values but also arranges them in a (2, 4) shape. We can then unpack them into two (4,) arrays: the x and the y coordinates.

We start off with 4 points, for which we also have the z value (the function evaluation on the first two coordinates), so a shape of (4, 2+1).

coords.shape
(4, 3)

We’ll be receiving x and y values in the following shape, so this is what we have to work with:

_x = coords[:, [0]]
_y = coords[:, [1]]
_x.shape, _y.shape
((4, 1), (4, 1))

Since the rotation_matrix is of shape (2, 2), the final einsum operation is:

mi, ij -> jm

or, in terms of shapes:

(nr_points, 2), (2, 2) -> (2, nr_points)

__x, __y = np.einsum('mi, ij -> jm',  np.hstack((coords[:, [0]], coords[:, [1]])), rotation_matrix)
__x, __y
(array([-1.41421356e+00,  4.44089210e-16, -4.44089210e-16, -5.65685425e+00]),
 array([-2.22044605e-16, -2.82842712e+00,  4.24264069e+00, -8.88178420e-16]))
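Spelled out, 'mi, ij -> jm' multiplies each point (row m) by the rotation matrix, summing over the shared index i. It writes the output with j first, so it equals (points @ rotation_matrix).T. Here is a self-contained check with four illustrative points and the same 225° matrix as above:

```python
import numpy as np

radians = np.deg2rad(225)
rotation_matrix = np.array([[np.cos(radians), -np.sin(radians)],
                            [np.sin(radians),  np.cos(radians)]])

points = np.array([[1.0, 1.0], [-2.0, 2.0], [3.0, -3.0], [4.0, 4.0]])  # (nr_points, 2)

via_einsum = np.einsum('mi, ij -> jm', points, rotation_matrix)  # (2, nr_points)
via_matmul = (points @ rotation_matrix).T                         # same numbers, transposed

x_rot, y_rot = via_einsum  # unpacks into two (nr_points,) arrays
print(np.allclose(via_einsum, via_matmul))  # True
```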

Wrapping function

Again, let’s wrap everything up in a single function to see where we are so far.

Coming back to the wrapping function, we can now also use the rotation function to rotate both the contour and the minimum values.

def plot_function_2d(function: Ifunction, ax=None, angle=45, zoom_factor=0):
    (x_min, x_max), (y_min, y_max) = zoom(*function.domain(), zoom_factor)

    xx, yy, zz = contour(function, x_min, x_max, y_min, y_max)

    ax = ax if ax else plt.gca()
    
    xx, yy = rotate(xx, yy, angle=angle)    # zz needs no rotation: each value stays attached to its (now rotated) grid point

    ax.contour(xx, yy, zz, levels=200, cmap='Spectral', norm=colors.Normalize(vmin=zz.min(), vmax=zz.max()), alpha=0.3)
    
    min_coords = function.min()
    ax.scatter(*rotate(min_coords[:, 0], min_coords[:, 1], angle=angle))
    ax.axis("off")

plot_function_2d(himmelblau())


A combined plot, with both the 3D and the 2D contours side by side

Using both the 2D and the 3D plotting functions, we can show them on the same figure, with the same rotation.

As always, we will first sketch the code in bulk then tie it up in a single function.

fig = plt.figure(figsize=(26, 10))

ax_3d = fig.add_subplot(1, 2, 1, projection='3d')
ax_2d = fig.add_subplot(1, 2, 2)

angle = 225

function = himmelblau()

plot_function_3d(function, ax=ax_3d, azimuth=30, angle=angle)
plot_function_2d(function, ax=ax_2d, angle=angle)


Wrapping function with both figures

The 2D and 3D plots can now be combined in a unified function, shown below:

def plot_function(function: Ifunction, angle=45, zoom_factor=0, azimuth_3d=30, fig=None, ax_2d=None, ax_3d=None):
    fig = plt.figure(figsize=(26, 10)) if fig is None else fig
    ax_3d = fig.add_subplot(1, 2, 1, projection='3d') if ax_3d is None else ax_3d
    ax_2d = fig.add_subplot(1, 2, 2) if ax_2d is None else ax_2d

    plot_function_3d(function=function, ax=ax_3d, azimuth=azimuth_3d, angle=angle, zoom_factor=zoom_factor)
    plot_function_2d(function=function, ax=ax_2d, angle=angle, zoom_factor=zoom_factor)

    return fig, ax_3d, ax_2d

plot_function(himmelblau(), angle=225)


Implementing a few other functions to see what they look like

I guess we now have sufficient machinery to try plotting different interesting functions. But first, let’s just define them.

mc_cormick

class mc_cormick(Ifunction):
    def __call__(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """
        Computes the given function
        """
        return np.sin(x+y) + (x-y)**2-1.5*x+2.5*y+1
    
    def _min(self) -> np.ndarray:
        """
        Returns a np.array of the shape (k, 2) with all the minimum k points of this function. 
        The two values of the second dimension are the (x,y) coordinates of the minimum values 
        """
        return np.array([
            [-0.54719, -1.54719],
        ])

    def domain(self) -> np.ndarray:
        """
        Returns the ((x_min, x_max), (y_min, y_max)) values where this function 
        is of most interest
        """
        return np.array([
            [-1.5, 4],
            [-3, 4]
        ])

mc_cormick().min()
array([[-0.54719   , -1.54719   , -1.91322295]])
plot_function(mc_cormick(), angle=225)

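A quick sanity check on the hard-coded minimum is that both partial derivatives should vanish there. Working them out by hand (my own derivation, not from the post) gives ∂f/∂x = cos(x+y) + 2(x-y) - 1.5 and ∂f/∂y = cos(x+y) - 2(x-y) + 2.5. At the listed point, x - y = 1 and x + y ≈ -2π/3, so cos(x+y) ≈ -1/2 and both derivatives come out at roughly zero:

```python
import numpy as np

def mc_cormick_value(x, y):
    return np.sin(x + y) + (x - y)**2 - 1.5 * x + 2.5 * y + 1

def mc_cormick_grad(x, y):
    # hand-derived partial derivatives of the McCormick function
    d_x = np.cos(x + y) + 2 * (x - y) - 1.5
    d_y = np.cos(x + y) - 2 * (x - y) + 2.5
    return d_x, d_y

x_min, y_min = -0.54719, -1.54719
d_x, d_y = mc_cormick_grad(x_min, y_min)
print(d_x, d_y)                          # both close to zero
print(mc_cormick_value(x_min, y_min))    # the z column returned by mc_cormick().min()
```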

holder_table

class holder_table(Ifunction):
    def __call__(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """
        Computes the given function
        """
        return -np.abs(np.sin(x)*np.cos(y)*np.exp(np.abs(1-np.sqrt(x**2+y**2)/np.pi)))
    
    def _min(self) -> np.ndarray:
        """
        Returns a np.array of the shape (k, 2) with all the minimum k points of this function. 
        The two values of the second dimension are the (x,y) coordinates of the minimum values 
        """
        return np.array([
            [8.05502, 9.66459],
            [8.05502, -9.66459],
            [-8.05502, 9.66459],
            [-8.05502, -9.66459],
        ])

    def domain(self) -> np.ndarray:
        """
        Returns the ((x_min, x_max), (y_min, y_max)) values where this function 
        is of most interest
        """
        return np.array([
            [-10, 10],
            [-10, 10]
        ])

holder_table().min()
array([[  8.05502   ,   9.66459   , -19.20850257],
       [  8.05502   ,  -9.66459   , -19.20850257],
       [ -8.05502   ,   9.66459   , -19.20850257],
       [ -8.05502   ,  -9.66459   , -19.20850257]])
plot_function(holder_table(), angle=225)

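The four minima listed for holder_table are mirror images of each other. sin(x) and cos(y) only appear inside an absolute value, and sqrt(x² + y²) ignores signs, so flipping the sign of x or y leaves f unchanged. The small check below reuses the same formula with plain numpy:

```python
import numpy as np

def holder_table_value(x, y):
    # same formula as holder_table.__call__
    return -np.abs(np.sin(x) * np.cos(y) * np.exp(np.abs(1 - np.sqrt(x**2 + y**2) / np.pi)))

minima = np.array([
    [8.05502, 9.66459],
    [8.05502, -9.66459],
    [-8.05502, 9.66459],
    [-8.05502, -9.66459],
])
values = holder_table_value(minima[:, 0], minima[:, 1])
print(values)  # four (near-)identical values, around -19.2085
```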

Drawing lines on the 3D plot

Some optimization strategies start off with multiple initial points that are updated on each iteration of the optimization.

These points are viewed as the vertices of a convex polygon (as is the case for the Nelder-Mead algorithm, where in 2D the polygon is a triangle). This means we need to show not only the points but also the lines connecting them.

To enable this use-case we need to experiment a bit with showing lines on both plots.

Dealing with the zorder of 3D axes

The problem with lines, though (and it was hard to anticipate how difficult this would be), is that on 3D plots they are always shown behind the contour plot.

I don’t really understand why this decision was made, but it seems to be a side-effect of having a 3D framework squashed on top of one designed for 2D. There are some rough edges and this is one of them.

What seems to actually happen is that the zorder parameter, which usually controls the order in which artists are drawn on the screen, is ignored for 3D objects: the 3D axes sort them by their projected depth instead.

One suggestion was to override the zorder attribute in a custom class to force some objects (lines in this case) to be shown above the graph.

fig = plt.figure(figsize=(13, 10))
ax = plt.axes(projection='3d')

plot_function_3d(himmelblau(), ax=ax, azimuth=20, angle=225)

from mpl_toolkits.mplot3d.art3d import Line3D, Poly3DCollection
class FixZorderLine3D(Line3D):
    @property
    def zorder(self):
        return 4

    @zorder.setter
    def zorder(self, value):
        pass


lines = ax.plot([3, -4], [4, 2], [500, 0], color="red", zorder=1, alpha=1)
# hack fix for the zorder of the Line3D and plot_contour, taken from
# https://stackoverflow.com/questions/20781859/drawing-a-line-on-a-3d-plot-in-matplotlib
for line in lines:
    line.__class__ = FixZorderLine3D

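As an aside, newer matplotlib releases (3.5 onwards, if I remember correctly) added a computed_zorder option to 3D axes. With it set to False, the axes draw artists by their zorder instead of sorting them by projected depth, which should make the class-swapping hack unnecessary. The sketch below is not tuned to match the figures here. It uses a plain-numpy stand-in for himmelblau() and calls plot_surface directly instead of the post’s plot_function_3d:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch also runs outside a notebook
import matplotlib.pyplot as plt
import numpy as np

def himmelblau_np(x, y):
    # stand-in for the post's himmelblau() class
    return (x**2 + y - 11)**2 + (x + y**2 - 7)**2

xx, yy = np.meshgrid(np.linspace(-5, 5, 100), np.linspace(-5, 5, 100))
zz = himmelblau_np(xx, yy)

fig = plt.figure(figsize=(13, 10))
# computed_zorder=False: respect each artist's zorder instead of sorting by depth
ax = fig.add_subplot(projection="3d", computed_zorder=False)
ax.plot_surface(xx, yy, zz, cmap="Spectral", alpha=0.3, zorder=1)
lines = ax.plot([3, -4], [4, 2], [500, 0], color="red", zorder=2)  # drawn after the surface
fig.canvas.draw()
```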

Going even deeper into the guts of matplotlib (way deeper than I’d like, actually), we see that plot_wireframe draws all its lines as a single Line3DCollection. So instead of recasting the class of each line in a loop, we can do a single recast on the whole collection object.

from mpl_toolkits.mplot3d.art3d import Line3DCollection
class FixZorderCollection(Line3DCollection):
    @property
    def zorder(self):
        return 1000

    @zorder.setter
    def zorder(self, value):
        pass

fig = plt.figure(figsize=(13, 10))
ax = plt.axes(projection='3d')

# hack fix for the zorder of the Line3D and plot_contour, taken from
# https://stackoverflow.com/questions/20781859/drawing-a-line-on-a-3d-plot-in-matplotlib
plot_function_3d(himmelblau(), ax=ax, azimuth=30, angle=225)
ax.plot_wireframe(np.array([[1], [-2], [3], [4]]), np.array([[1], [2], [-3], [4]]), np.array([[himmelblau()(1, 1)+1], [himmelblau()(-2, 2) + 1], [100], [500]]))
ax.collections[-1].__class__ = FixZorderCollection


This works as intended, but besides being a huge, ugly hack, it’s also strange to look at: the lines now float above the 3D contour instead of passing through it, as you would expect.

What worked in the end is a careful calibration of the contours’ transparency. The line is actually drawn under the contour, but the contour is semi-transparent (and not too light). Wherever several of its surfaces overlap, their transparencies compound and those regions look more opaque. The line then appears to sit between the layers, “beneath” the front ones… somehow.

fig = plt.figure(figsize=(13, 10))
ax = plt.axes(projection='3d')

plot_function_3d(himmelblau(), ax=ax, azimuth=30, angle=225)
ax.plot_wireframe(np.array([[1], [-2], [3], [4]]), np.array([[1], [2], [-3], [4]]), np.array([[himmelblau()(1, 1)+1], [himmelblau()(-2, 2) + 1], [100], [500]]))

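The “adding up” of transparencies described above can be made a bit more precise. Each semi-transparent layer lets through a fraction (1 - alpha) of what is behind it, so n stacked layers let through (1 - alpha)^n, and their combined opacity is 1 - (1 - alpha)^n. Here is a tiny helper, using alpha = 0.3 purely as an illustrative value:

```python
def stacked_opacity(alpha: float, layers: int) -> float:
    """Combined opacity of `layers` identical semi-transparent layers stacked on top of each other."""
    return 1 - (1 - alpha) ** layers

for n in range(1, 5):
    print(n, round(stacked_opacity(0.3, n), 4))
# 1 0.3
# 2 0.51
# 3 0.657
# 4 0.7599
```

So a line seen through two overlapping sheets is already about half hidden, while one behind a single sheet is still mostly visible. That difference is what creates the “between the layers” illusion.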

Now, let’s join forces with the 2D plot and display the lines on both graphs.

points = np.array([
    [1, 1],
    [-2, 2],
    [3, -3],
    [4, 4]
])

function = himmelblau()
coords = function.coord(points)
rotation=225

fig, ax_3d, ax_2d = plot_function(function, angle=rotation)

ax_3d.plot_wireframe(coords[:, [0]], coords[:, [1]], coords[:, [2]], color='r')    
ax_2d.plot(*rotate(coords[:, 0], coords[:, 1], angle=rotation), color='r')

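One detail matters for the Nelder-Mead use case. The lines drawn above connect the points in order, so the edge from the last vertex back to the first is missing. A small hypothetical helper, close_polygon (not part of the post), closes the shape by repeating the first vertex:

```python
import numpy as np

def close_polygon(points: np.ndarray) -> np.ndarray:
    """Append the first row at the end so a line plot draws a closed shape.

    `points` is an (n, k) array, e.g. the (x, y, z) rows returned by function.coord(...).
    """
    return np.vstack([points, points[:1]])

triangle = np.array([[1.0, 1.0], [-2.0, 2.0], [3.0, -3.0]])
closed = close_polygon(triangle)
print(closed.shape)  # (4, 2)
# then, for example: ax_2d.plot(*rotate(closed[:, 0], closed[:, 1], angle=rotation), color='r')
```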

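To play around with the azimuth, rotation, zoom and the choice between lines and points, the same plots can be wrapped in an ipywidgets interact function:
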
import ipywidgets as widgets
from ipywidgets import interact, interact_manual

@interact
def plot_interactive(azimuth=(0, 90), rotation=45, zoom_factor=(-1, 1, 0.1), lines=[True, False]):
    fig, ax_3d, ax_2d = plot_function(function, angle=rotation, azimuth_3d=azimuth, zoom_factor=zoom_factor)

    if lines:
        ax_3d.plot_wireframe(coords[:, [0]], coords[:, [1]], coords[:, [2]])    
        ax_2d.plot(*rotate(coords[:, 0], coords[:, 1], angle=rotation))
    else:
        ax_3d.scatter3D(coords[:, [0]], coords[:, [1]], coords[:, [2]])
        ax_2d.scatter(*rotate(coords[:, 0], coords[:, 1], angle=rotation))

Show me the money!

Now that we have everything in place, we just need to show how this is useful by displaying an actual use-case: an optimization taking place.

SGD optimizer tracing

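The simplest optimizer to implement that I know of is stochastic gradient descent (SGD), which has the update rule below. (Strictly speaking, our functions are deterministic and nothing is sampled, so this is plain gradient descent, but the update rule is the same.)

\(x \leftarrow x - \alpha \frac{\partial f(x, y)}{\partial x}\)

\(y \leftarrow y - \alpha \frac{\partial f(x, y)}{\partial y}\)

where \(\alpha\) is the learning rate.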

Unfortunately, as you can see from the formula, there is that partial derivative part that we have to deal with in order to make this work.

We could (at least in this instance) compute the partial derivatives analytically (by hand) and define a function that implements them. That also means updating the function interface class and adding a custom gradient method to every function we wish to optimize (deriving the gradient by hand for every function we wish to display). Also, not all the types of optimisations that we wish to investigate need the gradient (the derivative-free ones don’t use it at all), so the function class might not be the right place for it.

So the gradient brings a lot of changes with it, throughout the code.
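For Himmelblau’s function f(x, y) = (x² + y - 11)² + (x + y² - 7)², doing it by hand with the chain rule gives the partial derivatives in the sketch below (plain Python, my own derivation). The sketch also takes the first SGD step from (0, 0) with a learning rate of 0.01. That step should land on the first point of the JAX-based trace printed further down, (0.14, 0.22), up to float32 rounding. The value at (2, 3) is what function_grad(2., 3.) further down should return:

```python
def himmelblau_fn(x, y):
    return (x**2 + y - 11)**2 + (x + y**2 - 7)**2

def himmelblau_grad(x, y):
    # chain rule applied to both squared terms
    d_x = 4 * x * (x**2 + y - 11) + 2 * (x + y**2 - 7)
    d_y = 2 * (x**2 + y - 11) + 4 * y * (x + y**2 - 7)
    return d_x, d_y

def sgd_step(x, y, learning_rate=0.01):
    d_x, d_y = himmelblau_grad(x, y)
    return x - learning_rate * d_x, y - learning_rate * d_y

print(himmelblau_grad(0.0, 0.0))   # (-14.0, -22.0)
print(himmelblau_grad(2.0, 3.0))   # (-24.0, 40.0)
print(sgd_step(0.0, 0.0))          # (0.14, 0.22)
```

This is manageable for one function, but it has to be redone, and kept in sync, for every new function we add.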

Fortunately, we can use automatic differentiation. JAX is one such library that (given some constraints) computes the gradient function for you.

We first define our initial toy function and, even though this isn’t required, change the type hints to the class we expect to work with (JAX’s own ndarray objects):

import jax.numpy as jnp

def himmelblau_jnp(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    return (x**2+y-11)**2 + (x+y**2-7)**2

himmelblau_jnp(0, 1)
136

Then we call grad, which builds a callable that returns the partial derivatives. The argnums part is needed because the function takes two parameters and we want the derivative with respect to both of them.

When implementing Linear Regression or Neural Networks, all the parameters usually sit in a single large matrix, which is the only argument of the function, so argnums=0 (the default) is usually all we need. This case is different because the parameters are passed individually.
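To make the contrast concrete, here is a sketch of the “single parameter array” style. Packing x and y into one array means the default argnums=0 is enough, and grad returns a gradient array with the same shape as the input. The toy function is redefined so the block runs on its own. himmelblau_packed is a hypothetical name, not from the post:

```python
import jax.numpy as jnp
from jax import grad

def himmelblau_jnp(x, y):
    return (x**2 + y - 11)**2 + (x + y**2 - 7)**2

def himmelblau_packed(params: jnp.ndarray) -> jnp.ndarray:
    # params holds [x, y] in a single array, like the weights of a linear model would
    return himmelblau_jnp(params[0], params[1])

packed_grad = grad(himmelblau_packed)              # argnums=0 is the default
print(packed_grad(jnp.array([2.0, 3.0])))          # one array with both partial derivatives

separate_grad = grad(himmelblau_jnp, argnums=(0, 1))
print(separate_grad(2.0, 3.0))                     # a tuple with the same two values
```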

from jax import grad

function_grad = grad(himmelblau_jnp, argnums=(0, 1))    # we want the derivative of both arguments
function_grad(2., 3.)

Now that we have all that, we define our sgd optimization routine and collect the resulting coordinates to show later on.

x, y = [0.], [0.]

def sgd_update(x, y, learning_rate=0.01):
    d_x, d_y = function_grad(x, y)
    x = x - learning_rate * d_x
    y = y - learning_rate * d_y
    return x, y

for i in range(100):
    _x, _y = sgd_update(x[-1], y[-1])        
    x.append(float(_x)), y.append(float(_y))

print(x, y)
Output
[0.0, 0.14000000059604645, 0.3364902436733246, 0.6047241687774658, 0.9560018181800842, 1.3865301609039307, 1.860503911972046, 2.3018486499786377, 2.626394271850586, 2.808932065963745, 2.8941497802734375, 2.9341514110565186, 2.955716371536255, 2.9689831733703613, 2.9778130054473877, 2.9839375019073486, 2.988283395767212, 2.99141001701355, 2.9936795234680176, 2.995337724685669, 2.996554374694824, 2.997450113296509, 2.9981112480163574, 2.9985997676849365, 2.9989614486694336, 2.9992294311523438, 2.9994280338287354, 2.99957537651062, 2.9996845722198486, 2.999765634536743, 2.999825954437256, 2.999870777130127, 2.999904155731201, 2.9999287128448486, 2.9999470710754395, 2.9999606609344482, 2.9999706745147705, 2.9999783039093018, 2.999983787536621, 2.999988079071045, 2.9999911785125732, 2.999993324279785, 2.999994993209839, 2.9999961853027344, 2.99999737739563, 2.999997854232788, 2.9999983310699463, 2.9999988079071045, 2.9999990463256836, 2.9999992847442627, 2.9999992847442627, 2.999999523162842, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421] [0.0, 0.2199999988079071, 0.49515005946159363, 
0.8301041126251221, 1.2156579494476318, 1.6151021718978882, 1.9584805965423584, 2.1722240447998047, 2.2410364151000977, 2.220111846923828, 2.1723852157592773, 2.128113269805908, 2.093951940536499, 2.068840503692627, 2.050553798675537, 2.037219285964966, 2.027461528778076, 2.020296335220337, 2.0150198936462402, 2.0111260414123535, 2.0082476139068604, 2.006117105484009, 2.0045387744903564, 2.003368616104126, 2.0025007724761963, 2.001856803894043, 2.0013787746429443, 2.001024007797241, 2.000760555267334, 2.0005648136138916, 2.0004196166992188, 2.0003116130828857, 2.0002315044403076, 2.0001718997955322, 2.0001277923583984, 2.0000948905944824, 2.000070571899414, 2.0000524520874023, 2.0000388622283936, 2.0000288486480713, 2.000021457672119, 2.0000159740448, 2.000011920928955, 2.0000088214874268, 2.000006675720215, 2.000005006790161, 2.0000038146972656, 2.000002861022949, 2.000002145767212, 2.0000016689300537, 2.0000011920928955, 2.0000009536743164, 2.0000007152557373, 2.000000476837158, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579]

A bit of quick sketching and fiddling around with color maps, angles and rotations and we get this result:

angle = 45
fig, ax_3d, ax_2d = plot_function(himmelblau(), angle=angle)

ax_2d.scatter(*rotate(np.array(x), np.array(y), angle=angle), c=np.arange(len(x)), cmap='binary')


Now, the cool thing is that we can use JAX as a drop-in numpy replacement, even inside function definitions that were written assuming numpy. Notice below that we’ve used the himmelblau class without making any JAX-specific change.

If we pass in arguments that JAX can work with, we can call grad directly on the function and get the derivatives computed for us, for free! That means jnp.arrays or plain floats. grad expects floating-point inputs, which is why the starting point is 0. rather than 0.

How JAX manages this is somewhat of a mystery to me. It takes a piece of arbitrary Python and numpy code and makes it differentiable, only by changing the data structure that you pass in. I assume that JAX’s shadow objects (like jnp.array) trace the operations they take part in and so build a dynamic graph of operations, something akin to what PyTorch or TensorFlow 2.0 Eager do.
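One way to peek at what is going on is jax.make_jaxpr. It runs a function on abstract tracer values and returns the recorded sequence of primitive operations, which JAX calls a jaxpr. Here is a sketch, with the function redefined so it runs on its own. The exact printout depends on the JAX version:

```python
import jax

def himmelblau_fn(x, y):
    return (x**2 + y - 11)**2 + (x + y**2 - 7)**2

# trace with example float inputs; only their shapes and dtypes matter here
jaxpr = jax.make_jaxpr(himmelblau_fn)(1.0, 2.0)
print(jaxpr)  # two scalar inputs followed by a list of primitive ops (add, sub, powers, ...)
```

The printout is a flat list of operations on the two inputs, which fits the “graph of operations” guess above.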

x, y = [0.], [0.]

function_grad = grad(himmelblau(), argnums=(0, 1)) # argnums=(0, 1) says that we want the derivatives with respect to both arguments, x and y

def sgd_update(x, y, learning_rate=0.01):
    d_x, d_y = function_grad(x, y)
    x = x - learning_rate * d_x
    y = y - learning_rate * d_y
    return x, y

for i in range(100):
    _x, _y = sgd_update(x[-1], y[-1])        
    x.append(float(_x)), y.append(float(_y))

print(x, y)
Output
[0.0, 0.14000000059604645, 0.3364902436733246, 0.6047241687774658, 0.9560018181800842, 1.3865301609039307, 1.860503911972046, 2.3018486499786377, 2.626394271850586, 2.808932065963745, 2.8941497802734375, 2.9341514110565186, 2.955716371536255, 2.9689831733703613, 2.9778130054473877, 2.9839375019073486, 2.988283395767212, 2.99141001701355, 2.9936795234680176, 2.995337724685669, 2.996554374694824, 2.997450113296509, 2.9981112480163574, 2.9985997676849365, 2.9989614486694336, 2.9992294311523438, 2.9994280338287354, 2.99957537651062, 2.9996845722198486, 2.999765634536743, 2.999825954437256, 2.999870777130127, 2.999904155731201, 2.9999287128448486, 2.9999470710754395, 2.9999606609344482, 2.9999706745147705, 2.9999783039093018, 2.999983787536621, 2.999988079071045, 2.9999911785125732, 2.999993324279785, 2.999994993209839, 2.9999961853027344, 2.99999737739563, 2.999997854232788, 2.9999983310699463, 2.9999988079071045, 2.9999990463256836, 2.9999992847442627, 2.9999992847442627, 2.999999523162842, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421, 2.999999761581421] [0.0, 0.2199999988079071, 0.49515005946159363, 
0.8301041126251221, 1.2156579494476318, 1.6151021718978882, 1.9584805965423584, 2.1722240447998047, 2.2410364151000977, 2.220111846923828, 2.1723852157592773, 2.128113269805908, 2.093951940536499, 2.068840503692627, 2.050553798675537, 2.037219285964966, 2.027461528778076, 2.020296335220337, 2.0150198936462402, 2.0111260414123535, 2.0082476139068604, 2.006117105484009, 2.0045387744903564, 2.003368616104126, 2.0025007724761963, 2.001856803894043, 2.0013787746429443, 2.001024007797241, 2.000760555267334, 2.0005648136138916, 2.0004196166992188, 2.0003116130828857, 2.0002315044403076, 2.0001718997955322, 2.0001277923583984, 2.0000948905944824, 2.000070571899414, 2.0000524520874023, 2.0000388622283936, 2.0000288486480713, 2.000021457672119, 2.0000159740448, 2.000011920928955, 2.0000088214874268, 2.000006675720215, 2.000005006790161, 2.0000038146972656, 2.000002861022949, 2.000002145767212, 2.0000016689300537, 2.0000011920928955, 2.0000009536743164, 2.0000007152557373, 2.000000476837158, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579, 2.000000238418579]

Let’s try a different function, making sure that it also works in other instances.

x, y = [0.], [0.]

function_grad = grad(mc_cormick(), argnums=(0, 1)) # argnums=(0, 1) says that we want the derivatives with respect to both arguments, x and y

def sgd_update(x, y, learning_rate=0.01):
    d_x, d_y = function_grad(x, y)
    x = x - learning_rate * d_x
    y = y - learning_rate * d_y
    return x, y

for i in range(100):
    _x, _y = sgd_update(x[-1], y[-1])        
    x.append(float(_x)), y.append(float(_y))

print(x, y)
Output
---------------------------------------------------------------------------

Exception                                 Traceback (most recent call last)

<ipython-input-23-aea1ca96ba1d> in <module>()
     10 
     11 for i in range(100):
---> 12     _x, _y = sgd_update(x[-1], y[-1])
     13     x.append(float(_x)), y.append(float(_y))
     14 


<ipython-input-23-aea1ca96ba1d> in sgd_update(x, y, learning_rate)
      4 
      5 def sgd_update(x, y, learning_rate=0.01):
----> 6     d_x, d_y = function_grad(x, y)
      7     x = x - learning_rate * d_x
      8     y = y - learning_rate * d_y


/usr/local/lib/python3.6/dist-packages/jax/api.py in grad_f(*args, **kwargs)
    381   @wraps(fun, docstr=docstr, argnums=argnums)
    382   def grad_f(*args, **kwargs):
--> 383     _, g = value_and_grad_f(*args, **kwargs)
    384     return g
    385 


/usr/local/lib/python3.6/dist-packages/jax/api.py in value_and_grad_f(*args, **kwargs)
    438     f_partial, dyn_args = argnums_partial(f, argnums, args)
    439     if not has_aux:
--> 440       ans, vjp_py = _vjp(f_partial, *dyn_args)
    441     else:
    442       ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)


/usr/local/lib/python3.6/dist-packages/jax/api.py in _vjp(fun, *primals, **kwargs)
   1454   if not has_aux:
   1455     flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1456     out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
   1457     out_tree = out_tree()
   1458   else:


/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    104 def vjp(traceable, primals, has_aux=False):
    105   if not has_aux:
--> 106     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    107   else:
    108     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)


/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
     93   _, in_tree = tree_flatten(((primals, primals), {}))
     94   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
---> 95   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
     96   out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
     97   assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)


/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom, trace_type)
    436   with new_master(trace_type, bottom=bottom) as master:
    437     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 438     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    439     assert not env
    440     del master


/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148     gen = None
    149 
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:


<ipython-input-8-411bd359f427> in __call__(self, x, y)
      4         Computes the given function
      5         """
----> 6         return np.sin(x+y) + (x-y)**2-1.5*x+2.5*y+1
      7 
      8     def _min(self) -> np.ndarray:


/usr/local/lib/python3.6/dist-packages/jax/core.py in __array__(self, *args, **kw)
    372 
    373   def __array__(self, *args, **kw):
--> 374     raise Exception("Tracer can't be used with raw numpy functions. "
    375                     "You might have\n"
    376                     "  import numpy as np\n"


Exception: Tracer can't be used with raw numpy functions. You might have
  import numpy as np
instead of
  import jax.numpy as jnp

Unfortunately, I thought that all functions would work out of the box just by replacing the np.arrays with jnp.arrays, but there is another constraint. We also have to use JAX’s own functions that operate on these arrays (jnp.<method>) instead of the np.<method> ones. The arithmetic operators in himmelblau work on JAX’s traced values. As the traceback shows, however, a raw numpy function like np.sin tries to convert its argument into a numpy array, and the Tracer refuses.

In the case of the mc_cormick class, the __call__ function uses np.sin, a function that comes from numpy and which should be replaced with jnp.sin. This means that we either need to use the jnp. prefix in all previously written code (and be explicit about it) or do import jax.numpy as np and shadow the pure numpy module.

The first is more desirable while the latter is more pragmatic. I’m going to go with the first one and update all my functions to explicitly use the jnp. prefix, so I know what I’m dealing with.

import jax.numpy as jnp 

class jax_mc_cormick(mc_cormick):
    def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        """
        Computes the given function
        """
        return jnp.sin(x+y) + (x-y)**2-1.5*x+2.5*y+1

class jax_holder_table(holder_table):
    def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        """
        Computes the given function
        """
        return -jnp.abs(jnp.sin(x)*jnp.cos(y)*jnp.exp(jnp.abs(1-jnp.sqrt(x**2+y**2)/jnp.pi)))

x, y = [0.], [0.]
function_grad = grad(jax_mc_cormick(), argnums=(0, 1)) # argnums=(0, 1) says that we want the derivatives with respect to both arguments, x and y

def sgd_update(x, y, learning_rate=0.01):
    d_x, d_y = function_grad(x, y)
    x = x - learning_rate * d_x
    y = y - learning_rate * d_y
    return x, y

for i in range(100):
    _x, _y = sgd_update(x[-1], y[-1])        
    x.append(float(_x)), y.append(float(_y))

print(x, y)
Output
[0.0, 0.004999999888241291, 0.009204499423503876, 0.01265448797494173, 0.015389639884233475, 0.017448334023356438, 0.018867675215005875, 0.019683513790369034, 0.019930459558963776, 0.019641906023025513, 0.018850043416023254, 0.0175858773291111, 0.01587924174964428, 0.013758821412920952, 0.011252165772020817, 0.008385710418224335, 0.005184791516512632, 0.0016736702527850866, -0.0021244490053504705, -0.006187397986650467, -0.010494021698832512, -0.0150241544470191, -0.019758598878979683, -0.024679094552993774, -0.02976830117404461, -0.03500977158546448, -0.04038791358470917, -0.04588798061013222, -0.051496028900146484, -0.0571988970041275, -0.06298418343067169, -0.06884019821882248, -0.07475595921278, -0.08072115480899811, -0.086726114153862, -0.09276177734136581, -0.09881968051195145, -0.10489190369844437, -0.11097107827663422, -0.11705033481121063, -0.12312329560518265, -0.1291840374469757, -0.13522708415985107, -0.14124736189842224, -0.14724019169807434, -0.15320125222206116, -0.15912659466266632, -0.16501259803771973, -0.17085592448711395, -0.17665356397628784, -0.18240275979042053, -0.18810100853443146, -0.19374607503414154, -0.19933593273162842, -0.20486877858638763, -0.21034300327301025, -0.21575717628002167, -0.22111006081104279, -0.22640056908130646, -0.23162776231765747, -0.23679085075855255, -0.24188917875289917, -0.24692220985889435, -0.25188952684402466, -0.25679078698158264, -0.2616257965564728, -0.26639440655708313, -0.2710965871810913, -0.2757323682308197, -0.2803018391132355, -0.28480517864227295, -0.28924259543418884, -0.29361438751220703, -0.2979208827018738, -0.3021624684333801, -0.3063395619392395, -0.3104526400566101, -0.31450217962265015, -0.3184886872768402, -0.32241275906562805, -0.32627496123313904, -0.3300758898258209, -0.33381617069244385, -0.33749642968177795, -0.34111735224723816, -0.3446796238422394, -0.34818390011787415, -0.35163089632987976, -0.35502129793167114, -0.3583558201789856, -0.3616351783275604, -0.3648601174354553, 
-0.3680313229560852, -0.37114953994750977, -0.3742155134677887, -0.3772299587726593, -0.3801936209201813, -0.3831072151660919, -0.3859714865684509, -0.3887871503829956, -0.39155489206314087] [0.0, -0.03500000014901161, -0.06919549405574799, -0.10260950028896332, -0.13526378571987152, -0.1671789586544037, -0.19837452471256256, -0.2288689911365509, -0.25867995619773865, -0.2878240942955017, -0.3163173198699951, -0.3441748023033142, -0.3714110255241394, -0.398039847612381, -0.42407456040382385, -0.44952794909477234, -0.47441232204437256, -0.4987395703792572, -0.5225211381912231, -0.5457682013511658, -0.5684915781021118, -0.5907018184661865, -0.6124091744422913, -0.6336236596107483, -0.6543551087379456, -0.6746131181716919, -0.6944071054458618, -0.7137464284896851, -0.7326401472091675, -0.7510972619056702, -0.7691265940666199, -0.7867369055747986, -0.803936779499054, -0.8207347393035889, -0.8371391296386719, -0.8531582951545715, -0.8688003420829773, -0.8840733170509338, -0.8989852070808411, -0.9135438799858093, -0.9277570843696594, -0.9416325092315674, -0.9551776051521301, -0.9683998823165894, -0.9813066124916077, -0.9939050078392029, -1.006202220916748, -1.018205165863037, -1.0299208164215088, -1.041355848312378, -1.0525169372558594, -1.0634106397628784, -1.0740432739257812, -1.0844212770462036, -1.0945507287979126, -1.1044377088546753, -1.1140880584716797, -1.1235077381134033, -1.132702350616455, -1.1416774988174438, -1.1504385471343994, -1.1589909791946411, -1.1673399209976196, -1.1754904985427856, -1.1834477186203003, -1.1912164688110352, -1.1988015174865723, -1.2062073945999146, -1.2134387493133545, -1.2204999923706055, -1.2273954153060913, -1.2341291904449463, -1.2407054901123047, -1.2471283674240112, -1.2534016370773315, -1.2595291137695312, -1.265514612197876, -1.2713617086410522, -1.277073860168457, -1.2826545238494873, -1.2881070375442505, -1.2934346199035645, -1.298640489578247, -1.3037277460098267, -1.308699369430542, -1.3135583400726318, 
-1.3183075189590454, -1.3229495286941528, -1.3274872303009033, -1.3319231271743774, -1.3362598419189453, -1.340499758720398, -1.344645380973816, -1.3486990928649902, -1.3526630401611328, -1.3565396070480347, -1.3603309392929077, -1.3640390634536743, -1.3676660060882568, -1.3712139129638672, -1.3746845722198486]
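As a check on the jnp version, the gradient of the McCormick function at the origin works out by hand to (cos(0) - 1.5, cos(0) + 2.5) = (-0.5, 3.5). The first step with a learning rate of 0.01 should therefore be (0.005, -0.035), which is where the trace above starts. The self-contained sketch below compares JAX’s gradient with that hand-derived one (my own derivation, not from the post) at a few points:

```python
import jax.numpy as jnp
from jax import grad

def mc_cormick_jnp(x, y):
    # same formula as jax_mc_cormick.__call__, using jnp.sin instead of np.sin
    return jnp.sin(x + y) + (x - y)**2 - 1.5 * x + 2.5 * y + 1

def mc_cormick_grad_by_hand(x, y):
    return (jnp.cos(x + y) + 2 * (x - y) - 1.5,
            jnp.cos(x + y) - 2 * (x - y) + 2.5)

jax_grad = grad(mc_cormick_jnp, argnums=(0, 1))
for point in [(0.0, 0.0), (1.0, -2.0), (-0.5, 0.3)]:
    print(point, jax_grad(*point), mc_cormick_grad_by_hand(*point))
```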

Animating the complex plot, with the optimization result

Since we have an optimization that progresses over time, it’s only natural to show it as an animation, like the one below:

import matplotlib.animation as animation
from IPython.display import HTML, display

x, y = [0.], [0.]
angle = 45

def single_frame(i, _ax_3d, _ax_2d):
    # redraw both panels from scratch on every frame
    _ax_3d.clear()
    _ax_2d.clear()
    
    # take one SGD step and store the new point
    _x, _y = sgd_update(x[-1], y[-1])        
    x.append(float(_x)), y.append(float(_y))

    plot_function(function, angle=angle, fig=fig, ax_3d=_ax_3d, ax_2d=_ax_2d)
    _ax_2d.scatter(*rotate(np.array(x), np.array(y), angle=angle), c=np.arange(len(x)), cmap='binary')


function = himmelblau()
function_grad = grad(function, argnums=(0, 1))    # sgd_update reads this global, so it must match the plotted function

fig = plt.figure(figsize=(13,5))
fig, ax_3d, ax_2d = plot_function(function, fig=fig)
fig.tight_layout()

frames = 6
animator = animation.FuncAnimation(fig, single_frame, fargs=(ax_3d, ax_2d), frames=frames, interval=800, blit=False)
display(HTML(animator.to_html5_video()))
plt.close()

Also, some abstraction patterns seem to become obvious:

  • the initial plot (the two contour figures) is actually static
  • we may need to draw multiple traces on the same plot from multiple optimization strategies
  • the two points above may mean we need a unifying object for all the artifacts of the static plots (fig, ax_2d and ax_3d)
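A minimal sketch of what such a split could look like is below. It assumes a hypothetical OptimisationTrace class that only owns the path of one optimizer run. Its update() takes one step and returns the whole x and y history, which is what the animation needs to redraw. The figure and axes would then live in a separate object that is drawn once:

```python
from dataclasses import dataclass, field
from typing import Callable, List, Tuple

Step = Callable[[float, float], Tuple[float, float]]

@dataclass
class OptimisationTrace:
    """Hypothetical container for the path of a single optimizer run."""
    step: Step                                  # one optimizer update, e.g. sgd_update
    start: Tuple[float, float] = (0.0, 0.0)
    xs: List[float] = field(default_factory=list)
    ys: List[float] = field(default_factory=list)

    def __post_init__(self):
        self.xs.append(float(self.start[0]))
        self.ys.append(float(self.start[1]))

    def update(self) -> Tuple[List[float], List[float]]:
        """Take one optimizer step and return the whole path so far."""
        x, y = self.step(self.xs[-1], self.ys[-1])
        self.xs.append(float(x))
        self.ys.append(float(y))
        return self.xs, self.ys


def bowl_step(x, y, learning_rate=0.1):
    # gradient step on the toy function f(x, y) = x**2 + y**2
    return x - learning_rate * 2 * x, y - learning_rate * 2 * y

# two traces that could be drawn on the same static plot
traces = [OptimisationTrace(bowl_step, start=(1.0, 1.0)),
          OptimisationTrace(bowl_step, start=(-2.0, 0.5))]
for _ in range(3):
    for trace in traces:
        xs, ys = trace.update()
```

Keeping the path separate from the figure means several strategies can be animated on one plot without redrawing each other’s state.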

As a quick, useful visualization, look at what happens when the learning rate is too high.

Code
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML, display
from jax import grad

x, y = [1.], [1.]   # start at (1, 1)
angle = 225         # viewing angle shared by the 3D and 2D plots

def single_frame(i, _ax_3d, _ax_2d):
    _ax_3d.clear()
    _ax_2d.clear()
    print(x, y)     # dump the trace so far, to see the oscillation numerically
    # a deliberately large learning rate
    _x, _y = sgd_update(x[-1], y[-1], learning_rate=0.5)
    x.append(float(_x))
    y.append(float(_y))

    plot_function(function, angle=angle, fig=fig, ax_3d=_ax_3d, ax_2d=_ax_2d)
    _ax_2d.scatter(*rotate(np.array(x), np.array(y), angle=angle), c=np.arange(len(x)), cmap='binary')


function = jax_mc_cormick()
function_grad = grad(function, argnums=(0, 1))  # sgd_update reads this global

fig = plt.figure(figsize=(13,5))
fig, ax_3d, ax_2d = plot_function(function, fig=fig)
fig.tight_layout()

frames = 6
animator = animation.FuncAnimation(fig, single_frame, fargs=(ax_3d, ax_2d), frames=frames, interval=800, blit=False)
display(HTML(animator.to_html5_video()))
plt.close()
[1.0] [1.0]
[1.0, 1.958073377609253] [1.0, -0.04192662239074707]
[1.0, 1.958073377609253, 0.8773366212844849] [1.0, -0.04192662239074707, 0.8773366212844849]
[1.0, 1.958073377609253, 0.8773366212844849, 1.7187578678131104] [1.0, -0.04192662239074707, 0.8773366212844849, -0.28124213218688965]
[1.0, 1.958073377609253, 0.8773366212844849, 1.7187578678131104, 0.4023146629333496] [1.0, -0.04192662239074707, 0.8773366212844849, -0.28124213218688965, 0.4023146629333496]
[1.0, 1.958073377609253, 0.8773366212844849, 1.7187578678131104, 0.4023146629333496, 0.8056254386901855] [1.0, -0.04192662239074707, 0.8773366212844849, -0.28124213218688965, 0.4023146629333496, -1.1943745613098145]
[1.0, 1.958073377609253, 0.8773366212844849, 1.7187578678131104, 0.4023146629333496, 0.8056254386901855, -0.9070665836334229] [1.0, -0.04192662239074707, 0.8773366212844849, -0.28124213218688965, 0.4023146629333496, -1.1943745613098145, -0.9070665240287781]
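To double-check this trace without JAX, here is a pure-Python sketch of the same SGD loop. It assumes jax_mc_cormick is the standard McCormick function f(x, y) = sin(x + y) + (x - y)^2 - 1.5x + 2.5y + 1; mc_cormick_grad uses hand-derived partial derivatives and, like sgd_trace, is a hypothetical helper. The large learning rate runs next to a small one:

```python
import math


def mc_cormick(x, y):
    # standard McCormick function (assumed to match jax_mc_cormick)
    return math.sin(x + y) + (x - y) ** 2 - 1.5 * x + 2.5 * y + 1


def mc_cormick_grad(x, y):
    # hand-derived partial derivatives df/dx, df/dy
    c = math.cos(x + y)
    return c + 2 * (x - y) - 1.5, c - 2 * (x - y) + 2.5


def sgd_trace(x, y, learning_rate, steps):
    # plain gradient descent, recording every visited point
    trace = [(x, y)]
    for _ in range(steps):
        d_x, d_y = mc_cormick_grad(x, y)
        x, y = x - learning_rate * d_x, y - learning_rate * d_y
        trace.append((x, y))
    return trace


big = sgd_trace(1.0, 1.0, learning_rate=0.5, steps=6)
small = sgd_trace(1.0, 1.0, learning_rate=0.01, steps=6)
for (bx, by), (sx, sy) in zip(big, small):
    print(f"lr=0.5: ({bx:+.4f}, {by:+.4f}) f={mc_cormick(bx, by):+.3f}   "
          f"lr=0.01: ({sx:+.4f}, {sy:+.4f}) f={mc_cormick(sx, sy):+.3f}")
```

With lr=0.5 the first steps agree with the printed trace to float32 precision, (1, 1) → (1.9581, -0.0419) → (0.8773, 0.8773), jumping back and forth instead of descending smoothly, while lr=0.01 only nudges the point.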

[Bonus] Using the jax optimizers

What would be even cooler is to use the optimizers that ship with JAX.

So far I’ve reimplemented SGD by hand, which is easy enough, but other optimizers that would be nice to see in action, like Adam, are not as easy to implement.

JAX is kind of rough on this front: the optimizers (for now) sit inside the experimental submodule, which means their API might change in the future (in newer JAX releases they have indeed moved, to jax.example_libraries.optimizers).
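For a feel of what Adam adds on top of plain SGD, here is a minimal pure-Python sketch of its update rule (the one from the original Adam paper, with its usual default constants), applied to Himmelblau's function through a hand-derived gradient. adam_step and himmelblau_grad are hypothetical helpers for this sketch, not JAX code (JAX's optimizers module provides its own adam, with the same (init, update, get_params) shape as sgd). The extra bookkeeping, two running averages per parameter plus a step counter t for bias correction, is what makes it more work than SGD:

```python
import math


def adam_step(params, grads, m, v, t, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update for a list of scalar parameters; t counts from 1."""
    new_params, new_m, new_v = [], [], []
    for p, g, m_i, v_i in zip(params, grads, m, v):
        m_i = b1 * m_i + (1 - b1) * g          # running mean of the gradients
        v_i = b2 * v_i + (1 - b2) * g * g      # running mean of the squared gradients
        m_hat = m_i / (1 - b1 ** t)            # bias correction: both averages start at 0
        v_hat = v_i / (1 - b2 ** t)
        new_params.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(m_i)
        new_v.append(v_i)
    return new_params, new_m, new_v


def himmelblau_grad(x, y):
    # hand-derived gradient of Himmelblau's function
    a = x * x + y - 11
    b = x + y * y - 7
    return [4 * x * a + 2 * b, 2 * a + 4 * y * b]


first, _, _ = adam_step([1.0, 2.0], himmelblau_grad(1.0, 2.0), [0.0, 0.0], [0.0, 0.0], t=1, lr=0.1)
print("first step:", first)

params, m, v = [1.0, 2.0], [0.0, 0.0], [0.0, 0.0]
for t in range(1, 11):
    params, m, v = adam_step(params, himmelblau_grad(*params), m, v, t, lr=0.1)
print("after 10 steps:", params)
```

At t = 1 the bias-corrected ratio m_hat / sqrt(v_hat) equals g / |g|, so the first step moves each coordinate by about lr regardless of the gradient's size: (1, 2) goes to roughly (1.1, 2.1). This also means the step counter given to an Adam-style update has to keep growing across calls; restarting it at the first step every time changes the bias correction.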

In this API, an optimizer is a function that takes some configuration parameters (like the step size) and returns 3 functions:

  • init - a function to which you pass the initial values of your parameters, and which returns a state object. That state is a pytree structure (some internal representation). This is a bit confusing, and I’m guessing this intermediate pytree thing might disappear from the API in the near future.
  • update - the function that does a single update pass over all the parameters. It receives as inputs:
    • i - the index of the current iteration. This is useful because, depending on the optimizer implementation, you can have different learning properties at each iteration (like some annealing strategy for the learning rate, etc.)
    • g - the gradient values. You get these by extracting the params (the variables the optimizer will update) from the state object with the get_params function described below, passing them to your gradient function, and feeding its results into update.
    • state - the pytree structure you got from calling init (and which you’ll constantly replace with the result of this update call)
  • get_params - a utility function that extracts the params from a state object (which is a pytree).
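To make this contract concrete, here is a hypothetical pure-Python stand-in for an SGD optimizer. make_sgd, SGDState and himmelblau_grad are illustration-only names, not JAX's implementation (whose state is an OptimizerState pytree). It shows that the three functions only pass an opaque state object around:

```python
from typing import Callable, NamedTuple, Tuple


class SGDState(NamedTuple):
    params: Tuple[float, ...]


def make_sgd(step_size: float) -> Tuple[Callable, Callable, Callable]:
    """Hypothetical stand-in mirroring the (init, update, get_params) triple."""

    def init(params):
        # wrap the initial parameter values in a state object
        return SGDState(tuple(float(p) for p in params))

    def update(i, grads, state):
        # plain SGD ignores i; schedules or Adam-like optimizers would use it
        return SGDState(tuple(p - step_size * g for p, g in zip(state.params, grads)))

    def get_params(state):
        # unwrap the parameters from the state
        return state.params

    return init, update, get_params


def himmelblau_grad(x, y):
    # hand-derived gradient of Himmelblau's function
    a = x * x + y - 11
    b = x + y * y - 7
    return (4 * x * a + 2 * b, 2 * a + 4 * y * b)


init, update, get_params = make_sgd(step_size=0.001)
state = init((1.0, 2.0))
state = update(0, himmelblau_grad(*get_params(state)), state)
print(get_params(state))
```

A single update from (1, 2) with step size 0.001 uses the gradient (-36, -32) and lands on (1.036, 2.032), matching the JAX output in this section.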

So the full flow, in code, looks like this:

from jax import grad
try:
    # newer JAX releases
    from jax.example_libraries.optimizers import sgd
except ImportError:
    # older JAX releases (the ones this post was written with)
    from jax.experimental.optimizers import sgd


init, update, get_params = sgd(step_size=0.001)


state = init((1., 2.))    # initialize the optimizer with some initial weights and get a state back
print(state)
print(get_params(state))  # use this function to extract the weight values from the state object
grad_function = grad(himmelblau(), argnums=(0, 1))  # build the function that computes the gradients
print(grad_function(*get_params(state)))            # gradient at the initial point
# call update with the iteration number, the gradient of the params and the previous state; get a new state back
state = update(0, grad_function(*get_params(state)), state)
print(state)
OptimizerState(packed_state=((1.0,), (2.0,)), tree_def=PyTreeDef(tuple, [*,*]), subtree_defs=(*, *))
(1.0, 2.0)
(DeviceArray(-36., dtype=float32), DeviceArray(-32., dtype=float32))
OptimizerState(packed_state=((DeviceArray(1.036, dtype=float32),), (DeviceArray(2.032, dtype=float32),)), tree_def=PyTreeDef(tuple, [*,*]), subtree_defs=(*, *))

Below is the result of running 10 iterations of this update in a loop. The point moves in some direction, and I’m sure you’re eager to see where on the graph…

grad_function = grad(himmelblau(), argnums=(0, 1))
def run():
    state = init((1., 2.))
    for i in range(10):
        params = get_params(state)
        yield params
        state = update(i, grad_function(*params), state)
    
[(float(x), float(y)) for x, y in run()]
[(1.0, 2.0),
 (1.0360000133514404, 2.0320000648498535),
 (1.0723856687545776, 2.062704086303711),
 (1.1091352701187134, 2.092081069946289),
 (1.1462260484695435, 2.1201066970825195),
 (1.18363356590271, 2.1467630863189697),
 (1.22133207321167, 2.1720387935638428),
 (1.2592941522598267, 2.1959288120269775),
 (1.2974909543991089, 2.2184340953826904),
 (1.335891842842102, 2.239561080932617)]

Before showing these numbers on the chart, let’s spend some time encapsulating this way of using an optimizer (in general) into a class with a nice (a matter of taste) API. Later on, we can just swap in a different optimizer and see what happens!

from jax import grad


class optimize:
    def __init__(self, function):
        self.function = function
        self.grad_function = grad(function, argnums=(0, 1))  # gradient w.r.t. both x and y
        self.x, self.y = list(), list()
        self.iteration = 0  # global step counter, kept across update() calls

    def using(self, optimizer):
        # optimizer is the (init, update, get_params) triple, e.g. sgd(step_size=...)
        self._init, self._update, self._get_params = optimizer
        return self

    def start_from(self, params):
        self.state = self._init(tuple(params))
        self.iteration = 0
        return self

    def update(self, nr_iterations=1):
        for _ in range(nr_iterations):
            params = self._get_params(self.state)
            self.__add_point(*params)
            # pass the global step, so schedules and Adam's bias correction see the true iteration
            self.state = self._update(self.iteration, self.grad_function(*params), self.state)
            self.iteration += 1
        return self.x, self.y

    def __add_point(self, _x, _y):
        """
        Appends the x and y coordinates of this point to the trace lists
        """
        self.x.append(float(_x))
        self.y.append(float(_y))


optimize(himmelblau())\
    .using(sgd(step_size=0.001))\
    .start_from([1., 1.])\
    .update(10)
([1.0,
  1.0460000038146973,
  1.0928562879562378,
  1.1405162811279297,
  1.1889206171035767,
  1.2380036115646362,
  1.2876930236816406,
  1.3379102945327759,
  1.388570785522461,
  1.4395838975906372],
 [1.0,
  1.0379999876022339,
  1.0759831666946411,
  1.1138836145401,
  1.1516332626342773,
  1.1891623735427856,
  1.2264001369476318,
  1.2632750272750854,
  1.299715518951416,
  1.3356506824493408])
Code
import matplotlib.animation as animation
from IPython.display import HTML, display


def single_frame(i, optimisation, _ax_3d, _ax_2d, angle):
    _ax_3d.clear()
    _ax_2d.clear()
    
    x, y = optimisation.update()        
    x, y = np.array(x), np.array(y)

    plot_function(optimisation.function, angle=angle, fig=fig, ax_3d=_ax_3d, ax_2d=_ax_2d)
    _ax_2d.scatter(*rotate(x, y, angle=angle), c=np.arange(len(x)), cmap='cool')
    _ax_3d.scatter3D(x, y, optimisation.function(x, y), c=np.arange(len(x)), cmap='cool')

    _ax_2d.plot()


angle = 225
optimisation = optimize(himmelblau())\
    .using(sgd(step_size=0.01))\
    .start_from([-1., -1.])

fig = plt.figure(figsize=(13,5))
fig, ax_3d, ax_2d = plot_function(optimisation.function, fig=fig, angle=angle)
fig.tight_layout()

frames = 20
animator = animation.FuncAnimation(fig, single_frame, fargs=(optimisation, ax_3d, ax_2d, angle), frames=frames, interval=250, blit=False)
display(HTML(animator.to_html5_video()))
plt.close()

The trace is hard to see in the 3D plot because of the transparency issue, but in the 2D plot it is really informative!

Conclusions

This has been a really long and tiring post to write (maybe two weeks of full-time work?!). Overall, I can say:

  • 3D support in matplotlib is quite limited
  • the visualizations really help in understanding how a certain optimizer behaves
  • jax is cool but the optimizers API is a bit rough
  • the 3D view doesn’t add much value overall

[Bonus, Bonus] How does SGD handle a hard function?

Code
import matplotlib.animation as animation
from IPython.display import HTML, display


def single_frame(i, optimisation, _ax_3d, _ax_2d, angle):
    _ax_3d.clear()
    _ax_2d.clear()
    
    x, y = optimisation.update()        
    x, y = np.array(x), np.array(y)
    print(x, y)

    plot_function(optimisation.function, angle=angle, fig=fig, ax_3d=_ax_3d, ax_2d=_ax_2d)
    _ax_2d.scatter(*rotate(x, y, angle=angle), c=np.arange(len(x)), cmap='cool')
    _ax_3d.scatter3D(x, y, optimisation.function(x, y), c=np.arange(len(x)), cmap='cool')

    _ax_2d.plot()


angle = 225
optimisation = optimize(jax_holder_table())\
    .using(sgd(step_size=0.3))\
    .start_from([-1., -1.])

fig = plt.figure(figsize=(13,5))
fig, ax_3d, ax_2d = plot_function(optimisation.function, fig=fig, angle=angle)
fig.tight_layout()

frames = 20
animator = animation.FuncAnimation(fig, single_frame, fargs=(optimisation, ax_3d, ax_2d, angle), frames=frames, interval=250, blit=False)
display(HTML(animator.to_html5_video()))
plt.close()
[-1.] [-1.]
[-1.         -1.09856904] [-1.         -0.57867539]
[-1.         -1.09856904 -1.1924032 ] [-1.         -0.57867539 -0.25041041]
[-1.         -1.09856904 -1.1924032  -1.2352618 ] [-1.         -0.57867539 -0.25041041 -0.09040605]
[-1.         -1.09856904 -1.1924032  -1.2352618  -1.25142336] [-1.         -0.57867539 -0.25041041 -0.09040605 -0.03152514]
[-1.         -1.09856904 -1.1924032  -1.2352618  -1.25142336 -1.25790954] [-1.         -0.57867539 -0.25041041 -0.09040605 -0.03152514 -0.01097676]
[-1.         -1.09856904 -1.1924032  -1.2352618  -1.25142336 -1.25790954
 -1.26062083] [-1.         -0.57867539 -0.25041041 -0.09040605 -0.03152514 -0.01097676
 -0.00382642]
[-1.         -1.09856904 -1.1924032  -1.2352618  -1.25142336 -1.25790954
 -1.26062083 -1.26177108] [-1.         -0.57867539 -0.25041041 -0.09040605 -0.03152514 -0.01097676
 -0.00382642 -0.00133481]
[-1.         -1.09856904 -1.1924032  -1.2352618  -1.25142336 -1.25790954
 -1.26062083 -1.26177108 -1.26226151] [-1.00000000e+00 -5.78675389e-01 -2.50410408e-01 -9.04060453e-02
 -3.15251388e-02 -1.09767634e-02 -3.82641982e-03 -1.33480760e-03
 -4.65787482e-04]
[-1.         -1.09856904 -1.1924032  -1.2352618  -1.25142336 -1.25790954
 -1.26062083 -1.26177108 -1.26226151 -1.26247096] [-1.00000000e+00 -5.78675389e-01 -2.50410408e-01 -9.04060453e-02
 -3.15251388e-02 -1.09767634e-02 -3.82641982e-03 -1.33480760e-03
 -4.65787482e-04 -1.62562239e-04]
[-1.         -1.09856904 -1.1924032  -1.2352618  -1.25142336 -1.25790954
 -1.26062083 -1.26177108 -1.26226151 -1.26247096 -1.26256049] [-1.00000000e+00 -5.78675389e-01 -2.50410408e-01 -9.04060453e-02
 -3.15251388e-02 -1.09767634e-02 -3.82641982e-03 -1.33480760e-03
 -4.65787482e-04 -1.62562239e-04 -5.67385796e-05]
[-1.         -1.09856904 -1.1924032  -1.2352618  -1.25142336 -1.25790954
 -1.26062083 -1.26177108 -1.26226151 -1.26247096 -1.26256049 -1.26259875] [-1.00000000e+00 -5.78675389e-01 -2.50410408e-01 -9.04060453e-02
 -3.15251388e-02 -1.09767634e-02 -3.82641982e-03 -1.33480760e-03
 -4.65787482e-04 -1.62562239e-04 -5.67385796e-05 -1.98038142e-05]
[-1.         -1.09856904 -1.1924032  -1.2352618  -1.25142336 -1.25790954
 -1.26062083 -1.26177108 -1.26226151 -1.26247096 -1.26256049 -1.26259875
 -1.26261508] [-1.00000000e+00 -5.78675389e-01 -2.50410408e-01 -9.04060453e-02
 -3.15251388e-02 -1.09767634e-02 -3.82641982e-03 -1.33480760e-03
 -4.65787482e-04 -1.62562239e-04 -5.67385796e-05 -1.98038142e-05
 -6.91232253e-06]
[-1.         -1.09856904 -1.1924032  -1.2352618  -1.25142336 -1.25790954
 -1.26062083 -1.26177108 -1.26226151 -1.26247096 -1.26256049 -1.26259875
 -1.26261508 -1.26262212] [-1.00000000e+00 -5.78675389e-01 -2.50410408e-01 -9.04060453e-02
 -3.15251388e-02 -1.09767634e-02 -3.82641982e-03 -1.33480760e-03
 -4.65787482e-04 -1.62562239e-04 -5.67385796e-05 -1.98038142e-05
 -6.91232253e-06 -2.41268890e-06]
[-1.         -1.09856904 -1.1924032  -1.2352618  -1.25142336 -1.25790954
 -1.26062083 -1.26177108 -1.26226151 -1.26247096 -1.26256049 -1.26259875
 -1.26261508 -1.26262212 -1.2626251 ] [-1.00000000e+00 -5.78675389e-01 -2.50410408e-01 -9.04060453e-02
 -3.15251388e-02 -1.09767634e-02 -3.82641982e-03 -1.33480760e-03
 -4.65787482e-04 -1.62562239e-04 -5.67385796e-05 -1.98038142e-05
 -6.91232253e-06 -2.41268890e-06 -8.42130703e-07]
[-1.         -1.09856904 -1.1924032  -1.2352618  -1.25142336 -1.25790954
 -1.26062083 -1.26177108 -1.26226151 -1.26247096 -1.26256049 -1.26259875
 -1.26261508 -1.26262212 -1.2626251  -1.26262641] [-1.00000000e+00 -5.78675389e-01 -2.50410408e-01 -9.04060453e-02
 -3.15251388e-02 -1.09767634e-02 -3.82641982e-03 -1.33480760e-03
 -4.65787482e-04 -1.62562239e-04 -5.67385796e-05 -1.98038142e-05
 -6.91232253e-06 -2.41268890e-06 -8.42130703e-07 -2.93939536e-07]
[-1.         -1.09856904 -1.1924032  -1.2352618  -1.25142336 -1.25790954
 -1.26062083 -1.26177108 -1.26226151 -1.26247096 -1.26256049 -1.26259875
 -1.26261508 -1.26262212 -1.2626251  -1.26262641 -1.26262689] [-1.00000000e+00 -5.78675389e-01 -2.50410408e-01 -9.04060453e-02
 -3.15251388e-02 -1.09767634e-02 -3.82641982e-03 -1.33480760e-03
 -4.65787482e-04 -1.62562239e-04 -5.67385796e-05 -1.98038142e-05
 -6.91232253e-06 -2.41268890e-06 -8.42130703e-07 -2.93939536e-07
 -1.02597497e-07]
[-1.         -1.09856904 -1.1924032  -1.2352618  -1.25142336 -1.25790954
 -1.26062083 -1.26177108 -1.26226151 -1.26247096 -1.26256049 -1.26259875
 -1.26261508 -1.26262212 -1.2626251  -1.26262641 -1.26262689 -1.26262712] [-1.00000000e+00 -5.78675389e-01 -2.50410408e-01 -9.04060453e-02
 -3.15251388e-02 -1.09767634e-02 -3.82641982e-03 -1.33480760e-03
 -4.65787482e-04 -1.62562239e-04 -5.67385796e-05 -1.98038142e-05
 -6.91232253e-06 -2.41268890e-06 -8.42130703e-07 -2.93939536e-07
 -1.02597497e-07 -3.58109276e-08]
[-1.         -1.09856904 -1.1924032  -1.2352618  -1.25142336 -1.25790954
 -1.26062083 -1.26177108 -1.26226151 -1.26247096 -1.26256049 -1.26259875
 -1.26261508 -1.26262212 -1.2626251  -1.26262641 -1.26262689 -1.26262712
 -1.26262724] [-1.00000000e+00 -5.78675389e-01 -2.50410408e-01 -9.04060453e-02
 -3.15251388e-02 -1.09767634e-02 -3.82641982e-03 -1.33480760e-03
 -4.65787482e-04 -1.62562239e-04 -5.67385796e-05 -1.98038142e-05
 -6.91232253e-06 -2.41268890e-06 -8.42130703e-07 -2.93939536e-07
 -1.02597497e-07 -3.58109276e-08 -1.24995534e-08]
[-1.         -1.09856904 -1.1924032  -1.2352618  -1.25142336 -1.25790954
 -1.26062083 -1.26177108 -1.26226151 -1.26247096 -1.26256049 -1.26259875
 -1.26261508 -1.26262212 -1.2626251  -1.26262641 -1.26262689 -1.26262712
 -1.26262724 -1.26262724] [-1.00000000e+00 -5.78675389e-01 -2.50410408e-01 -9.04060453e-02
 -3.15251388e-02 -1.09767634e-02 -3.82641982e-03 -1.33480760e-03
 -4.65787482e-04 -1.62562239e-04 -5.67385796e-05 -1.98038142e-05
 -6.91232253e-06 -2.41268890e-06 -8.42130703e-07 -2.93939536e-07
 -1.02597497e-07 -3.58109276e-08 -1.24995534e-08 -4.36288161e-09]
[-1.         -1.09856904 -1.1924032  -1.2352618  -1.25142336 -1.25790954
 -1.26062083 -1.26177108 -1.26226151 -1.26247096 -1.26256049 -1.26259875
 -1.26261508 -1.26262212 -1.2626251  -1.26262641 -1.26262689 -1.26262712
 -1.26262724 -1.26262724 -1.26262724] [-1.00000000e+00 -5.78675389e-01 -2.50410408e-01 -9.04060453e-02
 -3.15251388e-02 -1.09767634e-02 -3.82641982e-03 -1.33480760e-03
 -4.65787482e-04 -1.62562239e-04 -5.67385796e-05 -1.98038142e-05
 -6.91232253e-06 -2.41268890e-06 -8.42130703e-07 -2.93939536e-07
 -1.02597497e-07 -3.58109276e-08 -1.24995534e-08 -4.36288161e-09
 -1.52283297e-09]

Not that good… SGD gets stuck at around (-1.263, 0), a local minimum, and the trace stops moving.
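Assuming jax_holder_table is the standard Hölder table function f(x, y) = -|sin(x) cos(y) exp(|1 - sqrt(x^2 + y^2)/π|)|, the point SGD converged to can be checked by hand: along y = 0 (with |x| < π) the function reduces to -sin|x| exp(1 - |x|/π), whose derivative vanishes where tan|x| = π, i.e. |x| = arctan(π) ≈ 1.26263, which matches the trace. The sketch below (holder_table is a hypothetical pure-Python helper) compares the value there with the value at (8.05502, 9.66459), one of the function's known global minima:

```python
import math


def holder_table(x, y):
    # standard Hölder table definition (assumed to match jax_holder_table)
    r = math.sqrt(x * x + y * y)
    return -abs(math.sin(x) * math.cos(y) * math.exp(abs(1 - r / math.pi)))


# Along y = 0 and |x| < pi: f = -sin|x| * exp(1 - |x|/pi);
# d/d|x| = 0  <=>  cos|x| = sin|x| / pi  <=>  tan|x| = pi
x_star = math.atan(math.pi)

print("SGD end point :", (-x_star, 0.0), "value:", holder_table(-x_star, 0.0))
print("global minimum:", (8.05502, 9.66459), "value:", holder_table(8.05502, 9.66459))
```

The local minimum sits at about -1.73, while the global minima are around -19.21, so from this starting point SGD settles into a much worse basin.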

Later update

Well, if I had known about this post earlier, it would have spared me a lot of trouble...
