Перейти к основному содержимому

Практикум — распознавание цифр на PyTorch

Разработчику Начальный уровень

О практикуме

TestPyTorch — учебное приложение из трёх модулей. Свёрточная нейросеть (CNN) обучается на MNIST (рукописные цифры 28×28), веса сохраняются на диск, а окно Tkinter принимает рисунок мышью и выводит предсказание с уверенностью.

Зачем такой проект после теории в PyTorch для разработчика?

  • сквозной цикл ML — данные → модель → loss.backward() → сохранение → инференс;
  • компьютерное зрение в миниатюре — свёртки, пулинг, softmax, препроцессинг изображения;
  • GUI + фоновое обучение — длинная задача уходит в поток, окно остаётся отзывчивым;
  • связка train и inference — одни и те же Normalize и архитектура в train.py и app.py.

Образец для сверки — F:\Projects\Python\TestPyTorch. Теория CNN и backprop — нейрон и слои, перцептрон на NumPy. Аналог по формату "десктоп + данные" — практикум Pandas Data Viewer. Табличный ML до нейросетей — scikit-learn и маршрут машинного обучения.

Для кого материал

Нужны Python 3.10+, базовые классы, установленные torch, torchvision, Pillow. Желательно прочитать обзор PyTorch, один раз пройти первую программу на Tkinter и по возможности — NumPy в Lab. GPU не обязателен: на CPU обучение займёт 1–2 минуты.
Маршрут раздела "ИИ" — машинное обучение, затем нейросети.

Ментальная модель

В 333 вы учитесь говорить на языке PyTorch: тензор, autograd, один батч. Здесь те же приёмы собраны в продукт — как в Melbourne для tabular ML, только задача — классификация изображений, а интерфейс — не ноутбук, а окно с canvas.

Оценка времени — 2–4 часа при прохождении всех этапов подряд с самопроверкой после каждого шага.

Чему учит практикум

НавыкЧто именно тренируемГде в коде
Архитектура CNNConv2d, MaxPool2d, Linear, Dropoutmodel.py
Цикл обученияDataLoader, loss, backward, метрикиtrain.py
Работа с датасетомMNIST, transforms, Subsettrain.py
Сохранение моделиstate_dict, torch.save / loadtrain.py, _load_model
Препроцессинг CVinvert, crop, resize 28×28, normalize_preprocess, _predict
GUIcanvas, события мыши, StringVar, progressbarapp.py
Многопоточностьобучение без блокировки UI_start_training

Ключевые определения

  • MNIST — набор 70 000 изображений рукописных цифр 0–9 в градациях серого 28×28; стандартный "Hello, World" для компьютерного зрения. Контекст датасетов — Data Science.
  • CNN (Convolutional Neural Network, свёрточная нейросеть) — архитектура, где фильтры скользят по картинке и выделяют локальные признаки (штрихи, углы, дуги). Пулинг уменьшает разрешение карты признаков и даёт устойчивость к сдвигу.
  • Logits — сырые выходы последнего слоя до softmax; модель возвращает именно их, а вероятности считают отдельно через F.softmax.
  • Эпоха (epoch) — один полный проход по обучающей выборке. Две эпохи в учебном проекте — компромисс между временем и качеством.
  • Батч (batch) — пачка примеров за один шаг градиента; BATCH_SIZE = 128 ускоряет обучение на GPU и стабилизирует оценку loss.
  • CrossEntropyLoss — функция потерь для многоклассовой классификации; внутри объединяет softmax и negative log-likelihood. Подробнее — 333 — loss и оптимизатор.
  • Adam — адаптивный оптимизатор; для прототипа удобнее ручной подбор learning rate, чем классический SGD.
  • Инференс — предсказание на новых данных; режим model.eval() и контекст torch.no_grad() отключают Dropout и autograd.
  • state_dict — словарь обученных весов и смещений; сохраняется через torch.save, загружается через load_state_dict в ту же архитектуру.
  • Train / evalmodel.train() включает Dropout и обучение; model.eval() — только предсказание. Смешивать режимы нельзя.

Архитектура проекта

Проект сознательно разделён на модель, обучение и интерфейс — тот же приём, что в Kivy-практикуме (логика отдельно от GUI):

