"""ML Feature Engineering — generate & analyze 90+ technical features.

Menjalankan FeatureEngineer untuk menghasilkan features dari OHLCV data.
Berguna untuk debug, quality check, dan analisis distribusi features.

Feature categories (90+ total):
  - Price (6): price_change, price_range, body_size, wicks, is_bullish
  - Momentum (8): RSI (7,14,21), ROC, Williams %R, Stochastic K/D
  - Trend (9): SMA/EMA (7,14,21,50), MACD, ADX (+/-)
  - Volatility (5): Bollinger Bands (width, %B), ATR, ATR %
  - Volume (5): OBV, volume_ratio, volume_momentum, VWAP
  - Lags (12): close_lag & return_lag (1,2,4,8,16,24)
  - Rolling Stats (20): mean/std/min/max/zscore × 4 windows (4,8,16,24)
  - Microstructure (4): Amihud, spread, flow_imbalance, Hurst
  - Wavelet (5): detail, energy (3 levels), trend, detrended
  - Fourier (4): dominant_freq, dominant_power, spectral_entropy, dominant_period

Usage:
    # Generate features untuk top 5 asset (default --asset-type stock), tampilkan stats
    python3 scripts/ml_feature_eng.py
    python3 scripts/ml_feature_eng.py --asset-type crypto   # top 5 crypto

    # Coin tertentu
    python3 scripts/ml_feature_eng.py --coin bitcoin,ethereum

    # Tampilkan korelasi antar features
    python3 scripts/ml_feature_eng.py --coin bitcoin --show-corr

    # Tampilkan distribusi (mean, std, min, max)
    python3 scripts/ml_feature_eng.py --coin bitcoin --show-dist

    # Tampilkan NaN report
    python3 scripts/ml_feature_eng.py --show-nan

    # Export features ke CSV
    python3 scripts/ml_feature_eng.py --coin bitcoin --export /tmp/features.csv

    # Output JSON
    python3 scripts/ml_feature_eng.py --json
"""
from __future__ import annotations

import argparse
import json as json_module
import os
import sys
import time

# Put the parent of this script's directory (presumably the project root that
# contains the `app` package imported in main()) first on sys.path, so the
# script works when run directly as `python3 scripts/ml_feature_eng.py`.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from dotenv import load_dotenv
# Load variables from a .env file into the process environment at import time,
# i.e. before main() calls create_app().
load_dotenv()


