"""
predictor.py - 추론 / 경고 판정 / 원료 영향도(SHAP) 분석
"""
import logging

import joblib
import numpy as np
import pandas as pd

try:
    import shap
    HAS_SHAP = True
except ImportError:
    HAS_SHAP = False

from config import MODEL_PATH, REG_TARGETS, get_conn
from features import build_batch_features, align_features

# Process-wide cache for the deserialized model bundle; None until the first
# load. Written by load_model() and reload_model().
_MODEL_CACHE = None


def load_model():
    """Return the model bundle, loading it from MODEL_PATH on first use.

    The bundle is kept in the module-level ``_MODEL_CACHE`` so that later
    calls return the same object without touching the disk again.

    NOTE: ``joblib.load`` unpickles the file, so MODEL_PATH must point to a
    trusted artifact.
    """
    global _MODEL_CACHE
    if _MODEL_CACHE is not None:
        return _MODEL_CACHE
    _MODEL_CACHE = joblib.load(MODEL_PATH)
    return _MODEL_CACHE


def reload_model():
    """Re-read the model bundle from MODEL_PATH and replace the cached one.

    The new bundle is loaded *before* the cache is touched, so if loading
    raises (missing or corrupt file, etc.) the previously cached model stays
    in place and keeps serving predictions. The exception still propagates
    to the caller.

    Returns:
        The freshly loaded model bundle.

    NOTE: ``joblib.load`` unpickles the file, so MODEL_PATH must point to a
    trusted artifact.
    """
    global _MODEL_CACHE
    fresh = joblib.load(MODEL_PATH)
    # Swap only after a successful load.
    _MODEL_CACHE = fresh
    return fresh


def _get_spec(conn, product_code):
    """Fetch the spec rows of a product from ``cq_spec``.

    Args:
        conn: Open DB connection handed to ``pandas.read_sql``.
        product_code: Product whose spec rows are selected.

    Returns:
        dict mapping each ``prop_name`` to its row (a pandas Series holding
        ``prop_name``, ``spec_low``, ``spec_high`` and ``spec_target``).
        If a property appears more than once, the last row wins.
    """
    query = (
        "SELECT prop_name, spec_low, spec_high, spec_target "
        "FROM cq_spec WHERE product_code=%s"
    )
    spec_rows = pd.read_sql(query, conn, params=[product_code])

    by_prop = {}
    for _, row in spec_rows.iterrows():
        by_prop[row["prop_name"]] = row
    return by_prop


def predict_batch(batch_no):
    """Predict properties, pass/fail verdict, warnings and impacts for a batch.

    Steps: build the batch features, look up the product's spec, run every
    available regressor and the optional classifier, evaluate warnings,
    derive an overall verdict and compute the top feature impacts.

    Args:
        batch_no: Batch identifier used for the feature build and the
            ``cq_batch`` lookup.

    Returns:
        ``{"error": <message>}`` when no features could be built or the batch
        row does not exist. Otherwise a dict with the keys ``batch_no``,
        ``product_code``, ``properties``, ``fail_prob``, ``pred_result``
        ("FAIL" / "REWORK" / "PASS"), ``warnings``, ``impacts`` and
        ``model_ver``.
    """
    bundle = load_model()
    db = get_conn()
    try:
        features = build_batch_features(batch_no, db)
        if not features:
            return {"error": "계근 데이터가 없습니다."}

        batch_rows = pd.read_sql(
            "SELECT product_code FROM cq_batch WHERE batch_no=%s",
            db, params=[batch_no])
        if batch_rows.empty:
            return {"error": "배치를 찾을 수 없습니다."}
        product_code = batch_rows.iloc[0]["product_code"]
        spec = _get_spec(db, product_code)
    finally:
        # The connection is only needed for the lookups above.
        db.close()

    X = align_features(features, bundle["features"])

    # ---- property predictions (only targets that have a regressor) ----
    props = {
        target: round(float(bundle["reg"][target].predict(X)[0]), 4)
        for target in REG_TARGETS
        if target in bundle["reg"]
    }

    # ---- failure probability ----
    # Index 1 of predict_proba is used as the failure class
    # (assumed from usage; confirm against training code).
    classifier = bundle["clf"]
    fail_prob = (
        None if classifier is None
        else round(float(classifier.predict_proba(X)[0][1]), 4)
    )

    # ---- warnings ----
    warnings = _evaluate_warnings(props, fail_prob, spec)

    # ---- overall verdict ----
    levels = {w["level"] for w in warnings}
    prob_known = fail_prob is not None
    if "CRITICAL" in levels or (prob_known and fail_prob > 0.5):
        pred_result = "FAIL"
    elif "WARNING" in levels or (prob_known and fail_prob > 0.3):
        pred_result = "REWORK"
    else:
        pred_result = "PASS"

    # ---- feature impacts (SHAP or fallback) ----
    impacts = _explain(bundle, X)

    return {
        "batch_no": batch_no,
        "product_code": product_code,
        "properties": props,
        "fail_prob": fail_prob,
        "pred_result": pred_result,
        "warnings": warnings,
        "impacts": impacts,
        "model_ver": bundle.get("ver"),
    }


