"""ML Full Pipeline — train + predict dalam satu langkah per coin.

Pipeline lengkap:
  1. Load OHLCV data
  2. Generate 90+ features (FeatureEngineer)
  3. Detect market regime (RegimeDetector)
  4. Train EnsemblePredictor (TreeBoost + DirectionClassifier)
  5. Predict multi-horizon (short/medium/long)
  6. Risk check (RiskManager)
  7. Output actionable summary

Ini adalah versi CLI dari halaman /admin/ml-ensemble.

Usage:
    # Full pipeline for the top 10 assets (default --asset-type is stock; add --asset-type crypto for crypto)
    python3 scripts/ml_ensemble.py

    # Coin tertentu
    python3 scripts/ml_ensemble.py --coin bitcoin,ethereum

    # Dengan output detail (regime, risk, horizons)
    python3 scripts/ml_ensemble.py --verbose

    # Saham IDX
    python3 scripts/ml_ensemble.py --asset-type stock --limit 10

    # Output JSON
    python3 scripts/ml_ensemble.py --json

Cron example:
    # Full pipeline setiap 4 jam
    0 */4 * * * cd /path && .venv/bin/python scripts/ml_ensemble.py --quiet >> logs/ml_ensemble.log 2>&1
"""
from __future__ import annotations

import argparse
import json as json_module
import os
import sys
import time

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from dotenv import load_dotenv
load_dotenv()


# Icons for the per-coin report line; any direction other than up/down is
# rendered with the neutral fallback icon instead of the bearish one.
_DIRECTION_ICONS = {'up': '📈', 'down': '📉'}
_NEUTRAL_ICON = '➖'


