"""ML Backtesting — walk-forward backtest dengan Monte Carlo validation.

Menjalankan walk-forward backtesting pada coin tertentu menggunakan
ML scoring function. Termasuk Monte Carlo simulation untuk confidence
interval dan robustness check.

Metrik yang dihasilkan:
  - Win rate, average return, profit factor
  - Sharpe ratio, max drawdown, total return
  - Equity curve (min/max/final)
  - Monthly returns breakdown
  - Monte Carlo confidence intervals (5th, 25th, 50th, 75th, 95th percentile)

Usage:
    # Backtest top 5 assets of the default asset type (stock; use --asset-type crypto for crypto)
    python3 scripts/ml_backtest.py

    # Coin tertentu
    python3 scripts/ml_backtest.py --coin bitcoin,ethereum

    # Timeframe 4h
    python3 scripts/ml_backtest.py --timeframe 4h

    # Custom fee (default: 0.62% roundtrip)
    python3 scripts/ml_backtest.py --fee 0.5

    # Verbose — equity curve + monthly returns
    python3 scripts/ml_backtest.py --verbose

    # Output JSON
    python3 scripts/ml_backtest.py --json

Cron example:
    # Backtest mingguan (Minggu jam 6 pagi)
    0 6 * * 0 cd /path && .venv/bin/python scripts/ml_backtest.py --quiet >> logs/ml_backtest.log 2>&1
"""
from __future__ import annotations

import argparse
import json as json_module
import os
import sys
import time

# Put the parent of this script's directory at the front of sys.path.
# NOTE(review): presumably this is the project root, so that the `app` package
# imported inside main() resolves when the script is run directly — confirm
# against the repository layout.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Load environment variables from a .env file at import time, i.e. before
# main() creates the application.
from dotenv import load_dotenv
load_dotenv()


