"""ML Ensemble Predictions — prediksi arah harga menggunakan 3 model ML.

Menjalankan RandomForest + GradientBoosting + Ridge Regression ensemble
untuk semua coin aktif atau coin tertentu. Cocok untuk cron job harian.

Models:
  - Random Forest (35% weight) — classification
  - Gradient Boosting (40% weight) — classification
  - Ridge Regression (25% weight) — magnitude

Usage:
    # Default run: asset type 'stock', top 50 (use --asset-type crypto for crypto)
    python3 scripts/ml_predict.py

    # Prediksi saham IDX
    python3 scripts/ml_predict.py --asset-type stock

    # Coin tertentu
    python3 scripts/ml_predict.py --symbol BTC,ETH,SOL

    # Limit jumlah dan sort
    python3 scripts/ml_predict.py --limit 20 --sort confidence

    # Hanya tampilkan strong signals
    python3 scripts/ml_predict.py --strong-only

    # Quiet mode (untuk cron)
    python3 scripts/ml_predict.py --quiet

    # Output JSON
    python3 scripts/ml_predict.py --json

Cron example:
    # Prediksi setiap 6 jam
    0 */6 * * * cd /path && .venv/bin/python scripts/ml_predict.py --quiet >> logs/ml_predict.log 2>&1
"""
from __future__ import annotations

import argparse
import json as json_module
import os
import sys
import time

# Put the parent of this script's directory at the front of sys.path.
# NOTE(review): presumably that is the project root holding the `app`
# package imported inside main() — confirm against the repo layout.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Load environment variables from a .env file at import time, i.e. before
# main() creates the application.
from dotenv import load_dotenv
load_dotenv()


def main():
    """Run the ML ensemble prediction CLI.

    Workflow:
      1. Parse the command-line options.
      2. Create the application, enter its app context and switch the
         market schema to the requested asset type.
      3. Collect predictions: one ``svc.predict`` call per ``--symbol``
         entry, or a single ``svc.scan_all`` call otherwise.
      4. Apply the ``--strong-only`` / ``--min-confidence`` filters and the
         ``--sort`` ordering.
      5. Print either a JSON document (``--json``) or a table followed by a
         summary (``--quiet`` keeps only the summary).

    Raises:
        SystemExit: with status 1 when ``scan_all`` does not report
            success; argparse also exits on ``--help`` or invalid options.
    """
    args = _build_parser().parse_args()
    # Progress lines and per-symbol errors are hidden in quiet and JSON mode.
    verbose = not args.quiet and not args.json

    # Imported after argument parsing, so --help and option errors exit
    # before the application is loaded.
    from app import create_app
    app = create_app()

    with app.app_context():
        from app.helpers.market_db import switch_market_schema
        switch_market_schema(args.asset_type)

        from app.services.ml_ensemble import MLEnsembleService
        svc = MLEnsembleService()

        start = time.time()
        results = _collect_results(svc, args, verbose)
        results = _sort_results(_filter_results(results, args), args.sort)
        elapsed = time.time() - start

        if args.json:
            print(json_module.dumps({
                'results': results,
                'total': len(results),
                'elapsed_seconds': round(elapsed, 1),
            }, indent=2))
            return

        if not args.quiet:
            _print_table(results)
        # The summary is printed even in quiet mode.
        _print_summary(results, elapsed)