test-pytorch/
├── model.py # класс DigitCNN (nn.Module)
├── train.py # обучение на MNIST, сохранение весов
├── app.py # GUI — рисование, предсказание
├── requirements.txt
├── data/ # MNIST (скачивается автоматически)
└── models/
└── digit_cnn.pt # веса после train()

Карта этапов

ЭтапФокусРезультат
0Окружениеvenv, зависимости, структура папок
1model.pyКласс DigitCNN
2train.pyЦикл обучения, файл весов
3app.py, каркасОкно, кнопки, загрузка модели
4РисованиеCanvas + PIL, препроцессинг под MNIST
5ПредсказаниеSoftmax, топ-3, фоновое обучение
6РевизияИтоговая самопроверка

Как проходить практикум

  1. Создайте папку test-pytorch/, venv и установите зависимости — этап 0.
  2. Идите этапы 1–5 по порядку — после этапа 2 запускайте python train.py, после 3–5 — python app.py.
  3. Копируйте код целиком из блока этапа; не пропускайте фрагменты вида # ....
  4. Отмечайте Самопроверку; читайте Разбор — там связь строк с теорией CNN и Tkinter.
  5. Финал — этап 6 и сверка с образцом F:\Projects\Python\TestPyTorch.

Маршрут чтения

  1. Ключевые определения и архитектура — до первой строки кода.
  2. Этапы 0–2 — модель и обучение без GUI.
  3. Этапы 3–5 — интерфейс и инференс.
  4. Дальше — 333, нейросети или применение CV.

Правило прохождения — не переходите к следующему этапу, пока не отмечены пункты самопроверки. Тот же подход — в отладке и разработке.


Этап 0 — окружение

Цель — отдельная папка проекта, виртуальное окружение, зафиксированные зависимости и каталоги под данные и веса.

Зачем venv и requirements.txt

Виртуальное окружение изолирует пакеты проекта от системного Python: torch тянет тяжёлые зависимости, и их лучше не смешивать с другими задачами на том же компьютере. Файл requirements.txt фиксирует состав окружения — через месяц или на другой машине воспроизведёте установку одной командой. Подробнее — Зависимости Python.

mkdir test-pytorch && cd test-pytorch
python -m venv .venv
.venv\Scripts\activate
pip install "torch>=2.0.0" "torchvision>=0.15.0" "Pillow>=10.0.0"

requirements.txt:

torch>=2.0.0
torchvision>=0.15.0
Pillow>=10.0.0

Разбор команд

  • python -m venv .venv — каталог .venv с копией интерпретатора и pip.
  • .venv\Scripts\activate (Windows) подключает окружение в текущей сессии терминала; в Linux/macOS — source .venv/bin/activate.
  • torch — тензоры, nn.Module, autograd, оптимизаторы; см. 333 — установка и device.
  • torchvision — готовый MNIST и transforms для CV.
  • Pillow — off-screen изображение для преобразования рисунка с canvas в массив 28×28.

С GPU-сборкой сверяйтесь с официальным индексом PyTorch; для этого практикума хватит CPU.

Структура каталогов

Создайте пустые папки:

  • models/ — сюда попадёт digit_cnn.pt после обучения;
  • data/ — опционально; MNIST скачается в data/MNIST при первом train().

Проверка окружения:

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"

Вторая часть вывода — True, если доступна CUDA; для практикума обе ситуации нормальны.

Самопроверка

  • python -c "import torch; print(torch.__version__)" выводит версию без ошибок.
  • pip list содержит torch, torchvision, Pillow.
  • Папка models/ существует.

Этап 1 — модель DigitCNN

Цель — описать архитектуру CNN для десяти классов (цифры 0–9) и проверить форму выхода.

Теория в двух абзацах

Полносвязная сеть на "сырых" 784 пикселях MNIST работает, но не использует пространственную структуру: соседние пиксели связаны смыслом (контур цифры), а не случайным порядком в векторе. Свёртка смотрит на локальное окно 3×3 и строит карты признаков — на ранних слоях это края, на поздних — фрагменты цифр. Нейрон и слои объясняют идею; здесь — практическая сборка в PyTorch.

Создайте model.py:

import torch
import torch.nn as nn


class DigitCNN(nn.Module):
"""Простая CNN для распознавания рукописных цифр MNIST."""

