Практикум — распознавание цифр на PyTorch
О практикуме
TestPyTorch — учебное приложение из трёх модулей. Свёрточная нейросеть (CNN) обучается на MNIST (рукописные цифры 28×28), веса сохраняются на диск, а окно Tkinter принимает рисунок мышью и выводит предсказание с уверенностью.
Зачем такой проект после теории в PyTorch для разработчика?
- сквозной цикл ML — данные → модель →
loss.backward()→ сохранение → инференс; - компьютерное зрение в миниатюре — свёртки, пулинг, softmax, препроцессинг изображения;
- GUI + фоновое обучение — длинная задача уходит в поток, окно остаётся отзывчивым;
- связка train и inference — одни и те же
Normalizeи архитектура вtrain.pyиapp.py.
Образец для сверки — F:\Projects\Python\TestPyTorch. Теория CNN и backprop — нейрон и слои, перцептрон на NumPy. Аналог по формату "десктоп + данные" — практикум Pandas Data Viewer. Табличный ML до нейросетей — scikit-learn и маршрут машинного обучения.
Нужны Python 3.10+, базовые классы, установленные torch, torchvision, Pillow. Желательно прочитать обзор PyTorch, один раз пройти первую программу на Tkinter и по возможности — NumPy в Lab. GPU не обязателен: на CPU обучение займёт 1–2 минуты.
Маршрут раздела "ИИ" — машинное обучение, затем нейросети.
Оценка времени — 2–4 часа при прохождении всех этапов подряд с самопроверкой после каждого шага.
Чему учит практикум
| Навык | Что именно тренируем | Где в коде |
|---|---|---|
| Архитектура CNN | Conv2d, MaxPool2d, Linear, Dropout | model.py |
| Цикл обучения | DataLoader, loss, backward, метрики | train.py |
| Работа с датасетом | MNIST, transforms, Subset | train.py |
| Сохранение модели | state_dict, torch.save / load | train.py, _load_model |
| Препроцессинг CV | invert, crop, resize 28×28, normalize | _preprocess, _predict |
| GUI | canvas, события мыши, StringVar, progressbar | app.py |
| Многопоточность | обучение без блокировки UI | _start_training |
Ключевые определения
- MNIST — набор 70 000 изображений рукописных цифр 0–9 в градациях серого 28×28; стандартный "Hello, World" для компьютерного зрения. Контекст датасетов — Data Science.
- CNN (Convolutional Neural Network, свёрточная нейросеть) — архитектура, где фильтры скользят по картинке и выделяют локальные признаки (штрихи, углы, дуги). Пулинг уменьшает разрешение карты признаков и даёт устойчивость к сдвигу.
- Logits — сырые выходы последнего слоя до softmax; модель возвращает именно их, а вероятности считают отдельно через
F.softmax. - Эпоха (epoch) — один полный проход по обучающей выборке. Две эпохи в учебном проекте — компромисс между временем и качеством.
- Батч (batch) — пачка примеров за один шаг градиента;
BATCH_SIZE = 128ускоряет обучение на GPU и стабилизирует оценку loss. - CrossEntropyLoss — функция потерь для многоклассовой классификации; внутри объединяет softmax и negative log-likelihood. Подробнее — 333 — loss и оптимизатор.
- Adam — адаптивный оптимизатор; для прототипа удобнее ручной подбор learning rate, чем классический SGD.
- Инференс — предсказание на новых данных; режим
model.eval()и контекстtorch.no_grad()отключают Dropout и autograd. - state_dict — словарь обученных весов и смещений; сохраняется через
torch.save, загружается черезload_state_dictв ту же архитектуру. - Train / eval —
model.train()включает Dropout и обучение;model.eval()— только предсказание. Смешивать режимы нельзя.
Архитектура проекта
Проект сознательно разделён на модель, обучение и интерфейс — тот же приём, что в Kivy-практикуме (логика отдельно от GUI):
test-pytorch/
├── model.py # класс DigitCNN (nn.Module)
├── train.py # обучение на MNIST, сохранение весов
├── app.py # GUI — рисование, предсказание
├── requirements.txt
├── data/ # MNIST (скачивается автоматически)
└── models/
└── digit_cnn.pt # веса после train()
Карта этапов
| Этап | Фокус | Результат |
|---|---|---|
| 0 | Окружение | venv, зависимости, структура папок |
| 1 | model.py | Класс DigitCNN |
| 2 | train.py | Цикл обучения, файл весов |
| 3 | app.py, каркас | Окно, кнопки, загрузка модели |
| 4 | Рисование | Canvas + PIL, препроцессинг под MNIST |
| 5 | Предсказание | Softmax, топ-3, фоновое обучение |
| 6 | Ревизия | Итоговая самопроверка |
Как проходить практикум
- Создайте папку
test-pytorch/, venv и установите зависимости — этап 0. - Идите этапы 1–5 по порядку — после этапа 2 запускайте
python train.py, после 3–5 —python app.py. - Копируйте код целиком из блока этапа; не пропускайте фрагменты вида
# .... - Отмечайте Самопроверку; читайте Разбор — там связь строк с теорией CNN и Tkinter.
- Финал — этап 6 и сверка с образцом
F:\Projects\Python\TestPyTorch.
Маршрут чтения
- Ключевые определения и архитектура — до первой строки кода.
- Этапы 0–2 — модель и обучение без GUI.
- Этапы 3–5 — интерфейс и инференс.
- Дальше — 333, нейросети или применение CV.
Правило прохождения — не переходите к следующему этапу, пока не отмечены пункты самопроверки. Тот же подход — в отладке и разработке.
Этап 0 — окружение
Цель — отдельная папка проекта, виртуальное окружение, зафиксированные зависимости и каталоги под данные и веса.
Зачем venv и requirements.txt
Виртуальное окружение изолирует пакеты проекта от системного Python: torch тянет тяжёлые зависимости, и их лучше не смешивать с другими задачами на том же компьютере. Файл requirements.txt фиксирует состав окружения — через месяц или на другой машине воспроизведёте установку одной командой. Подробнее — Зависимости Python.
mkdir test-pytorch && cd test-pytorch
python -m venv .venv
.venv\Scripts\activate
pip install "torch>=2.0.0" "torchvision>=0.15.0" "Pillow>=10.0.0"
requirements.txt:
torch>=2.0.0
torchvision>=0.15.0
Pillow>=10.0.0
Разбор команд
python -m venv .venv— каталог.venvс копией интерпретатора иpip..venv\Scripts\activate(Windows) подключает окружение в текущей сессии терминала; в Linux/macOS —source .venv/bin/activate.torch— тензоры,nn.Module, autograd, оптимизаторы; см. 333 — установка и device.torchvision— готовый MNIST иtransformsдля CV.Pillow— off-screen изображение для преобразования рисунка с canvas в массив 28×28.
С GPU-сборкой сверяйтесь с официальным индексом PyTorch; для этого практикума хватит CPU.
Структура каталогов
Создайте пустые папки:
models/— сюда попадётdigit_cnn.ptпосле обучения;data/— опционально; MNIST скачается вdata/MNISTпри первомtrain().
Проверка окружения:
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
Вторая часть вывода — True, если доступна CUDA; для практикума обе ситуации нормальны.
Самопроверка
-
python -c "import torch; print(torch.__version__)"выводит версию без ошибок. -
pip listсодержитtorch,torchvision,Pillow. - Папка
models/существует.
Этап 1 — модель DigitCNN
Цель — описать архитектуру CNN для десяти классов (цифры 0–9) и проверить форму выхода.
Теория в двух абзацах
Полносвязная сеть на "сырых" 784 пикселях MNIST работает, но не использует пространственную структуру: соседние пиксели связаны смыслом (контур цифры), а не случайным порядком в векторе. Свёртка смотрит на локальное окно 3×3 и строит карты признаков — на ранних слоях это края, на поздних — фрагменты цифр. Нейрон и слои объясняют идею; здесь — практическая сборка в PyTorch.
Создайте model.py:
import torch
import torch.nn as nn
class DigitCNN(nn.Module):
"""Простая CNN для распознавания рукописных цифр MNIST."""
def __init__(self) -> None:
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(64 * 7 * 7, 128),
nn.ReLU(),
nn.Dropout(0.25),
nn.Linear(128, 10),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x)
return self.classifier(x)
Разбор архитектуры
Базовый контракт nn.Module.
super().__init__()— регистрация подмодулей в PyTorch.forward— что происходит при вызовеmodel(x); обучение и autograd идут через этот путь.- Разделение на
features(свёртки) иclassifier(полносвязные слои) — привычный паттерн torchvision; удобно менять "голову" под другую задачу.
Вход — тензор формы (batch, 1, 28, 28):
batch— сколько картинок в пачке;1— один канал (оттенки серого, не RGB);28 × 28— размер MNIST.
Блок features — размеры по шагам
| Слой | Выходная форма | Комментарий |
|---|---|---|
| Вход | (N, 1, 28, 28) | N — размер батча |
Conv2d(1→32, 3, pad=1) + ReLU | (N, 32, 28, 28) | padding сохраняет сторону |
MaxPool2d(2) | (N, 32, 14, 14) | окно 2×2, шаг 2 |
Conv2d(32→64, 3, pad=1) + ReLU | (N, 64, 14, 14) | больше фильтров — богаче признаки |
MaxPool2d(2) | (N, 64, 7, 7) | финальная карта перед flatten |
Блок classifier.
Flatten()— вектор длины64 × 7 × 7 = 3136.Linear(3136, 128)+ReLU— скрытый полносвязный слой.Dropout(0.25)— случайное "выключение" 25% нейронов только вmodel.train(); борется с переобучением — см. смещение и дисперсия.Linear(128, 10)— по одному logit на каждый класс 0…9; softmax не внутри модели — его добавят при инференсе.
Проверка формы (добавьте в конец model.py):
if __name__ == "__main__":
model = DigitCNN()
dummy = torch.randn(4, 1, 28, 28)
out = model(dummy)
assert out.shape == (4, 10)
print("OK:", out.shape)
Разбор проверки.
torch.randn(4, 1, 28, 28)— четыре случайных "картинки" для smoke-теста.assert out.shape == (4, 10)— на каждый образец десять logits.- Такой тест не проверяет качество, только согласованность размерностей — первый шаг перед обучением.
Самопроверка
-
python model.pyпечатаетOK: torch.Size([4, 10]). - Понимаете, откуда берётся число
64 * 7 * 7. - Можете словами объяснить разницу между
featuresиclassifier.
Этап 2 — обучение train.py
Цель — скачать MNIST, прогнать цикл обучения, сохранить веса в models/digit_cnn.pt.
Гиперпараметры учебного запуска
| Константа | Значение | Зачем |
|---|---|---|
EPOCHS | 2 | быстрый прототип; для продакшена — десятки эпох |
BATCH_SIZE | 128 | баланс скорости и памяти |
LEARNING_RATE | 1e-3 | типичный старт для Adam |
TRAIN_SAMPLES | 20_000 | часть train MNIST; полный train — 60 000 |
Создайте train.py:
"""Обучение модели и сохранение весов в models/digit_cnn.pt"""
from pathlib import Path
from typing import Callable
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from model import DigitCNN
MODEL_PATH = Path(__file__).parent / "models" / "digit_cnn.pt"
EPOCHS = 2
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
TRAIN_SAMPLES = 20_000
def train(on_progress: Callable[[str], None] | None = None) -> Path:
def report(message: str) -> None:
if on_progress:
on_progress(message)
else:
print(message)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
report(f"Устройство: {device}")
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
)
dataset = datasets.MNIST(root="data", train=True, download=True, transform=transform)
subset = Subset(dataset, range(min(TRAIN_SAMPLES, len(dataset))))
train_loader = DataLoader(subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
model = DigitCNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()
model.train()
for epoch in range(EPOCHS):
total_loss = 0.0
correct = 0
total = 0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
total_loss += loss.item() * images.size(0)
correct += (outputs.argmax(dim=1) == labels).sum().item()
total += labels.size(0)
accuracy = 100.0 * correct / total
avg_loss = total_loss / total
report(f"Эпоха {epoch + 1}/{EPOCHS} — loss: {avg_loss:.4f}, accuracy: {accuracy:.2f}%")
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)
report(f"Модель сохранена: {MODEL_PATH}")
return MODEL_PATH
if __name__ == "__main__":
train()
Разбор цикла обучения
Устройство и перенос данных.
torch.device("cuda" if ... else "cpu")— модель и батчи должны жить на одном device..to(device)дляmodel,images,labels— иначе будет ошибка смешения CPU/CUDA.
Трансформации MNIST.
ToTensor()— PIL или ndarray → float-тензор[0, 1].Normalize((0.1307,), (0.3081,))— вычитание среднего и деление на std одного канала; константы посчитаны для MNIST.- Критично: те же числа должны быть в
app.pyпри инференсе — иначе модель "видит" другой масштаб пикселей и accuracy падает.
DataLoader и Subset.
datasets.MNIST(..., download=True)— первый запуск скачает архив (~10 МБ).Subset(..., range(20_000))— учебное ускорение; для соревнования качества уберите Subset.shuffle=True— каждую эпоху батчи в новом порядке;num_workers=0— проще на Windows (без fork-воркеров).
Один шаг батча — мини-алгоритм обучения.
optimizer.zero_grad()— градиенты не накапливаются между шагами.outputs = model(images)— прямой проход (forward).loss = criterion(outputs, labels)— scalar loss.loss.backward()— backprop; цепное правило — перцептрон на NumPy.optimizer.step()— обновление весов Adam.
Метрики эпохи.
loss.item()— Python-float из scalar-тензора;.item()нужен, чтобы не держать граф autograd в истории.outputs.argmax(dim=1)— предсказанный класс по максимальному logit.accuracyздесь считается на train — оптимистичная оценка; честнее держать отложенный test MNIST (train=False) — см. разбиение данных.
Сохранение.
model.state_dict()— только веса, без архитектуры.torch.save(..., MODEL_PATH)— файлdigit_cnn.pt; загрузка возможна только в тот же классDigitCNN.
Запуск:
python train.py
Ожидаемый вывод (числа могут чуть отличаться):
Устройство: cpu
Эпоха 1/2 — loss: 0.23xx, accuracy: 9x.xx%
Эпоха 2/2 — loss: 0.07xx, accuracy: 9x.xx%
Модель сохранена: ...\models\digit_cnn.pt
Самопроверка
- Две строки "Эпоха …" с падающим loss.
- Файл
models/digit_cnn.ptсоздан. - Папка
data/MNISTпоявилась после первого запуска. - Accuracy после 2-й эпохи обычно >95% на подвыборке.
Этап 3 — каркас GUI app.py
Цель — окно Tkinter, кнопки, загрузка сохранённых весов; рисование и предсказание добавим на следующих этапах.
Слои интерфейса
Десктопное приложение удобно мыслить сверху вниз:
- Заголовок — подсказка пользователю.
- Canvas — область рисования 280×280 (масштаб для удобства; модель всё равно получит 28×28).
- Панель кнопок — "Распознать" и "Очистить".
- Строка результата —
StringVar+ Label. - Служебная строка — device PyTorch.
Компоновка через grid — см. pack и grid в 3111. Контекст жанра — десктопные приложения.
Начните app.py (этапы 4–5 дополнят этот файл):
"""GUI-приложение: рисуйте цифру мышью, PyTorch распознаёт её."""
import tkinter as tk
from tkinter import ttk
import torch
from model import DigitCNN
from train import MODEL_PATH
CANVAS_SIZE = 280
class DigitRecognizerApp:
def __init__(self, root: tk.Tk) -> None:
self.root = root
self.root.title("Распознавание цифр — PyTorch")
self.root.resizable(False, False)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model: DigitCNN | None = None
self._build_ui()
if MODEL_PATH.exists():
self._load_model()
else:
self.result_var.set("Сначала выполните: python train.py")
def _build_ui(self) -> None:
main = ttk.Frame(self.root, padding=16)
main.grid(row=0, column=0)
ttk.Label(main, text="Нарисуйте цифру от 0 до 9", font=("Segoe UI", 14)).grid(
row=0, column=0, columnspan=2, pady=(0, 12)
)
self.canvas = tk.Canvas(
main,
width=CANVAS_SIZE,
height=CANVAS_SIZE,
bg="black",
highlightthickness=2,
highlightbackground="#cccccc",
)
self.canvas.grid(row=1, column=0, columnspan=2)
btn_frame = ttk.Frame(main)
btn_frame.grid(row=2, column=0, columnspan=2, pady=12)
self.predict_btn = ttk.Button(btn_frame, text="Распознать", state=tk.DISABLED)
self.predict_btn.pack(side=tk.LEFT, padx=4)
ttk.Button(btn_frame, text="Очистить", command=self._clear).pack(side=tk.LEFT, padx=4)
self.result_var = tk.StringVar(value="Загрузка...")
ttk.Label(main, textvariable=self.result_var, font=("Segoe UI", 12)).grid(
row=3, column=0, columnspan=2
)
ttk.Label(
main,
text=f"PyTorch · {self.device}",
font=("Segoe UI", 9),
foreground="#666666",
).grid(row=4, column=0, columnspan=2, pady=(8, 0))
def _load_model(self) -> None:
model = DigitCNN().to(self.device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=self.device, weights_only=True))
model.eval()
self.model = model
self.result_var.set("Результат появится здесь")
self.predict_btn.config(state=tk.NORMAL)
def _clear(self) -> None:
self.canvas.delete("all")
def main() -> None:
root = tk.Tk()
style = ttk.Style()
if "vista" in style.theme_names():
style.theme_use("vista")
DigitRecognizerApp(root)
root.mainloop()
if __name__ == "__main__":
main()
Разбор каркаса
Импорты и точка входа.
ttk— themed widgets; на Windows темаvistaближе к нативному виду.MODEL_PATHимпортируется изtrain.py, чтобы путь к весам был один в проекте.if __name__ == "__main__"— см. точку входа Python.
Состояние приложения.
self.model—None, пока веса не загружены; кнопка "Распознать" disabled.self.device— тот же выбор CPU/CUDA, что вtrain.py.StringVar— связка текста результата с Label; обновление через.set()без пересоздания виджета — 3111 — переменные Tkinter.
Загрузка весов _load_model.
- Сначала создаётся пустая архитектура
DigitCNN(), затем в неё заливаются веса. map_location=self.device— файл, сохранённый на GPU, откроется на CPU и наоборот.weights_only=True(PyTorch 2.x) — не выполнять произвольный pickle, только тензоры весов.model.eval()— Dropout выключен; иначе предсказания будут случайными.
Самопроверка
- После
python train.pyзапускpython app.pyпоказывает чёрный canvas и активную кнопку "Распознать". - Без файла весов — подсказка выполнить
train.py. - Внизу окна видно
PyTorch · cpuилиcuda:0.
Этап 4 — рисование и препроцессинг
Цель — рисовать белым по чёрному, дублировать штрих в PIL и готовить изображение 28×28 в стиле MNIST.
Конвенция пикселей MNIST vs canvas
| Фон | Цифра | Размер | |
|---|---|---|---|
| Наш canvas | чёрный (0) | белый (255) | 280×280 |
| MNIST | светлый (~белый) | тёмная (~чёрная) | 28×28 |
Поэтому перед подачей в сеть нужны invert, crop по контуру и resize — не "сырой" canvas.
Добавьте в начало app.py импорты:
from PIL import Image, ImageDraw, ImageOps
В __init__ после создания canvas:
self.image = Image.new("L", (CANVAS_SIZE, CANVAS_SIZE), 0)
self.draw = ImageDraw.Draw(self.image)
self.last_x: int | None = None
self.last_y: int | None = None
self.canvas.bind("<Button-1>", self._on_press)
self.canvas.bind("<B1-Motion>", self._on_drag)
self.canvas.bind("<ButtonRelease-1>", self._on_release)
Константы (рядом с CANVAS_SIZE):
BRUSH_WIDTH = 18
DRAW_COLOR = 255
BACKGROUND_COLOR = 0
Методы рисования:
def _on_press(self, event: tk.Event) -> None:
self.last_x, self.last_y = event.x, event.y
self._draw_point(event.x, event.y)
def _on_drag(self, event: tk.Event) -> None:
if self.last_x is None or self.last_y is None:
return
self.canvas.create_line(
self.last_x,
self.last_y,
event.x,
event.y,
fill="white",
width=BRUSH_WIDTH,
capstyle=tk.ROUND,
smooth=True,
)
self.draw.line(
[self.last_x, self.last_y, event.x, event.y],
fill=DRAW_COLOR,
width=BRUSH_WIDTH,
)
self.last_x, self.last_y = event.x, event.y
def _on_release(self, _event: tk.Event) -> None:
self.last_x, self.last_y = None, None
def _draw_point(self, x: int, y: int) -> None:
r = BRUSH_WIDTH // 2
self.canvas.create_oval(x - r, y - r, x + r, y + r, fill="white", outline="white")
self.draw.ellipse([x - r, y - r, x + r, y + r], fill=DRAW_COLOR)
def _clear(self) -> None:
self.canvas.delete("all")
self.image = Image.new("L", (CANVAS_SIZE, CANVAS_SIZE), BACKGROUND_COLOR)
self.draw = ImageDraw.Draw(self.image)
if self.model is not None:
self.result_var.set("Результат появится здесь")
Зачем два слоя — Canvas и PIL
- Canvas — то, что видит пользователь; Tkinter хранит примитивы (линии, овалы), а не готовую bitmap для numpy/torch.
- PIL (
self.image) — off-screen растр того же рисунка; из него удобно делатьcrop,resize,ToTensor.
События мыши:
<Button-1>— нажатие;<B1-Motion>— drag с зажатой левой кнопкой;<ButtonRelease-1>— отпускание.last_x,last_y— предыдущая точка для непрерывной линии.
Препроцессинг _preprocess
Добавьте метод в класс:
def _preprocess(self) -> Image.Image | None:
if self.image.getbbox() is None:
return None
processed = ImageOps.invert(self.image)
bbox = processed.getbbox()
if bbox is None:
return None
cropped = processed.crop(bbox)
return ImageOps.fit(cropped, (28, 28), method=Image.Resampling.LANCZOS)
Разбор по строкам.
getbbox()— прямоугольник непустых пикселей;None— canvas пуст.ImageOps.invert— белая цифра на чёрном → тёмная на светлом, как в MNIST.crop(bbox)— убираем лишние поля вокруг цифры.ImageOps.fit(..., (28, 28), LANCZOS)— вписать в квадрат MNIST с сохранением пропорций; качественнее, чем грубыйresize.
Для отладки можно временно сохранять результат:
resized.save("debug_28.png")
Самопроверка
- Линия мышью рисуется плавно на чёрном фоне.
- "Очистить" стирает и canvas, и PIL-слой.
- После рисования "7" файл
debug_28.png(если добавили save) похож на MNIST-цифру.
Этап 5 — предсказание и автообучение
Цель — softmax и топ-3 классов; при отсутствии весов — обучение в фоновом потоке без заморозки окна.
Train vs inference в одном приложении
На первом запуске файла digit_cnn.pt может не быть. Тогда GUI сам вызывает train() из train.py — пользователь не обязан помнить про отдельный скрипт. Долгая операция уходит в daemon-поток; UI обновляется через root.after(0, ...).
Добавьте импорты:
import threading
from tkinter import messagebox
import torch.nn.functional as F
from torchvision import transforms
from train import MODEL_PATH, train
Подключите кнопку в _build_ui:
self.predict_btn = ttk.Button(btn_frame, text="Распознать", command=self._predict, state=tk.DISABLED)
Добавьте progressbar после result_var:
self.progress = ttk.Progressbar(main, mode="indeterminate", length=280)
self.progress.grid(row=4, column=0, columnspan=2, pady=(8, 0))
self.progress.grid_remove()
Сдвиньте метку устройства на row=5.
Замените ветку "нет весов" в __init__:
if MODEL_PATH.exists():
self._load_model()
else:
self._start_training()
Методы предсказания и фонового обучения:
def _start_training(self) -> None:
self.result_var.set("Первый запуск: обучение модели на MNIST (~1–2 мин)...")
self.progress.grid()
self.progress.start(12)
def worker() -> None:
try:
train(on_progress=lambda msg: self.root.after(0, self.result_var.set, msg))
self.root.after(0, self._on_training_done)
except Exception as exc:
self.root.after(0, self._on_training_failed, str(exc))
threading.Thread(target=worker, daemon=True).start()
def _on_training_done(self) -> None:
self.progress.stop()
self.progress.grid_remove()
self._load_model()
def _on_training_failed(self, error: str) -> None:
self.progress.stop()
self.progress.grid_remove()
messagebox.showerror("Ошибка обучения", error)
self.root.destroy()
def _predict(self) -> None:
if self.model is None:
return
resized = self._preprocess()
if resized is None:
self.result_var.set("Сначала нарисуйте цифру")
return
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
)
tensor = transform(resized).unsqueeze(0).to(self.device)
with torch.no_grad():
logits = self.model(tensor)
probs = F.softmax(logits, dim=1)[0]
digit = int(probs.argmax().item())
confidence = float(probs[digit].item()) * 100
top3 = probs.topk(3)
details = ", ".join(
f"{int(idx.item())}: {prob.item() * 100:.1f}%"
for prob, idx in zip(top3.values, top3.indices)
)
self.result_var.set(f"Цифра: {digit} ({confidence:.1f}%) · топ-3: {details}")
Разбор инференса
Тензор на вход модели.
transform(resized)— PIL → tensor[1, 28, 28]после ToTensor + Normalize.unsqueeze(0)— добавляет размерность batch →[1, 1, 28, 28], как вtrain.py.
Softmax и уверенность.
- Logits — любые вещественные числа; softmax переводит их в вероятности, сумма = 1.
confidence— доля max-класса; 85% не значит "абсолютная истина", но помогает увидеть сомнение модели.topk(3)— полезно для пар вроде 3/8 или 4/9 — типичная путаница на MNIST.
Режим без градиентов.
torch.no_grad()— не строить граф autograd; быстрее и меньше RAM.- Вместе с
model.eval()— каноничный инференс; подробнее — 333 — сохранение и загрузка.
Поток threading.
- Обучение в главном потоке заморозило бы
mainloop. root.after(0, callback)ставит обновление UI в очередь Tkinter — безопасно из worker-потока.daemon=True— поток не держит процесс после закрытия окна.- Тот же паттерн — в 311 — Tkinter и GUI; теория потоков — многопоточность Python.
Самопроверка
- Удалите
models/digit_cnn.pt, запуститеapp.py— идёт обучение с progressbar, затем можно рисовать. - Нарисованная "7" распознаётся с confidence обычно >70%.
- Пустой canvas — "Сначала нарисуйте цифру".
- В топ-3 видны альтернативные цифры с процентами.
Этап 6 — ревизия и дальнейшие шаги
Цель — убедиться, что три модуля работают вместе; понять типичные ошибки и куда развивать проект.
Итоговая архитектура
Полный код совпадает с образцом в F:\Projects\Python\TestPyTorch. Запуск:
python app.py
Итоговая самопроверка
-
model.py,train.py,app.pyв одной папке; импорты без ошибок. - Обучение создаёт
models/digit_cnn.pt; повторный запуск GUI грузит веса без переобучения. - Рисование, "Очистить", "Распознать" и топ-3 работают.
- На CPU обучение укладывается в несколько минут.
- Можете объяснить, зачем invert и те же константы Normalize в train и predict.
Частые ошибки
| Симптом | Вероятная причина | Что проверить |
|---|---|---|
| Всегда одна цифра / случайный ответ | model.train() при predict | _load_model → model.eval() |
| Низкая confidence на нормальной цифре | другой Normalize или без invert | _preprocess и transforms |
RuntimeError device | tensor на CPU, model на CUDA | .to(self.device) везде |
| Окно "висит" при первом запуске | train() в main thread | _start_training + threading |
size mismatch при load | изменили архитектуру | удалите старый .pt, переобучите |
Что добавить самостоятельно
| Улучшение | Зачем | Подсказка |
|---|---|---|
| Больше эпох / полный MNIST | выше accuracy | убрать Subset, EPOCHS = 5 |
| Валидация на test MNIST | честная метрика | datasets.MNIST(..., train=False) |
| Сохранение рисунка | отладка препроцессинга | resized.save("debug.png") |
| Экспорт ONNX | деплой без Python | torch.onnx.export — применение ИИ |
Unit-тесты _preprocess | регрессии | pytest на синтетическом PIL |
| Data augmentation | устойчивость к сдвигу | RandomAffine в train.py |
Связанные материалы
PyTorch и нейросети
- PyTorch для разработчика — тензоры, autograd, Dataset, checkpoint.
- Нейрон и слои · перцептрон на NumPy.
- Keras и TensorFlow — альтернативный фреймворк.
- Маршрут ML — от scikit-learn к глубокому обучению.
- Распознавание лиц и объектов — куда растёт CV после MNIST.
Данные и GUI
- NumPy — массивы и матрицы — основа под torch.
- практикум Pandas Data Viewer — другой сквозной Tkinter-проект.
- Tkinter и GUI · 3111 — первая программа · 3112 — справочник виджетов.
- Data Science · анализ данных — о разделе.
В подборках
Аналитика данных — Анализ данных — о разделе, Data Science, Python — о разделе.
Нейросети и ИИ — Машинное обучение — о разделе, Нейросети — о разделе, Введение в ИИ — о разделе.