def force_aspect(ax, aspect=1, image=None):
    """Resize the axes box of *ax* so its width/height ratio equals *aspect*.

    The data aspect is derived from an image's extent, so the displayed box
    ends up with ``width / height == aspect`` regardless of the image's
    pixel dimensions.

    Args:
        ax: matplotlib Axes containing at least one image.
        aspect: desired box width divided by box height (default 1 = square).
        image: the image whose extent to use; defaults to the first image on
            *ax* (the original behaviour).

    Raises:
        IndexError: if *image* is None and *ax* contains no images.
        ZeroDivisionError: if the image extent has zero height or *aspect* is 0.
    """
    if image is None:
        images = ax.get_images()
        if not images:
            raise IndexError("force_aspect: axes contains no images")
        image = images[0]
    left, right, bottom, top = image.get_extent()
    # abs(): matshow places the origin at the top, so top < bottom there.
    ax.set_aspect(abs((right - left) / (top - bottom)) / aspect)


def mpl_matshow(ax, data, shape):
    """Draw *data* as a 'hot' heat-map matrix on *ax* with no tick marks.

    Args:
        ax: matplotlib Axes to draw on.
        data: array-like object with a ``reshape`` method (e.g. a numpy array)
            holding ``rows * cols`` values.
        shape: side length of a square matrix (int), or a ``(rows, cols)`` pair.

    Returns:
        The image object returned by ``ax.matshow``.
    """
    try:
        rows, cols = shape
    except TypeError:
        # A scalar side length means a square matrix (original behaviour).
        rows = cols = shape
    # aspect='auto' is immediately overridden by force_aspect below.
    image = ax.matshow(data.reshape(rows, cols), cmap='hot', aspect='auto')
    ax.set_xticks([])
    ax.set_yticks([])
    # Pass the new image explicitly so a pre-existing image on ax is not used.
    force_aspect(ax, image=image)
    return image