def __init__(self) -> None:
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(64 * 7 * 7, 128),
nn.ReLU(),
nn.Dropout(0.25),
nn.Linear(128, 10),
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x)
return self.classifier(x)

Разбор архитектуры

Базовый контракт nn.Module.

  • super().__init__() — регистрация подмодулей в PyTorch.
  • forward — что происходит при вызове model(x); обучение и autograd идут через этот путь.
  • Разделение на features (свёртки) и classifier (полносвязные слои) — привычный паттерн torchvision; удобно менять "голову" под другую задачу.

Вход — тензор формы (batch, 1, 28, 28):

  • batch — сколько картинок в пачке;
  • 1 — один канал (оттенки серого, не RGB);
  • 28 × 28 — размер MNIST.

Блок features — размеры по шагам

СлойВыходная формаКомментарий
Вход(N, 1, 28, 28)N — размер батча
Conv2d(1→32, 3, pad=1) + ReLU(N, 32, 28, 28)padding сохраняет сторону
MaxPool2d(2)(N, 32, 14, 14)окно 2×2, шаг 2
Conv2d(32→64, 3, pad=1) + ReLU(N, 64, 14, 14)больше фильтров — богаче признаки
MaxPool2d(2)(N, 64, 7, 7)финальная карта перед flatten

Блок classifier.

  • Flatten() — вектор длины 64 × 7 × 7 = 3136.
  • Linear(3136, 128) + ReLU — скрытый полносвязный слой.
  • Dropout(0.25) — случайное "выключение" 25% нейронов только в model.train(); борется с переобучением — см. смещение и дисперсия.
  • Linear(128, 10) — по одному logit на каждый класс 0…9; softmax не внутри модели — его добавят при инференсе.

Проверка формы (добавьте в конец model.py):

if __name__ == "__main__":
model = DigitCNN()
dummy = torch.randn(4, 1, 28, 28)
out = model(dummy)
assert out.shape == (4, 10)
print("OK:", out.shape)

Разбор проверки.

  • torch.randn(4, 1, 28, 28) — четыре случайных "картинки" для smoke-теста.
  • assert out.shape == (4, 10) — на каждый образец десять logits.
  • Такой тест не проверяет качество, только согласованность размерностей — первый шаг перед обучением.

Самопроверка

  • python model.py печатает OK: torch.Size([4, 10]).
  • Понимаете, откуда берётся число 64 * 7 * 7.
  • Можете словами объяснить разницу между features и classifier.

Этап 2 — обучение train.py

Цель — скачать MNIST, прогнать цикл обучения, сохранить веса в models/digit_cnn.pt.

Гиперпараметры учебного запуска

КонстантаЗначениеЗачем
EPOCHS2быстрый прототип; для продакшена — десятки эпох
BATCH_SIZE128баланс скорости и памяти
LEARNING_RATE1e-3типичный старт для Adam
TRAIN_SAMPLES20_000часть train MNIST; полный train — 60 000

Создайте train.py:

"""Обучение модели и сохранение весов в models/digit_cnn.pt"""

from pathlib import Path
from typing import Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

from model import DigitCNN

MODEL_PATH = Path(__file__).parent / "models" / "digit_cnn.pt"
EPOCHS = 2
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
TRAIN_SAMPLES = 20_000


def train(on_progress: Callable[[str], None] | None = None) -> Path:
def report(message: str) -> None:
if on_progress:
on_progress(message)
else:
print(message)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
report(f"Устройство: {device}")

transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
)

dataset = datasets.MNIST(root="data", train=True, download=True, transform=transform)
subset = Subset(dataset, range(min(TRAIN_SAMPLES, len(dataset))))
train_loader = DataLoader(subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

model = DigitCNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

model.train()
for epoch in range(EPOCHS):
total_loss = 0.0
correct = 0
total = 0

for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)

optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

total_loss += loss.item() * images.size(0)
correct += (outputs.argmax(dim=1) == labels).sum().item()
total += labels.size(0)

accuracy = 100.0 * correct / total
avg_loss = total_loss / total
report(f"Эпоха {epoch + 1}/{EPOCHS} — loss: {avg_loss:.4f}, accuracy: {accuracy:.2f}%")

MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)
report(f"Модель сохранена: {MODEL_PATH}")
return MODEL_PATH


if __name__ == "__main__":
train()

Разбор цикла обучения

