Plot Animation with Matplotlib
Hello AI enthusiast, this is a short and spot-on post on a fairly unusual topic. If you are a Data Scientist or an aspiring one, then you have surely used Matplotlib to analyse your data with different plots.
In this post, we will go a step further and use Matplotlib's built-in animation support to animate those plots. You can save the animation as a video file if required, or simply use it to demonstrate any iterative learning process.
This post does not cover the basics of Matplotlib, so a working knowledge of it is assumed.
FuncAnimation is the class that provides everything needed to turn a simple plot into an animation, so let's look at its initialization parameters.
class matplotlib.animation.FuncAnimation(fig, func, frames=None, init_func=None, fargs=None, save_count=None, *, cache_frame_data=True, **kwargs)
- fig - The figure object used to get needed events, such as draw or resize.
- func - The function the class calls for every frame to draw the plot for that frame; together, these frames make up the animation. The first argument will be the next value in frames. Any additional positional arguments can be supplied via the fargs parameter.
- frames - The source of the values passed to func (the previous parameter), usually an iterable such as a list. Each value becomes the first argument of func, and we use it to draw the plot for that frame. An integer n is treated as range(n).
- init_func - A function used to draw a clear frame.
- fargs - Additional arguments to pass to each call to func.
- interval - Delay between frames in milliseconds.
- repeat_delay - The delay in milliseconds between consecutive animation runs, if repeat is True.
- repeat - Whether the animation repeats when the sequence of frames is completed.
The last three do not appear in the signature above; they are keyword arguments passed through **kwargs.
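To see the optional parameters working together, here is a small sketch (not from the original post) that animates a travelling sine wave. It assumes a headless run, hence `matplotlib.use("Agg")`, which you can drop in a notebook. `frames=30` behaves like `range(30)`, `fargs=(1.2,)` is forwarded as the extra `amplitude` argument of the illustrative `update` function, and `init` is the `init_func` that draws the empty starting frame. Rather than clearing the axes each time, this version updates a single existing line with `set_data`, a common alternative to the `ax.clear()` approach used in the examples that follow.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts; omit in a notebook
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation

fig, ax = plt.subplots(figsize=(6, 4))
xs = np.linspace(0, 2 * np.pi, 200)
line, = ax.plot([], [])
ax.set_xlim(0, 2 * np.pi)
ax.set_ylim(-1.5, 1.5)

def init():
    # init_func: called once before the first frame to draw a clear frame
    line.set_data([], [])
    return (line,)

def update(frame, amplitude):
    # frame is the next value from `frames`; amplitude arrives through fargs
    line.set_data(xs, amplitude * np.sin(xs + 0.2 * frame))
    return (line,)

anim = animation.FuncAnimation(
    fig, update,
    frames=30,           # an int n behaves like range(n)
    init_func=init,
    fargs=(1.2,),        # extra positional arguments for update
    interval=50,         # 50 ms between frames
    repeat=True,
    repeat_delay=1000,   # pause 1 s before each replay
)
```

In a notebook, `HTML(anim.to_jshtml())` (with `HTML` from `IPython.display`) renders this as an interactive player without needing ffmpeg.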
Let's program a simple example:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation, rc

# Display animations in Jupyter as HTML5 video (needs ffmpeg installed)
rc('animation', html='html5')

fig = plt.figure(figsize=(8, 6))
ax = plt.axes()

def func_plot(w):  #1
    ax.clear()
    x = np.random.normal(loc=0.0, scale=50.0, size=50)
    y = np.random.normal(loc=0.0, scale=50.0, size=50)
    plot = ax.scatter(x, y, c=w)  # w is the colour for this frame
    return plot

anim = animation.FuncAnimation(fig, func_plot, frames=['r', 'b', 'g', 'k'], repeat=True)  #2
anim  # as the last expression of a notebook cell, this renders the animation
#1 - This is the function passed to FuncAnimation. It clears the axes, draws a scatter plot of 50 random points and returns it.
#2 - The frames argument is a list of 4 strings. Each one is passed in turn as the parameter of func_plot, which uses it as the colour of the scatter plot. This also means the animation has 4 frames, one per colour.
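The introduction mentions saving the animation as a file. One way to do that is `Animation.save` with a writer. The sketch below rebuilds the 4-frame scatter example (with a seeded random generator so runs are repeatable) and writes a GIF using `PillowWriter`, which only needs Pillow, a Matplotlib dependency. The MP4 line is left commented out because `FFMpegWriter` requires the external ffmpeg program; the file names are arbitrary choices.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts; omit in a notebook
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation

rng = np.random.default_rng(0)
fig, ax = plt.subplots(figsize=(6, 4.5))

def func_plot(w):
    ax.clear()
    x = rng.normal(loc=0.0, scale=50.0, size=50)
    y = rng.normal(loc=0.0, scale=50.0, size=50)
    return ax.scatter(x, y, c=w)

anim = animation.FuncAnimation(fig, func_plot, frames=['r', 'b', 'g', 'k'])

# GIF through Pillow; fps=2 shows each frame for half a second
anim.save("scatter.gif", writer=animation.PillowWriter(fps=2))

# An MP4 needs the ffmpeg executable to be installed:
# anim.save("scatter.mp4", writer=animation.FFMpegWriter(fps=2))
```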
Visualize Linear Regression convergence
So, we are done with a basic example. Let's now move to a more useful one, where we plot the convergence of a simple linear regression fitted with stochastic gradient descent.
We will use dummy data x, y and fit a linear regression on it one pass at a time. After each pass we save the learnt slope and intercept, and then draw the corresponding line in each frame of the animation.
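For reference, this is roughly the update each `partial_fit` pass applies for every sample $(x_i, y_i)$, assuming scikit-learn's default `SGDRegressor` settings: squared-error loss, an L2 penalty of strength $\alpha$ (the `alpha` parameter) on the slope only, and the `'invscaling'` learning rate with `power_t=0.25`. Here $\eta_0$ is the `eta0` argument, $m$ and $c$ are the slope and intercept, and $t$ counts the samples seen so far.

$$
e_i = (m x_i + c) - y_i, \qquad \eta_t = \frac{\eta_0}{t^{0.25}}
$$

$$
m \leftarrow m - \eta_t \left( e_i x_i + \alpha m \right), \qquad c \leftarrow c - \eta_t \, e_i
$$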
# Linear Regression
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation, rc
from sklearn.datasets import make_regression
from sklearn.linear_model import SGDRegressor

rc('animation', html='html5')

x, y = make_regression(n_samples=1000, n_features=1, n_informative=1, noise=150)
model = SGDRegressor(warm_start=True, eta0=2.5, random_state=0)

parms = []
for i in range(25):
    model.partial_fit(x, y)  # one more pass of SGD, continuing from the current fit
    # Store plain floats rather than references to the model's arrays
    m, c = float(model.coef_[0]), float(model.intercept_[0])
    parms.append((m, c))

fig = plt.figure(figsize=(8, 6))
ax = plt.axes()

def func_plot(w):
    ax.clear()
    m, c = w  # slope and intercept learnt after one pass
    x1, x2 = x.min(), x.max()
    y1, y2 = m * x1 + c, m * x2 + c
    ax.scatter(x[:, 0], y)
    plot = ax.plot([x1, x2], [y1, y2], color='b')
    return plot

anim = animation.FuncAnimation(fig, func_plot, frames=parms, repeat=True, interval=500)
anim
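The code is straightforward: each call to partial_fit runs one more pass of stochastic gradient descent over the data, so after each of the 25 passes we store the current slope and intercept, and each frame draws the corresponding line over the scatter plot.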
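To see the mechanics without scikit-learn, here is a minimal NumPy sketch of the same idea. It is not `SGDRegressor`'s exact algorithm: it assumes a constant learning rate, no regularisation and a reshuffle of the samples every epoch. `sgd_history` is a hypothetical helper that returns one `(slope, intercept)` pair per epoch, the same kind of list that can serve as `frames` for the regression animation.

```python
import numpy as np

def sgd_history(x, y, lr=0.01, epochs=25, seed=0):
    """Hypothetical helper: per-sample SGD for y ≈ m*x + c, recording (m, c) after each epoch."""
    rng = np.random.default_rng(seed)
    m, c = 0.0, 0.0
    history = []
    for _ in range(epochs):
        for i in rng.permutation(len(x)):
            err = (m * x[i] + c) - y[i]  # residual for one sample
            m -= lr * err * x[i]         # gradient of 0.5 * err**2 w.r.t. m
            c -= lr * err                # gradient of 0.5 * err**2 w.r.t. c
        history.append((float(m), float(c)))
    return history

# Usage: noiseless data on the line y = 3x + 2
x_demo = np.linspace(-1, 1, 200)
history = sgd_history(x_demo, 3 * x_demo + 2)
print(history[-1])  # approaches (3.0, 2.0)
```

With noiseless data every sample lies exactly on the line, so the recorded pairs settle on the true slope and intercept; with noisy data like `make_regression` produces, they keep jittering around the least-squares fit unless the learning rate decays, which is one reason scikit-learn uses a decaying schedule by default.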
Extension to 3-D plot
In the same way, you can extend this code to a 3-D plot. The only change is that the axes use Matplotlib's 3-D projection and its 3-D plotting functions; the animation code remains as it is.
Let's plot the plane z = 2x + 5y - 7.5, whose extent grows with each frame, together with a few randomly moving scatter points.
# Main plane
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation, rc
from mpl_toolkits import mplot3d  # registers the '3d' projection on older Matplotlib

rc('animation', html='html5')

fig = plt.figure(figsize=(8, 6))
ax = plt.axes(projection='3d')

N = 1
x = 500 * np.random.random(N)
y = 500 * np.random.random(N)
z = 500 * np.random.random(N)

def func_plot(w):
    global x, y, z
    ax.clear()
    # Random-walk the points (the single start point broadcasts to 10 points)
    x = x + np.random.normal(loc=0.0, scale=50.0, size=10)
    y = y + np.random.normal(loc=0.0, scale=50.0, size=10)
    z = z + np.random.normal(loc=0.0, scale=50.0, size=10)
    # Plane over a grid that widens with w; 100 points per axis is plenty,
    # since plot_surface samples at most 50 rows/columns by default
    X = np.linspace(-w * 10, w * 10, 100)
    Y = np.linspace(-w * 10, w * 10, 100)
    X, Y = np.meshgrid(X, Y)
    Z = 2 * X + 5 * Y - 7.5
    ax.set_xlim3d(-500.0, 500.0)
    ax.set_ylim3d(-500.0, 500.0)
    ax.set_zlim3d(-500.0, 500.0)
    ax.scatter3D(x, y, z, c='r')
    plot = ax.plot_surface(X, Y, Z, color='r')
    return plot

anim = animation.FuncAnimation(fig, func_plot, frames=range(1, 10))
anim
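Apart from the plotting calls, nothing changes: the axes are created with projection='3d', scatter3D draws the moving points and plot_surface draws the plane, which grows as w goes from 1 to 9, while the FuncAnimation call works exactly as in the 2-D examples.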
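The regression example itself also carries over to 3-D: with two features, the fitted model is a plane z = w1*x + w2*y + b. The sketch below is an illustration under stated assumptions, not code from the post. It generates noisy points around the same plane z = 2x + 5y - 7.5, fits it with plain per-sample SGD (constant learning rate, no regularisation) instead of `SGDRegressor`, records the parameters after each epoch and animates the fitted plane. `matplotlib.use("Agg")` is only there for headless runs.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts; omit in a notebook
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation

rng = np.random.default_rng(0)

# Noisy samples around the plane z = 2x + 5y - 7.5
X = rng.uniform(-1, 1, size=(200, 2))
z = 2 * X[:, 0] + 5 * X[:, 1] - 7.5 + rng.normal(scale=0.5, size=200)

# Plain per-sample SGD for z ≈ w1*x + w2*y + b, storing parameters after each epoch
w = np.zeros(2)
b = 0.0
lr = 0.01
history = []
for epoch in range(15):
    for i in rng.permutation(len(X)):
        err = X[i] @ w + b - z[i]  # residual for one sample
        w -= lr * err * X[i]
        b -= lr * err
    history.append((w[0], w[1], b))  # scalar copies, unaffected by later updates

gx, gy = np.meshgrid(np.linspace(-1, 1, 20), np.linspace(-1, 1, 20))
fig = plt.figure(figsize=(6, 5))
ax = fig.add_subplot(projection="3d")

def func_plot(params):
    w1, w2, b0 = params
    ax.clear()
    ax.scatter3D(X[:, 0], X[:, 1], z, c="k", s=5)
    surface = ax.plot_surface(gx, gy, w1 * gx + w2 * gy + b0, color="r", alpha=0.5)
    ax.set_zlim(-16, 1)
    return surface

anim = animation.FuncAnimation(fig, func_plot, frames=history, interval=300)
```

In a notebook, ending the cell with `anim` (after `rc('animation', html='html5')`) displays it like the other examples.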
That's all for this post. You may extend the idea to:
- KMeans Clustering
- Neural Network Learning of Decision Boundary
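As a starting point for the first of these, here is a sketch that records every iteration of Lloyd's algorithm (the standard k-means iteration), written in plain NumPy so each step can be captured, and animates the cluster assignments and centroids. `kmeans_history` is a hypothetical helper made up for this illustration, and the three synthetic blobs are arbitrary test data.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts; omit in a notebook
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation

def kmeans_history(points, k, iters=10, seed=0):
    """Hypothetical helper: Lloyd's k-means, recording (centroids, labels) per iteration.

    The labels in each entry are the assignment used to compute those centroids.
    """
    rng = np.random.default_rng(seed)
    # Start from k distinct data points chosen at random
    centroids = points[rng.choice(len(points), size=k, replace=False)]
    history = []
    for _ in range(iters):
        # Assignment step: label each point with its nearest centroid
        dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Update step: move each centroid to the mean of its points (keep it if empty)
        centroids = np.array([
            points[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        history.append((centroids, labels))
    return history

# Three synthetic 2-D blobs
rng = np.random.default_rng(1)
centers = np.array([[0.0, 0.0], [5.0, 5.0], [0.0, 5.0]])
points = np.vstack([c + rng.normal(scale=0.7, size=(100, 2)) for c in centers])
history = kmeans_history(points, k=3)

fig, ax = plt.subplots(figsize=(6, 5))

def func_plot(frame):
    centroids, labels = frame
    ax.clear()
    ax.scatter(points[:, 0], points[:, 1], c=labels, s=10)
    plot = ax.scatter(centroids[:, 0], centroids[:, 1], c="r", marker="x", s=100)
    return plot

anim = animation.FuncAnimation(fig, func_plot, frames=history, interval=500)
```

Each iteration of Lloyd's algorithm can only lower (or keep) the total squared distance from points to their assigned centroids, so the animation shows the centroids settling rather than wandering.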