import inspect
import numpy as np
import pandas as pd
from xarray import DataArray
import xarray.plot as xplt
from xarray.plot.plot import _infer_interval_breaks
from xarray.plot.utils import (_determine_cmap_params,
_build_discrete_cmap,
_color_palette)
from . import TestCase, requires_matplotlib, incompatible_2_6
try:
import matplotlib as mpl
# Using a different backend makes Travis CI work.
mpl.use('Agg')
# Order of imports is important here.
import matplotlib.pyplot as plt
except ImportError:
pass
def text_in_fig():
'''
Return the set of all text in the figure
'''
alltxt = [t.get_text() for t in plt.gcf().findobj(mpl.text.Text)]
# Set comprehension not compatible with Python 2.6
return set(alltxt)
def find_possible_colorbars():
# nb. this function also matches meshes from pcolormesh
return plt.gcf().findobj(mpl.collections.QuadMesh)
def substring_in_axes(substring, ax):
'''
Return True if a substring is found anywhere in an axes
'''
alltxt = set([t.get_text() for t in ax.findobj(mpl.text.Text)])
for txt in alltxt:
if substring in txt:
return True
return False
def easy_array(shape, start=0, stop=1):
'''
Make an array with desired shape using np.linspace
shape is a tuple like (2, 3)
'''
a = np.linspace(start, stop, num=np.prod(shape))
return a.reshape(shape)
@requires_matplotlib
class PlotTestCase(TestCase):
def tearDown(self):
# Remove all matplotlib figures
plt.close('all')
def pass_in_axis(self, plotmethod):
fig, axes = plt.subplots(ncols=2)
plotmethod(ax=axes[0])
self.assertTrue(axes[0].has_data())
def imshow_called(self, plotmethod):
plotmethod()
images = plt.gca().findobj(mpl.image.AxesImage)
return len(images) > 0
def contourf_called(self, plotmethod):
plotmethod()
paths = plt.gca().findobj(mpl.collections.PathCollection)
return len(paths) > 0
class TestPlot(PlotTestCase):
def setUp(self):
self.darray = DataArray(easy_array((2, 3, 4)))
def test1d(self):
self.darray[:, 0, 0].plot()
def test_2d_before_squeeze(self):
a = DataArray(easy_array((1, 5)))
a.plot()
def test2d_uniform_calls_imshow(self):
self.assertTrue(self.imshow_called(self.darray[:, :, 0].plot.imshow))
def test2d_nonuniform_calls_contourf(self):
a = self.darray[:, :, 0]
a.coords['dim_1'] = [2, 1, 89]
self.assertTrue(self.contourf_called(a.plot.contourf))
def test3d(self):
self.darray.plot()
def test_can_pass_in_axis(self):
self.pass_in_axis(self.darray.plot)
def test__infer_interval_breaks(self):
self.assertArrayEqual([-0.5, 0.5, 1.5], _infer_interval_breaks([0, 1]))
self.assertArrayEqual([-0.5, 0.5, 5.0, 9.5, 10.5],
_infer_interval_breaks([0, 1, 9, 10]))
self.assertArrayEqual(pd.date_range('20000101', periods=4) - np.timedelta64(12, 'h'),
_infer_interval_breaks(pd.date_range('20000101', periods=3)))
@incompatible_2_6
def test_datetime_dimension(self):
nrow = 3
ncol = 4
time = pd.date_range('2000-01-01', periods=nrow)
a = DataArray(easy_array((nrow, ncol)),
coords=[('time', time), ('y', range(ncol))])
a.plot()
ax = plt.gca()
self.assertTrue(ax.has_data())
def test_convenient_facetgrid(self):
a = easy_array((10, 15, 4))
d = DataArray(a, dims=['y', 'x', 'z'])
d.coords['z'] = list('abcd')
g = d.plot(x='x', y='y', col='z', col_wrap=2, cmap='cool')
self.assertArrayEqual(g.axes.shape, [2, 2])
for ax in g.axes.flat:
self.assertTrue(ax.has_data())
with self.assertRaisesRegexp(ValueError, '[Ff]acet'):
d.plot(x='x', y='y', col='z', ax=plt.gca())
with self.assertRaisesRegexp(ValueError, '[Ff]acet'):
d[0].plot(x='x', y='y', col='z', ax=plt.gca())
def test_subplot_kws(self):
a = easy_array((10, 15, 4))
d = DataArray(a, dims=['y', 'x', 'z'])
d.coords['z'] = list('abcd')
g = d.plot(x='x', y='y', col='z', col_wrap=2, cmap='cool',
subplot_kws=dict(axisbg='r'))
for ax in g.axes.flat:
self.assertEqual(ax.get_axis_bgcolor(), 'r')
def test_convenient_facetgrid_4d(self):
a = easy_array((10, 15, 2, 3))
d = DataArray(a, dims=['y', 'x', 'columns', 'rows'])
g = d.plot(x='x', y='y', col='columns', row='rows')
self.assertArrayEqual(g.axes.shape, [3, 2])
for ax in g.axes.flat:
self.assertTrue(ax.has_data())
with self.assertRaisesRegexp(ValueError, '[Ff]acet'):
d.plot(x='x', y='y', col='columns', ax=plt.gca())
class TestPlot1D(PlotTestCase):
def setUp(self):
d = [0, 1.1, 0, 2]
self.darray = DataArray(d, coords={'period': range(len(d))})
def test_xlabel_is_index_name(self):
self.darray.plot()
self.assertEqual('period', plt.gca().get_xlabel())
def test_no_label_name_on_y_axis(self):
self.darray.plot()
self.assertEqual('', plt.gca().get_ylabel())
def test_ylabel_is_data_name(self):
self.darray.name = 'temperature'
self.darray.plot()
self.assertEqual(self.darray.name, plt.gca().get_ylabel())
def test_wrong_dims_raises_valueerror(self):
twodims = DataArray(easy_array((2, 5)))
with self.assertRaises(ValueError):
twodims.plot.line()
def test_format_string(self):
self.darray.plot.line('ro')
def test_can_pass_in_axis(self):
self.pass_in_axis(self.darray.plot.line)
def test_nonnumeric_index_raises_typeerror(self):
a = DataArray([1, 2, 3], {'letter': ['a', 'b', 'c']})
with self.assertRaisesRegexp(TypeError, r'[Pp]lot'):
a.plot.line()
def test_primitive_returned(self):
p = self.darray.plot.line()
self.assertTrue(isinstance(p[0], mpl.lines.Line2D))
def test_plot_nans(self):
self.darray[1] = np.nan
self.darray.plot.line()
def test_x_ticks_are_rotated_for_time(self):
time = pd.date_range('2000-01-01', '2000-01-10')
a = DataArray(np.arange(len(time)), {'t': time})
a.plot.line()
rotation = plt.gca().get_xticklabels()[0].get_rotation()
self.assertFalse(rotation == 0)
def test_slice_in_title(self):
self.darray.coords['d'] = 10
self.darray.plot.line()
title = plt.gca().get_title()
self.assertEqual('d = 10', title)
class TestPlotHistogram(PlotTestCase):
def setUp(self):
self.darray = DataArray(easy_array((2, 3, 4)))
def test_3d_array(self):
self.darray.plot.hist()
def test_title_no_name(self):
self.darray.plot.hist()
self.assertEqual('', plt.gca().get_title())
def test_title_uses_name(self):
self.darray.name = 'testpoints'
self.darray.plot.hist()
self.assertIn(self.darray.name, plt.gca().get_title())
def test_ylabel_is_count(self):
self.darray.plot.hist()
self.assertEqual('Count', plt.gca().get_ylabel())
def test_can_pass_in_kwargs(self):
nbins = 5
self.darray.plot.hist(bins=nbins)
self.assertEqual(nbins, len(plt.gca().patches))
def test_can_pass_in_axis(self):
self.pass_in_axis(self.darray.plot.hist)
def test_primitive_returned(self):
h = self.darray.plot.hist()
self.assertTrue(isinstance(h[-1][0], mpl.patches.Rectangle))
def test_plot_nans(self):
self.darray[0, 0, 0] = np.nan
self.darray.plot.hist()
@requires_matplotlib
class TestDetermineCmapParams(TestCase):
def setUp(self):
self.data = np.linspace(0, 1, num=100)
def test_robust(self):
cmap_params = _determine_cmap_params(self.data, robust=True)
self.assertEqual(cmap_params['vmin'], np.percentile(self.data, 2))
self.assertEqual(cmap_params['vmax'], np.percentile(self.data, 98))
self.assertEqual(cmap_params['cmap'].name, 'viridis')
self.assertEqual(cmap_params['extend'], 'both')
self.assertIsNone(cmap_params['levels'])
self.assertIsNone(cmap_params['norm'])
def test_center(self):
cmap_params = _determine_cmap_params(self.data, center=0.5)
self.assertEqual(cmap_params['vmax'] - 0.5, 0.5 - cmap_params['vmin'])
self.assertEqual(cmap_params['cmap'], 'RdBu_r')
self.assertEqual(cmap_params['extend'], 'neither')
self.assertIsNone(cmap_params['levels'])
self.assertIsNone(cmap_params['norm'])
def test_integer_levels(self):
data = self.data + 1
cmap_params = _determine_cmap_params(data, levels=5, vmin=0, vmax=5,
cmap='Blues')
self.assertEqual(cmap_params['vmin'], cmap_params['levels'][0])
self.assertEqual(cmap_params['vmax'], cmap_params['levels'][-1])
self.assertEqual(cmap_params['cmap'].name, 'Blues')
self.assertEqual(cmap_params['extend'], 'neither')
self.assertEqual(cmap_params['cmap'].N, 5)
self.assertEqual(cmap_params['norm'].N, 6)
cmap_params = _determine_cmap_params(data, levels=5,
vmin=0.5, vmax=1.5)
self.assertEqual(cmap_params['cmap'].name, 'viridis')
self.assertEqual(cmap_params['extend'], 'max')
def test_list_levels(self):
data = self.data + 1
orig_levels = [0, 1, 2, 3, 4, 5]
# vmin and vmax should be ignored if levels are explicitly provided
cmap_params = _determine_cmap_params(data, levels=orig_levels,
vmin=0, vmax=3)
self.assertEqual(cmap_params['vmin'], 0)
self.assertEqual(cmap_params['vmax'], 5)
self.assertEqual(cmap_params['cmap'].N, 5)
self.assertEqual(cmap_params['norm'].N, 6)
for wrap_levels in [list, np.array, pd.Index, DataArray]:
cmap_params = _determine_cmap_params(
data, levels=wrap_levels(orig_levels))
self.assertArrayEqual(cmap_params['levels'], orig_levels)
def test_divergentcontrol(self):
neg = self.data - 0.1
pos = self.data
# Default with positive data will be a normal cmap
cmap_params = _determine_cmap_params(pos)
self.assertEqual(cmap_params['vmin'], 0)
self.assertEqual(cmap_params['vmax'], 1)
self.assertEqual(cmap_params['cmap'].name, "viridis")
# Default with negative data will be a divergent cmap
cmap_params = _determine_cmap_params(neg)
self.assertEqual(cmap_params['vmin'], -0.9)
self.assertEqual(cmap_params['vmax'], 0.9)
self.assertEqual(cmap_params['cmap'], "RdBu_r")
# Setting vmin or vmax should prevent this only if center is false
cmap_params = _determine_cmap_params(neg, vmin=-0.1, center=False)
self.assertEqual(cmap_params['vmin'], -0.1)
self.assertEqual(cmap_params['vmax'], 0.9)
self.assertEqual(cmap_params['cmap'].name, "viridis")
cmap_params = _determine_cmap_params(neg, vmax=0.5, center=False)
self.assertEqual(cmap_params['vmin'], -0.1)
self.assertEqual(cmap_params['vmax'], 0.5)
self.assertEqual(cmap_params['cmap'].name, "viridis")
# Setting center=False too
cmap_params = _determine_cmap_params(neg, center=False)
self.assertEqual(cmap_params['vmin'], -0.1)
self.assertEqual(cmap_params['vmax'], 0.9)
self.assertEqual(cmap_params['cmap'].name, "viridis")
# However, I should still be able to set center and have a div cmap
cmap_params = _determine_cmap_params(neg, center=0)
self.assertEqual(cmap_params['vmin'], -0.9)
self.assertEqual(cmap_params['vmax'], 0.9)
self.assertEqual(cmap_params['cmap'], "RdBu_r")
# Setting vmin or vmax alone will force symetric bounds around center
cmap_params = _determine_cmap_params(neg, vmin=-0.1)
self.assertEqual(cmap_params['vmin'], -0.1)
self.assertEqual(cmap_params['vmax'], 0.1)
self.assertEqual(cmap_params['cmap'], "RdBu_r")
cmap_params = _determine_cmap_params(neg, vmax=0.5)
self.assertEqual(cmap_params['vmin'], -0.5)
self.assertEqual(cmap_params['vmax'], 0.5)
self.assertEqual(cmap_params['cmap'], "RdBu_r")
cmap_params = _determine_cmap_params(neg, vmax=0.6, center=0.1)
self.assertEqual(cmap_params['vmin'], -0.4)
self.assertEqual(cmap_params['vmax'], 0.6)
self.assertEqual(cmap_params['cmap'], "RdBu_r")
# But this is only true if vmin or vmax are negative
cmap_params = _determine_cmap_params(pos, vmin=-0.1)
self.assertEqual(cmap_params['vmin'], -0.1)
self.assertEqual(cmap_params['vmax'], 0.1)
self.assertEqual(cmap_params['cmap'], "RdBu_r")
cmap_params = _determine_cmap_params(pos, vmin=0.1)
self.assertEqual(cmap_params['vmin'], 0.1)
self.assertEqual(cmap_params['vmax'], 1)
self.assertEqual(cmap_params['cmap'].name, "viridis")
cmap_params = _determine_cmap_params(pos, vmax=0.5)
self.assertEqual(cmap_params['vmin'], 0)
self.assertEqual(cmap_params['vmax'], 0.5)
self.assertEqual(cmap_params['cmap'].name, "viridis")
# If both vmin and vmax are provided, output is non-divergent
cmap_params = _determine_cmap_params(neg, vmin=-0.2, vmax=0.6)
self.assertEqual(cmap_params['vmin'], -0.2)
self.assertEqual(cmap_params['vmax'], 0.6)
self.assertEqual(cmap_params['cmap'].name, "viridis")
@requires_matplotlib
class TestDiscreteColorMap(TestCase):
def setUp(self):
x = np.arange(start=0, stop=10, step=2)
y = np.arange(start=9, stop=-7, step=-3)
xy = np.dstack(np.meshgrid(x, y))
distance = np.linalg.norm(xy, axis=2)
self.darray = DataArray(distance, list(zip(('y', 'x'), (y, x))))
self.data_min = distance.min()
self.data_max = distance.max()
def test_recover_from_seaborn_jet_exception(self):
pal = _color_palette('jet', 4)
self.assertTrue(type(pal) == np.ndarray)
self.assertEqual(len(pal), 4)
def test_build_discrete_cmap(self):
for (cmap, levels, extend, filled) in [('jet', [0, 1], 'both', False),
('hot', [-4, 4], 'max', True)]:
ncmap, cnorm = _build_discrete_cmap(cmap, levels, extend, filled)
self.assertEqual(ncmap.N, len(levels) - 1)
self.assertEqual(len(ncmap.colors), len(levels) - 1)
self.assertEqual(cnorm.N, len(levels))
self.assertArrayEqual(cnorm.boundaries, levels)
self.assertEqual(max(levels), cnorm.vmax)
self.assertEqual(min(levels), cnorm.vmin)
if filled:
self.assertEqual(ncmap.colorbar_extend, extend)
else:
self.assertEqual(ncmap.colorbar_extend, 'neither')
def test_discrete_colormap_list_of_levels(self):
for extend, levels in [('max', [-1, 2, 4, 8, 10]),
('both', [2, 5, 10, 11]),
('neither', [0, 5, 10, 15]),
('min', [2, 5, 10, 15])]:
for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']:
primitive = getattr(self.darray.plot, kind)(levels=levels)
self.assertArrayEqual(levels, primitive.norm.boundaries)
self.assertEqual(max(levels), primitive.norm.vmax)
self.assertEqual(min(levels), primitive.norm.vmin)
if kind != 'contour':
self.assertEqual(extend, primitive.cmap.colorbar_extend)
else:
self.assertEqual('neither', primitive.cmap.colorbar_extend)
self.assertEqual(len(levels) - 1, len(primitive.cmap.colors))
def test_discrete_colormap_int_levels(self):
for extend, levels, vmin, vmax in [('neither', 7, None, None),
('neither', 7, None, 20),
('both', 7, 4, 8),
('min', 10, 4, 15)]:
for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']:
primitive = getattr(self.darray.plot, kind)(levels=levels,
vmin=vmin,
vmax=vmax)
self.assertGreaterEqual(levels,
len(primitive.norm.boundaries) - 1)
if vmax is None:
self.assertGreaterEqual(primitive.norm.vmax, self.data_max)
else:
self.assertGreaterEqual(primitive.norm.vmax, vmax)
if vmin is None:
self.assertLessEqual(primitive.norm.vmin, self.data_min)
else:
self.assertLessEqual(primitive.norm.vmin, vmin)
if kind != 'contour':
self.assertEqual(extend, primitive.cmap.colorbar_extend)
else:
self.assertEqual('neither', primitive.cmap.colorbar_extend)
self.assertGreaterEqual(levels, len(primitive.cmap.colors))
def test_discrete_colormap_list_levels_and_vmin_or_vmax(self):
levels = [0, 5, 10, 15]
primitive = self.darray.plot(levels=levels, vmin=-3, vmax=20)
self.assertEqual(primitive.norm.vmax, max(levels))
self.assertEqual(primitive.norm.vmin, min(levels))
class Common2dMixin:
"""
Common tests for 2d plotting go here.
These tests assume that a staticmethod for `self.plotfunc` exists.
Should have the same name as the method.
"""
def setUp(self):
da = DataArray(easy_array(
(10, 15), start=-1), dims=['y', 'x'])
# add 2d coords
ds = da.to_dataset(name='testvar')
x, y = np.meshgrid(da.x.values, da.y.values)
ds['x2d'] = DataArray(x, dims=['y', 'x'])
ds['y2d'] = DataArray(y, dims=['y', 'x'])
ds.set_coords(['x2d', 'y2d'], inplace=True)
# set darray and plot method
self.darray = ds.testvar
self.plotmethod = getattr(self.darray.plot, self.plotfunc.__name__)
def test_label_names(self):
self.plotmethod()
self.assertEqual('x', plt.gca().get_xlabel())
self.assertEqual('y', plt.gca().get_ylabel())
def test_1d_raises_valueerror(self):
with self.assertRaisesRegexp(ValueError, r'DataArray must be 2d'):
self.plotfunc(self.darray[0, :])
def test_3d_raises_valueerror(self):
a = DataArray(easy_array((2, 3, 4)))
with self.assertRaisesRegexp(ValueError, r'DataArray must be 2d'):
self.plotfunc(a)
def test_nonnumeric_index_raises_typeerror(self):
a = DataArray(easy_array((3, 2)),
coords=[['a', 'b', 'c'], ['d', 'e']])
with self.assertRaisesRegexp(TypeError, r'[Pp]lot'):
self.plotfunc(a)
def test_can_pass_in_axis(self):
self.pass_in_axis(self.plotmethod)
def test_xyincrease_false_changes_axes(self):
self.plotmethod(xincrease=False, yincrease=False)
xlim = plt.gca().get_xlim()
ylim = plt.gca().get_ylim()
diffs = xlim[0] - 14, xlim[1] - 0, ylim[0] - 9, ylim[1] - 0
self.assertTrue(all(abs(x) < 1 for x in diffs))
def test_xyincrease_true_changes_axes(self):
self.plotmethod(xincrease=True, yincrease=True)
xlim = plt.gca().get_xlim()
ylim = plt.gca().get_ylim()
diffs = xlim[0] - 0, xlim[1] - 14, ylim[0] - 0, ylim[1] - 9
self.assertTrue(all(abs(x) < 1 for x in diffs))
def test_plot_nans(self):
x1 = self.darray[:5]
x2 = self.darray.copy()
x2[5:] = np.nan
clim1 = self.plotfunc(x1).get_clim()
clim2 = self.plotfunc(x2).get_clim()
self.assertEqual(clim1, clim2)
def test_viridis_cmap(self):
cmap_name = self.plotmethod(cmap='viridis').get_cmap().name
self.assertEqual('viridis', cmap_name)
def test_default_cmap(self):
cmap_name = self.plotmethod().get_cmap().name
self.assertEqual('RdBu_r', cmap_name)
cmap_name = self.plotfunc(abs(self.darray)).get_cmap().name
self.assertEqual('viridis', cmap_name)
def test_seaborn_palette_as_cmap(self):
try:
import seaborn
cmap_name = self.plotmethod(
levels=2, cmap='husl').get_cmap().name
self.assertEqual('husl', cmap_name)
except ImportError:
pass
def test_can_change_default_cmap(self):
cmap_name = self.plotmethod(cmap='Blues').get_cmap().name
self.assertEqual('Blues', cmap_name)
def test_diverging_color_limits(self):
artist = self.plotmethod()
vmin, vmax = artist.get_clim()
self.assertAlmostEqual(-vmin, vmax)
def test_xy_strings(self):
self.plotmethod('y', 'x')
ax = plt.gca()
self.assertEqual('y', ax.get_xlabel())
self.assertEqual('x', ax.get_ylabel())
def test_positional_coord_string(self):
with self.assertRaisesRegexp(ValueError, 'cannot supply only one'):
self.plotmethod('y')
with self.assertRaisesRegexp(ValueError, 'cannot supply only one'):
self.plotmethod(y='x')
def test_bad_x_string_exception(self):
with self.assertRaisesRegexp(ValueError, 'x and y must be coordinate'):
self.plotmethod('not_a_real_dim', 'y')
self.darray.coords['z'] = 100
with self.assertRaisesRegexp(ValueError, 'cannot supply only one'):
self.plotmethod('z')
def test_coord_strings(self):
# 1d coords (same as dims)
self.assertIn('x', self.darray.coords)
self.assertIn('y', self.darray.coords)
self.plotmethod(y='y', x='x')
def test_default_title(self):
a = DataArray(easy_array((4, 3, 2)), dims=['a', 'b', 'c'])
a.coords['d'] = u'foo'
self.plotfunc(a.isel(c=1))
title = plt.gca().get_title()
self.assertTrue('c = 1, d = foo' == title or 'd = foo, c = 1' == title)
def test_colorbar_label(self):
self.darray.name = 'testvar'
self.plotmethod()
self.assertIn(self.darray.name, text_in_fig())
def test_no_labels(self):
self.darray.name = 'testvar'
self.plotmethod(add_labels=False)
alltxt = text_in_fig()
for string in ['x', 'y', 'testvar']:
self.assertNotIn(string, alltxt)
def test_verbose_facetgrid(self):
a = easy_array((10, 15, 3))
d = DataArray(a, dims=['y', 'x', 'z'])
g = xplt.FacetGrid(d, col='z')
g.map_dataarray(self.plotfunc, 'x', 'y')
for ax in g.axes.flat:
self.assertTrue(ax.has_data())
@incompatible_2_6
def test_2d_function_and_method_signature_same(self):
func_sig = inspect.getcallargs(self.plotfunc, self.darray)
method_sig = inspect.getcallargs(self.plotmethod)
del method_sig['_PlotMethods_obj']
del func_sig['darray']
self.assertEqual(func_sig, method_sig)
def test_convenient_facetgrid(self):
a = easy_array((10, 15, 4))
d = DataArray(a, dims=['y', 'x', 'z'])
g = self.plotfunc(d, x='x', y='y', col='z', col_wrap=2)
self.assertArrayEqual(g.axes.shape, [2, 2])
for (y, x), ax in np.ndenumerate(g.axes):
self.assertTrue(ax.has_data())
if x == 0:
self.assertEqual('y', ax.get_ylabel())
else:
self.assertEqual('', ax.get_ylabel())
if y == 1:
self.assertEqual('x', ax.get_xlabel())
else:
self.assertEqual('', ax.get_xlabel())
# Infering labels
g = self.plotfunc(d, col='z', col_wrap=2)
self.assertArrayEqual(g.axes.shape, [2, 2])
for (y, x), ax in np.ndenumerate(g.axes):
self.assertTrue(ax.has_data())
if x == 0:
self.assertEqual('y', ax.get_ylabel())
else:
self.assertEqual('', ax.get_ylabel())
if y == 1:
self.assertEqual('x', ax.get_xlabel())
else:
self.assertEqual('', ax.get_xlabel())
def test_convenient_facetgrid_4d(self):
a = easy_array((10, 15, 2, 3))
d = DataArray(a, dims=['y', 'x', 'columns', 'rows'])
g = self.plotfunc(d, x='x', y='y', col='columns', row='rows')
self.assertArrayEqual(g.axes.shape, [3, 2])
for ax in g.axes.flat:
self.assertTrue(ax.has_data())
def test_facetgrid_cmap(self):
# Regression test for GH592
data = (np.random.random(size=(20, 25, 12)) + np.linspace(-3, 3, 12))
d = DataArray(data, dims=['x', 'y', 'time'])
fg = d.plot.pcolormesh(col='time')
# check that all color limits are the same
self.assertTrue(len(set(m.get_clim() for m in fg._mappables)) == 1)
# check that all colormaps are the same
self.assertTrue(len(set(m.get_cmap().name for m in fg._mappables)) == 1)
class TestContourf(Common2dMixin, PlotTestCase):
plotfunc = staticmethod(xplt.contourf)
def test_contourf_called(self):
# Having both statements ensures the test works properly
self.assertFalse(self.contourf_called(self.darray.plot.imshow))
self.assertTrue(self.contourf_called(self.darray.plot.contourf))
def test_primitive_artist_returned(self):
artist = self.plotmethod()
self.assertTrue(isinstance(artist, mpl.contour.QuadContourSet))
def test_extend(self):
artist = self.plotmethod()
self.assertEqual(artist.extend, 'neither')
self.darray[0, 0] = -100
self.darray[-1, -1] = 100
artist = self.plotmethod(robust=True)
self.assertEqual(artist.extend, 'both')
self.darray[0, 0] = 0
self.darray[-1, -1] = 0
artist = self.plotmethod(vmin=-0, vmax=10)
self.assertEqual(artist.extend, 'min')
artist = self.plotmethod(vmin=-10, vmax=0)
self.assertEqual(artist.extend, 'max')
def test_2d_coord_names(self):
self.plotmethod(x='x2d', y='y2d')
# make sure labels came out ok
ax = plt.gca()
self.assertEqual('x2d', ax.get_xlabel())
self.assertEqual('y2d', ax.get_ylabel())
def test_levels(self):
artist = self.plotmethod(levels=[-0.5, -0.4, 0.1])
self.assertEqual(artist.extend, 'both')
artist = self.plotmethod(levels=3)
self.assertEqual(artist.extend, 'neither')
class TestContour(Common2dMixin, PlotTestCase):
plotfunc = staticmethod(xplt.contour)
def test_colors(self):
# matplotlib cmap.colors gives an rgbA ndarray
# when seaborn is used, instead we get an rgb tuble
def _color_as_tuple(c):
return tuple(c[:3])
artist = self.plotmethod(colors='k')
self.assertEqual(
_color_as_tuple(artist.cmap.colors[0]),
(0.0, 0.0, 0.0))
artist = self.plotmethod(colors=['k', 'b'])
self.assertEqual(
_color_as_tuple(artist.cmap.colors[1]),
(0.0, 0.0, 1.0))
def test_cmap_and_color_both(self):
with self.assertRaises(ValueError):
self.plotmethod(colors='k', cmap='RdBu')
def list_of_colors_in_cmap_deprecated(self):
with self.assertRaises(Exception):
self.plotmethod(cmap=['k', 'b'])
def test_2d_coord_names(self):
self.plotmethod(x='x2d', y='y2d')
# make sure labels came out ok
ax = plt.gca()
self.assertEqual('x2d', ax.get_xlabel())
self.assertEqual('y2d', ax.get_ylabel())
class TestPcolormesh(Common2dMixin, PlotTestCase):
plotfunc = staticmethod(xplt.pcolormesh)
def test_primitive_artist_returned(self):
artist = self.plotmethod()
self.assertTrue(isinstance(artist, mpl.collections.QuadMesh))
def test_everything_plotted(self):
artist = self.plotmethod()
self.assertEqual(artist.get_array().size, self.darray.size)
def test_2d_coord_names(self):
self.plotmethod(x='x2d', y='y2d')
# make sure labels came out ok
ax = plt.gca()
self.assertEqual('x2d', ax.get_xlabel())
self.assertEqual('y2d', ax.get_ylabel())
class TestImshow(Common2dMixin, PlotTestCase):
plotfunc = staticmethod(xplt.imshow)
def test_imshow_called(self):
# Having both statements ensures the test works properly
self.assertFalse(self.imshow_called(self.darray.plot.contourf))
self.assertTrue(self.imshow_called(self.darray.plot.imshow))
def test_xy_pixel_centered(self):
self.darray.plot.imshow(yincrease=False)
self.assertTrue(np.allclose([-0.5, 14.5], plt.gca().get_xlim()))
self.assertTrue(np.allclose([9.5, -0.5], plt.gca().get_ylim()))
def test_default_aspect_is_auto(self):
self.darray.plot.imshow()
self.assertEqual('auto', plt.gca().get_aspect())
def test_can_change_aspect(self):
self.darray.plot.imshow(aspect='equal')
self.assertEqual('equal', plt.gca().get_aspect())
def test_primitive_artist_returned(self):
artist = self.plotmethod()
self.assertTrue(isinstance(artist, mpl.image.AxesImage))
def test_seaborn_palette_needs_levels(self):
try:
import seaborn
with self.assertRaises(ValueError):
self.plotmethod(cmap='husl')
except ImportError:
pass
def test_2d_coord_names(self):
with self.assertRaisesRegexp(ValueError, 'requires 1D coordinates'):
self.plotmethod(x='x2d', y='y2d')
class TestFacetGrid(PlotTestCase):
def setUp(self):
d = easy_array((10, 15, 3))
self.darray = DataArray(d, dims=['y', 'x', 'z'],
coords={'z': ['a', 'b', 'c']})
self.g = xplt.FacetGrid(self.darray, col='z')
def test_no_args(self):
self.g.map_dataarray(xplt.contourf, 'x', 'y')
# Don't want colorbar labeled with 'None'
alltxt = text_in_fig()
self.assertNotIn('None', alltxt)
for ax in self.g.axes.flat:
self.assertTrue(ax.has_data())
# default font size should be small
fontsize = ax.title.get_size()
self.assertLessEqual(fontsize, 12)
def test_names_appear_somewhere(self):
self.darray.name = 'testvar'
self.g.map_dataarray(xplt.contourf, 'x', 'y')
for k, ax in zip('abc', self.g.axes.flat):
self.assertEqual('z = {0}'.format(k), ax.get_title())
alltxt = text_in_fig()
self.assertIn(self.darray.name, alltxt)
for label in ['x', 'y']:
self.assertIn(label, alltxt)
def test_text_not_super_long(self):
self.darray.coords['z'] = [100 * letter for letter in 'abc']
g = xplt.FacetGrid(self.darray, col='z')
g.map_dataarray(xplt.contour, 'x', 'y')
alltxt = text_in_fig()
maxlen = max(len(txt) for txt in alltxt)
self.assertLess(maxlen, 50)
t0 = g.axes[0, 0].get_title()
self.assertTrue(t0.endswith('...'))
def test_colorbar(self):
vmin = self.darray.values.min()
vmax = self.darray.values.max()
expected = np.array((vmin, vmax))
self.g.map_dataarray(xplt.imshow, 'x', 'y')
for image in plt.gcf().findobj(mpl.image.AxesImage):
clim = np.array(image.get_clim())
self.assertTrue(np.allclose(expected, clim))
self.assertEqual(1, len(find_possible_colorbars()))
def test_empty_cell(self):
g = xplt.FacetGrid(self.darray, col='z', col_wrap=2)
g.map_dataarray(xplt.imshow, 'x', 'y')
bottomright = g.axes[-1, -1]
self.assertFalse(bottomright.has_data())
self.assertFalse(bottomright.get_visible())
def test_norow_nocol_error(self):
with self.assertRaisesRegexp(ValueError, r'[Rr]ow'):
xplt.FacetGrid(self.darray)
def test_groups(self):
self.g.map_dataarray(xplt.imshow, 'x', 'y')
upperleft_dict = self.g.name_dicts[0, 0]
upperleft_array = self.darray.loc[upperleft_dict]
z0 = self.darray.isel(z=0)
self.assertDataArrayEqual(upperleft_array, z0)
def test_float_index(self):
self.darray.coords['z'] = [0.1, 0.2, 0.4]
g = xplt.FacetGrid(self.darray, col='z')
g.map_dataarray(xplt.imshow, 'x', 'y')
def test_nonunique_index_error(self):
self.darray.coords['z'] = [0.1, 0.2, 0.2]
with self.assertRaisesRegexp(ValueError, r'[Uu]nique'):
xplt.FacetGrid(self.darray, col='z')
def test_robust(self):
z = np.zeros((20, 20, 2))
darray = DataArray(z, dims=['y', 'x', 'z'])
darray[:, :, 1] = 1
darray[2, 0, 0] = -1000
darray[3, 0, 0] = 1000
g = xplt.FacetGrid(darray, col='z')
g.map_dataarray(xplt.imshow, 'x', 'y', robust=True)
# Color limits should be 0, 1
# The largest number displayed in the figure should be less than 21
numbers = set()
alltxt = text_in_fig()
for txt in alltxt:
try:
numbers.add(float(txt))
except ValueError:
pass
largest = max(abs(x) for x in numbers)
self.assertLess(largest, 21)
def test_can_set_vmin_vmax(self):
vmin, vmax = 50.0, 1000.0
expected = np.array((vmin, vmax))
self.g.map_dataarray(xplt.imshow, 'x', 'y', vmin=vmin, vmax=vmax)
for image in plt.gcf().findobj(mpl.image.AxesImage):
clim = np.array(image.get_clim())
self.assertTrue(np.allclose(expected, clim))
def test_figure_size(self):
self.assertArrayEqual(self.g.fig.get_size_inches(), (10, 3))
g = xplt.FacetGrid(self.darray, col='z', size=6)
self.assertArrayEqual(g.fig.get_size_inches(), (19, 6))
g = self.darray.plot.imshow(col='z', size=6)
self.assertArrayEqual(g.fig.get_size_inches(), (19, 6))
g = xplt.FacetGrid(self.darray, col='z', size=4, aspect=0.5)
self.assertArrayEqual(g.fig.get_size_inches(), (7, 4))
def test_num_ticks(self):
nticks = 100
maxticks = nticks + 1
self.g.map_dataarray(xplt.imshow, 'x', 'y')
self.g.set_ticks(max_xticks=nticks, max_yticks=nticks)
for ax in self.g.axes.flat:
xticks = len(ax.get_xticks())
yticks = len(ax.get_yticks())
self.assertLessEqual(xticks, maxticks)
self.assertLessEqual(yticks, maxticks)
self.assertGreaterEqual(xticks, nticks / 2.0)
self.assertGreaterEqual(yticks, nticks / 2.0)
def test_map(self):
self.g.map(plt.contourf, 'x', 'y', Ellipsis)
self.g.map(lambda: None)
def test_map_dataset(self):
g = xplt.FacetGrid(self.darray.to_dataset(name='foo'), col='z')
g.map(plt.contourf, 'x', 'y', 'foo')
alltxt = text_in_fig()
for label in ['x', 'y']:
self.assertIn(label, alltxt)
# everything has a label
self.assertNotIn('None', alltxt)
# colorbar can't be inferred automatically
self.assertNotIn('foo', alltxt)
self.assertEqual(0, len(find_possible_colorbars()))
g.add_colorbar(label='colors!')
self.assertIn('colors!', text_in_fig())
self.assertEqual(1, len(find_possible_colorbars()))
def test_set_axis_labels(self):
g = self.g.map_dataarray(xplt.contourf, 'x', 'y')
g.set_axis_labels('longitude', 'latitude')
alltxt = text_in_fig()
for label in ['longitude', 'latitude']:
self.assertIn(label, alltxt)
def test_facetgrid_colorbar(self):
a = easy_array((10, 15, 4))
d = DataArray(a, dims=['y', 'x', 'z'], name='foo')
d.plot.imshow(x='x', y='y', col='z')
self.assertEqual(1, len(find_possible_colorbars()))
d.plot.imshow(x='x', y='y', col='z', add_colorbar=True)
self.assertEqual(1, len(find_possible_colorbars()))
d.plot.imshow(x='x', y='y', col='z', add_colorbar=False)
self.assertEqual(0, len(find_possible_colorbars()))
class TestFacetGrid4d(PlotTestCase):
def setUp(self):
a = easy_array((10, 15, 3, 2))
darray = DataArray(a, dims=['y', 'x', 'col', 'row'])
darray.coords['col'] = np.array(['col' + str(x) for x in
darray.coords['col'].values])
darray.coords['row'] = np.array(['row' + str(x) for x in
darray.coords['row'].values])
self.darray = darray
def test_default_labels(self):
g = xplt.FacetGrid(self.darray, col='col', row='row')
self.assertEqual((2, 3), g.axes.shape)
g.map_dataarray(xplt.imshow, 'x', 'y')
# Rightmost column should be labeled
for label, ax in zip(self.darray.coords['row'].values, g.axes[:, -1]):
self.assertTrue(substring_in_axes(label, ax))
# Top row should be labeled
for label, ax in zip(self.darray.coords['col'].values, g.axes[0, :]):
self.assertTrue(substring_in_axes(label, ax))
|