Plot Animation with Matplotlib

Apr 4, 2021


Hello AI enthusiast, this is a short and to-the-point post on a rather unique topic. If you are a Data Scientist or an aspiring one, you have probably used Matplotlib to analyse your data with different plots.
In this post, we will go a step further and use Matplotlib's built-in animation support to animate those plots. You can save the animation as a video file if required, or simply use it to demonstrate an iterative learning process.
This post is not about the basics of Matplotlib, so some familiarity with it is assumed.

FuncAnimation class

FuncAnimation is the class that provides everything needed to turn a simple plot into an animation. Let's go through its initialization parameters.

class matplotlib.animation.FuncAnimation(fig, func, frames=None, init_func=None, fargs=None, save_count=None, *, cache_frame_data=True, **kwargs)


  • fig - The figure object used to get needed events, such as draw or resize.
  • func - The function called at every frame to draw or update the plot; the frames it produces make up the animation. Its first argument is the next value from frames, and any additional positional arguments can be supplied via the fargs parameter.
  • frames - The source of data passed to func, one value per frame. It can be an iterable, an int (equivalent to range(n)), a generator function, or None.
  • init_func - A function used to draw a clear frame. If it is not given, the result of drawing the first item from frames is used.
  • fargs - Additional arguments to pass to each call to func.
  • interval - Delay between frames in milliseconds (default 200).
  • repeat_delay - The delay in milliseconds between consecutive animation runs, if repeat is True.
  • repeat - Whether the animation repeats when the sequence of frames is completed (default True).

The last three are passed as keyword arguments through **kwargs.
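To see how frames, init_func, fargs, interval, repeat and repeat_delay fit together, here is a minimal sketch (not part of the original post) that animates a sine wave. The names init, update, amplitude and calls are illustrative; the Agg backend and to_jshtml() are used only so that the frames are rendered without a display or ffmpeg.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation

fig, ax = plt.subplots()
xs = np.linspace(0, 2 * np.pi, 200)
line, = ax.plot([], [])
ax.set_xlim(0, 2 * np.pi)
ax.set_ylim(-1.5, 1.5)

calls = []  # records (frame, amplitude) for every call to update

def init():
    # init_func: start from an empty line
    line.set_data([], [])
    return (line,)

def update(frame, amplitude):
    # frame comes from `frames`; amplitude is supplied through `fargs`
    calls.append((frame, amplitude))
    line.set_data(xs, amplitude * np.sin(xs + 0.1 * frame))
    return (line,)

anim = animation.FuncAnimation(
    fig, update,
    frames=30,            # an int behaves like range(30)
    init_func=init,
    fargs=(1.2,),         # extra positional argument for update
    interval=50,          # 50 ms between frames
    repeat=True,
    repeat_delay=1000,    # pause 1 s before each repeat
)

html = anim.to_jshtml()  # renders every frame into a self-contained HTML/JS player
```

In a notebook, wrapping the result in IPython.display.HTML shows a player with play and loop controls.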

Let's program a simple example:

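import numpy as np, matplotlib.pyplot as plt
from matplotlib import animation, rc

# Render animations in the notebook as HTML5 video (this needs ffmpeg installed)
rc('animation', html='html5')

# Dummy data for the scatter plot
x = np.random.rand(50)
y = np.random.rand(50)

fig = plt.figure(figsize=(8,6))
ax = plt.axes()

def func_plot(w): #1
    # w is the current value from frames, used here as the colour of the points
    plot = ax.scatter(x, y, c=w)
    return plot

anim = animation.FuncAnimation(fig, func_plot, frames=['r','b','g','k'], repeat=True) #2
anim  # as the last expression of a notebook cell, this displays the animation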

#1 - This is the function passed to the FuncAnimation class. It draws a scatter plot and returns it.
#2 - In the frames argument, we pass a list of four strings. Each one is passed in turn to func_plot, which uses it as the colour of the scatter plot, so the animation has four frames.
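The introduction mentions saving the animation as a file. A minimal sketch, assuming Pillow is available (Matplotlib depends on it) for GIF output; MP4 output with writer='ffmpeg' additionally needs ffmpeg installed. As a variation on the example above, it updates the existing scatter collection with set_color instead of adding a new scatter on every frame. The file name scatter_colors.gif is illustrative.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend: we only write a file
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation

rng = np.random.default_rng(0)
x, y = rng.random(50), rng.random(50)

fig, ax = plt.subplots(figsize=(4, 3))
scat = ax.scatter(x, y, c="r")

def recolor(color):
    # update the existing collection rather than drawing a new scatter
    scat.set_color(color)
    return (scat,)

anim = animation.FuncAnimation(fig, recolor, frames=["r", "b", "g", "k"], interval=500)

out_path = os.path.join(tempfile.gettempdir(), "scatter_colors.gif")
anim.save(out_path, writer=animation.PillowWriter(fps=2))  # 2 fps matches interval=500
# For MP4 instead: anim.save("scatter_colors.mp4", writer="ffmpeg"), which requires ffmpeg.
```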


Visualize Linear Regression convergence

So, we are done with a basic example. Let's now move to a more interesting one, where we plot the convergence of a simple linear regression trained with (stochastic) gradient descent.
We will generate dummy data x, y and fit a linear regression on it one epoch at a time. After each epoch we save the learned slope and intercept, and then animate the resulting lines.

# Linear Regression
import numpy as np, matplotlib.pyplot as plt
from matplotlib import animation, rc
from sklearn.datasets import make_regression
from sklearn.linear_model import SGDRegressor
rc('animation', html='html5')

x, y = make_regression(n_samples=1000, n_features=1, n_informative=1, noise=150, random_state=0)
model = SGDRegressor(warm_start=True, eta0=2.5, random_state=0)

# Train one epoch at a time and record (slope, intercept) after each epoch
parms = []
for i in range(25):
    model.partial_fit(x, y)
    m, c = model.coef_[0], model.intercept_[0]
    parms.append((m, c))

fig = plt.figure(figsize=(8,6))
ax = plt.axes()

def func_plot(w):
    # w = (slope, intercept) learned after one epoch
    ax.clear()  # remove the previous frame's line so only the current fit is shown
    x1, y1, x2, y2 = x.min(), w[0]*x.min()+w[1], x.max(), w[0]*x.max()+w[1]
    ax.scatter(x, y)
    plot = ax.plot([x1, x2], [y1, y2], color='b')
    return plot

anim = animation.FuncAnimation(fig, func_plot, frames=parms, repeat=True, interval=500)
anim

The code is straightforward: we train the SGDRegressor one epoch at a time, record the learned slope and intercept after each epoch, and use them to draw the fitted line in each frame.
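SGDRegressor hides the update rule. As a swapped-in alternative (plain batch gradient descent on mean squared error, which is not what scikit-learn does internally), this NumPy sketch records the (slope, intercept) pair after every step. The function name gradient_descent_history, the learning rate and the demo data are illustrative; the returned list could serve as frames for the same func_plot, passing x.ravel() since make_regression returns a 2-D x.

```python
import numpy as np

def gradient_descent_history(x, y, lr=0.1, n_iter=25):
    """Batch gradient descent on MSE for y ~ m*x + c.

    Returns one (m, c) pair per iteration, usable as `frames`.
    """
    m, c = 0.0, 0.0
    n = len(x)
    history = []
    for _ in range(n_iter):
        residual = m * x + c - y                 # prediction error per sample
        grad_m = 2.0 / n * np.dot(residual, x)   # d(MSE)/dm
        grad_c = 2.0 / n * residual.sum()        # d(MSE)/dc
        m -= lr * grad_m
        c -= lr * grad_c
        history.append((m, c))
    return history

# Illustrative data: y = 3x + 2 plus noise
rng = np.random.default_rng(0)
x_demo = rng.normal(size=1000)
y_demo = 3.0 * x_demo + 2.0 + rng.normal(scale=0.5, size=1000)
history = gradient_descent_history(x_demo, y_demo, lr=0.1, n_iter=200)
```

After enough iterations the last pair matches the least-squares fit from np.polyfit(x_demo, y_demo, 1); with fewer iterations (say 25) the early frames show the line swinging into place.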


Extension to a 3-D plot

In a similar manner, you can extend this code to a 3-D plot. The only change is the use of Matplotlib's 3-D plotting functions; the animation code remains as it is.
Let's plot a plane defined by a new equation, along with a random 3-D scatter plot.

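# Main plane
import numpy as np, matplotlib.pyplot as plt
from mpl_toolkits import mplot3d  # registers the '3d' projection on older Matplotlib versions
from matplotlib import animation, rc
rc('animation', html='html5')

# Random points for the 3-D scatter plot
x, y = np.random.uniform(-50, 50, (2, 100))
z = np.random.uniform(-350, 350, 100)

fig = plt.figure(figsize=(8,6))
ax = plt.axes(projection='3d')

def func_plot(w):
    # w grows from 1 to 9, so the plane covers a wider square each frame
    X = np.linspace(-w*10, w*10, 50)   # 50 points per axis is enough for a flat plane
    Y = np.linspace(-w*10, w*10, 50)
    X, Y = np.meshgrid(X, Y)
    Z = 2*X + 5*Y - 7.5

    ax.clear()  # drop the previous frame's surface
    ax.scatter3D(x, y, z, c='r')
    plot = ax.plot_surface(X, Y, Z, color='r', alpha=0.5)  # semi-transparent so points stay visible
    # fixed limits so the growth of the plane is visible
    ax.set_xlim(-90, 90); ax.set_ylim(-90, 90); ax.set_zlim(-650, 650)
    return plot

anim = animation.FuncAnimation(fig, func_plot, frames=range(1, 10))
anim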

Each frame evaluates the plane z = 2x + 5y - 7.5 over a larger square grid, so the plane appears to grow while the scatter points stay fixed.



That's all for this post. You may extend the idea to:

  • KMeans Clustering
  • Neural Network learning a Decision Boundary
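As a starting point for the KMeans idea, here is a minimal sketch that records every iteration of Lloyd's algorithm (a from-scratch NumPy version, not scikit-learn's KMeans) and animates the centroid updates with the same FuncAnimation pattern used above. The blob data and the names kmeans_history and draw_step are illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation

def kmeans_history(points, k, n_iter=10, seed=0):
    """Lloyd's algorithm; returns one (centroids, labels) pair per iteration."""
    rng = np.random.default_rng(seed)
    centroids = points[rng.choice(len(points), size=k, replace=False)]
    history = []
    for _ in range(n_iter):
        # assign every point to its nearest centroid
        dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        history.append((centroids.copy(), labels))
        # move each centroid to the mean of its points (keep it if its cluster is empty)
        centroids = np.array([
            points[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
    return history

# Illustrative data: three Gaussian blobs of 100 points each
rng = np.random.default_rng(1)
true_centers = np.array([[0.0, 0.0], [5.0, 5.0], [0.0, 5.0]])
points = np.vstack([c + rng.normal(scale=0.5, size=(100, 2)) for c in true_centers])
history = kmeans_history(points, k=3)

fig, ax = plt.subplots(figsize=(6, 5))

def draw_step(step):
    centroids, labels = step
    ax.clear()
    ax.scatter(points[:, 0], points[:, 1], c=labels, cmap="viridis", s=10)
    ax.scatter(centroids[:, 0], centroids[:, 1], c="r", marker="x", s=100)

anim = animation.FuncAnimation(fig, draw_step, frames=history, interval=500)
```

Each frame shows the cluster assignment and the centroids at one iteration; the within-cluster sum of squares never increases from one frame to the next.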