Subplots in matplotlib

In a previous post we learned how to use matplotlib’s gridspec to make subplots of unequal size. gridspec is quite powerful, but it can be a little complicated to use, especially if you are coming to Python from Matlab.
In this post we learn how to make figures and subplots in a way that will be more familiar to those who know Matlab.
Importing modules and creating data to plot
The first thing we need to do is import matplotlib.pyplot for plotting and numpy to calculate the cosine of some data. To save on typing we will import these libraries using aliases:
import matplotlib.pyplot as plt
import numpy as np
# Create data
X = np.linspace(-np.pi, np.pi, 256, endpoint=True)
Y = np.cos(X)
A basic plot
The simplest plot requires a single line of code:
plt.plot(X, Y)
The resulting figure is shown below. plt.plot() is similar to Matlab’s plot() function. In this case it plots the Y values against the X values.
Working with subplots
A slightly more complicated figure can be achieved by splitting things into various subplots. The following example creates an 8 inch x 8 inch figure with 2 rows and 2 columns of subplots, resulting in a total of 4 subplots.
plt.figure(figsize=(8,8))
plt.subplot(2,2,1)
plt.plot(X, Y, color="blue")
plt.title('subplot(2,2,1)')
plt.subplot(2,2,2)
plt.plot(X, Y*-1, color="red")
plt.title('subplot(2,2,2)')
plt.subplot(2,2,3)
plt.plot(X, Y*-1, color="green")
plt.title('subplot(2,2,3)')
plt.subplot(2,2,4)
plt.plot(X, Y, color="black")
plt.title('subplot(2,2,4)')
As can be seen in the code above, subplots are specified using plt.subplot(), similar to Matlab’s subplot(). The three values passed to this command are (rows, columns, subplot_id). In the above example we wanted 2 rows and 2 columns. The third number specifies the current subplot; any plotting after the plt.subplot() command appears on the subplot specified by subplot_id. In the figure below, the input to plt.subplot() is included in the title of each subplot. Subplots are numbered starting at 1, going from left to right along the first row and then left to right along each subsequent row.
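The numbering rule can be checked directly. The following is a minimal sketch, assuming a 2 row by 3 column grid chosen purely for illustration (it is not one of the figures in this post): it creates every subplot in turn and asks matplotlib which row and column each subplot_id landed in. The positions dictionary is a hypothetical helper name, and the Agg backend line is only there so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; omit for interactive use
import matplotlib.pyplot as plt

n_rows, n_cols = 2, 3  # illustrative grid size (an assumption, not from the post)

fig = plt.figure(figsize=(9, 6))
positions = {}  # subplot_id -> (row, column), both zero-based
for subplot_id in range(1, n_rows * n_cols + 1):
    ax = plt.subplot(n_rows, n_cols, subplot_id)
    # Ask matplotlib where this subplot sits in the grid
    spec = ax.get_subplotspec()
    row, col = spec.rowspan.start, spec.colspan.start
    positions[subplot_id] = (row, col)
    ax.set_title(f"id {subplot_id}: row {row + 1}, col {col + 1}")

for subplot_id, (row, col) in positions.items():
    print(subplot_id, "-> row", row + 1, "col", col + 1)
```

Subplots 1 to 3 should fill the first row from left to right, and subplot 4 should start the second row.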
Subplots of unequal size
Similar to Matlab, it is possible to pass more than one value as the subplot_id. In matplotlib this is a tuple of two values, the first and last subplot to cover, and the result is a single subplot that occupies the space of the specified subplots.
plt.figure(figsize=(8,8))
plt.subplot(2,2,1)
plt.plot(X, Y, color="blue")
plt.title('subplot(2,2,1)')
plt.subplot(2,2,2)
plt.plot(X, Y*-1, color="red")
plt.title('subplot(2,2,2)')
plt.subplot(2,2,(3,4))
plt.plot(X, Y*-1, color="green")
plt.plot(X, Y, color="black")
plt.title('subplot(2,2,(3,4))')
In the code above, the third call to plt.subplot() specifies two values for the subplot_id. Specifically, we provide (3,4) as the subplot_id, which means this subplot will occupy the space of the third and fourth subplots in a 2 row by 2 column grid. This corresponds to the entire bottom row, and is illustrated in the figure below.
To make sure this concept is clear, a second example is provided where the subplot_id is (2,4). In a 2 row by 2 column grid, this corresponds to a subplot that occupies the entire right column. This is illustrated in the figure below.
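The code for this second example is not shown in the text, so the following is a sketch of how it might look, modelled on the previous example. Which curves go in which subplot, and the colours used, are assumptions; the name tall_ax and the final check on the rows and columns spanned are additions for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; omit for interactive use
import matplotlib.pyplot as plt
import numpy as np

# Same data as in the rest of the post
X = np.linspace(-np.pi, np.pi, 256, endpoint=True)
Y = np.cos(X)

fig = plt.figure(figsize=(8, 8))

# Left column: two ordinary subplots
plt.subplot(2, 2, 1)
plt.plot(X, Y, color="blue")
plt.title('subplot(2,2,1)')

plt.subplot(2, 2, 3)
plt.plot(X, -Y, color="red")
plt.title('subplot(2,2,3)')

# Right column: one subplot covering positions 2 through 4
tall_ax = plt.subplot(2, 2, (2, 4))
plt.plot(X, -Y, color="green")
plt.plot(X, Y, color="black")
plt.title('subplot(2,2,(2,4))')

# Check which grid cells the tall subplot occupies (zero-based)
spec = tall_ax.get_subplotspec()
rows_spanned = list(spec.rowspan)
cols_spanned = list(spec.colspan)
print("rows:", rows_spanned, "columns:", cols_spanned)
```

The tall subplot should span both rows of the second column only, leaving positions 1 and 3 free for the two subplots on the left.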
Summary
Including subplots is simple in matplotlib and the similarity between plt.subplot() and Matlab’s subplot() commands should help make the transition to Python easier.