"""
app.py - Flask 예측 API 서버
PHP 프론트엔드가 HTTP로 호출하는 추론 엔진.

엔드포인트:
  GET  /health                 상태 확인
  POST /train                  모델 (재)학습
  POST /predict {batch_no}     배치 품질 예측 + 경고 + 영향도 (DB 저장)
  POST /simulate {batch_no, material_code, delta_pct}   What-if 시뮬레이션
  GET  /model_info             현재 모델 정보
"""
from flask import Flask, request, jsonify
from config import get_conn, MODEL_VER
import trainer
import predictor

# Flask application object; the route decorators in this module register on it.
app = Flask(__name__)


@app.route("/health")
def health():
    """Liveness probe: report an ``ok`` status plus ``MODEL_VER`` from config."""
    payload = {"status": "ok", "model_ver": MODEL_VER}
    return jsonify(payload)


@app.route("/train", methods=["POST"])
def train_endpoint():
    """Train the model via ``trainer.train`` and reload it in ``predictor``.

    Returns:
        ``{"success": True, "report": ...}`` when training, reloading and
        serialising the report all succeed; otherwise a 400 response with
        ``{"success": False, "error": <exception text>}``.
    """
    try:
        training_report = trainer.train(verbose=False)
        predictor.reload_model()
        # Built inside the try so a failure to serialise the report is
        # reported through the same error response.
        response = jsonify({"success": True, "report": training_report})
    except Exception as exc:
        failure = {"success": False, "error": str(exc)}
        return jsonify(failure), 400
    return response


@app.route("/predict", methods=["POST"])
def predict_endpoint():
    """Predict quality for one batch, persist the outcome, and return it.

    Expects a JSON object body containing a non-empty ``batch_no``.

    Returns:
        ``{"success": True, "result": ...}`` on success.
        400 with an ``error`` key when the body is not a JSON object,
        ``batch_no`` is missing or empty, the predictor's result contains an
        ``error`` key, or a ``FileNotFoundError`` is raised (reported to the
        client as "no trained model, run /train first").
        500 with an ``error`` key for any other exception, including a
        failure inside ``_persist``.
    """
    # silent=True yields None for malformed JSON instead of Flask's HTML 400
    # page, so the caller always gets a JSON error body.
    data = request.get_json(force=True, silent=True)
    # Valid JSON that is not an object (null, list, string, ...) has no
    # .get(); reject it here rather than crash with AttributeError.
    if not isinstance(data, dict):
        return jsonify({"error": "JSON object body required"}), 400
    batch_no = data.get("batch_no")
    if not batch_no:
        return jsonify({"error": "batch_no required"}), 400
    try:
        result = predictor.predict_batch(batch_no)
        if "error" in result:
            return jsonify(result), 400
        _persist(batch_no, result)
        return jsonify({"success": True, "result": result})
    except FileNotFoundError:
        return jsonify({"error": "학습된 모델이 없습니다. 먼저 /train 을 실행하세요."}), 400
    except Exception as e:
        # Keep the traceback server-side; the client only receives str(e).
        app.logger.exception("prediction failed for batch %s", batch_no)
        return jsonify({"error": str(e)}), 500


@app.route("/simulate", methods=["POST"])
def simulate_endpoint():
    """Run a what-if simulation through ``predictor.simulate``.

    Expects a JSON object body with ``batch_no``, ``material_code`` and a
    ``delta_pct`` value convertible to ``float``.

    Returns:
        ``{"success": True, "result": ...}`` on success.
        400 with an ``error`` key when the body is not a JSON object, a
        required field is absent, ``delta_pct`` is not numeric, or
        ``predictor.simulate`` raises.
    """
    # silent=True yields None for malformed JSON instead of Flask's HTML 400
    # page, so the caller always gets a JSON error body.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "JSON object body required"}), 400

    # Name the absent fields explicitly; a bare KeyError would surface to the
    # client as just "'batch_no'".
    missing = [key for key in ("batch_no", "material_code", "delta_pct")
               if key not in data]
    if missing:
        return jsonify({"error": ", ".join(missing) + " required"}), 400

    try:
        delta_pct = float(data["delta_pct"])
    except (TypeError, ValueError):
        return jsonify({"error": "delta_pct must be a number"}), 400

    try:
        r = predictor.simulate(
            data["batch_no"], data["material_code"], delta_pct)
        return jsonify({"success": True, "result": r})
    except Exception as e:
        return jsonify({"error": str(e)}), 400


