Перейти к основному содержимому

Scikit-learn — регрессия и классификация

Разработчику

Scikit-learn — стандартная библиотека для табличного машинного обучения на Python: регрессия, классификация, кластеризация, предобработка и оценка моделей в едином API. Для изображений, длинных текстов и больших нейросетей обычно переходят к TensorFlow или PyTorch — см. Keras и TensorFlow. Обзор всех типов обучения — в Машинное обучение; практический end-to-end пример — проект Melbourne.


Когда достаточно scikit-learn

ЗадачаТипичные алгоритмы sklearn
Прогноз числа (цена, спрос)LinearRegression, GradientBoostingRegressor, RandomForestRegressor
Категория (спам, отток)LogisticRegression, RandomForestClassifier, SVC
Группы без метокKMeans, DBSCAN
Сжатие признаковPCA, TruncatedSVD

Данные — таблица: строки — объекты, столбцы — признаки. Перед обучением нужны кодирование категорий и честное разбиение train/test.


Единый контракт API

Почти все модели sklearn поддерживают три метода:

  • fit(X, y) — обучение;
  • predict(X) — прогноз;
  • score(X, y) — быстрая оценка (для классификаторов — accuracy по умолчанию).
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.25, random_state=42, stratify=y
)

clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test))

Регрессия — прогноз числа

Регрессия предсказывает непрерывную величину y (цена, температура, время доставки).

from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error, r2_score
import numpy as np

X = np.array([[1], [2], [3], [4], [5]])
y = np.array([2.0, 4.1, 5.0, 4.2, 5.1])

model = LinearRegression()
model.fit(X, y)
pred = model.predict([[6]])

print(pred[0], model.coef_, model.intercept_)
print("MAE:", mean_absolute_error(y, model.predict(X)))
print("R²:", r2_score(y, model.predict(X)))
МетрикаСмысл
MAEСредняя абсолютная ошибка в единицах целевой переменной
RMSEШтрафует крупные промахи сильнее MAE
Доля дисперсии, объяснённая моделью (1.0 — идеально на train, на test обычно ниже)

Сквозной пример с GradientBoostingRegressor и GridSearchCVМельбурн.


Классификация — прогноз категории

Классификация предсказывает метку класса: 0/1, «кошка»/«собака», один из нескольких видов.

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix

clf = LogisticRegression(max_iter=500)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

print(classification_report(y_test, y_pred))
print(confusion_matrix(y_test, y_pred))

При дисбалансе классов (99% «не спам», 1% «спам») accuracy вводит в заблуждение. Смотрите precision, recall, F1 и матрицу ошибок — подробнее в разделе метрик Машинное обучение.

predict_proba(X) возвращает вероятности по классам — удобно для порога «срабатывания» в проде.


Pipeline — без утечки данных

Ошибка новичка: сначала StandardScaler на всей таблице, потом train_test_split. Статистики scaler «видят» test — метрики завышаются.

Pipeline объединяет предобработку и модель; fit на train применяет scaler только к train.

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

pipe = Pipeline([
("scaler", StandardScaler()),
("clf", SVC(kernel="rbf", probability=True)),
])
pipe.fit(X_train, y_train)
print(pipe.score(X_test, y_test))

Тот же pipeline можно передать в GridSearchCV — перебор гиперпараметров внутри кросс-валидации на train.


Подбор гиперпараметров

from sklearn.model_selection import GridSearchCV

param_grid = {
"clf__C": [0.1, 1, 10],
"clf__gamma": ["scale", "auto"],
}

search = GridSearchCV(
pipe,
param_grid,
cv=5,
scoring="f1_macro",
n_jobs=-1,
)
search.fit(X_train, y_train)
print(search.best_params_, search.best_score_)
print("Test:", search.score(X_test, y_test))

cv=5 — пять фолдов на train; test остаётся для финального отчёта один раз.


Сохранение модели

import joblib

joblib.dump(pipe, "model.joblib")
loaded = joblib.load("model.joblib")
loaded.predict(X_test[:1])

В проде версионируйте и модель, и схему признаков (имена и типы столбцов), иначе после переобучения API сломается.


Связь с другими материалами

ТемаСтатья
Алгоритмы (деревья, SVM, бустинг)Алгоритмы ИИ
Pandas и EDAPython для анализа
Нейросети, изображения, текстKeras и TensorFlow, распознавание
Облачные API без своего обученияCognitive Services

См. также

Другие статьи этого же раздела в боковом меню (как на странице "О разделе").