def _build_parser():
    """Build the argument parser describing every CLI option."""
    parser = argparse.ArgumentParser(
        description='Run ML ensemble predictions (RF + GB + Ridge)',
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('--asset-type', default='stock',
                        choices=['crypto', 'stock', 'stock_us'],
                        help='Asset type (default: stock)')
    parser.add_argument('--symbol', type=str, default='',
                        help='Symbols tertentu, comma-separated (e.g. BTC,ETH,SOL)')
    parser.add_argument('--limit', type=int, default=50,
                        help='Max jumlah coin (default: 50)')
    parser.add_argument('--sort', default='confidence',
                        choices=['confidence', 'signal', 'symbol'],
                        help='Sort by (default: confidence)')
    parser.add_argument('--strong-only', action='store_true',
                        help='Hanya tampilkan strong_buy / strong_sell')
    parser.add_argument('--min-confidence', type=int, default=0,
                        help='Minimum confidence %% (default: 0)')
    parser.add_argument('--quiet', '-q', action='store_true',
                        help='Quiet mode — hanya summary')
    parser.add_argument('--json', action='store_true',
                        help='Output JSON format')
    return parser


def _collect_results(svc, args, verbose):
    """Return the list of prediction dicts for the requested run.

    With ``--symbol`` each symbol is predicted individually and failures
    are reported only when ``verbose`` is true. Otherwise ``svc.scan_all``
    is called once; if it does not report success the process exits with
    status 1.
    """
    if args.symbol:
        # Normalise "btc, eth,," into ['BTC', 'ETH'].
        symbols = [s.strip().upper() for s in args.symbol.split(',') if s.strip()]
        if verbose:
            print(f'\n🤖  ML Ensemble Predictions — {len(symbols)} symbols')
            print(f'{"─" * 70}')

        results = []
        for sym in symbols:
            r = svc.predict(sym, args.asset_type)
            if r.get('status') == 'success':
                results.append(r['data'])
            elif verbose:
                print(f'  ✗  {sym}: {r.get("message", "error")}')
        return results

    if verbose:
        print(f'\n🤖  ML Ensemble Predictions — top {args.limit} {args.asset_type}')
        print(f'{"─" * 70}')

    resp = svc.scan_all(asset_type=args.asset_type, limit=args.limit)
    if resp.get('status') != 'success':
        # Written to stderr so that stdout stays valid JSON under --json.
        print(f'✗  Error: {resp.get("message", "unknown")}', file=sys.stderr)
        sys.exit(1)
    return resp.get('data', [])


def _filter_results(results, args):
    """Apply the ``--strong-only`` and ``--min-confidence`` filters."""
    if args.strong_only:
        results = [r for r in results if r.get('signal', '').startswith('strong_')]
    if args.min_confidence > 0:
        results = [r for r in results if r.get('confidence', 0) >= args.min_confidence]
    return results


def _sort_results(results, sort_key):
    """Return a new list ordered by ``sort_key`` (input is not mutated)."""
    if sort_key == 'confidence':
        return sorted(results, key=lambda x: x.get('confidence', 0), reverse=True)
    if sort_key == 'signal':
        # Strongest buy first, strongest sell last; unknown signals at the end.
        order = {'strong_buy': 0, 'buy': 1, 'weak_buy': 2,
                 'weak_sell': 3, 'sell': 4, 'strong_sell': 5}
        return sorted(results, key=lambda x: order.get(x.get('signal', ''), 99))
    if sort_key == 'symbol':
        return sorted(results, key=lambda x: x.get('symbol', ''))
    return list(results)


def _print_table(results):
    """Print the header and one formatted row per prediction."""
    print(f'\n  {"#":>3}  {"Symbol":<8} {"Direction":>9}  {"Conf":>5}  {"Signal":<12} '
          f'{"RF":>5}  {"GB":>5}  {"Ridge":>6}  {"Agree":>5}')
    print(f'  {"─" * 3}  {"─" * 8} {"─" * 9}  {"─" * 5}  {"─" * 12} '
          f'{"─" * 5}  {"─" * 5}  {"─" * 6}  {"─" * 5}')

    for i, r in enumerate(results, 1):
        direction = r.get('predicted_direction', '?')
        confidence = r.get('confidence', 0)
        signal = r.get('signal', '?')
        models = r.get('models', {})
        rf_acc = models.get('random_forest', {}).get('accuracy', 0)
        gb_acc = models.get('gradient_boost', {}).get('accuracy', 0)
        ridge_mag = models.get('ridge_regression', {}).get('magnitude', 0)
        agreement = r.get('agreement', 0)

        # Anything other than 'up' is rendered with the down icon.
        dir_icon = '📈' if direction == 'up' else '📉'
        sig_icon = _signal_icon(signal)

        # '.0f' prints integers exactly like 'd' would, but does not raise
        # ValueError when the confidence is a float.
        print(f'  {i:3d}  {r.get("symbol", "?"):<8} {dir_icon} {direction:>5}  '
              f'{confidence:4.0f}%  {sig_icon} {signal:<10} '
              f'{rf_acc:4.0f}%  {gb_acc:4.0f}%  {ridge_mag:+5.1f}%  {agreement}/3')


def _print_summary(results, elapsed):
    """Print buy/sell counts and average confidence/agreement."""
    # Substring match: 'buy' also counts 'strong_buy' and 'weak_buy'.
    buy_count = sum(1 for r in results if 'buy' in r.get('signal', ''))
    sell_count = sum(1 for r in results if 'sell' in r.get('signal', ''))
    strong_buy = sum(1 for r in results if r.get('signal') == 'strong_buy')
    strong_sell = sum(1 for r in results if r.get('signal') == 'strong_sell')
    # max(..., 1) keeps the averages at 0 instead of dividing by zero.
    total = max(len(results), 1)
    avg_conf = sum(r.get('confidence', 0) for r in results) / total
    avg_agree = sum(r.get('agreement', 0) for r in results) / total

    print(f'\n{"═" * 70}')
    print(f'📊  {len(results)} predictions in {elapsed:.1f}s')
    print(f'    BUY: {buy_count} ({strong_buy} strong)  |  '
          f'SELL: {sell_count} ({strong_sell} strong)')
    print(f'    Avg confidence: {avg_conf:.0f}%  |  Avg agreement: {avg_agree:.1f}/3')
    print(f'{"═" * 70}')


def _signal_icon(signal):
    """Return the emoji marker displayed next to ``signal`` in the table.

    Known signals map to a coloured marker; the weak signals and any
    unrecognised value map to the neutral white square.
    """
    neutral = '⬜'
    markers = dict(
        strong_buy='🟢',
        buy='🟩',
        weak_buy=neutral,
        weak_sell=neutral,
        sell='🟥',
        strong_sell='🔴',
    )
    try:
        return markers[signal]
    except KeyError:
        return neutral


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
