Matplotlib returning a plot object
Question:
I have a function that wraps pyplot.plt so I can quickly create graphs with oft-used defaults:
import matplotlib.pyplot as plt

def plot_signal(time, signal, title='', xlab='', ylab='',
                line_width=1, alpha=1, color='k',
                subplots=False, show_grid=True, fig_size=(10, 5)):
    # Skipping a lot of other complexity here
    f, axarr = plt.subplots(figsize=fig_size)
    axarr.plot(time, signal, linewidth=line_width,
               alpha=alpha, color=color)
    axarr.set_xlim(min(time), max(time))
    axarr.set_xlabel(xlab)
    axarr.set_ylabel(ylab)
    axarr.grid(show_grid)
    plt.suptitle(title, size=16)
    plt.show()
However, there are times when I’d like to return the plot so I can manually add or edit things for a specific graph. For example, I want to be able to change the axis labels, or add a second line to the plot, after calling the function:
import numpy as np
x = np.random.rand(100)
y = np.random.rand(100)
plot = plot_signal(np.arange(len(x)), x)
plot.plt(y, 'r')
plot.show()
I’ve seen a few questions on this ("How to return a matplotlib.figure.Figure object from Pandas plot function?" and "AttributeError: 'Figure' object has no attribute 'plot'") and, as a result, I’ve tried adding each of the following to the end of the function:
- return axarr
- return axarr.get_figure()
- return plt.axes()
However, they all raise a similar error: AttributeError: 'AxesSubplot' object has no attribute 'plt'
What's the correct way to return a plot object so it can be edited later?
Answers:
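I think the error is pretty self-explanatory. There is no such thing as pyplot.plt, or similar. plt is the quasi-standard abbreviated name for pyplot when it is imported, i.e., import matplotlib.pyplot as plt. The object you get back (an Axes) has no plt attribute either: the method that draws a line on it is plot, so the call you want is ax.plot(y, 'r'). Axes also have no show method; display the figure with plt.show().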
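A quick way to see what the error is telling you: the object returned by plt.subplots is an Axes, which has a plot method but no plt attribute. A minimal check, assuming only a standard matplotlib install:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# Axes objects expose plotting methods such as plot, scatter, bar, ...
has_plot = hasattr(ax, "plot")
# ... but nothing called "plt"; that name only exists as your alias for pyplot
has_plt = hasattr(ax, "plt")

print(type(ax).__name__, has_plot, has_plt)
```

The printed class name varies between matplotlib versions (older ones report AxesSubplot), but the two checks do not.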
As for the actual problem, the first approach, return axarr, is the most versatile one. You get back an Axes (or an array of Axes) and can plot to it directly.
The code may look like:
def plot_signal(x, y, ..., **kwargs):
    # Skipping a lot of other complexity here
    f, ax = plt.subplots(figsize=fig_size)
    ax.plot(x, y, ...)
    # further stuff
    return ax

ax = plot_signal(x, y, ...)
ax.plot(x2, y2, ...)
plt.show()
This is actually a great question that took me YEARS to figure out. A great way to do this is to pass a figure object to your function, have the function add an Axes to it, and then return the updated figure. Here is an example:
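import matplotlib.pyplot as plt
import numpy as np

fig_size = (10, 5)
f = plt.figure(figsize=fig_size)

def plot_signal(time, signal, title='', xlab='', ylab='',
                line_width=1, alpha=1, color='k',
                subplots=False, show_grid=True, fig=None):
    # Skipping a lot of other complexity here
    if fig is None:
        fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)  # here is where you add the subplot to fig
    # pyplot has no set_xlim/set_xlabel, so call the Axes methods instead
    ax.plot(time, signal, linewidth=line_width,
            alpha=alpha, color=color)
    ax.set_xlim(min(time), max(time))
    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)
    ax.grid(show_grid)
    ax.set_title(title, size=16)
    return fig

time = np.arange(100)            # example data
signal = np.random.rand(100)
f = plot_signal(time, signal, fig=f)
f  # in a Jupyter notebook, this displays the figure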
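If the function hands back a Figure, as in the example above, the Axes it created are still reachable through the figure's axes list, so you can keep editing. A self-contained sketch; here the figure is built inline rather than by plot_signal, and the data is random placeholder data:

```python
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(1, 1, 1)
ax.plot(np.arange(100), np.random.rand(100), color='k')

# Later, with only the Figure in hand, recover its Axes to keep editing.
first_ax = fig.axes[0]  # Figure.axes is a list of all Axes on the figure
first_ax.plot(np.random.rand(100), 'r')
first_ax.set_title('edited after the fact')
```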
From the matplotlib docs, the recommended signature for such a helper function is:
def my_plotter(ax, data1, data2, param_dict):
    """
    A helper function to make a graph

    Parameters
    ----------
    ax : Axes
        The axes to draw to
    data1 : array
        The x data
    data2 : array
        The y data
    param_dict : dict
        Dictionary of keyword arguments to pass to ax.plot

    Returns
    -------
    out : list
        list of artists added
    """
    out = ax.plot(data1, data2, **param_dict)
    return out
This can be used as:
import matplotlib.pyplot as plt
import numpy as np

data1, data2, data3, data4 = np.random.randn(4, 100)
fig, ax = plt.subplots(1, 1)
my_plotter(ax, data1, data2, {'marker': 'x'})
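Because the helper only receives an Axes, the same function can fill several subplots of one figure, which is where the otherwise unused data3 and data4 come in. A minimal sketch, repeating my_plotter so the block runs on its own:

```python
import matplotlib.pyplot as plt
import numpy as np


def my_plotter(ax, data1, data2, param_dict):
    """Plot data2 against data1 on ax and return the list of artists added."""
    out = ax.plot(data1, data2, **param_dict)
    return out


data1, data2, data3, data4 = np.random.randn(4, 100)

# One figure with two Axes side by side; the helper doesn't care which it gets.
fig, (ax1, ax2) = plt.subplots(1, 2)
lines1 = my_plotter(ax1, data1, data2, {'marker': 'x'})
lines2 = my_plotter(ax2, data3, data4, {'marker': 'o'})
```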
You should pass the Axes rather than the Figure. A Figure contains one or more Axes. An Axes is the region in which data points are actually plotted, whether in x-y coordinates, as a 3D plot, etc. The Figure is the top-level container that holds the Axes; it is what gets displayed, for example inline in a Jupyter notebook or in a GUI window.
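Applied to the question's function, passing the Axes might look like the sketch below. The optional ax keyword with a fall-back to a fresh figure is an assumed design (a common convention in plotting helpers), not something taken from the code above; title handling is omitted for brevity.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_signal(time, signal, ax=None, xlab='', ylab='',
                line_width=1, alpha=1, color='k', show_grid=True):
    """Draw signal on ax (or on a new Axes if none is given) and return the Axes."""
    if ax is None:
        # Fall back to a fresh figure when the caller doesn't supply an Axes.
        _, ax = plt.subplots(figsize=(10, 5))
    ax.plot(time, signal, linewidth=line_width, alpha=alpha, color=color)
    ax.set_xlim(min(time), max(time))
    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)
    ax.grid(show_grid)
    return ax


t = np.arange(100)
fig, (top, bottom) = plt.subplots(2, 1, sharex=True)
plot_signal(t, np.random.rand(100), ax=top, ylab='x')
plot_signal(t, np.random.rand(100), ax=bottom, xlab='sample', ylab='y', color='r')

# The Figure can always be recovered from any of its Axes.
parent = top.get_figure()

# Without an ax argument the helper creates its own figure.
standalone = plot_signal(t, np.random.rand(100))
```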