"""ML Model Training — train/retrain advanced ensemble models per coin.

Uses EnsemblePredictor from app/engine/ml/ensemble.py, which comprises:
  - TreeBoost: XGBoost/LightGBM/CatBoost/GradientBoosting ensemble
  - DirectionClassifier: calibrated Platt-scaling classifier
  - FeatureEngineer: 90+ features (momentum, trend, volatility, wavelet, fourier)

Optional:
  - Hyperparameter optimization (Optuna)
  - Feature selection (FAST/DEEP with SHAP)

Trained models are saved to disk (model_cache/) and can be used for prediction without retraining.

Usage:
    # Train all crypto (top 20 by market cap); the default asset type is stock
    python3 scripts/ml_train.py --asset-type crypto

    # Train specific coins
    python3 scripts/ml_train.py --coin bitcoin,ethereum,solana

    # Train IDX stocks (the default asset type)
    python3 scripts/ml_train.py --asset-type stock

    # With hyperparameter optimization (slower, more accurate)
    python3 scripts/ml_train.py --hyperopt --trials 100

    # With DEEP feature selection (SHAP, 2-5 min/coin)
    python3 scripts/ml_train.py --feature-selection deep

    # Force retrain (ignore the retrain timer)
    python3 scripts/ml_train.py --force

    # Specific timeframe
    python3 scripts/ml_train.py --timeframe 4h

    # Limit + quiet
    python3 scripts/ml_train.py --limit 10 --quiet

Cron example:
    # Retrain every day at 3 AM
    0 3 * * * cd /path && .venv/bin/python scripts/ml_train.py --quiet >> logs/ml_train.log 2>&1
"""
from __future__ import annotations

import argparse
import os
import sys
import time

# Put the parent of this file's directory at the front of the import path.
# NOTE(review): presumably this is the project root, so that the `app`
# package imported inside main() resolves when the script is run directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Load environment variables from a .env file before the application is created.
from dotenv import load_dotenv
load_dotenv()