Устройство и перенос данных.

  • torch.device("cuda" if ... else "cpu") — модель и батчи должны жить на одном device.
  • .to(device) для model, images, labels — иначе будет ошибка смешения CPU/CUDA.

Трансформации MNIST.

  • ToTensor() — PIL или ndarray → float-тензор [0, 1].
  • Normalize((0.1307,), (0.3081,)) — вычитание среднего и деление на std одного канала; константы посчитаны для MNIST.
  • Критично: те же числа должны быть в app.py при инференсе — иначе модель "видит" другой масштаб пикселей и accuracy падает.

DataLoader и Subset.

  • datasets.MNIST(..., download=True) — первый запуск скачает архив (~10 МБ).
  • Subset(..., range(20_000)) — учебное ускорение; для соревнования качества уберите Subset.
  • shuffle=True — каждую эпоху батчи в новом порядке; num_workers=0 — проще на Windows (без fork-воркеров).

Один шаг батча — мини-алгоритм обучения.

  1. optimizer.zero_grad() — градиенты не накапливаются между шагами.
  2. outputs = model(images) — прямой проход (forward).
  3. loss = criterion(outputs, labels) — scalar loss.
  4. loss.backward() — backprop; цепное правило — перцептрон на NumPy.
  5. optimizer.step() — обновление весов Adam.

Метрики эпохи.

  • loss.item() — Python-float из scalar-тензора; .item() нужен, чтобы не держать граф autograd в истории.
  • outputs.argmax(dim=1) — предсказанный класс по максимальному logit.
  • accuracy здесь считается на train — оптимистичная оценка; честнее держать отложенный test MNIST (train=False) — см. разбиение данных.

Сохранение.

  • model.state_dict() — только веса, без архитектуры.
  • torch.save(..., MODEL_PATH) — файл digit_cnn.pt; загрузка возможна только в тот же класс DigitCNN.

Запуск:

python train.py

Ожидаемый вывод (числа могут чуть отличаться):

Устройство: cpu
Эпоха 1/2 — loss: 0.23xx, accuracy: 9x.xx%
Эпоха 2/2 — loss: 0.07xx, accuracy: 9x.xx%
Модель сохранена: ...\models\digit_cnn.pt

Самопроверка

  • Две строки "Эпоха …" с падающим loss.
  • Файл models/digit_cnn.pt создан.
  • Папка data/MNIST появилась после первого запуска.
  • Accuracy после 2-й эпохи обычно >95% на подвыборке.

Этап 3 — каркас GUI app.py

Цель — окно Tkinter, кнопки, загрузка сохранённых весов; рисование и предсказание добавим на следующих этапах.

Слои интерфейса

Десктопное приложение удобно мыслить сверху вниз:

  1. Заголовок — подсказка пользователю.
  2. Canvas — область рисования 280×280 (масштаб для удобства; модель всё равно получит 28×28).
  3. Панель кнопок — "Распознать" и "Очистить".
  4. Строка результатаStringVar + Label.
  5. Служебная строка — device PyTorch.

Компоновка через grid — см. pack и grid в 3111. Контекст жанра — десктопные приложения.

Начните app.py (этапы 4–5 дополнят этот файл):

"""GUI-приложение: рисуйте цифру мышью, PyTorch распознаёт её."""

import tkinter as tk
from tkinter import ttk

import torch

from model import DigitCNN
from train import MODEL_PATH

CANVAS_SIZE = 280


class DigitRecognizerApp:
def __init__(self, root: tk.Tk) -> None:
self.root = root
self.root.title("Распознавание цифр — PyTorch")
self.root.resizable(False, False)

self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model: DigitCNN | None = None

self._build_ui()

if MODEL_PATH.exists():
self._load_model()
else:
self.result_var.set("Сначала выполните: python train.py")

def _build_ui(self) -> None:
main = ttk.Frame(self.root, padding=16)
main.grid(row=0, column=0)

ttk.Label(main, text="Нарисуйте цифру от 0 до 9", font=("Segoe UI", 14)).grid(
row=0, column=0, columnspan=2, pady=(0, 12)
)

self.canvas = tk.Canvas(
main,
width=CANVAS_SIZE,
height=CANVAS_SIZE,
bg="black",
highlightthickness=2,
highlightbackground="#cccccc",
)
self.canvas.grid(row=1, column=0, columnspan=2)

