"""
trainer.py - 품질 예측 모델 학습
- 물성 회귀: XGBoost (데이터 적으면 RandomForest 자동 fallback)
- 적합/부적합 분류: XGBoost Classifier (데이터 적으면 RandomForest 자동 fallback)
- 학습 결과는 quality_model.pkl 로 저장
"""
import os
import joblib
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.metrics import mean_absolute_error, f1_score

try:
    import xgboost as xgb
    HAS_XGB = True
except ImportError:
    HAS_XGB = False

from config import MODEL_PATH, MODEL_VER, REG_TARGETS, get_conn
from features import build_training_matrix

# Minimum number of training rows required to use XGBoost; with fewer rows
# (or when xgboost is not installed) RandomForest is used instead, to limit
# overfitting on small datasets.
XGB_MIN_SAMPLES = 80


def _make_regressor(n):
    """Return an unfitted regressor suited to a dataset of ``n`` rows.

    XGBoost is chosen only when it was importable and ``n`` is at least
    ``XGB_MIN_SAMPLES``; otherwise a RandomForestRegressor is returned.
    """
    use_xgboost = HAS_XGB and n >= XGB_MIN_SAMPLES
    if not use_xgboost:
        # Small data / no xgboost: shallower-variance forest fallback.
        return RandomForestRegressor(
            n_estimators=200,
            max_depth=8,
            min_samples_leaf=2,
            random_state=42,
            n_jobs=-1,
        )
    return xgb.XGBRegressor(
        n_estimators=300,
        max_depth=5,
        learning_rate=0.05,
        subsample=0.8,
        colsample_bytree=0.8,
        random_state=42,
    )


def _make_classifier(n, pos_ratio):
    """Return an unfitted pass/fail classifier for a dataset of ``n`` rows.

    ``pos_ratio`` is the fraction of positive samples. For XGBoost it sets
    ``scale_pos_weight`` to negatives/positives, floored at 1.0 so positives
    are never down-weighted; the denominator is clamped at 1e-6 to avoid a
    division by zero. The RandomForest fallback (small data or no xgboost)
    uses ``class_weight="balanced"`` instead.
    """
    neg_ratio = 1 - pos_ratio
    safe_pos_ratio = max(pos_ratio, 1e-6)
    pos_weight = max(1.0, neg_ratio / safe_pos_ratio)

    if not (HAS_XGB and n >= XGB_MIN_SAMPLES):
        return RandomForestClassifier(
            n_estimators=200,
            max_depth=6,
            class_weight="balanced",
            random_state=42,
            n_jobs=-1,
        )
    return xgb.XGBClassifier(
        n_estimators=300,
        max_depth=4,
        learning_rate=0.05,
        subsample=0.8,
        colsample_bytree=0.8,
        scale_pos_weight=pos_weight,
        eval_metric="logloss",
        random_state=42,
    )


# Minimum number of rows needed to train a model (and carve out a holdout
# split); also the overall minimum enforced by train().
_MIN_TRAIN_ROWS = 10


def _fit_regressors(X, y_reg, n, verbose):
    """Fit one regressor per column in ``REG_TARGETS``.

    Rows whose target value is missing are dropped for that target rather
    than median-imputed: imputed labels would train the model toward the
    median and would also leak into the holdout MAE.

    Returns ``(models, maes)``: dicts keyed by target column holding the model
    refit on all labeled rows, and the holdout MAE rounded to 4 decimals.
    Targets with fewer than ``_MIN_TRAIN_ROWS`` labeled rows or fewer than two
    distinct values are skipped and appear in neither dict.
    """
    models, maes = {}, {}
    for col in REG_TARGETS:
        y_all = y_reg[col].astype(float)
        # Positional mask, so row selection does not depend on X and y_reg
        # sharing the same index labels.
        labeled = y_all.notna().to_numpy()
        y = y_all[labeled]
        if len(y) < _MIN_TRAIN_ROWS or y.nunique() < 2:
            if verbose:
                print(f"[REG] {col:14s} 건너뜀 (라벨 {len(y)}개, 고유값 {y.nunique()}개)")
            continue
        X_lab = X[labeled]
        Xtr, Xte, ytr, yte = train_test_split(
            X_lab, y, test_size=0.2, random_state=42)
        # The model family is chosen from the total row count so that it
        # matches the "algo" recorded in the report and the saved bundle.
        model = _make_regressor(n)
        model.fit(Xtr, ytr)
        mae = mean_absolute_error(yte, model.predict(Xte))
        maes[col] = round(float(mae), 4)
        # Refit on every labeled row for deployment.
        model.fit(X_lab, y)
        models[col] = model
        if verbose:
            print(f"[REG] {col:14s} MAE={mae:.4f}")
    return models, maes


