|
- import tkinter as tk
-
- import numpy as np
- from PIL import ImageGrab, ImageTk
- from PIL.Image import Resampling
-
class DigitDrawer(tk.Frame):
    """Frame with a drawing canvas, a downsampled preview canvas and a Clear button.

    The user draws with the left mouse button on the main canvas.
    :meth:`convert_to_array` screenshots that canvas, downsamples it to a
    ``target_size`` x ``target_size`` grayscale image, shows it in the preview
    canvas and returns it as a flat float array.
    """

    # Side length in pixels of the downsampled image and of the preview canvas.
    target_size = 28

    def __init__(self, parent, canvas_width, canvas_height, brush_size=3):
        """Create the frame and build its child widgets.

        Args:
            parent: Tk master widget for this frame.
            canvas_width: Width in pixels of the drawing canvas.
            canvas_height: Height in pixels of the drawing canvas.
            brush_size: Brush radius in pixels. The default of 3 matches the
                previously hard-coded value.
        """
        super().__init__(parent)
        self.canvas_width = canvas_width
        self.canvas_height = canvas_height
        self.brush_size = brush_size
        # Pointer position of the previous dab in the current stroke;
        # None between strokes.
        self._last_point = None
        # Reference to the preview PhotoImage. Tk does not keep one, so
        # without this the image would be garbage-collected and disappear.
        self.demo_image = None
        self.update_ui()

    def clear_ui(self):
        """Destroy every child widget of this frame."""
        for widget in self.winfo_children():
            widget.destroy()

    def update_ui(self):
        """(Re)build the drawing canvas, preview canvas and Clear button."""
        self.clear_ui()
        self._last_point = None

        # Drawing canvas. With no border and no highlight ring, the widget's
        # on-screen rectangle is exactly the drawable area. That makes
        # convert_to_array capture only the drawing, not the focus ring.
        self.canvas = tk.Canvas(
            self,
            width=self.canvas_width,
            height=self.canvas_height,
            bg='white',
            borderwidth=0,
            highlightthickness=0,
        )
        self.canvas.pack(padx=10, pady=10)

        # Preview of the downsampled image. Canvas coordinate (0, 0) sits under
        # the highlight ring, so the ring is removed to keep the edge pixels
        # of the preview visible.
        self.canvas_demo = tk.Canvas(
            self,
            width=self.target_size,
            height=self.target_size,
            bg='white',
            borderwidth=0,
            highlightthickness=0,
        )
        self.canvas_demo.pack(padx=10, pady=10)

        self.clear_button = tk.Button(self, text="Clear", command=self.clear_canvas)
        self.clear_button.pack(expand=True, fill='both')

        # Press starts a stroke (so a plain click leaves a dot), dragging
        # extends it, and release ends it so the next stroke is not joined on.
        self.canvas.bind("<Button-1>", self._start_stroke)
        self.canvas.bind("<B1-Motion>", self.paint)
        self.canvas.bind("<ButtonRelease-1>", self._end_stroke)

    def paint(self, event):
        """Stamp a brush dab at the pointer and join it to the previous dab.

        Motion events arrive at discrete positions. Stamping ovals alone leaves
        gaps whenever the pointer moves more than a brush diameter between
        events, so a round-capped line of the same width connects consecutive
        positions within a stroke.
        """
        x, y = event.x, event.y
        r = self.brush_size
        self.canvas.create_oval(x - r, y - r, x + r, y + r, fill='black', outline='black')
        if self._last_point is not None:
            x_prev, y_prev = self._last_point
            self.canvas.create_line(
                x_prev, y_prev, x, y,
                fill='black', width=2 * r, capstyle=tk.ROUND,
            )
        self._last_point = (x, y)

    def _start_stroke(self, event):
        """Begin a new stroke at the pointer; draws a dot even without dragging."""
        self._last_point = None
        self.paint(event)

    def _end_stroke(self, _event):
        """Finish the current stroke so the next one starts unconnected."""
        self._last_point = None

    def clear_canvas(self):
        """Erase the drawing and its preview so a new digit can be drawn."""
        self.canvas.delete("all")
        # The preview would otherwise keep showing the previous digit.
        self.canvas_demo.delete("all")
        self._last_point = None

    def convert_to_array(self, *, verbose=False):
        """Capture the drawing as a flat, normalised grayscale array.

        The canvas region is grabbed from the screen, so the window must be
        mapped and not covered by other windows when this is called. The
        capture is processed as follows:

        * converted to 8-bit grayscale ("L");
        * resized to ``target_size`` x ``target_size`` with Hamming resampling;
        * shown in the preview canvas;
        * divided by 255. White background maps to 1.0 and black ink to 0.0.

        NOTE(review): models trained on MNIST-style data usually expect
        light ink on a dark background; confirm the consumer expects this
        polarity.

        NOTE(review): on Windows with display scaling, screen-grab coordinates
        may not match Tk coordinates unless the process is DPI-aware; verify.

        Args:
            verbose: If True, print the 2-D array. The original printed it
                unconditionally.

        Returns:
            1-D float numpy array of length ``target_size ** 2``, in row-major
            order.
        """
        # Flush pending redraws so the latest strokes are actually on screen.
        self.canvas.update_idletasks()

        # Absolute screen rectangle of the canvas.
        x0 = self.canvas.winfo_rootx()
        y0 = self.canvas.winfo_rooty()
        x1 = x0 + self.canvas.winfo_width()
        y1 = y0 + self.canvas.winfo_height()

        size = self.target_size
        image = (
            ImageGrab.grab((x0, y0, x1, y1))
            .convert("L")
            .resize((size, size), resample=Resampling.HAMMING)
        )

        # Replace the previous preview instead of stacking items on every
        # call. Keep the PhotoImage reference alive on self.
        self.canvas_demo.delete("all")
        self.demo_image = ImageTk.PhotoImage(image)
        self.canvas_demo.create_image(0, 0, anchor=tk.NW, image=self.demo_image)

        # An "L"-mode image converts to a 2-D (size, size) uint8 array.
        image_array = np.asarray(image) / 255.0
        if verbose:
            print(image_array)

        return image_array.flatten()
|