def scatter_plot(data, x, y, by=None, ax=None, figsize=None, grid=False,
                 **kwargs):
    """
    Make a scatter plot from two DataFrame columns

    Parameters
    ----------
    data : DataFrame
    x : Column name for the x-axis values
    y : Column name for the y-axis values
    by : Grouping specification, optional
        When not None, plotting is delegated to ``_grouped_plot`` and this
        value is forwarded to it unchanged (presumably one or more column
        names to group by -- confirm against ``_grouped_plot``).
    ax : Matplotlib axis object
        When given (and ``by`` is None) the points are drawn on this axis
        and its parent figure is returned; ``figsize`` is not applied in
        that case.
    figsize : A tuple (width, height) in inches
        Used when a new figure is created here, or forwarded to
        ``_grouped_plot`` when ``by`` is given. None keeps matplotlib's
        default figure size.
    grid : Setting this to True will show the grid
    kwargs : other plotting keyword arguments
        To be passed to scatter function

    Returns
    -------
    fig : matplotlib.Figure
        In the grouped case this is whatever ``_grouped_plot`` returns.
    """
    import matplotlib.pyplot as plt

    # workaround because `c='b'` is hardcoded in matplotlib's scatter method:
    # fall back to the configured patch face colour unless the caller set `c`
    kwargs.setdefault('c', plt.rcParams['patch.facecolor'])

    def plot_group(group, ax):
        # Draw one scatter layer for `group` on `ax` and apply the grid flag.
        xvals = group[x].values
        yvals = group[y].values
        ax.scatter(xvals, yvals, **kwargs)
        ax.grid(grid)

    if by is not None:
        # Grouped case: the helper owns figure/axes creation and layout.
        return _grouped_plot(plot_group, data, by=by, figsize=figsize, ax=ax)

    if ax is None:
        # Honour the requested size; previously `figsize` was silently
        # ignored on this path because the figure was created without it.
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)
    else:
        fig = ax.get_figure()

    # plot_group already applies ax.grid(grid), so no second call is needed.
    plot_group(data, ax)
    ax.set_ylabel(com.pprint_thing(y))
    ax.set_xlabel(com.pprint_thing(x))

    return fig
plotting.py 文件源码
python
阅读 43
收藏 0
点赞 0
评论 0
评论列表
文章目录