from abc import ABC

from matplotlib.figure import Figure

from ui.components.plot_figure import PlotFrame
from ui.plotters.plotter import Plotter


class PredictionsPlot(PlotFrame):
    """Plot frame that hosts a ``PredictionsPlotter`` on its figure."""

    def __init__(self, parent):
        """Build the frame under ``parent`` and attach the plotter.

        ``height=32`` is forwarded to ``PlotFrame``; its meaning/units are
        defined there.
        """
        super().__init__(parent, height=32)
        # ``self.figure`` is expected to be provided by ``PlotFrame.__init__``.
        self.plotter = PredictionsPlotter(self.figure)


class PredictionsPlotter(Plotter, ABC):
    """Draw prediction values as a colour-mapped strip with index labels."""

    # Number of index labels ("0" .. "9") annotated on top of the image.
    NUM_LABELS = 10

    def __init__(self, figure: Figure):
        """Add a single subplot to ``figure`` and strip its decorations."""
        super().__init__(figure)
        self.axes = figure.add_subplot()
        self.clean_axes()

    def plot(self, data):
        """Render ``data`` as an image and annotate the label indices.

        Args:
            data: Image-like data accepted by ``Axes.imshow``.
                Presumably a 2-D array with one column per label --
                TODO confirm against the caller.
        """
        # Drop artists left over from a previous call; otherwise every call
        # stacks another image and another set of annotations on the axes.
        self.axes.clear()
        self.axes.imshow(data, cmap='coolwarm', aspect='auto')
        # With imshow's default extent, column ``idx`` is centred at x == idx,
        # so the small negative offset nudges the label left of the centre.
        for idx in range(self.NUM_LABELS):
            self.axes.annotate(f"{idx}", xy=(idx - 0.2, 0.2))
        self.clean_axes()

    def clean_axes(self):
        """Remove ticks, hide all four spines and set a black background."""
        self.axes.set_xticks([])
        self.axes.set_yticks([])
        for side in ('top', 'bottom', 'left', 'right'):
            self.axes.spines[side].set_visible(False)
        self.axes.set_facecolor((0, 0, 0))

    def reset_plot(self):
        """Clear the axes and restore the undecorated appearance."""
        self.axes.clear()
        # ``Axes.clear()`` resets the tick locators, so default ticks would
        # reappear on the empty plot unless the clean-up is re-applied.
        self.clean_axes()