def _fit_classifier(X, y_class, n, verbose):
    """Fit the pass/fail classifier and estimate its holdout F1.

    Returns ``(clf, f1)``. ``clf`` is None when ``y_class`` has fewer than two
    distinct labels. ``f1`` (rounded to 4 decimals) is None when the holdout
    evaluation is skipped: a stratified split needs at least two samples of
    every class, so when the rarest class has a single sample the model is
    trained on all rows without evaluation instead of letting
    ``train_test_split`` raise ``ValueError``.
    """
    classes, counts = np.unique(y_class, return_counts=True)
    if len(classes) < 2:
        return None, None

    pos_ratio = float(np.mean(y_class))
    clf = _make_classifier(n, pos_ratio)
    f1 = None
    if counts.min() >= 2:
        Xtr, Xte, ytr, yte = train_test_split(
            X, y_class, test_size=0.2, random_state=42, stratify=y_class)
        clf.fit(Xtr, ytr)
        f1 = f1_score(yte, clf.predict(Xte), zero_division=0)
    # Refit on all rows for deployment.
    clf.fit(X, y_class)

    if verbose:
        if f1 is None:
            print(f"[CLF] 소수 클래스 표본이 2개 미만이라 F1 검증 생략  (불량비율={pos_ratio:.2%})")
        else:
            print(f"[CLF] F1={f1:.4f}  (불량비율={pos_ratio:.2%})")
    return clf, (None if f1 is None else round(float(f1), 4))


def _save_model(bundle):
    """Persist ``bundle`` to ``MODEL_PATH`` with joblib.

    The parent directory is created only when ``MODEL_PATH`` has one;
    ``os.makedirs("")`` would raise for a bare filename.
    """
    model_dir = os.path.dirname(MODEL_PATH)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    joblib.dump(bundle, MODEL_PATH)


def train(verbose=True):
    """Train the property regressors and pass/fail classifier, then save them.

    Loads the training matrix from the database, fits one regressor per
    ``REG_TARGETS`` column plus a classifier, and writes a bundle with keys
    ``reg``, ``clf``, ``features``, ``ver`` and ``algo`` to ``MODEL_PATH``.

    Args:
        verbose: print per-model metrics and the save location.

    Returns:
        dict with ``algo``, ``n_samples``, ``reg_mae`` (target -> holdout MAE)
        and ``clf_f1`` (holdout F1, or None if no classifier/evaluation).

    Raises:
        RuntimeError: fewer than 10 training rows are available.
    """
    conn = get_conn()
    try:
        X, y_reg, y_class = build_training_matrix(conn)
    finally:
        conn.close()

    if X is None or len(X) < _MIN_TRAIN_ROWS:
        raise RuntimeError(
            "학습 데이터가 부족합니다 (최소 10배치 필요). "
            "품질검사 결과가 입력된 배치를 더 확보하세요.")

    n = len(X)
    feature_names = list(X.columns)
    algo = "XGBoost" if (HAS_XGB and n >= XGB_MIN_SAMPLES) else "RandomForest"

    reg_models, reg_mae = _fit_regressors(X, y_reg, n, verbose)
    clf, clf_f1 = _fit_classifier(X, y_class, n, verbose)
    report = {"algo": algo, "n_samples": n, "reg_mae": reg_mae, "clf_f1": clf_f1}

    _save_model({
        "reg": reg_models, "clf": clf,
        "features": feature_names, "ver": MODEL_VER,
        "algo": algo,
    })
    if verbose:
        print(f"\n모델 저장 완료: {MODEL_PATH} ({algo}, {n} samples)")
    return report


if __name__ == "__main__":
    # Script entry point: train, save, and echo the metrics report.
    training_report = train()
    print(training_report)