def _evaluate_warnings(props, fail_prob, spec):
    warnings = []
    for prop, value in props.items():
        if prop not in spec:
            continue
        low = float(spec[prop]["spec_low"])
        high = float(spec[prop]["spec_high"])
        margin = (high - low) * 0.10
        if value < low or value > high:
            warnings.append({
                "prop": prop, "level": "CRITICAL",
                "message": f"{_kname(prop)} 예측값 {value} 이(가) "
                           f"관리기준({low}~{high})을 벗어남"})
        elif value < low + margin or value > high - margin:
            warnings.append({
                "prop": prop, "level": "WARNING",
                "message": f"{_kname(prop)} 예측값 {value} 이(가) "
                           f"관리기준 경계에 근접"})
    if fail_prob is not None and fail_prob > 0.3:
        lvl = "CRITICAL" if fail_prob > 0.5 else "WARNING"
        warnings.append({
            "prop": "overall", "level": lvl,
            "message": f"부적합 발생 확률 {fail_prob*100:.1f}%"})
    return warnings


def _explain(model, X):
    """SHAP 기반 상위 영향 원료. SHAP 미설치 시 feature_importance fallback."""
    out = {}
    targets = {**model["reg"]}
    if model["clf"] is not None:
        targets["result"] = model["clf"]

    for name, m in targets.items():
        try:
            if HAS_SHAP:
                explainer = shap.TreeExplainer(m)
                sv = explainer.shap_values(X)
                vals = sv[1][0] if isinstance(sv, list) else sv[0]
            else:
                vals = m.feature_importances_
            pairs = sorted(zip(model["features"], np.ravel(vals)),
                           key=lambda t: abs(t[1]), reverse=True)[:5]
            out[name] = [{
                "feature": f, "impact": round(float(v), 6),
                "direction": "UP" if v > 0 else "DOWN"
            } for f, v in pairs]
        except Exception:
            continue
    return out


def simulate(batch_no, material_code, delta_pct):
    """What-if: predicted property change when one material ratio is scaled.

    The feature column ``ratio_<material_code>`` is multiplied by
    ``1 + delta_pct / 100`` and every available regressor is evaluated on
    both the original and the adjusted features.

    If the aligned features have no such column, nothing is adjusted and
    every reported ``diff`` is therefore 0.

    Args:
        batch_no: Batch whose features form the baseline.
        material_code: Material whose ratio feature is adjusted.
        delta_pct: Relative adjustment in percent (e.g. 5 for +5 %).

    Returns:
        ``{"error": <message>}`` when no features could be built; otherwise
        a dict with ``material_code``, ``delta_pct`` and ``changes``, where
        ``changes`` maps each target to ``before`` / ``after`` / ``diff``
        rounded to 4 decimals.
    """
    bundle = load_model()
    db = get_conn()
    try:
        features = build_batch_features(batch_no, db)
    finally:
        db.close()
    if not features:
        return {"error": "계근 데이터가 없습니다."}

    baseline = align_features(features, bundle["features"])
    adjusted = baseline.copy()
    column = f"ratio_{material_code}"
    if column in adjusted.columns:
        adjusted[column] *= (1 + delta_pct / 100.0)

    changes = {}
    for target in REG_TARGETS:
        if target not in bundle["reg"]:
            continue
        before = float(bundle["reg"][target].predict(baseline)[0])
        after = float(bundle["reg"][target].predict(adjusted)[0])
        changes[target] = {
            "before": round(before, 4),
            "after": round(after, 4),
            "diff": round(after - before, 4),
        }

    return {
        "material_code": material_code,
        "delta_pct": delta_pct,
        "changes": changes,
    }


_KMAP = {"ph": "pH", "hardness": "경도", "specific_grav": "비중",
         "content_pct": "함량", "viscosity": "점도", "overall": "종합"}


def _kname(prop):
    return _KMAP.get(prop, prop)
