import argparse
import math
import sys
import time
from datetime import datetime, timezone
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd
import requests
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier
from sklearn.experimental import enable_hist_gradient_boosting  # noqa: F401
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (accuracy_score, classification_report,
                             confusion_matrix, roc_auc_score)
from sklearn.model_selection import train_test_split, TimeSeriesSplit
from joblib import dump

BINANCE_API = "https://api.binance.com/api/v3/klines"


def parse_date_to_ms(date_str: Optional[str]) -> Optional[int]:
    """
    Convert a date or datetime string to epoch milliseconds.

    Accepts ``YYYY-MM-DD`` or any ISO-8601 datetime understood by
    ``datetime.fromisoformat``. A trailing ``Z`` UTC designator is also
    accepted; ``fromisoformat`` only parses it natively from Python 3.11.
    Values without an explicit offset are interpreted as UTC.

    Args:
        date_str: the string to parse; surrounding whitespace is ignored.

    Returns:
        Milliseconds since the Unix epoch, or None when ``date_str`` is None,
        empty, or whitespace-only.

    Raises:
        ValueError: if the string is not a recognizable date/datetime.
    """
    if date_str is None:
        return None
    text = date_str.strip()
    if not text:
        return None
    # Normalize the "Z" (Zulu/UTC) suffix so older Pythons can parse it.
    if text[-1] in "Zz":
        text = text[:-1] + "+00:00"
    try:
        dt = datetime.fromisoformat(text)
    except ValueError:
        raise ValueError(
            f"Unrecognized date {date_str!r}; expected YYYY-MM-DD or an ISO-8601 datetime"
        ) from None
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return int(dt.timestamp() * 1000)


def fetch_binance_klines(
    symbol: str = "BTCUSDT",
    interval: str = "15m",
    max_candles: int = 2000,
    start_time_ms: Optional[int] = None,
    end_time_ms: Optional[int] = None,
    pause: float = 0.2,
) -> List[List]:
    """
    Fetch historical klines from the Binance public klines endpoint.

    Pages through ``BINANCE_API`` in chunks of at most 1000 candles. It stops
    when ``max_candles`` rows have been collected, when a page comes back
    empty, or when a page is shorter than requested.

    Paging direction:
      * ``start_time_ms`` given: page forward from it, each request bounded by
        ``end_time_ms`` when that is given.
      * ``start_time_ms`` None: page backward, starting at ``end_time_ms``
        (or at the most recent candles when that is None too).

    Args:
        symbol: trading pair, e.g. "BTCUSDT".
        interval: kline interval string, e.g. "15m".
        max_candles: upper bound on the number of rows returned.
        start_time_ms / end_time_ms: optional epoch-millisecond bounds.
        pause: seconds to sleep between page requests.

    Returns:
        Raw kline rows as returned by the API (index 0 = open time,
        index 6 = close time). Rows are in ascending open-time order with
        duplicate open times removed. Backward paging receives pages
        newest-first, so this ordering has to be restored explicitly.
        NOTE(review): the newest row may be a still-forming candle when
        paging reaches the present; confirm before using it as a final bar.

    Raises:
        requests.HTTPError: if any page request returns an error status.
    """
    page_size = 1000
    forward_mode = start_time_ms is not None
    next_start = start_time_ms
    next_end = end_time_ms
    collected: List[List] = []

    # One session keeps the HTTP connection alive across pages.
    with requests.Session() as session:
        while len(collected) < max_candles:
            this_limit = min(page_size, max_candles - len(collected))
            params = {"symbol": symbol, "interval": interval, "limit": this_limit}
            if forward_mode:
                params["startTime"] = next_start
                if end_time_ms is not None:
                    params["endTime"] = end_time_ms
            elif next_end is not None:
                params["endTime"] = next_end

            resp = session.get(BINANCE_API, params=params, timeout=30)
            resp.raise_for_status()
            page = resp.json()
            if not page:
                break
            collected.extend(page)

            if forward_mode:
                # Next page starts right after this page's last close time.
                next_start = page[-1][6] + 1
            else:
                # Next page ends right before this page's first open time.
                next_end = page[0][0] - 1

            # A short page means the server has nothing further in this direction.
            if len(page) < this_limit:
                break

            # Be nice to the API between page requests.
            time.sleep(pause)

    # Backward paging appends newest pages first; restore chronological
    # order and collapse any rows sharing an open time.
    by_open_time = {row[0]: row for row in collected[:max_candles]}
    return [by_open_time[open_ms] for open_ms in sorted(by_open_time)]