btn_frame = ttk.Frame(main)
btn_frame.grid(row=2, column=0, columnspan=2, pady=12)

self.predict_btn = ttk.Button(btn_frame, text="Распознать", state=tk.DISABLED)
self.predict_btn.pack(side=tk.LEFT, padx=4)
ttk.Button(btn_frame, text="Очистить", command=self._clear).pack(side=tk.LEFT, padx=4)

self.result_var = tk.StringVar(value="Загрузка...")
ttk.Label(main, textvariable=self.result_var, font=("Segoe UI", 12)).grid(
row=3, column=0, columnspan=2
)

ttk.Label(
main,
text=f"PyTorch · {self.device}",
font=("Segoe UI", 9),
foreground="#666666",
).grid(row=4, column=0, columnspan=2, pady=(8, 0))

def _load_model(self) -> None:
model = DigitCNN().to(self.device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=self.device, weights_only=True))
model.eval()
self.model = model
self.result_var.set("Результат появится здесь")
self.predict_btn.config(state=tk.NORMAL)

def _clear(self) -> None:
self.canvas.delete("all")


def main() -> None:
root = tk.Tk()
style = ttk.Style()
if "vista" in style.theme_names():
style.theme_use("vista")
DigitRecognizerApp(root)
root.mainloop()


if __name__ == "__main__":
main()

Разбор каркаса

Импорты и точка входа.

  • ttk — themed widgets; на Windows тема vista ближе к нативному виду.
  • MODEL_PATH импортируется из train.py, чтобы путь к весам был один в проекте.
  • if __name__ == "__main__" — см. точку входа Python.

Состояние приложения.

  • self.modelNone, пока веса не загружены; кнопка "Распознать" disabled.
  • self.device — тот же выбор CPU/CUDA, что в train.py.
  • StringVar — связка текста результата с Label; обновление через .set() без пересоздания виджета — 3111 — переменные Tkinter.

Загрузка весов _load_model.

  • Сначала создаётся пустая архитектура DigitCNN(), затем в неё заливаются веса.
  • map_location=self.device — файл, сохранённый на GPU, откроется на CPU и наоборот.
  • weights_only=True (PyTorch 2.x) — не выполнять произвольный pickle, только тензоры весов.
  • model.eval() — Dropout выключен; иначе предсказания будут случайными.

Самопроверка

  • После python train.py запуск python app.py показывает чёрный canvas и активную кнопку "Распознать".
  • Без файла весов — подсказка выполнить train.py.
  • Внизу окна видно PyTorch · cpu или cuda:0.

Этап 4 — рисование и препроцессинг

Цель — рисовать белым по чёрному, дублировать штрих в PIL и готовить изображение 28×28 в стиле MNIST.

Конвенция пикселей MNIST vs canvas

ФонЦифраРазмер
Наш canvasчёрный (0)белый (255)280×280
MNISTсветлый (~белый)тёмная (~чёрная)28×28

Поэтому перед подачей в сеть нужны invert, crop по контуру и resize — не "сырой" canvas.

Добавьте в начало app.py импорты:

from PIL import Image, ImageDraw, ImageOps

В __init__ после создания canvas:

self.image = Image.new("L", (CANVAS_SIZE, CANVAS_SIZE), 0)
self.draw = ImageDraw.Draw(self.image)
self.last_x: int | None = None
self.last_y: int | None = None

self.canvas.bind("<Button-1>", self._on_press)
self.canvas.bind("<B1-Motion>", self._on_drag)
self.canvas.bind("<ButtonRelease-1>", self._on_release)

Константы (рядом с CANVAS_SIZE):

BRUSH_WIDTH = 18
DRAW_COLOR = 255
BACKGROUND_COLOR = 0

Методы рисования:

def _on_press(self, event: tk.Event) -> None:
self.last_x, self.last_y = event.x, event.y
self._draw_point(event.x, event.y)

def _on_drag(self, event: tk.Event) -> None:
if self.last_x is None or self.last_y is None:
return
self.canvas.create_line(
self.last_x,
self.last_y,
event.x,
event.y,
fill="white",
width=BRUSH_WIDTH,
capstyle=tk.ROUND,
smooth=True,
)
self.draw.line(
[self.last_x, self.last_y, event.x, event.y],
fill=DRAW_COLOR,
width=BRUSH_WIDTH,
)
self.last_x, self.last_y = event.x, event.y