@app.route("/model_info")
def model_info():
    """Summarise the object returned by ``predictor.load_model()``.

    The response carries the ``ver`` and ``algo`` entries, the number of
    ``features``, the keys of the ``reg`` entry, and whether a ``clf`` entry
    is present. A ``FileNotFoundError`` results in a 404 with an ``error`` key.
    """
    try:
        model = predictor.load_model()
        feature_names = model.get("features", [])
        regressors = model.get("reg", {})
        summary = {
            "ver": model.get("ver"),
            "algo": model.get("algo"),
            "n_features": len(feature_names),
            "reg_targets": list(regressors.keys()),
            "has_classifier": model.get("clf") is not None,
        }
        return jsonify(summary)
    except FileNotFoundError:
        not_found = {"error": "모델 없음"}
        return jsonify(not_found), 404


def _persist(batch_no, result):
    """Save a prediction result, its warnings and its impacts to the DB.

    On a connection obtained from ``get_conn()`` this inserts one
    ``cq_prediction`` row, replaces the batch's unacknowledged
    (``is_ack=0``) ``cq_warning`` rows and all of its ``cq_impact`` rows, and
    sets ``cq_batch.status`` to ``'PREDICTED'``. Everything is committed
    together at the end; on any failure the work is rolled back and the
    error re-raised. The connection is always closed.

    Args:
        batch_no: Batch identifier written to every table.
        result: Mapping with ``properties``, ``pred_result``, ``fail_prob``,
            ``model_ver``, ``warnings`` (items with ``prop`` / ``level`` /
            ``message``) and ``impacts`` (property name -> sequence of items
            with ``feature`` / ``impact`` / ``direction``; the position in
            the sequence, starting at 1, is stored as ``rank_no``).

    Raises:
        KeyError: If ``result`` lacks a required key (raised before any
            connection is opened).
        Exception: Whatever the database driver raises.
    """
    # Build every parameter tuple up front so a malformed ``result`` fails
    # before a connection is opened.
    p = result["properties"]
    prediction_row = (
        batch_no, p.get("ph"), p.get("hardness"),
        p.get("specific_grav"), p.get("content_pct"),
        p.get("viscosity"), result["pred_result"],
        result["fail_prob"], result["model_ver"],
    )
    warning_rows = [
        (batch_no, w["prop"], w["level"], w["message"])
        for w in result["warnings"]
    ]
    impact_rows = [
        (batch_no, prop, it["feature"], it["impact"], it["direction"], rank)
        for prop, items in result["impacts"].items()
        for rank, it in enumerate(items, 1)
    ]

    conn = get_conn()
    try:
        with conn.cursor() as cur:
            cur.execute("""
                INSERT INTO cq_prediction
                  (batch_no, pred_ph, pred_hardness, pred_sg, pred_content,
                   pred_visc, pred_result, fail_prob, model_ver)
                VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s)
            """, prediction_row)

            # Clear the unacknowledged warnings, then write the fresh ones.
            cur.execute("DELETE FROM cq_warning WHERE batch_no=%s AND is_ack=0",
                        (batch_no,))
            if warning_rows:
                cur.executemany("""
                    INSERT INTO cq_warning (batch_no, prop_name, level, message)
                    VALUES (%s,%s,%s,%s)
                """, warning_rows)

            cur.execute("DELETE FROM cq_impact WHERE batch_no=%s", (batch_no,))
            if impact_rows:
                cur.executemany("""
                    INSERT INTO cq_impact
                      (batch_no, prop_name, feature_name, shap_value,
                       direction, rank_no)
                    VALUES (%s,%s,%s,%s,%s,%s)
                """, impact_rows)

            cur.execute(
                "UPDATE cq_batch SET status='PREDICTED' WHERE batch_no=%s",
                (batch_no,))
        conn.commit()
    except Exception:
        # Undo partial writes explicitly instead of relying on close(); a
        # failing rollback must not mask the original error.
        try:
            conn.rollback()
        except Exception:
            app.logger.exception("rollback failed for batch %s", batch_no)
        raise
    finally:
        conn.close()


if __name__ == "__main__":
    # Start Flask's built-in server only when this file is run as a script:
    # bound to all interfaces (0.0.0.0) on port 5005 with debug mode off.
    app.run(host="0.0.0.0", port=5005, debug=False)