def _build_parser():
    """Create the command-line argument parser for the backtest script."""
    parser = argparse.ArgumentParser(
        description='Walk-forward backtesting with Monte Carlo validation',
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('--asset-type', default='stock',
                        choices=['crypto', 'stock', 'stock_us'],
                        help='Asset type (default: stock)')
    parser.add_argument('--coin', type=str, default='',
                        help='Coin IDs, comma-separated (e.g. bitcoin,ethereum)')
    parser.add_argument('--limit', type=int, default=5,
                        help='Max coins (default: 5)')
    parser.add_argument('--timeframe', default='1D',
                        choices=['1h', '4h', '1D'],
                        help='OHLCV timeframe (default: 1D)')
    parser.add_argument('--fee', type=float, default=0.62,
                        help='Roundtrip fee %% (default: 0.62)')
    parser.add_argument('--quiet', '-q', action='store_true',
                        help='Quiet — hanya summary')
    parser.add_argument('--verbose', '-v', action='store_true',
                        help='Verbose — equity curve, monthly returns, Monte Carlo')
    parser.add_argument('--json', action='store_true',
                        help='Output JSON')
    return parser


def _print_verbose_details(metrics):
    """Print the optional per-coin details shown with ``--verbose``.

    Reads the ``monthly_returns``, ``monte_carlo`` and ``equity_curve`` entries
    of ``metrics`` (a dict) and prints each one only when it is present and
    non-empty.
    """
    # Monthly returns — only the first six entries are shown to keep the line short.
    monthly = metrics.get('monthly_returns', {})
    if monthly:
        print('         Monthly: ', end='')
        for month, ret in list(monthly.items())[:6]:
            print(f'{month}={ret:+.1f}%  ', end='')
        print()

    # Monte Carlo percentiles; missing percentiles are shown as 0.
    mc = metrics.get('monte_carlo', {})
    if mc:
        print(f'         Monte Carlo CI: '
              f'5th={mc.get("p5", 0):.1f}%  '
              f'25th={mc.get("p25", 0):.1f}%  '
              f'median={mc.get("p50", 0):.1f}%  '
              f'75th={mc.get("p75", 0):.1f}%  '
              f'95th={mc.get("p95", 0):.1f}%')

    # Equity curve summary: first value, last value and maximum.
    equity = metrics.get('equity_curve', [])
    if equity:
        print(f'         Equity: start={equity[0]:.0f} → end={equity[-1]:.0f} '
              f'(peak={max(equity):.0f})')
    print()


def _print_summary(all_results, elapsed):
    """Print the aggregate summary over all successfully backtested coins.

    Args:
        all_results: Non-empty list of per-coin result dicts, each with a
            ``symbol`` and a ``summary`` dict of numeric metrics.
        elapsed: Total wall-clock run time in seconds.
    """
    count = len(all_results)
    summaries = [r['summary'] for r in all_results]

    avg_wr = sum(s['winrate'] for s in summaries) / count
    avg_pf = sum(s['profit_factor'] for s in summaries) / count
    avg_sharpe = sum(s['sharpe'] for s in summaries) / count
    avg_return = sum(s['total_return'] for s in summaries) / count
    profitable = sum(1 for s in summaries if s['total_return'] > 0)

    best = max(all_results, key=lambda x: x['summary']['total_return'])
    worst = min(all_results, key=lambda x: x['summary']['total_return'])

    mins = int(elapsed // 60)
    secs = int(elapsed % 60)

    print(f'\n{"═" * 80}')
    print(f'📈  Backtest Complete — {mins}m {secs}s')
    print(f'    Coins tested: {count}  |  Profitable: {profitable}/{count}')
    print(f'    Avg Win Rate: {avg_wr:.1f}%  |  Avg Profit Factor: {avg_pf:.2f}')
    print(f'    Avg Sharpe: {avg_sharpe:.2f}  |  Avg Return: {avg_return:+.1f}%')
    print(f'    Best : {best["symbol"]} ({best["summary"]["total_return"]:+.1f}%)')
    print(f'    Worst: {worst["symbol"]} ({worst["summary"]["total_return"]:+.1f}%)')
    print(f'{"═" * 80}')


def main():
    """Run the walk-forward backtest for the selected coins and report results.

    Parses the command line, creates the application, switches to the market
    schema for the requested asset type, resolves the coin IDs (explicit
    ``--coin`` list, or the top ``--limit`` active coins ordered by market-cap
    rank), loads each coin's OHLCV rows for the chosen timeframe and runs
    ``WalkForwardBacktester`` on them.

    Output modes:
      * default: header, one progress line per coin, then a summary;
      * ``--verbose``: additionally prints monthly returns, Monte Carlo
        percentiles and an equity summary per coin;
      * ``--quiet``: summary only;
      * ``--json``: a single JSON document on stdout.

    A coin that fails is reported and skipped; the remaining coins still run.
    In quiet/JSON mode the failure is written to stderr so stdout stays clean.
    Exits with status 0 when no coins could be resolved.
    """
    args = _build_parser().parse_args()

    # Per-coin progress lines are only shown in the default (non-quiet, non-JSON) mode.
    show_progress = not args.quiet and not args.json

    from app import create_app
    app = create_app()

    with app.app_context():
        from app.helpers.market_db import switch_market_schema
        switch_market_schema(args.asset_type)

        from app.extensions import db
        from app.models.coin import Coin
        from app.models.ohlcv import OHLCVData
        from app.engine.ml.backtester import WalkForwardBacktester
        import pandas as pd

        start_time = time.time()

        # ── Resolve coins ──
        if args.coin:
            coin_ids = [c.strip() for c in args.coin.split(',') if c.strip()]
            # IDs without a '.' get the 'COIN.' prefix.
            coin_ids = [cid if '.' in cid else f'COIN.{cid}' for cid in coin_ids]
        else:
            coins = (
                Coin.query
                .filter_by(asset_type=args.asset_type, is_active=True)
                .order_by(
                    # Coins without a rank sort after all ranked coins.
                    db.case((Coin.market_cap_rank.is_(None), 1), else_=0),
                    Coin.market_cap_rank.asc(),
                )
                .limit(args.limit)
                .all()
            )
            coin_ids = [c.id for c in coins]

        if not coin_ids:
            if args.json:
                # Keep stdout machine-readable even when there is nothing to test.
                print(json_module.dumps({
                    'results': [],
                    'total': 0,
                    'elapsed_seconds': round(time.time() - start_time, 1),
                }, indent=2, default=str))
            else:
                print('⊘  Tidak ada coin.')
            sys.exit(0)

        total = len(coin_ids)
        all_results = []

        if show_progress:
            print('\n📈  ML Walk-Forward Backtesting')
            print('─' * 80)
            print(f'  Coins    : {total}')
            print(f'  Timeframe: {args.timeframe}')
            print(f'  Fee      : {args.fee}% roundtrip')
            print()

        for idx, coin_id in enumerate(coin_ids, 1):
            coin = db.session.get(Coin, coin_id)
            # Fall back to the raw ID when the coin row does not exist.
            symbol = coin.symbol.upper() if coin else coin_id

            try:
                t0 = time.time()

                # Load OHLCV
                rows = (
                    OHLCVData.query
                    .filter_by(coin_id=coin_id, timeframe=args.timeframe)
                    .order_by(OHLCVData.timestamp.asc())
                    .all()
                )
                if len(rows) < 100:
                    if show_progress:
                        print(f'  ⊘  [{idx}/{total}] {symbol:<8} — insufficient data ({len(rows)} rows)')
                    continue

                # Missing (None) column values are treated as 0.
                df = pd.DataFrame([{
                    'open': float(r.open or 0), 'high': float(r.high or 0),
                    'low': float(r.low or 0), 'close': float(r.close or 0),
                    'volume': float(r.volume or 0),
                } for r in rows])

                # Run backtest
                # NOTE(review): args.fee is parsed and shown in the header but is
                # not passed to the backtester here — confirm whether
                # WalkForwardBacktester should receive it.
                backtester = WalkForwardBacktester()
                bt_result = backtester.backtest(df, timeframe=args.timeframe)

                # Normalise the result to a plain dict of metrics.
                metrics = bt_result.metrics if hasattr(bt_result, 'metrics') else bt_result
                if not isinstance(metrics, dict):
                    metrics = vars(metrics) if hasattr(metrics, '__dict__') else {}

                elapsed_coin = time.time() - t0

                result = {
                    'symbol': symbol,
                    'coin_id': coin_id,
                    'metrics': metrics,
                    'elapsed': round(elapsed_coin, 1),
                }

                # Extract key metrics safely
                winrate = _safe(metrics, 'winrate', 0)
                avg_return = _safe(metrics, 'avg_return', 0)
                profit_factor = _safe(metrics, 'profit_factor', 0)
                sharpe = _safe(metrics, 'sharpe_ratio', 0)
                max_dd = _safe(metrics, 'max_drawdown', 0)
                total_return = _safe(metrics, 'total_return', 0)
                total_trades = _safe(metrics, 'total_trades', 0)

                result['summary'] = {
                    'winrate': winrate,
                    'avg_return': avg_return,
                    'profit_factor': profit_factor,
                    'sharpe': sharpe,
                    'max_drawdown': max_dd,
                    'total_return': total_return,
                    'total_trades': total_trades,
                }
                all_results.append(result)

                if show_progress:
                    pf_icon = '🟢' if profit_factor > 1.5 else ('🟡' if profit_factor > 1.0 else '🔴')
                    # total_trades is a float (from _safe); format it without
                    # decimals so a count reads "12" rather than "12.0".
                    print(f'  {pf_icon} [{idx}/{total}] {symbol:<8} '
                          f'WR={winrate:.1f}%  PF={profit_factor:.2f}  '
                          f'Sharpe={sharpe:.2f}  MaxDD={max_dd:.1f}%  '
                          f'Return={total_return:+.1f}%  Trades={total_trades:.0f}  '
                          f'({elapsed_coin:.1f}s)')

                    if args.verbose:
                        _print_verbose_details(metrics)

            except Exception as e:
                # Best-effort per coin: one failure must not abort the remaining coins.
                # Roll back so a failed DB statement does not leave the session in
                # a broken transaction state for the next coin.
                db.session.rollback()
                if show_progress:
                    print(f'  ✗  [{idx}/{total}] {symbol:<8} — {e}')
                else:
                    # Quiet/JSON mode: keep stdout clean but do not hide the failure.
                    print(f'  ✗  [{idx}/{total}] {symbol:<8} — {e}', file=sys.stderr)

        elapsed = time.time() - start_time

        # ── Output ──
        if args.json:
            # default=str stringifies any metric value json cannot serialise natively.
            print(json_module.dumps({
                'results': all_results,
                'total': len(all_results),
                'elapsed_seconds': round(elapsed, 1),
            }, indent=2, default=str))
            return

        # Summary
        if all_results:
            _print_summary(all_results, elapsed)


def _safe(d, key, default=0):
    """Safely extract numeric value from dict."""
    val = d.get(key, default)
    try:
        return float(val)
    except (TypeError, ValueError):
        return default


# Run the backtest only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