def main():
    """Generate ML features for one or more coins and report on their quality.

    Resolves coins from ``--coin`` (or the top ``--limit`` active coins of
    ``--asset-type``, ordered by market-cap rank), loads their OHLCV rows for
    ``--timeframe``, runs ``FeatureEngineer.generate`` on each, and prints a
    per-coin line plus optional distribution / NaN / correlation reports.

    With ``--json`` only a JSON document is written to stdout. With
    ``--quiet`` only the final summary is printed. Per-coin failures are
    reported (to stderr when stdout must stay clean) and the coin is skipped.
    """
    parser = argparse.ArgumentParser(
        description='Generate & analyze 90+ ML features from OHLCV data',
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('--asset-type', default='stock',
                        choices=['crypto', 'stock', 'stock_us'],
                        help='Asset type (default: stock)')
    parser.add_argument('--coin', type=str, default='',
                        help='Coin IDs, comma-separated')
    parser.add_argument('--limit', type=int, default=5,
                        help='Max coins (default: 5)')
    parser.add_argument('--timeframe', default='1D',
                        choices=['1h', '4h', '1D'],
                        help='OHLCV timeframe (default: 1D)')
    parser.add_argument('--show-corr', action='store_true',
                        help='Tampilkan top correlated feature pairs')
    parser.add_argument('--show-dist', action='store_true',
                        help='Tampilkan distribusi features (mean, std, min, max)')
    parser.add_argument('--show-nan', action='store_true',
                        help='Tampilkan NaN report per feature')
    parser.add_argument('--export', type=str, default='',
                        help='Export features ke CSV file path')
    parser.add_argument('--quiet', '-q', action='store_true',
                        help='Quiet — hanya summary')
    parser.add_argument('--json', action='store_true',
                        help='Output JSON')
    args = parser.parse_args()

    from app import create_app
    app = create_app()

    with app.app_context():
        from app.helpers.market_db import switch_market_schema
        switch_market_schema(args.asset_type)

        from app.extensions import db
        from app.models.coin import Coin
        from app.models.ohlcv import OHLCVData
        from app.engine.ml.feature_eng import FeatureEngineer

        # Per-coin human-readable output; --json and --quiet keep stdout clean.
        verbose = not args.quiet and not args.json
        start_time = time.time()

        coin_ids = _resolve_coin_ids(args, db, Coin)
        if not coin_ids:
            print('⊘  Tidak ada coin.')
            sys.exit(0)

        total = len(coin_ids)
        all_summaries = []

        # A single CSV path can only hold one coin's features, so --export is
        # honoured only for a single coin. Say so instead of ignoring it silently.
        export_path = args.export if total == 1 else ''
        if args.export and not export_path:
            print(f'⚠  --export hanya didukung untuk 1 coin — diabaikan ({total} coins)',
                  file=sys.stderr)

        if verbose:
            print(f'\n🔧  Feature Engineering — 90+ features × {total} coins')
            print(f'{"─" * 80}')

        fe = FeatureEngineer()

        for idx, coin_id in enumerate(coin_ids, 1):
            coin = db.session.get(Coin, coin_id)
            symbol = coin.symbol.upper() if coin else coin_id

            try:
                t0 = time.time()

                rows = (
                    OHLCVData.query
                    .filter_by(coin_id=coin_id, timeframe=args.timeframe)
                    .order_by(OHLCVData.timestamp.asc())
                    .all()
                )
                if len(rows) < 100:
                    if verbose:
                        print(f'  ⊘  [{idx}/{total}] {symbol:<8} — insufficient data ({len(rows)} rows)')
                    continue

                feature_df = fe.generate(_rows_to_frame(rows))
                elapsed_coin = time.time() - t0

                n_features = len(feature_df.columns)
                n_rows = len(feature_df)
                nan_count = feature_df.isna().sum().sum()
                total_cells = n_features * n_rows
                # Guard both dimensions: rows with zero columns would otherwise divide by zero.
                nan_pct = nan_count / total_cells * 100 if total_cells > 0 else 0

                all_summaries.append({
                    'symbol': symbol,
                    'coin_id': coin_id,
                    'ohlcv_rows': len(rows),
                    'feature_rows': n_rows,
                    'feature_cols': n_features,
                    'nan_count': int(nan_count),
                    'nan_pct': round(nan_pct, 2),
                    'elapsed': round(elapsed_coin, 3),
                    'features': list(feature_df.columns),
                })

                if verbose:
                    quality = '🟢' if nan_pct < 1 else ('🟡' if nan_pct < 5 else '🔴')
                    print(f'  {quality} [{idx}/{total}] {symbol:<8} '
                          f'{n_features} features × {n_rows} rows  '
                          f'NaN={nan_pct:.1f}%  '
                          f'({elapsed_coin:.2f}s)')

                    if args.show_dist:
                        _print_distribution(feature_df)
                    if args.show_nan:
                        _print_nan_report(feature_df)
                    if args.show_corr:
                        _print_top_correlations(feature_df)
                    if args.show_dist or args.show_nan or args.show_corr:
                        print()

                if export_path:
                    feature_df.to_csv(export_path, index=False)
                    # Not under --json: the line would corrupt the JSON on stdout.
                    if verbose:
                        print(f'  📁  Exported to {export_path}')

            except Exception as e:
                # Read-only script: rolling back is harmless and keeps a failed
                # query from leaving the session unusable for the next coin.
                db.session.rollback()
                message = f'  ✗  [{idx}/{total}] {symbol:<8} — {e}'
                if verbose:
                    print(message)
                else:
                    # Keep stdout clean for --json/--quiet, but do not lose the error.
                    print(message, file=sys.stderr)

        elapsed = time.time() - start_time

        if args.json:
            print(json_module.dumps({
                'summaries': all_summaries,
                'total': len(all_summaries),
                'elapsed_seconds': round(elapsed, 1),
            }, indent=2))
            return

        if all_summaries:
            _print_summary(all_summaries, elapsed)


def _resolve_coin_ids(args, db, Coin):
    """Return coin ids from ``--coin``, else the top ``--limit`` active coins by rank."""
    if args.coin:
        coin_ids = [c.strip() for c in args.coin.split(',') if c.strip()]
        # Bare ids (without a '.') get the 'COIN.' namespace prefix.
        return [cid if '.' in cid else f'COIN.{cid}' for cid in coin_ids]

    coins = (
        Coin.query
        .filter_by(asset_type=args.asset_type, is_active=True)
        .order_by(
            # Coins without a market-cap rank sort after ranked ones.
            db.case((Coin.market_cap_rank.is_(None), 1), else_=0),
            Coin.market_cap_rank.asc(),
        )
        .limit(args.limit)
        .all()
    )
    return [c.id for c in coins]


def _rows_to_frame(rows):
    """Build an OHLCV DataFrame from ORM rows; NULL/falsy values become 0.0."""
    import pandas as pd

    return pd.DataFrame([{
        'open': float(r.open or 0), 'high': float(r.high or 0),
        'low': float(r.low or 0), 'close': float(r.close or 0),
        'volume': float(r.volume or 0),
    } for r in rows])


def _print_distribution(feature_df, max_cols=20):
    """Print mean/std/min/max for the first ``max_cols`` feature columns."""
    n_features = len(feature_df.columns)
    print(f'\n         {"Feature":<30} {"Mean":>10} {"Std":>10} {"Min":>10} {"Max":>10}')
    print(f'         {"─" * 30} {"─" * 10} {"─" * 10} {"─" * 10} {"─" * 10}')
    for col in feature_df.columns[:max_cols]:
        s = feature_df[col]
        print(f'         {col:<30} {s.mean():>10.4f} {s.std():>10.4f} '
              f'{s.min():>10.4f} {s.max():>10.4f}')
    if n_features > max_cols:
        print(f'         ... dan {n_features - max_cols} features lainnya')


def _print_nan_report(feature_df):
    """Print NaN counts per feature (columns with NaN only, most NaN first)."""
    n_rows = len(feature_df)
    nan_cols = feature_df.isna().sum()
    nan_cols = nan_cols[nan_cols > 0].sort_values(ascending=False)
    if len(nan_cols) > 0:
        print('\n         NaN per feature:')
        for col, cnt in nan_cols.items():
            print(f'           {col:<30} {cnt:>5} ({cnt / n_rows * 100:.1f}%)')
    else:
        print('\n         ✅ No NaN values')


def _top_correlated_pairs(feature_df, top_n=10):
    """Return up to ``top_n`` ``(col_a, col_b, |r|)`` tuples, strongest first.

    Pairs whose correlation is NaN (e.g. involving a constant or all-NaN
    column) are dropped. NaN compares false against everything, so leaving
    them in would scramble the sort and could push real pairs out of the
    top list.
    """
    import math
    from itertools import combinations

    corr = feature_df.corr().abs()
    names = list(corr.columns)
    values = corr.to_numpy()
    pairs = [
        (names[i], names[j], float(values[i, j]))
        for i, j in combinations(range(len(names)), 2)
        if not math.isnan(values[i, j])
    ]
    # Stable sort: equal |r| values keep their column order.
    pairs.sort(key=lambda p: p[2], reverse=True)
    return pairs[:top_n]


def _print_top_correlations(feature_df, top_n=10):
    """Print the ``top_n`` most strongly correlated (by |r|) feature pairs."""
    print(f'\n         Top {top_n} correlated pairs:')
    for a, b, r in _top_correlated_pairs(feature_df, top_n):
        print(f'           {a:<25} ↔ {b:<25} r={r:.3f}')


def _print_summary(all_summaries, elapsed):
    """Print the aggregate report; ``all_summaries`` must be non-empty."""
    count = len(all_summaries)
    # Feature list/count is taken from the first coin.
    total_features = all_summaries[0]['feature_cols']
    avg_rows = sum(s['feature_rows'] for s in all_summaries) / count
    avg_nan = sum(s['nan_pct'] for s in all_summaries) / count
    avg_time = sum(s['elapsed'] for s in all_summaries) / count

    categories = _categorize_features(all_summaries[0]['features'])

    print(f'\n{"═" * 80}')
    print(f'🔧  Feature Engineering Complete — {elapsed:.1f}s')
    print(f'    Coins: {count}  |  Features: {total_features}')
    print(f'    Avg rows: {avg_rows:.0f}  |  Avg NaN: {avg_nan:.1f}%  |  Avg time: {avg_time:.2f}s')
    if categories:
        cats = ', '.join(f'{k}={v}' for k, v in sorted(categories.items()))
        print(f'    Categories: {cats}')
    print(f'{"═" * 80}')


def _categorize_features(features):
    """Categorize features by prefix/pattern."""
    cats = {}
    for f in features:
        if 'rsi' in f or 'roc' in f or 'williams' in f or 'stoch' in f:
            cats['momentum'] = cats.get('momentum', 0) + 1
        elif 'sma' in f or 'ema' in f or 'macd' in f or 'adx' in f:
            cats['trend'] = cats.get('trend', 0) + 1
        elif 'bollinger' in f or 'atr' in f or 'volatility' in f or 'vol_' in f:
            cats['volatility'] = cats.get('volatility', 0) + 1
        elif 'obv' in f or 'vwap' in f or 'volume' in f:
            cats['volume'] = cats.get('volume', 0) + 1
        elif 'lag' in f:
            cats['lags'] = cats.get('lags', 0) + 1
        elif 'rolling' in f or 'zscore' in f:
            cats['rolling'] = cats.get('rolling', 0) + 1
        elif 'wavelet' in f:
            cats['wavelet'] = cats.get('wavelet', 0) + 1
        elif 'fft' in f or 'fourier' in f or 'spectral' in f:
            cats['fourier'] = cats.get('fourier', 0) + 1
        elif 'amihud' in f or 'spread' in f or 'hurst' in f or 'flow' in f:
            cats['microstructure'] = cats.get('microstructure', 0) + 1
        else:
            cats['price'] = cats.get('price', 0) + 1
    return cats


if __name__ == '__main__':
    main()
