Here's a function I wrote to plot a scalar function defined on a 2D mesh. Note that you can always extract a surface from a 3D mesh by creating a BoundaryMesh and then a SubMesh.
import numpy as np
import os
from matplotlib import colors
from pylab import plt
def plot_variable(u, name, direc, cmap='gist_yarg', scale='lin', numLvls=12,
                  umin=None, umax=None, tp=False, tpAlpha=0.5, show=True,
                  hide_ax_tick_labels=False, label_axes=True, title='',
                  use_colorbar=True, hide_axis=False, colorbar_loc='right'):
    """
    Plot a scalar function on a 2D mesh as filled contours and save it to
    ``direc + name + '.png'`` (300 dpi).

    The function is sampled at the mesh vertices and drawn with
    ``tricontourf`` over the mesh cells. Only the first two vertex
    coordinates (x, y) are used.

    Parameters
    ----------
    u : Function
        Function to plot. It must provide ``function_space().mesh()`` and
        ``compute_vertex_values(mesh)``. Presumably this is a FEniCS/dolfin
        ``Function``; confirm against callers.
    name : str
        File name stem of the saved image (``'.png'`` is appended).
    direc : str
        Output prefix, prepended verbatim to ``name``. It is normally a
        directory path ending in a separator. Its directory part is created
        if it does not exist.
    cmap : str or Colormap
        Colormap, passed through ``plt.get_cmap``.
    scale : {'lin', 'log', 'bool'}
        'lin'  : ``numLvls`` linearly spaced levels from vmin to vmax.
        'log'  : ``numLvls`` logarithmically spaced levels; requires vmin > 0.
        'bool' : fixed levels [0, 1, 2]; no colorbar is drawn.
    numLvls : int
        Number of contour levels for the 'lin' and 'log' scales.
    umin, umax : float or None
        Lower/upper limits of the colour range. They default to the data
        min/max. For 'lin' and 'log', vertex values outside the range are
        clamped just inside it.
    tp : bool
        If true, overlay the mesh triangulation.
    tpAlpha : float
        Opacity of the triangulation overlay.
    show : bool
        If true, call ``plt.show()`` after saving.
    hide_ax_tick_labels : bool
        If true, remove the numeric tick labels.
    label_axes : bool
        If true, label the axes ``$x$`` and ``$y$``.
    title : str
        Title of the plot axes.
    use_colorbar : bool
        If true (and ``scale != 'bool'``), attach a colorbar with one tick
        per contour level.
    hide_axis : bool
        If true, turn the plot axes off entirely.
    colorbar_loc : str
        Side on which the colorbar axes is appended (passed to
        ``append_axes``, e.g. 'right').

    Raises
    ------
    ValueError
        If ``scale`` is not 'lin', 'log' or 'bool', or if ``scale`` is 'log'
        and the lower limit of the colour range is not positive.
    """
    # Deferred imports, matching this file's local-import style.
    from matplotlib.ticker import LogFormatter, ScalarFormatter

    # Validate up front: an unknown scale would otherwise leave the
    # level/formatter/norm variables unbound further down.
    if scale not in ('lin', 'log', 'bool'):
        raise ValueError("scale must be 'lin', 'log' or 'bool', got %r"
                         % (scale,))

    mesh = u.function_space().mesh()
    v = u.compute_vertex_values(mesh)
    coords = mesh.coordinates()
    x = coords[:, 0]
    y = coords[:, 1]
    t = mesh.cells()

    # A bare prefix such as 'plot_' has no directory part and dirname()
    # returns ''; os.makedirs('') would raise, so only create real dirs.
    d = os.path.dirname(direc)
    if d:
        os.makedirs(d, exist_ok=True)

    vmin = v.min() if umin is None else umin
    vmax = v.max() if umax is None else umax

    # Contour levels, tick formatter and colour normalisation per scale.
    if scale == 'bool':
        levels = [0, 1, 2]
        formatter = ScalarFormatter()
        norm = None
    else:
        if scale == 'log' and vmin <= 0:
            raise ValueError("scale='log' requires a positive lower limit, "
                             "got %r" % (vmin,))
        # Nudge out-of-range values just inside [vmin, vmax]. Otherwise
        # tricontourf leaves regions outside the level range unfilled.
        v[v < vmin] = vmin + 1e-12
        v[v > vmax] = vmax - 1e-12
        if scale == 'log':
            levels = np.logspace(np.log10(vmin), np.log10(vmax), numLvls)
            formatter = LogFormatter(10, labelOnlyBase=False)
            norm = colors.LogNorm()
        else:
            levels = np.linspace(vmin, vmax, numLvls)
            formatter = ScalarFormatter()
            norm = None

    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111)
    c = ax.tricontourf(x, y, t, v, levels=levels, norm=norm,
                       cmap=plt.get_cmap(cmap))
    ax.axis('equal')
    if tp:
        ax.triplot(x, y, t, '-', lw=0.2, alpha=tpAlpha)
    ax.set_xlim([x.min(), x.max()])
    ax.set_ylim([y.min(), y.max()])
    if label_axes:
        ax.set_xlabel(r'$x$')
        ax.set_ylabel(r'$y$')
    if hide_ax_tick_labels:
        ax.set_xticklabels([])
        ax.set_yticklabels([])
    if hide_axis:
        ax.axis('off')

    # Colorbar in an axes carved off the side of the plot axes.
    if scale != 'bool' and use_colorbar:
        from mpl_toolkits.axes_grid1 import make_axes_locatable
        divider = make_axes_locatable(ax)
        cax = divider.append_axes(colorbar_loc, "5%", pad="3%")
        plt.colorbar(c, cax=cax, format=formatter, ticks=levels)

    # Title the plot axes explicitly. After the colorbar is added, the
    # pyplot "current axes" is the colorbar axes, so plt.title() would
    # label the colorbar instead.
    ax.set_title(title)

    if use_colorbar:
        fig.tight_layout(rect=[.03, .03, 0.97, 0.97])
    else:
        fig.tight_layout()
    fig.savefig(direc + name + '.png', dpi=300)
    if show:
        plt.show()
    plt.close(fig)