|
- from ui.components.plot_figure import PlotFrame
- from abc import ABC
-
- from matplotlib.figure import Figure
-
- from ui.plotters.plotter import Plotter
-
class PredictionsPlot(PlotFrame):
    """Fixed-height plot frame that displays prediction scores.

    Wraps a ``PlotFrame`` and attaches a ``PredictionsPlotter`` that draws
    onto the frame's figure.
    """

    def __init__(self, parent):
        """Build the frame inside *parent* and attach its plotter.

        Args:
            parent: Parent container handed straight to ``PlotFrame``; its
                expected type is defined by ``PlotFrame`` (not visible here).
        """
        super().__init__(parent, height=32)
        # NOTE(review): assumes PlotFrame.__init__ creates ``self.figure``.
        target_figure = self.figure
        self.plotter = PredictionsPlotter(target_figure)
-
class PredictionsPlotter(Plotter, ABC):
    """Render prediction scores as a heat-map strip with column index labels.

    Draws onto a single subplot of the supplied figure and keeps that axes
    free of ticks and spines, on a black background, so only the coloured
    cells and their index labels are visible.
    """

    def __init__(self, figure: Figure):
        """Create the subplot on *figure* and strip its decorations.

        Args:
            figure: Matplotlib figure to draw into; also passed to ``Plotter``.
        """
        super().__init__(figure)
        self.axes = figure.add_subplot()
        self.clean_axes()

    def plot(self, data):
        """Draw *data* as an image and label each column with its index.

        Args:
            data: Array-like accepted by ``Axes.imshow``. Presumably a 2-D
                array of per-class scores such as shape ``(1, 10)``; confirm
                against callers.

        Note:
            The axes are not cleared first. Call ``reset_plot`` between
            updates, or successive images and labels will stack.
        """
        image = self.axes.imshow(data, cmap='coolwarm', aspect='auto')
        # Label every column that was actually drawn instead of a fixed 10,
        # so data with a different number of columns is labelled correctly.
        # For both (M, N) and (M, N, 3/4) images the column count is shape[1].
        n_columns = image.get_array().shape[1]
        for idx in range(n_columns):
            # imshow centres column idx at x=idx; the small offsets place the
            # label text roughly inside the first row's cell.
            self.axes.annotate(f"{idx}", xy=(idx - 0.2, 0.2))
        self.clean_axes()

    def clean_axes(self):
        """Hide ticks and every spine, and paint the axes background black."""
        self.axes.set_xticks([])
        self.axes.set_yticks([])
        for spine in self.axes.spines.values():
            spine.set_visible(False)
        self.axes.set_facecolor((0, 0, 0))

    def reset_plot(self):
        """Remove everything drawn on the axes and restore the clean look.

        ``Axes.clear`` brings back matplotlib's default ticks, so
        ``clean_axes`` is re-applied to keep the axes undecorated after a
        reset, the same as right after construction.
        """
        self.axes.clear()
        self.clean_axes()
|