def _on_release(self, _event: tk.Event) -> None:
self.last_x, self.last_y = None, None

def _draw_point(self, x: int, y: int) -> None:
r = BRUSH_WIDTH // 2
self.canvas.create_oval(x - r, y - r, x + r, y + r, fill="white", outline="white")
self.draw.ellipse([x - r, y - r, x + r, y + r], fill=DRAW_COLOR)

def _clear(self) -> None:
self.canvas.delete("all")
self.image = Image.new("L", (CANVAS_SIZE, CANVAS_SIZE), BACKGROUND_COLOR)
self.draw = ImageDraw.Draw(self.image)
if self.model is not None:
self.result_var.set("Результат появится здесь")

Зачем два слоя — Canvas и PIL

  • Canvas — то, что видит пользователь; Tkinter хранит примитивы (линии, овалы), а не готовую bitmap для numpy/torch.
  • PIL (self.image) — off-screen растр того же рисунка; из него удобно делать crop, resize, ToTensor.

События мыши:

  • <Button-1> — нажатие; <B1-Motion> — drag с зажатой левой кнопкой; <ButtonRelease-1> — отпускание.
  • last_x, last_y — предыдущая точка для непрерывной линии.

Препроцессинг _preprocess

Добавьте метод в класс:

def _preprocess(self) -> Image.Image | None:
if self.image.getbbox() is None:
return None
processed = ImageOps.invert(self.image)
bbox = processed.getbbox()
if bbox is None:
return None
cropped = processed.crop(bbox)
return ImageOps.fit(cropped, (28, 28), method=Image.Resampling.LANCZOS)

Разбор по строкам.

  • getbbox() — прямоугольник непустых пикселей; None — canvas пуст.
  • ImageOps.invert — белая цифра на чёрном → тёмная на светлом, как в MNIST.
  • crop(bbox) — убираем лишние поля вокруг цифры.
  • ImageOps.fit(..., (28, 28), LANCZOS) — вписать в квадрат MNIST с сохранением пропорций; качественнее, чем грубый resize.

Для отладки можно временно сохранять результат:

resized.save("debug_28.png")

Самопроверка

  • Линия мышью рисуется плавно на чёрном фоне.
  • "Очистить" стирает и canvas, и PIL-слой.
  • После рисования "7" файл debug_28.png (если добавили save) похож на MNIST-цифру.

Этап 5 — предсказание и автообучение

Цель — softmax и топ-3 классов; при отсутствии весов — обучение в фоновом потоке без заморозки окна.

Train vs inference в одном приложении

На первом запуске файла digit_cnn.pt может не быть. Тогда GUI сам вызывает train() из train.py — пользователь не обязан помнить про отдельный скрипт. Долгая операция уходит в daemon-поток; UI обновляется через root.after(0, ...).

Добавьте импорты:

import threading
from tkinter import messagebox

import torch.nn.functional as F
from torchvision import transforms

from train import MODEL_PATH, train

Подключите кнопку в _build_ui:

self.predict_btn = ttk.Button(btn_frame, text="Распознать", command=self._predict, state=tk.DISABLED)

Добавьте progressbar после result_var:

self.progress = ttk.Progressbar(main, mode="indeterminate", length=280)
self.progress.grid(row=4, column=0, columnspan=2, pady=(8, 0))
self.progress.grid_remove()

Сдвиньте метку устройства на row=5.

Замените ветку "нет весов" в __init__:

if MODEL_PATH.exists():
self._load_model()
else:
self._start_training()

Методы предсказания и фонового обучения:

def _start_training(self) -> None:
self.result_var.set("Первый запуск: обучение модели на MNIST (~1–2 мин)...")
self.progress.grid()
self.progress.start(12)

def worker() -> None:
try:
train(on_progress=lambda msg: self.root.after(0, self.result_var.set, msg))
self.root.after(0, self._on_training_done)
except Exception as exc:
self.root.after(0, self._on_training_failed, str(exc))

threading.Thread(target=worker, daemon=True).start()

def _on_training_done(self) -> None:
self.progress.stop()
self.progress.grid_remove()
self._load_model()

def _on_training_failed(self, error: str) -> None:
self.progress.stop()
self.progress.grid_remove()
messagebox.showerror("Ошибка обучения", error)
self.root.destroy()