def _build_parser():
    """Build the command-line parser for the ML pipeline script."""
    parser = argparse.ArgumentParser(
        description='Full ML pipeline: features → regime → train → predict → risk check',
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('--asset-type', default='stock',
                        choices=['crypto', 'stock', 'stock_us'],
                        help='Asset type (default: stock)')
    parser.add_argument('--coin', type=str, default='',
                        help='Coin IDs, comma-separated (e.g. bitcoin,ethereum)')
    parser.add_argument('--limit', type=int, default=10,
                        help='Max coins (default: 10)')
    parser.add_argument('--timeframe', default='1D',
                        choices=['1h', '4h', '1D'],
                        help='OHLCV timeframe (default: 1D)')
    parser.add_argument('--quiet', '-q', action='store_true',
                        help='Quiet — hanya summary')
    parser.add_argument('--verbose', '-v', action='store_true',
                        help='Verbose — regime, risk details, horizons')
    parser.add_argument('--json', action='store_true',
                        help='Output JSON')
    return parser


def _as_float(value, default=0.0):
    """Coerce *value* to float for display, returning *default* if it cannot be."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _print_coin_report(idx, total, result, regime, verbose=False):
    """Print the one-line report for a processed coin, plus details if *verbose*."""
    dir_icon = _DIRECTION_ICONS.get(result['direction'], _NEUTRAL_ICON)
    allowed = bool(result['risk_allowed'])
    risk_icon = '✅' if allowed else '⛔'
    # str() keeps the width specs valid even when a field came back as None.
    print(f'  {dir_icon} [{idx}/{total}] {result["symbol"]:<8} '
          f'regime={str(result["regime"]):<14} '
          f'dir={str(result["direction"]):<4} '
          f'prob={result["probability"]:5.1f}%  '
          f'{risk_icon} trade={"YES" if allowed else "NO":<3}  '
          f'({result["elapsed"]:.1f}s)')

    if not verbose:
        return

    # Regime details
    print(f'         Hurst={_as_float(regime.get("hurst_exponent")):.3f}  '
          f'ADX={_as_float(regime.get("adx")):.1f}  '
          f'VolRank={_as_float(regime.get("volatility_rank")):.0f}%')
    # Horizons
    for hz_name, hz_data in (result['horizons'] or {}).items():
        if isinstance(hz_data, dict):
            print(f'         {hz_name}: {_as_float(hz_data.get("predicted_return")):+.2f}%')
    # Risk reasons (at most three)
    for reason in list(result['risk_reasons'] or [])[:3]:
        print(f'         ⚠ {reason}')
    print()


def _dump_json(results, failed, skipped, errors, elapsed):
    """Print the machine-readable report to stdout."""
    print(json_module.dumps({
        'results': results,
        'total': len(results),
        'failed': failed,
        'skipped': skipped,
        'errors': errors,
        'elapsed_seconds': round(elapsed, 1),
    }, indent=2, default=str))


def _print_summary(results, success, failed, skipped, total, elapsed):
    """Print the human-readable end-of-run summary."""
    buy_count = sum(1 for r in results if r['direction'] == 'up')
    sell_count = sum(1 for r in results if r['direction'] == 'down')
    tradeable = sum(1 for r in results if r['risk_allowed'])
    avg_prob = sum(r['probability'] for r in results) / max(len(results), 1)

    # Keys are stringified so sorting cannot fail on a None/str mix.
    regime_counts = {}
    for r in results:
        key = str(r['regime'])
        regime_counts[key] = regime_counts.get(key, 0) + 1

    mins = int(elapsed // 60)
    secs = int(elapsed % 60)

    print(f'\n{"═" * 80}')
    print(f'🔬  Pipeline Complete — {mins}m {secs}s')
    print(f'    Processed: {success}  |  Failed: {failed}  |  Skipped: {skipped}  |  Total: {total}')
    print(f'    Direction: {buy_count} bullish, {sell_count} bearish  |  Avg probability: {avg_prob:.1f}%')
    print(f'    Tradeable: {tradeable}/{len(results)} (passed risk check)')
    print(f'    Regimes  : {", ".join(f"{k}={v}" for k, v in sorted(regime_counts.items()))}')
    print(f'{"═" * 80}')


def main():
    """Run the full ML pipeline (load → features → regime → train/predict → risk).

    Command-line options are read from ``sys.argv``. For every resolved coin
    the OHLCV history is loaded, features are generated, the market regime is
    detected, the ensemble is loaded from cache (or trained and saved), a
    prediction is made and the risk manager is consulted.

    Output:
        * default: one progress line per coin plus a summary on stdout;
        * ``--quiet``: summary only;
        * ``--json``: a single JSON document on stdout.
        Per-coin failures are always written to stderr so they are never lost.

    Exit status:
        * 2 for invalid arguments (argparse);
        * 0 when no coins are found, and always 0 in ``--json`` mode (failures
          are reported in the payload's ``failed``/``errors`` fields);
        * otherwise 1 if any coin failed, else 0.
    """
    parser = _build_parser()
    args = parser.parse_args()
    if args.limit < 0:
        # A negative value would otherwise reach the SQL LIMIT clause.
        parser.error('--limit must be >= 0')

    show_progress = not args.quiet and not args.json

    from app import create_app
    app = create_app()

    with app.app_context():
        # The schema switch is kept ahead of the model imports, as in the
        # original ordering — presumably the models depend on it; TODO confirm.
        from app.helpers.market_db import switch_market_schema
        switch_market_schema(args.asset_type)

        from app.extensions import db
        from app.models.coin import Coin
        from app.models.ohlcv import OHLCVData
        from app.engine.ml.feature_eng import FeatureEngineer
        from app.engine.ml.ensemble import EnsemblePredictor
        from app.engine.ml.regime import RegimeDetector
        from app.engine.ml.risk_manager import RiskManager
        import pandas as pd

        start_time = time.time()

        # ── Resolve coins ──
        if args.coin:
            coin_ids = [c.strip() for c in args.coin.split(',') if c.strip()]
            # IDs without a dot get the 'COIN.' prefix.
            coin_ids = [cid if '.' in cid else f'COIN.{cid}' for cid in coin_ids]
        else:
            coins = (
                Coin.query
                .filter_by(asset_type=args.asset_type, is_active=True)
                .order_by(
                    # Unranked coins sort after ranked ones.
                    db.case((Coin.market_cap_rank.is_(None), 1), else_=0),
                    Coin.market_cap_rank.asc(),
                )
                .limit(args.limit)
                .all()
            )
            coin_ids = [c.id for c in coins]

        if not coin_ids:
            if args.json:
                # Keep stdout valid JSON even when there is nothing to do.
                _dump_json([], 0, 0, [], time.time() - start_time)
            else:
                print('⊘  Tidak ada coin.')
            sys.exit(0)

        total = len(coin_ids)
        results = []
        errors = []
        success = 0
        failed = 0
        skipped = 0

        if show_progress:
            print(f'\n🔬  ML Full Pipeline — {total} coins × {args.timeframe}')
            print(f'{"─" * 80}')

        fe = FeatureEngineer()
        regime_det = RegimeDetector()
        risk_mgr = RiskManager()

        for idx, coin_id in enumerate(coin_ids, 1):
            symbol = coin_id  # fallback label until the coin row is loaded

            try:
                coin = db.session.get(Coin, coin_id)
                if coin is not None and coin.symbol:
                    symbol = coin.symbol.upper()

                t0 = time.time()

                # 1. Load OHLCV
                rows = (
                    OHLCVData.query
                    .filter_by(coin_id=coin_id, timeframe=args.timeframe)
                    .order_by(OHLCVData.timestamp.asc())
                    .all()
                )
                if len(rows) < 100:
                    skipped += 1
                    if show_progress:
                        print(f'  ⊘  [{idx}/{total}] {symbol:<8} — insufficient data ({len(rows)} rows)')
                    continue

                df = pd.DataFrame([{
                    'open': float(r.open or 0), 'high': float(r.high or 0),
                    'low': float(r.low or 0), 'close': float(r.close or 0),
                    'volume': float(r.volume or 0),
                } for r in rows])

                # 2. Feature Engineering
                feature_df = fe.generate(df)
                if len(feature_df) < 60:
                    skipped += 1
                    if show_progress:
                        print(f'  ⊘  [{idx}/{total}] {symbol:<8} — insufficient features')
                    continue

                # Assumes generate() only drops leading warm-up rows, so the
                # trailing closes line up with feature_df — TODO confirm.
                closes = df['close'].values[-len(feature_df):]

                # 3. Regime Detection
                regime = regime_det.detect(df)

                # 4. Train + Predict
                predictor = EnsemblePredictor()
                predictor.load_models(coin_id, args.timeframe)  # try cache first
                if predictor.needs_training():
                    predictor.train(feature_df, closes, use_hyperopt=False, use_feature_selection=True)
                    predictor.save_models(coin_id, args.timeframe)

                predictions = predictor.predict(feature_df, closes, n_steps=48)

                # 5. Risk Check
                direction = predictions.get('direction', {})
                risk = risk_mgr.should_trade(
                    coin_id=coin_id,
                    signal=direction.get('direction', 'hold'),
                    score=50,
                    confidence=direction.get('probability', 0.5),
                    regime=regime.get('regime', 'unknown'),
                    direction_probability=direction.get('probability', 0.5),
                )

                elapsed_coin = time.time() - t0

                result = {
                    'symbol': symbol,
                    'coin_id': coin_id,
                    'regime': regime.get('regime', '?'),
                    'regime_desc': regime.get('description', ''),
                    'direction': direction.get('direction', '?'),
                    'probability': round(direction.get('probability', 0) * 100, 1),
                    'risk_allowed': risk.get('allowed', False),
                    'risk_reasons': risk.get('reasons', []),
                    'horizons': predictions.get('horizons', {}),
                    'train_metrics': predictions.get('training_metrics', {}),
                    'elapsed': round(elapsed_coin, 1),
                }
            except Exception as e:
                # Best-effort batch: record the failure and move on to the next coin.
                failed += 1
                errors.append({'coin_id': coin_id, 'symbol': symbol, 'error': str(e)})
                # Reset the session so a failed transaction cannot poison the
                # queries issued for the remaining coins.
                db.session.rollback()
                # stderr keeps failures visible under --quiet and keeps stdout
                # clean for --json consumers.
                print(f'  ✗  [{idx}/{total}] {symbol:<8} — {e}', file=sys.stderr)
                continue

            # Recorded outside the try so a display problem below can never
            # count the same coin as both processed and failed.
            results.append(result)
            success += 1

            if show_progress:
                try:
                    _print_coin_report(idx, total, result, regime, args.verbose)
                except Exception as e:
                    print(f'  ⚠  [{idx}/{total}] {symbol:<8} — report formatting failed: {e}',
                          file=sys.stderr)

        elapsed = time.time() - start_time

        # ── Output ──
        if args.json:
            # Exit status stays 0 in JSON mode; failures are reported through
            # the 'failed' / 'errors' fields of the payload.
            _dump_json(results, failed, skipped, errors, elapsed)
            return

        _print_summary(results, success, failed, skipped, total, elapsed)

        sys.exit(1 if failed > 0 else 0)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