def _build_parser() -> argparse.ArgumentParser:
    """Return the argument parser describing every CLI option of this script."""
    parser = argparse.ArgumentParser(
        description='Train/retrain ML ensemble models per coin',
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('--asset-type', default='stock',
                        choices=['crypto', 'stock', 'stock_us'],
                        help='Asset type (default: stock)')
    parser.add_argument('--coin', type=str, default='',
                        help='Coin IDs tertentu, comma-separated (e.g. bitcoin,ethereum)')
    parser.add_argument('--limit', type=int, default=20,
                        help='Max jumlah coin untuk train (default: 20)')
    parser.add_argument('--timeframe', default='1D',
                        choices=['1h', '4h', '1D'],
                        help='Timeframe OHLCV (default: 1D)')
    parser.add_argument('--force', action='store_true',
                        help='Force retrain (abaikan retrain timer)')
    parser.add_argument('--hyperopt', action='store_true',
                        help='Enable Optuna hyperparameter optimization')
    parser.add_argument('--trials', type=int, default=50,
                        help='Jumlah Optuna trials (default: 50)')
    parser.add_argument('--feature-selection', default='fast',
                        choices=['none', 'fast', 'deep'],
                        help='Feature selection method (default: fast)')
    parser.add_argument('--quiet', '-q', action='store_true',
                        help='Quiet mode — hanya summary')
    return parser


def _print_summary(trained: int, skipped: int, failed: int, total: int,
                   metrics_list: list[dict], elapsed: float) -> None:
    """Print the end-of-run banner: counts, averages and best/worst coin.

    Args:
        trained: Number of coins whose model was trained and saved.
        skipped: Number of coins skipped (too little data, fresh model, ...).
        failed: Number of coins whose processing raised an exception.
        total: Number of coins that were considered.
        metrics_list: One dict per trained coin with the keys ``symbol``,
            ``mae``, ``dir_acc``, ``features``, ``samples`` and ``time``.
        elapsed: Wall-clock duration of the whole run, in seconds.
    """
    mins = int(elapsed // 60)
    secs = int(elapsed % 60)

    print(f'\n{"═" * 70}')
    print(f'🧠  Training Complete — {mins}m {secs}s')
    print(f'    Trained: {trained}  |  Skipped: {skipped}  |  Failed: {failed}  |  Total: {total}')

    if metrics_list:
        count = len(metrics_list)
        avg_mae = sum(m['mae'] for m in metrics_list) / count
        avg_dir = sum(m['dir_acc'] for m in metrics_list) / count
        avg_time = sum(m['time'] for m in metrics_list) / count
        print(f'    Avg MAE: {avg_mae:.4f}  |  Avg Dir Accuracy: {avg_dir:.1f}%  |  Avg time/coin: {avg_time:.1f}s')

        # Best and worst are only meaningful with at least two trained coins.
        if count > 1:
            best = max(metrics_list, key=lambda x: x['dir_acc'])
            worst = min(metrics_list, key=lambda x: x['dir_acc'])
            print(f'    Best:  {best["symbol"]} ({best["dir_acc"]:.1f}% accuracy)')
            print(f'    Worst: {worst["symbol"]} ({worst["dir_acc"]:.1f}% accuracy)')

    print(f'{"═" * 70}')


def main() -> None:
    """Train (or retrain) one ensemble model per coin and print a summary.

    Steps, per run:
      1. Parse the command line and create the application context.
      2. Resolve the coin list: either the IDs given with ``--coin`` or the
         first ``--limit`` active coins of ``--asset-type`` ordered by
         market-cap rank (unranked coins last).
      3. For every coin: load its OHLCV rows for ``--timeframe``, build
         features, train an ``EnsemblePredictor`` and save it. Coins with
         fewer than 100 OHLCV rows, all-zero closes, fewer than 60 feature
         rows, or a still-fresh saved model (unless ``--force``) are skipped.
      4. Print the summary banner.

    A failure while processing one coin is counted and reported, and the loop
    continues with the next coin. Failure messages are written to stdout in
    normal mode and to stderr in ``--quiet`` mode, so unattended runs still
    record why a coin failed.

    Exits the process with status 1 if at least one coin failed, otherwise 0
    (including the case where there is nothing to train).
    """
    args = _build_parser().parse_args()

    # The application is created only after argument parsing, so `--help` and
    # usage errors exit before any app start-up work happens.
    from app import create_app
    app = create_app()

    with app.app_context():
        from app.helpers.market_db import switch_market_schema
        switch_market_schema(args.asset_type)

        from app.extensions import db
        from app.models.coin import Coin
        from app.models.ohlcv import OHLCVData
        from app.engine.ml.feature_eng import FeatureEngineer
        from app.engine.ml.ensemble import EnsemblePredictor
        import pandas as pd

        start_time = time.time()

        # ── Resolve coin list ──
        if args.coin:
            coin_ids = [c.strip() for c in args.coin.split(',') if c.strip()]
            # IDs given without a namespace (no '.') get the 'COIN.' prefix.
            coin_ids = [cid if '.' in cid else f'COIN.{cid}' for cid in coin_ids]
        else:
            # Ranked coins first (ascending rank), coins without a rank last.
            coins = (
                Coin.query
                .filter_by(asset_type=args.asset_type, is_active=True)
                .order_by(
                    db.case((Coin.market_cap_rank.is_(None), 1), else_=0),
                    Coin.market_cap_rank.asc(),
                )
                .limit(args.limit)
                .all()
            )
            coin_ids = [c.id for c in coins]

        if not coin_ids:
            print('⊘  Tidak ada coin untuk ditrain.')
            sys.exit(0)

        total = len(coin_ids)
        trained = 0
        skipped = 0
        failed = 0
        metrics_list = []

        if not args.quiet:
            hyperopt_label = f'Ya ({args.trials} trials)' if args.hyperopt else 'Tidak'
            print('\n🧠  ML Model Training')
            print(f'{"─" * 70}')
            print(f'  Coins     : {total}')
            print(f'  Timeframe : {args.timeframe}')
            print(f'  Hyperopt  : {hyperopt_label}')
            print(f'  Feature   : {args.feature_selection}')
            print(f'  Force     : {"Ya" if args.force else "Tidak"}')
            print()

        fe = FeatureEngineer()

        for idx, coin_id in enumerate(coin_ids, 1):
            # Fallback label, used when the coin row is missing or when the
            # lookup itself fails.
            symbol = coin_id

            try:
                # The lookup lives inside the try block so that one broken
                # coin row is counted as a failure instead of ending the run.
                coin = db.session.get(Coin, coin_id)
                if coin:
                    symbol = coin.symbol.upper()

                # Load OHLCV data, oldest first
                ohlcv_rows = (
                    OHLCVData.query
                    .filter_by(coin_id=coin_id, timeframe=args.timeframe)
                    .order_by(OHLCVData.timestamp.asc())
                    .all()
                )

                if len(ohlcv_rows) < 100:
                    skipped += 1
                    if not args.quiet:
                        print(f'  ⊘  [{idx}/{total}] {symbol:<10} — insufficient data ({len(ohlcv_rows)} rows, need 100+)')
                    continue

                # Build DataFrame; missing (falsy) values become 0.0
                df = pd.DataFrame([{
                    'open': float(r.open or 0),
                    'high': float(r.high or 0),
                    'low': float(r.low or 0),
                    'close': float(r.close or 0),
                    'volume': float(r.volume or 0),
                } for r in ohlcv_rows])

                if df['close'].sum() == 0:
                    skipped += 1
                    if not args.quiet:
                        print(f'  ⊘  [{idx}/{total}] {symbol:<10} — all zero prices')
                    continue

                # Generate features; the per-coin timer starts here, so the
                # reported time covers feature generation, training and saving.
                t0 = time.time()
                feature_df = fe.generate(df)

                if len(feature_df) < 60:
                    skipped += 1
                    if not args.quiet:
                        print(f'  ⊘  [{idx}/{total}] {symbol:<10} — insufficient features ({len(feature_df)} rows)')
                    continue

                # Keep as many trailing closes as there are feature rows.
                # NOTE(review): assumes the feature rows correspond to the
                # last rows of `df` — confirm against FeatureEngineer.generate.
                closes = df['close'].values[-len(feature_df):]

                # Create predictor and train
                predictor = EnsemblePredictor()

                # NOTE(review): needs_training() is queried on a freshly
                # constructed predictor, before load_models(); confirm it can
                # report freshness without a loaded model.
                if not args.force and not predictor.needs_training():
                    # Try loading existing model
                    if predictor.load_models(coin_id, args.timeframe):
                        skipped += 1
                        if not args.quiet:
                            print(f'  ⊘  [{idx}/{total}] {symbol:<10} — model masih fresh (use --force)')
                        continue

                # NOTE(review): only on/off flags are forwarded here; the
                # `--trials` value and the fast/deep distinction of
                # `--feature-selection` are not passed to train() — confirm
                # whether EnsemblePredictor obtains them some other way.
                predictor.train(
                    feature_df,
                    closes,
                    use_hyperopt=args.hyperopt,
                    use_feature_selection=args.feature_selection != 'none',
                )

                # Save model to cache
                predictor.save_models(coin_id, args.timeframe)

                elapsed_coin = time.time() - t0
                trained += 1

                # Extract training metrics; absent attribute or keys fall back to 0
                train_metrics = getattr(predictor, 'last_train_metrics', {}) or {}
                mae = train_metrics.get('mae', 0)
                dir_acc = train_metrics.get('directional_accuracy', 0)

                metrics_list.append({
                    'symbol': symbol,
                    'mae': mae,
                    'dir_acc': dir_acc,
                    'features': len(feature_df.columns),
                    'samples': len(feature_df),
                    'time': elapsed_coin,
                })

                if not args.quiet:
                    print(f'  ✓  [{idx}/{total}] {symbol:<10} '
                          f'MAE={mae:.4f}  DirAcc={dir_acc:.1f}%  '
                          f'feat={len(feature_df.columns)}  rows={len(feature_df)}  '
                          f'({elapsed_coin:.1f}s)')

            except Exception as e:
                # Per-coin boundary: record the failure and keep going.
                failed += 1
                # In quiet mode the reason goes to stderr rather than being
                # dropped, so an unattended run still records what went wrong.
                # NOTE(review): the DB session is not rolled back here; if a
                # failed SQL statement can leave it unusable for the following
                # coins, a rollback may be needed — confirm how that interacts
                # with switch_market_schema before adding one.
                print(f'  ✗  [{idx}/{total}] {symbol:<10} — {e}',
                      file=sys.stderr if args.quiet else sys.stdout)

        # ── Summary ──
        _print_summary(trained, skipped, failed, total, metrics_list,
                       time.time() - start_time)

        sys.exit(1 if failed else 0)


# Run the trainer only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