def klines_to_df(klines: List[List]) -> pd.DataFrame:
    """
    Turn raw Binance kline rows into a typed DataFrame indexed by close time.

    Price and volume fields are cast to float and the trade count to int.
    Both timestamp fields are converted from epoch milliseconds to UTC
    datetimes. The frame is indexed by ``close_time`` and sorted ascending.
    The trailing ``ignore`` field is kept untouched.
    """
    column_names = [
        "open_time", "open", "high", "low", "close", "volume",
        "close_time", "quote_asset_volume", "number_of_trades",
        "taker_buy_base", "taker_buy_quote", "ignore",
    ]
    float_columns = [
        "open", "high", "low", "close", "volume",
        "quote_asset_volume", "taker_buy_base", "taker_buy_quote",
    ]

    frame = pd.DataFrame(klines, columns=column_names)
    frame[float_columns] = frame[float_columns].astype(float)
    frame["number_of_trades"] = frame["number_of_trades"].astype(int)

    # Binance timestamps are epoch milliseconds.
    for ts_column in ("open_time", "close_time"):
        frame[ts_column] = pd.to_datetime(frame[ts_column], unit="ms", utc=True)

    return frame.set_index("close_time").sort_index()


def add_features(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` with basic technical-indicator columns appended.

    Reads ``close``, ``high``, ``low`` and ``volume``. The index must be a
    tz-aware DatetimeIndex, because ``tz_convert`` is called on it.

    Added columns, in this order:
      * 1-bar simple and log returns (``ret_1``, ``log_ret_1``);
      * rolling std of ``ret_1`` over 10 and 20 bars;
      * EMAs with spans 5/10/20/50 (``adjust=False``), each followed by the
        close's relative distance to it;
      * RSI over 14 bars, using exponential smoothing with alpha = 1/14;
      * MACD line (EMA12 - EMA26), its 9-span signal, and their difference;
      * high-low range relative to close, and ``log1p`` of volume;
      * UTC hour and day of week of each index timestamp.
    """
    out = df.copy()
    close = out["close"]

    # One-bar returns and their rolling dispersion.
    out["ret_1"] = close.pct_change()
    out["log_ret_1"] = np.log(close).diff()
    for window in (10, 20):
        out[f"volatility_{window}"] = out["ret_1"].rolling(window).std()

    # Exponential moving averages and the close's relative distance to each.
    for span in (5, 10, 20, 50):
        ema = close.ewm(span=span, adjust=False).mean()
        out[f"ema_{span}"] = ema
        out[f"close_over_ema_{span}"] = close / ema - 1.0

    # RSI(14): smoothed average gains vs. losses; epsilon avoids divide-by-zero.
    period = 14
    change = close.diff()
    avg_gain = change.clip(lower=0).ewm(alpha=1 / period, adjust=False).mean()
    avg_loss = (-change.clip(upper=0)).ewm(alpha=1 / period, adjust=False).mean()
    out["rsi_14"] = 100 - (100 / (1 + avg_gain / (avg_loss + 1e-12)))

    # MACD (12, 26, 9).
    macd_line = close.ewm(span=12, adjust=False).mean() - close.ewm(span=26, adjust=False).mean()
    signal_line = macd_line.ewm(span=9, adjust=False).mean()
    out["macd"] = macd_line
    out["macd_signal"] = signal_line
    out["macd_hist"] = macd_line - signal_line

    # Bar range relative to price, and skew-reduced volume.
    out["hl_range"] = (out["high"] - out["low"]) / close
    out["log_volume"] = np.log1p(out["volume"])

    # Calendar features in UTC.
    utc_index = out.index.tz_convert("UTC")
    out["hour"] = utc_index.hour
    out["dayofweek"] = utc_index.dayofweek

    return out


def add_advanced_features(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` with additional price/volume indicator columns.

    Reads ``open``, ``high``, ``low``, ``close`` and ``volume``. Added
    columns, in this order:
      * multi-bar simple and log returns over 3/6/12/24 bars;
      * Bollinger Band width and %B (20-bar SMA +/- 2 rolling std);
      * stochastic %K over 14 bars and its 3-bar mean %D;
      * position of the close inside the 20-bar high/low (Donchian) range;
      * candle body / upper wick / lower wick as fractions of the bar's
        high-low range, clipped to [-5, 5] (NaN when high == low);
      * 20-bar z-score of volume;
      * on-balance volume (OBV) and its 20-bar z-score.
    """
    out = df.copy()
    close = out["close"]
    high = out["high"]
    low = out["low"]
    volume = out["volume"]
    log_close = np.log(close)

    # Multi-horizon returns.
    for horizon in (3, 6, 12, 24):
        out[f"ret_{horizon}"] = close.pct_change(horizon)
        out[f"logret_{horizon}"] = log_close.diff(horizon)

    # Bollinger Bands (20, 2).
    middle = close.rolling(20).mean()
    spread = close.rolling(20).std()
    band_hi = middle + 2 * spread
    band_lo = middle - 2 * spread
    out["bb_width"] = (band_hi - band_lo) / middle
    out["bb_percent_b"] = (close - band_lo) / (band_hi - band_lo)

    # Stochastic oscillator (14) with 3-bar smoothing.
    lowest_14 = low.rolling(14).min()
    highest_14 = high.rolling(14).max()
    out["stoch_k"] = (close - lowest_14) / (highest_14 - lowest_14)
    out["stoch_d"] = out["stoch_k"].rolling(3).mean()

    # Where the close sits in the 20-bar high/low channel.
    lowest_20 = low.rolling(20).min()
    highest_20 = high.rolling(20).max()
    out["donch_pos_20"] = (close - lowest_20) / (highest_20 - lowest_20)

    # Candle anatomy relative to the bar range; zero-range bars become NaN.
    body_top = out[["open", "close"]].max(axis=1)
    body_bottom = out[["open", "close"]].min(axis=1)
    bar_range = (high - low).replace(0, np.nan)
    candle_parts = (
        ("body_pct_tr", close - out["open"]),
        ("upper_pct_tr", high - body_top),
        ("lower_pct_tr", body_bottom - low),
    )
    for column, part in candle_parts:
        out[column] = (part / bar_range).clip(-5, 5)

    # Volume z-score; epsilon guards a flat-volume window.
    vol_mean = volume.rolling(20).mean()
    vol_std = volume.rolling(20).std()
    out["vol_z20"] = (volume - vol_mean) / (vol_std + 1e-12)

    # On-balance volume: cumulative volume signed by the close-to-close change.
    direction = np.sign(close.diff()).fillna(0)
    obv = (direction * volume).fillna(0).cumsum()
    out["obv"] = obv
    out["obv_z20"] = (obv - obv.rolling(20).mean()) / (obv.rolling(20).std() + 1e-12)

    return out


def add_labels(df: pd.DataFrame, threshold: float = 0.0) -> pd.DataFrame:
    """
    Return a copy of ``df`` with a ``target_up`` label for the next bar's move.

    ``next_ret`` is the simple return from this row's close to the next row's
    close. Rows are assumed to be ordered oldest to newest.

    * ``threshold > 0``: 1 if ``next_ret > threshold``, 0 if
      ``next_ret < -threshold``. Moves inside the dead zone get NaN, so a
      subsequent ``dropna`` removes them.
    * otherwise: 1 if ``next_ret > 0`` else 0; a flat move counts as 0.

    In both modes the final row, which has no next bar, is labeled NaN
    instead of receiving a fabricated 0.

    Args:
        df: frame with a ``close`` column.
        threshold: fractional move threshold (e.g. 0.001 == 10 bps).

    Returns:
        A copy of ``df`` with a float ``target_up`` column holding 1.0, 0.0
        or NaN.
    """
    df = df.copy()
    next_ret = df["close"].shift(-1) / df["close"] - 1.0
    if threshold > 0:
        target = pd.Series(np.nan, index=df.index)
        target.loc[next_ret > threshold] = 1.0
        target.loc[next_ret < -threshold] = 0.0
    else:
        # Keep the label undefined where next_ret is NaN (the last bar);
        # `NaN > 0` is False and would otherwise silently become label 0.
        target = (next_ret > 0).astype(float).where(next_ret.notna())
    df["target_up"] = target
    return df


def build_dataset(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.Series]:
    """
    Select the model features and binary target, dropping unusable rows.

    A row is dropped when any feature or ``target_up`` is NaN. This covers
    rolling-indicator warm-up, unlabeled bars and dead-zone labels. A row is
    also dropped when any feature is +/-inf (e.g. a ratio with a zero
    denominator). ``dropna`` alone keeps infinite values, and scikit-learn
    input validation rejects them for estimators such as RandomForest and
    LogisticRegression.

    Returns:
        X: feature matrix with columns in the fixed order below.
        y: integer 0/1 target aligned with ``X``.
    """
    feature_cols = [
        # basic
        "ret_1", "log_ret_1", "volatility_10", "volatility_20",
        "close_over_ema_5", "close_over_ema_10", "close_over_ema_20", "close_over_ema_50",
        "rsi_14", "macd", "macd_signal", "macd_hist", "hl_range", "log_volume",
        "hour", "dayofweek",
        # advanced
        "ret_3", "ret_6", "ret_12", "ret_24",
        "logret_3", "logret_6", "logret_12", "logret_24",
        "bb_width", "bb_percent_b",
        "stoch_k", "stoch_d",
        "donch_pos_20",
        "body_pct_tr", "upper_pct_tr", "lower_pct_tr",
        "vol_z20", "obv_z20",
    ]
    # Map +/-inf to NaN (on a new frame) so the dropna below removes them too.
    data = df[feature_cols + ["target_up"]].replace([np.inf, -np.inf], np.nan)
    data = data.dropna()
    X = data[feature_cols]
    y = data["target_up"].astype(int)
    return X, y


def three_way_time_split(X: pd.DataFrame, y: pd.Series, train_ratio=0.6, val_ratio=0.2):
    """
    Split ``(X, y)`` chronologically into train, validation and test parts.

    The first ``int(n * train_ratio)`` rows form the train part. The next
    ``int(n * val_ratio)`` rows form the validation part. Everything after
    that is the test part. Rows are not shuffled.

    Returns:
        (X_train, y_train, X_val, y_val, X_test, y_test)
    """
    total = len(X)
    train_stop = int(total * train_ratio)
    val_stop = train_stop + int(total * val_ratio)
    windows = (
        slice(None, train_stop),
        slice(train_stop, val_stop),
        slice(val_stop, None),
    )
    pieces = []
    for window in windows:
        pieces.extend((X.iloc[window], y.iloc[window]))
    return tuple(pieces)


def evaluate_with_threshold(y_true, proba, threshold=0.5):
    """
    Binarize scores at ``threshold`` and score them against ``y_true``.

    A score equal to the threshold counts as the positive class (1).

    Returns:
        (accuracy, predictions), where predictions are integer 0/1 labels of
        the same container type as ``proba``.
    """
    is_positive = proba >= threshold
    labels = is_positive.astype(int)
    return accuracy_score(y_true, labels), labels


def fit_and_select_model(X_tr, y_tr, X_val, y_val, seed=42):
    """
    Fit a grid of candidate classifiers on one train split and pick the one
    with the best validation accuracy.

    Candidates:
      * RandomForest and ExtraTrees (max_depth x min_samples_leaf grids);
      * HistGradientBoosting (max_depth x learning_rate grid);
      * standardized LogisticRegression (C grid).

    For each candidate, the decision threshold in
    ``np.linspace(0.2, 0.8, 13)`` that maximizes validation accuracy is
    chosen. Ties keep the lowest threshold and the earliest candidate.

    Args:
        X_tr, y_tr: training features / labels.
        X_val, y_val: validation features / labels used for selection.
        seed: random_state for the stochastic estimators.

    Returns:
        dict with ``name``, ``model`` (fitted on ``X_tr``), ``val_acc``,
        ``threshold`` and ``roc_auc``. ``roc_auc`` is None when AUC is
        undefined, e.g. for a single-class validation set.
    """
    candidates = []
    # RandomForest small grid
    for max_depth in [None, 8, 16]:
        for min_leaf in [1, 5]:
            rf = RandomForestClassifier(
                n_estimators=400,
                max_depth=max_depth,
                min_samples_leaf=min_leaf,
                n_jobs=-1,
                random_state=seed,
                class_weight="balanced_subsample",
            )
            candidates.append((f"RF(d={max_depth},leaf={min_leaf})", rf))
    # ExtraTrees
    for max_depth in [None, 8, 16]:
        for min_leaf in [1, 5]:
            et = ExtraTreesClassifier(
                n_estimators=500,
                max_depth=max_depth,
                min_samples_leaf=min_leaf,
                n_jobs=-1,
                random_state=seed,
                class_weight="balanced",
            )
            candidates.append((f"ET(d={max_depth},leaf={min_leaf})", et))

    # HistGradientBoosting small grid
    for max_depth in [None, 8]:
        for lr in [0.05, 0.1]:
            hgb = HistGradientBoostingClassifier(
                learning_rate=lr,
                max_depth=max_depth,
                max_iter=300,
                random_state=seed,
                validation_fraction=None,
            )
            candidates.append((f"HGB(d={max_depth},lr={lr})", hgb))

    # Logistic Regression (with scaling)
    for C in [0.1, 1.0, 3.0]:
        logit = Pipeline([
            ("scaler", StandardScaler()),
            ("clf", LogisticRegression(max_iter=2000, solver="lbfgs", class_weight="balanced", C=C))
        ])
        candidates.append((f"LogReg(C={C})", logit))

    best = {
        "name": None,
        "model": None,
        "val_acc": -np.inf,
        "threshold": 0.5,
        "roc_auc": None,
    }

    thresholds = np.linspace(0.2, 0.8, 13)
    for name, model in candidates:
        model.fit(X_tr, y_tr)
        if hasattr(model, "predict_proba"):
            proba_val = model.predict_proba(X_val)[:, 1]
        elif hasattr(model, "decision_function"):
            # Map decision_function to [0,1] via logistic
            raw = model.decision_function(X_val)
            proba_val = 1.0 / (1.0 + np.exp(-raw))
        else:
            # Fallback to predictions (treat as proba)
            proba_val = model.predict(X_val)

        # Pick threshold maximizing validation accuracy (first best wins ties).
        local_best_acc = -np.inf
        local_best_thr = 0.5
        for thr in thresholds:
            acc, _ = evaluate_with_threshold(y_val, proba_val, thr)
            if acc > local_best_acc:
                local_best_acc = acc
                local_best_thr = thr

        try:
            roc = roc_auc_score(y_val, proba_val)
        except ValueError:
            # AUC is undefined when y_val contains a single class; any other
            # error is a real bug and should surface.
            roc = None

        if local_best_acc > best["val_acc"]:
            best.update({
                "name": name,
                "model": model,
                "val_acc": local_best_acc,
                "threshold": float(local_best_thr),
                "roc_auc": None if roc is None else float(roc),
            })

    return best


def _positive_class_scores(model, X):
    """
    Return a positive-class score for each row of ``X`` from a fitted model.

    Uses ``predict_proba`` when available. Otherwise it maps
    ``decision_function`` through a logistic sigmoid, and as a last resort
    falls back to hard ``predict`` labels.
    """
    if hasattr(model, "predict_proba"):
        return model.predict_proba(X)[:, 1]
    if hasattr(model, "decision_function"):
        raw = model.decision_function(X)
        return 1.0 / (1.0 + np.exp(-raw))
    return model.predict(X)


def fit_with_tscv_and_select(X_trainval, y_trainval, seed=42, n_splits=3, n_estimators=400):
    """
    Pick the best candidate model by walk-forward (TimeSeriesSplit) CV.

    Candidates:
      * RandomForest (max_depth x min_samples_leaf grid);
      * HistGradientBoosting (max_depth x learning_rate grid);
      * standardized LogisticRegression (C grid).

    Unlike ``fit_and_select_model``, ExtraTrees is not part of this grid.

    In every fold, the threshold in ``np.linspace(0.2, 0.8, 13)`` that
    maximizes that fold's validation accuracy is found. A candidate's score
    is the mean of those per-fold best accuracies, and its threshold is the
    mean of the per-fold best thresholds.

    Args:
        X_trainval, y_trainval: chronologically ordered features / labels.
        seed: random_state for the tree-based candidates.
        n_splits: number of TimeSeriesSplit folds.
        n_estimators: tree count for the RandomForest candidates.

    Returns:
        dict with ``name``, ``model``, ``mean_val_acc``, ``mean_thr`` and
        ``mean_roc_auc``. ``model`` is the winning estimator, still fitted
        only on the last fold's training window, so refit it before use.
        ``mean_roc_auc`` is None if AUC was undefined in every fold.
    """
    tscv = TimeSeriesSplit(n_splits=n_splits)
    candidates = []

    for max_depth in [None, 8, 16]:
        for min_leaf in [1, 5]:
            rf = RandomForestClassifier(
                n_estimators=n_estimators,
                max_depth=max_depth,
                min_samples_leaf=min_leaf,
                n_jobs=-1,
                random_state=seed,
                class_weight="balanced_subsample",
            )
            candidates.append((f"RF(d={max_depth},leaf={min_leaf})", rf))
    for max_depth in [None, 8]:
        for lr in [0.05, 0.1]:
            hgb = HistGradientBoostingClassifier(
                learning_rate=lr,
                max_depth=max_depth,
                max_iter=300,
                random_state=seed,
                validation_fraction=None,
            )
            candidates.append((f"HGB(d={max_depth},lr={lr})", hgb))
    for C in [0.1, 1.0, 3.0]:
        logit = Pipeline([
            ("scaler", StandardScaler()),
            ("clf", LogisticRegression(max_iter=2000, solver="lbfgs", class_weight="balanced", C=C))
        ])
        candidates.append((f"LogReg(C={C})", logit))

    thresholds = np.linspace(0.2, 0.8, 13)
    best = {
        "name": None,
        "model": None,
        "mean_val_acc": -np.inf,
        "mean_thr": 0.5,
        "mean_roc_auc": None,
    }

    for name, model in candidates:
        accs = []
        thrs = []
        rocs = []
        for tr_idx, val_idx in tscv.split(X_trainval):
            X_tr, X_val = X_trainval.iloc[tr_idx], X_trainval.iloc[val_idx]
            y_tr, y_val = y_trainval.iloc[tr_idx], y_trainval.iloc[val_idx]
            model.fit(X_tr, y_tr)
            proba_val = _positive_class_scores(model, X_val)

            # Threshold search per fold (first best wins ties).
            best_acc = -np.inf
            best_thr = 0.5
            for thr in thresholds:
                acc, _ = evaluate_with_threshold(y_val, proba_val, thr)
                if acc > best_acc:
                    best_acc = acc
                    best_thr = thr
            accs.append(best_acc)
            thrs.append(best_thr)
            try:
                rocs.append(roc_auc_score(y_val, proba_val))
            except ValueError:
                # AUC is undefined for a single-class fold; skip it.
                pass

        mean_acc = float(np.mean(accs)) if accs else -np.inf
        mean_thr = float(np.mean(thrs)) if thrs else 0.5
        mean_roc = float(np.mean(rocs)) if rocs else None
        if mean_acc > best["mean_val_acc"]:
            # The caller refits this estimator on the full train+val data.
            best.update({
                "name": name,
                "model": model,
                "mean_val_acc": mean_acc,
                "mean_thr": mean_thr,
                "mean_roc_auc": mean_roc,
            })

    return best


def time_series_split(X: pd.DataFrame, y: pd.Series, test_size: float = 0.2) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]:
    """
    Split ``(X, y)`` chronologically into train and test parts.

    The last ``ceil(n * test_size)`` rows become the test part; no shuffling.

    Returns:
        (X_train, X_test, y_train, y_test)
    """
    n = len(X)
    n_test = int(math.ceil(n * test_size))
    n_train = n - n_test
    X_train, X_test = X.iloc[:n_train], X.iloc[n_train:]
    y_train, y_test = y.iloc[:n_train], y.iloc[n_train:]
    return X_train, X_test, y_train, y_test


def train_and_evaluate(X: pd.DataFrame, y: pd.Series, seed: int = 42, n_estimators: int = 300):
    """
    Select, refit and evaluate a classifier on a chronological hold-out.

    The first 80% of rows (60% train + 20% validation from
    ``three_way_time_split``) drive TimeSeriesSplit model selection. The
    winner is then refit on all of them. The remaining rows form the test
    set, scored at the CV-tuned decision threshold.

    Args:
        X, y: chronologically ordered features and 0/1 labels.
        seed: random_state for the tree-based candidates.
        n_estimators: tree count for the RandomForest candidates.

    Returns:
        (final_model, metrics). ``metrics`` holds ``model``,
        ``val_accuracy``, ``val_roc_auc``, ``threshold``, ``accuracy``,
        ``confusion_matrix`` and ``classification_report``. It also holds
        ``roc_auc`` when defined on the test set, and ``feature_importances``
        (top 20) for estimators exposing ``feature_importances_``.
    """
    # Chronological 60/20/20 split; train+val (first 80%) is used for selection.
    X_tr, y_tr, X_val, y_val, X_te, y_te = three_way_time_split(X, y, 0.6, 0.2)
    X_trainval = pd.concat([X_tr, X_val])
    y_trainval = pd.concat([y_tr, y_val])

    best = fit_with_tscv_and_select(
        X_trainval, y_trainval, seed=seed, n_splits=3, n_estimators=n_estimators,
    )
    final_model = best["model"]
    threshold = best.get("mean_thr", 0.5)

    # The selected estimator was last fitted on a single CV fold; refit it on
    # all of train+val before scoring the untouched test window.
    final_model.fit(X_trainval, y_trainval)

    proba_te = _positive_class_scores(final_model, X_te)
    test_acc, y_pred = evaluate_with_threshold(y_te, proba_te, threshold)
    metrics = {
        "model": best["name"],
        "val_accuracy": float(best.get("mean_val_acc", np.nan)),
        "val_roc_auc": best.get("mean_roc_auc"),
        "threshold": float(threshold),
        "accuracy": float(test_acc),
        "confusion_matrix": confusion_matrix(y_te, y_pred).tolist(),
        "classification_report": classification_report(y_te, y_pred, output_dict=True),
    }
    try:
        metrics["roc_auc"] = float(roc_auc_score(y_te, proba_te))
    except ValueError:
        # AUC is undefined when the test window contains a single class.
        pass

    if hasattr(final_model, "feature_importances_"):
        ranked = sorted(
            zip(X.columns, final_model.feature_importances_),
            key=lambda pair: pair[1],
            reverse=True,
        )
        metrics["feature_importances"] = {col: float(imp) for col, imp in ranked[:20]}

    return final_model, metrics


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the training script."""
    parser = argparse.ArgumentParser(description="Train a model to predict next 15m BTC move (up/down) using Binance data.")
    option_table = (
        ("--symbol", {"default": "BTCUSDT", "help": "Trading pair symbol (default: BTCUSDT)"}),
        ("--interval", {"default": "15m", "help": "Kline interval (default: 15m)"}),
        ("--max-candles", {"type": int, "default": 3000, "dest": "max_candles", "help": "Max candles to fetch (default: 3000)"}),
        ("--start", {"type": str, "default": None, "help": "Start date (YYYY-MM-DD or ISO). If omitted, fetches most recent candles."}),
        ("--seed", {"type": int, "default": 42, "help": "Random seed (default: 42)"}),
        ("--n-estimators", {"type": int, "default": 300, "dest": "n_estimators", "help": "RandomForest n_estimators (default: 300)"}),
        ("--label-threshold-bps", {"type": float, "default": 0.0, "dest": "label_threshold_bps", "help": "Min move in bps to label up/down; smaller moves are dropped (default: 0)"}),
        ("--model-out", {"type": str, "default": "btc_15m_model.pkl", "dest": "model_out", "help": "Path to save trained model (joblib)"}),
        ("--csv-out", {"type": str, "default": None, "dest": "csv_out", "help": "Optional path to save the augmented dataset as CSV"}),
        ("--no-save", {"action": "store_true", "help": "Do not save the trained model"}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser


def _print_metrics(metrics) -> None:
    """Print headline test metrics, the confusion matrix and a compact report."""
    print("\nMetrics:")
    print(f"Accuracy: {metrics.get('accuracy'):.4f}")
    if "roc_auc" in metrics:
        print(f"ROC AUC: {metrics['roc_auc']:.4f}")
    if metrics.get("model"):
        print(f"Best model: {metrics['model']}  |  Val Acc: {metrics.get('val_accuracy'):.4f}  |  Thr: {metrics.get('threshold'):.2f}")
    print("Confusion Matrix (rows=true, cols=pred):")
    print(np.array(metrics["confusion_matrix"]))

    print("\nClassification Report:")
    report = metrics["classification_report"]
    for label in ("0", "1"):
        if label in report:
            row = report[label]
            print(f"Class {label}: precision={row['precision']:.3f}, recall={row['recall']:.3f}, f1={row['f1-score']:.3f}")
    for key, title in (("macro avg", "Macro avg"), ("weighted avg", "Weighted avg")):
        if key in report:
            row = report[key]
            print(f"{title}: precision={row['precision']:.3f}, recall={row['recall']:.3f}, f1={row['f1-score']:.3f}")


def main():
    """
    Command-line entry point.

    Parses arguments and fetches klines from Binance. It then builds the
    features and labels, trains and evaluates a model, and prints the
    metrics. Finally it optionally writes the feature dataset as CSV and
    saves the model artifact with joblib. Exits with status 1 when too
    little data is available.
    """
    args = _build_arg_parser().parse_args()
    start_ms = parse_date_to_ms(args.start)

    print(f"Fetching klines: symbol={args.symbol}, interval={args.interval}, max_candles={args.max_candles}, start={args.start}")
    klines = fetch_binance_klines(
        symbol=args.symbol,
        interval=args.interval,
        max_candles=args.max_candles,
        start_time_ms=start_ms,
        end_time_ms=None,
    )
    if not klines or len(klines) < 100:
        print("Not enough data fetched to train (need at least 100 candles). Exiting.")
        sys.exit(1)

    # Basis points -> fractional return threshold for labeling.
    label_threshold = (args.label_threshold_bps or 0.0) / 10000.0
    featured = add_advanced_features(add_features(klines_to_df(klines)))
    df = add_labels(featured, threshold=label_threshold)

    X, y = build_dataset(df)
    if len(X) < 200:
        print("Not enough feature rows after preprocessing (need at least 200). Exiting.")
        sys.exit(1)

    print(f"Training on {len(X)} samples, features={list(X.columns)}")
    model, metrics = train_and_evaluate(X, y, seed=args.seed, n_estimators=args.n_estimators)
    _print_metrics(metrics)

    if args.csv_out:
        df.to_csv(args.csv_out)
        print(f"Saved dataset with features to: {args.csv_out}")

    if args.no_save:
        return
    artifact = {
        "model": model,
        "features": list(X.columns),
        "symbol": args.symbol,
        "interval": args.interval,
        "trained_at": datetime.now(timezone.utc).isoformat(),
        "threshold": metrics.get("threshold", 0.5),
        "best_model_name": metrics.get("model"),
    }
    dump(artifact, args.model_out)
    print(f"Saved model to: {args.model_out}")


if __name__ == "__main__":
    # Run the training CLI only when executed as a script, not on import.
    main()