def _predict(self) -> None:
if self.model is None:
return

resized = self._preprocess()
if resized is None:
self.result_var.set("Сначала нарисуйте цифру")
return

transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
)
tensor = transform(resized).unsqueeze(0).to(self.device)

with torch.no_grad():
logits = self.model(tensor)
probs = F.softmax(logits, dim=1)[0]
digit = int(probs.argmax().item())
confidence = float(probs[digit].item()) * 100

top3 = probs.topk(3)
details = ", ".join(
f"{int(idx.item())}: {prob.item() * 100:.1f}%"
for prob, idx in zip(top3.values, top3.indices)
)
self.result_var.set(f"Цифра: {digit} ({confidence:.1f}%) · топ-3: {details}")

Разбор инференса

Тензор на вход модели.

  • transform(resized) — PIL → tensor [1, 28, 28] после ToTensor + Normalize.
  • unsqueeze(0) — добавляет размерность batch → [1, 1, 28, 28], как в train.py.

Softmax и уверенность.

  • Logits — любые вещественные числа; softmax переводит их в вероятности, сумма = 1.
  • confidence — доля max-класса; 85% не значит "абсолютная истина", но помогает увидеть сомнение модели.
  • topk(3) — полезно для пар вроде 3/8 или 4/9 — типичная путаница на MNIST.

Режим без градиентов.

  • torch.no_grad() — не строить граф autograd; быстрее и меньше RAM.
  • Вместе с model.eval() — каноничный инференс; подробнее — 333 — сохранение и загрузка.

Поток threading.

  • Обучение в главном потоке заморозило бы mainloop.
  • root.after(0, callback) ставит обновление UI в очередь Tkinter — безопасно из worker-потока.
  • daemon=True — поток не держит процесс после закрытия окна.
  • Тот же паттерн — в 311 — Tkinter и GUI; теория потоков — многопоточность Python.

Самопроверка

  • Удалите models/digit_cnn.pt, запустите app.py — идёт обучение с progressbar, затем можно рисовать.
  • Нарисованная "7" распознаётся с confidence обычно >70%.
  • Пустой canvas — "Сначала нарисуйте цифру".
  • В топ-3 видны альтернативные цифры с процентами.

Этап 6 — ревизия и дальнейшие шаги

Цель — убедиться, что три модуля работают вместе; понять типичные ошибки и куда развивать проект.

Итоговая архитектура

Полный код совпадает с образцом в F:\Projects\Python\TestPyTorch. Запуск:

python app.py

Итоговая самопроверка

  • model.py, train.py, app.py в одной папке; импорты без ошибок.
  • Обучение создаёт models/digit_cnn.pt; повторный запуск GUI грузит веса без переобучения.
  • Рисование, "Очистить", "Распознать" и топ-3 работают.
  • На CPU обучение укладывается в несколько минут.
  • Можете объяснить, зачем invert и те же константы Normalize в train и predict.

Частые ошибки

СимптомВероятная причинаЧто проверить
Всегда одна цифра / случайный ответmodel.train() при predict_load_modelmodel.eval()
Низкая confidence на нормальной цифредругой Normalize или без invert_preprocess и transforms
RuntimeError devicetensor на CPU, model на CUDA.to(self.device) везде
Окно "висит" при первом запускеtrain() в main thread_start_training + threading
size mismatch при loadизменили архитектуруудалите старый .pt, переобучите

Что добавить самостоятельно

УлучшениеЗачемПодсказка
Больше эпох / полный MNISTвыше accuracyубрать Subset, EPOCHS = 5
Валидация на test MNISTчестная метрикаdatasets.MNIST(..., train=False)
Сохранение рисункаотладка препроцессингаresized.save("debug.png")
Экспорт ONNXдеплой без Pythontorch.onnx.exportприменение ИИ
Unit-тесты _preprocessрегрессииpytest на синтетическом PIL
Data augmentationустойчивость к сдвигуRandomAffine в train.py

Связанные материалы

PyTorch и нейросети

Данные и GUI


В подборках

Аналитика данныхАнализ данных — о разделе, Data Science, Python — о разделе.

Нейросети и ИИМашинное обучение — о разделе, Нейросети — о разделе, Введение в ИИ — о разделе.