"""Volatility Surface Service — analyzes volatility structure across all assets.

Computes annualized historical volatility from sparkline data (7d hourly),
classifies the vol regime, ranks assets by volatility, and detects the vol
trend (expanding/contracting/stable).  Also reports Bollinger Band width and
ATR-based volatility proxies from precomputed range-trading scores.

Data sources (all lazy-imported):
- Asset: id, symbol, name, icon_thumb_url, asset_type, market_cap_rank, is_active
- AssetProfile: current_price_idr, market_cap_idr, total_volume_idr,
  price_change_1h/24h/7d/30d, ath_idr, atl_idr (latest per asset)
- RangeTradingScore: bb_width, atr_pct, range_width_pct, zscore
- BullishMomentumScore: rsi_latest, velocity_short_pct
- Sparkline (via app.helpers.sparkline.load_sparkline, from profile_json):
  nominally 168 hourly prices for 7d; assets with < 24 points are skipped
"""
from __future__ import annotations

import logging
import math
import statistics
from bisect import bisect_left
from datetime import datetime
from typing import Optional

logger = logging.getLogger(__name__)

# ── Constants ────────────────────────────────────────────────────────
# Time scaling used for sqrt-of-time volatility annualization.
HOURS_PER_DAY = 24
HOURS_PER_YEAR = HOURS_PER_DAY * 365  # 8760 — assumes round-the-clock trading
# NOTE(review): 24 * 365 suits 24/7 crypto markets; for stock assets whose
# sparklines may only cover session hours this could over-annualize — confirm.

# Exclusive upper bounds (annualized vol, in percent) for each regime bucket;
# anything at or above the 'high' bound is classified as 'extreme'.
VOL_REGIME_THRESHOLDS = dict(low=30, normal=60, high=100)

# Length (in hourly returns) of the "recent" window for vol-trend detection.
# The last RECENT_WINDOW returns (capped at half the series) are compared
# against the earlier returns that precede them.
RECENT_WINDOW = 48


def _safe_float(val, default: float = 0.0) -> float:
    """Convert a DB ``Numeric`` / ``Decimal`` / numeric-like value to a finite float.

    Returns *default* when *val* is ``None``, cannot be converted, is too
    large to represent as a float, or converts to NaN / ±inf.
    """
    if val is None:
        return default
    try:
        f = float(val)
    except (ValueError, TypeError, OverflowError):
        # OverflowError: float() of an int beyond the double range raises
        # instead of returning inf, so it must be caught here as well.
        return default
    return f if math.isfinite(f) else default


def _compute_returns(prices: list[float]) -> list[float]:
    """Return hourly log returns ``ln(p[t] / p[t-1])`` for consecutive prices.

    Any pair in which either price is non-positive is skipped (the log is
    undefined), so the result may be shorter than ``len(prices) - 1``.
    Fewer than two prices yields an empty list.
    """
    if len(prices) < 2:
        return []
    return [
        math.log(curr / prev)
        for prev, curr in zip(prices, prices[1:])
        if prev > 0 and curr > 0
    ]


def _std_dev(values: list[float]) -> float:
    """Return the sample (Bessel-corrected, ``n - 1``) standard deviation.

    Fewer than two values, or a variance that is not positive, yields ``0.0``.
    """
    count = len(values)
    if count >= 2:
        avg = sum(values) / count
        var = sum((v - avg) ** 2 for v in values) / (count - 1)
        if var > 0:
            return math.sqrt(var)
    return 0.0


def _mean(values: list[float]) -> float:
    """Return the arithmetic mean of *values*, or ``0.0`` when it is empty."""
    return sum(values) / len(values) if values else 0.0


def _classify_vol_regime(ann_vol_pct: float) -> str:
    """Map an annualized volatility (in percent) to a regime label.

    Buckets are checked in order against ``VOL_REGIME_THRESHOLDS``:
    below 'low' → ``'low'``, below 'normal' → ``'normal'``, below 'high' →
    ``'high'``; anything else (at/above the 'high' bound) → ``'extreme'``.
    """
    for label in ('low', 'normal', 'high'):
        if ann_vol_pct < VOL_REGIME_THRESHOLDS[label]:
            return label
    return 'extreme'


def _percentile_rank(value: float, sorted_values: list[float]) -> float:
    """Return the percentile rank (0-100, one decimal) of *value*.

    The rank is the share of entries in *sorted_values* strictly below
    *value*, so an element of the list never reaches 100. An empty list
    yields the neutral ``50.0``.

    *sorted_values* must be in ascending order: the count is found by binary
    search (O(log n)) instead of a linear scan, which keeps ranking every
    asset against the whole universe O(n log n) rather than O(n²).
    """
    n = len(sorted_values)
    if n == 0:
        return 50.0
    # bisect_left = index of the first element >= value = count strictly below.
    count_below = bisect_left(sorted_values, value)
    return round((count_below / n) * 100, 1)


class VolatilitySurfaceService:
    """Analyzes volatility structure across all assets using sparkline data
    and precomputed range-trading metrics.

    Holds no instance state: every ``scan_all`` call reloads its inputs via
    the ``_load_*`` and ``_get_sparkline`` helpers.
    """

    # ──────────────────────────────────────────────────────────────
    #  Public API
    # ──────────────────────────────────────────────────────────────

    def scan_all(
        self,
        asset_type: str = 'crypto',
        limit: int = 50,
        page: int = 1,
        sort_by: str = 'vol_desc',
        min_vol: Optional[float] = None,
        max_vol: Optional[float] = None,
        regime_filter: Optional[str] = None,
    ) -> dict:
        """Analyze volatility structure across all active assets.

        Parameters
        ----------
        asset_type : str
            'crypto', 'stock', 'stock_us', or 'all'.
        limit : int
            Items per page. Values below 1 are treated as 1.
        page : int
            Page number (1-based). Values below 1 are treated as 1.
        sort_by : str
            'vol_desc' (highest first, default), 'vol_asc', 'rank'
            (market-cap rank ascending), or 'bb_width' / 'atr' /
            'trend_ratio' (each descending). Unknown values fall back to
            'vol_desc'.
        min_vol / max_vol : float | None
            Inclusive filter on annualized vol (percent).
        regime_filter : str | None
            'low', 'normal', 'high', 'extreme' — only show this regime.

        Returns
        -------
        dict with 'items', 'total', 'page', 'limit', 'has_more', 'stats'.
        Each item's ``vol_rank`` / ``vol_percentile`` are computed over every
        analyzed asset *before* filtering; ``stats`` describes the filtered
        set before pagination.
        """
        # Guard pagination: page <= 0 used to produce a negative offset (an
        # empty page reporting has_more=True) and limit <= 0 an endless
        # has_more loop.
        page = max(1, page)
        limit = max(1, limit)

        # ── 1. Load data ──────────────────────────────────────────
        coins_map = self._load_coins(asset_type)
        if not coins_map:
            # Nothing to analyze — skip the remaining queries entirely.
            return self._empty_result(page, limit)
        profiles_map = self._load_latest_profiles(asset_type)
        range_map = self._load_range_scores(asset_type)
        bullish_map = self._load_bullish_scores(asset_type)

        # ── 2. Compute per-asset volatility ────────────────────────
        records: list[dict] = []
        for asset_id, asset in coins_map.items():
            record = self._analyze_asset(
                asset_id,
                asset,
                profiles_map.get(asset_id),
                range_map.get(asset_id),
                bullish_map.get(asset_id),
            )
            if record is not None:
                records.append(record)

        if not records:
            return self._empty_result(page, limit)

        # ── 3. Compute ranks and percentiles ──────────────────────
        # Relative to the whole analyzed universe, not just filtered rows.
        all_vols_sorted = sorted(r['annualized_vol_pct'] for r in records)
        records.sort(key=lambda r: r['annualized_vol_pct'], reverse=True)

        for rank_idx, rec in enumerate(records, start=1):
            rec['vol_rank'] = rank_idx
            rec['vol_percentile'] = _percentile_rank(
                rec['annualized_vol_pct'], all_vols_sorted
            )

        # ── 4. Apply filters ──────────────────────────────────────
        filtered = [
            r for r in records
            if (min_vol is None or r['annualized_vol_pct'] >= min_vol)
            and (max_vol is None or r['annualized_vol_pct'] <= max_vol)
            and (not regime_filter or r['vol_regime'] == regime_filter)
        ]

        # ── 5. Sort ───────────────────────────────────────────────
        sort_configs = {
            'vol_desc': (lambda r: r['annualized_vol_pct'], True),
            'vol_asc': (lambda r: r['annualized_vol_pct'], False),
            'rank': (lambda r: r['market_cap_rank'], False),
            'bb_width': (lambda r: r['bb_width'], True),
            'atr': (lambda r: r['atr_pct'], True),
            'trend_ratio': (lambda r: r['vol_trend_ratio'], True),
        }
        key_fn, reverse = sort_configs.get(sort_by, sort_configs['vol_desc'])
        filtered.sort(key=key_fn, reverse=reverse)

        # ── 6. Build stats (pre-pagination) ───────────────────────
        stats = self._build_stats(filtered, all_vols_sorted)

        # ── 7. Paginate ───────────────────────────────────────────
        total = len(filtered)
        offset = (page - 1) * limit
        return {
            'items': filtered[offset:offset + limit],
            'total': total,
            'page': page,
            'limit': limit,
            'has_more': (offset + limit) < total,
            'stats': stats,
        }

    # ──────────────────────────────────────────────────────────────
    #  Private helpers — per-asset metrics
    # ──────────────────────────────────────────────────────────────

    def _analyze_asset(self, asset_id, asset, profile, range_sc, bullish) -> Optional[dict]:
        """Build the volatility record for one asset.

        Returns ``None`` when the sparkline is too short to be meaningful
        (fewer than 24 prices or fewer than 12 usable log returns).
        ``profile``, ``range_sc`` and ``bullish`` are the rows loaded by the
        ``_load_*`` helpers and may each be ``None``; their metrics then fall
        back to 0.0 (RSI falls back to 50.0). ``vol_rank`` and
        ``vol_percentile`` are placeholders that ``scan_all`` fills in.
        """
        # Treat a missing sparkline like an empty one instead of crashing on len().
        sparkline = self._get_sparkline(asset) or []
        if len(sparkline) < 24:
            # Need at minimum 24 hours of data for a meaningful vol calc.
            return None

        returns = _compute_returns(sparkline)
        if len(returns) < 12:
            return None

        # Full-period historical vol, annualized by sqrt-of-time scaling.
        hourly_std = _std_dev(returns)
        ann_vol_pct = round(hourly_std * math.sqrt(HOURS_PER_YEAR) * 100, 2)

        # Intraday vol: last 24 returns (or all, if fewer) scaled to one day.
        recent_24h = returns[-24:]
        intraday_vol_pct = round(_std_dev(recent_24h) * math.sqrt(HOURS_PER_DAY) * 100, 2)

        vol_trend, vol_trend_ratio = self._vol_trend(returns, hourly_std)

        # Band/range metrics from the precomputed range score.
        bb_width = _safe_float(range_sc.bb_width) if range_sc else 0.0
        atr_pct = _safe_float(range_sc.atr_pct) if range_sc else 0.0
        range_width = _safe_float(range_sc.range_width_pct) if range_sc else 0.0
        zscore = _safe_float(range_sc.zscore) if range_sc else 0.0

        # Additional metrics.
        rsi = _safe_float(bullish.rsi_latest, 50.0) if bullish else 50.0
        price = _safe_float(profile.current_price_idr) if profile else 0.0
        change_24h = _safe_float(profile.price_change_24h) if profile else 0.0
        change_7d = _safe_float(profile.price_change_7d) if profile else 0.0
        market_cap = _safe_float(profile.market_cap_idr) if profile else 0.0

        # High/low spread across the sparkline (non-empty at this point).
        sparkline_high = max(sparkline)
        sparkline_low = min(sparkline)
        sparkline_range_pct = 0.0
        if sparkline_low > 0:
            sparkline_range_pct = round(
                ((sparkline_high - sparkline_low) / sparkline_low) * 100, 2
            )

        # Mean hourly log return, annualized linearly.
        mean_return = _mean(returns)
        drift_annualized = round(mean_return * HOURS_PER_YEAR * 100, 2)
        skew = self._return_skewness(returns, mean_return, hourly_std)

        return {
            'asset_id': asset_id,
            'symbol': asset.symbol,
            'name': asset.name,
            'icon_thumb_url': asset.icon_thumb_url or '',
            'asset_type': asset.asset_type or 'crypto',
            'market_cap_rank': asset.market_cap_rank or 9999,
            'current_price_idr': price,
            'market_cap_idr': market_cap,
            # Core vol metrics
            'annualized_vol_pct': ann_vol_pct,
            'intraday_vol_pct': intraday_vol_pct,
            'vol_regime': _classify_vol_regime(ann_vol_pct),
            'vol_rank': 0,  # placeholder — ranked by scan_all
            'vol_trend': vol_trend,
            'vol_trend_ratio': vol_trend_ratio,
            # Band/range metrics
            'bb_width': round(bb_width, 6),
            'atr_pct': round(atr_pct, 4),
            'range_width_pct': round(range_width, 2),
            'sparkline_range_pct': sparkline_range_pct,
            # Additional
            'zscore': round(zscore, 3),
            'rsi': round(rsi, 1),
            'change_24h': round(change_24h, 2),
            'change_7d': round(change_7d, 2),
            'drift_annualized_pct': drift_annualized,
            'return_skewness': skew,
            'vol_percentile': 0.0,  # computed by scan_all
            'sparkline_points': len(sparkline),
        }

    @staticmethod
    def _vol_trend(returns: list[float], hourly_std: float) -> tuple[str, float]:
        """Compare recent return volatility with the earlier part of the series.

        The last ``min(RECENT_WINDOW, len(returns) // 2)`` returns form the
        recent window and everything before it the historical window. A
        recent/historical std ratio above 1.15 is 'expanding', below 0.85
        'contracting', otherwise 'stable'. A recent window of 6 or fewer
        returns yields ``('stable', 1.0)``.
        """
        recent_n = min(RECENT_WINDOW, len(returns) // 2)
        if recent_n <= 6:
            return 'stable', 1.0
        recent_std = _std_dev(returns[-recent_n:])
        historical_returns = returns[:-recent_n]
        hist_std = _std_dev(historical_returns) if len(historical_returns) > 2 else hourly_std
        vol_ratio = (recent_std / hist_std) if hist_std > 0 else 1.0
        if vol_ratio > 1.15:
            trend = 'expanding'
        elif vol_ratio < 0.85:
            trend = 'contracting'
        else:
            trend = 'stable'
        return trend, round(vol_ratio, 3)

    @staticmethod
    def _return_skewness(returns: list[float], mean_r: float, std_r: float) -> float:
        """Skewness proxy: mean cubed deviation divided by ``std_r ** 3``.

        ``std_r`` is the sample (n - 1) std, so the result is slightly smaller
        in magnitude than the textbook Fisher-Pearson g1. Returns 0.0 for 10 or
        fewer returns or a non-positive std.
        """
        if len(returns) <= 10 or std_r <= 0:
            return 0.0
        skew = sum((r - mean_r) ** 3 for r in returns) / (len(returns) * std_r ** 3)
        return round(skew, 3)

    # ──────────────────────────────────────────────────────────────
    #  Private helpers — data loading
    # ──────────────────────────────────────────────────────────────

    def _load_coins(self, asset_type: str = 'all') -> dict:
        """Load active Asset rows (optionally of one asset_type), keyed by id."""
        from app.models.asset import Asset

        query = Asset.query.filter(Asset.is_active.is_(True))
        if asset_type and asset_type != 'all':
            query = query.filter(Asset.asset_type == asset_type)
        return {c.id: c for c in query.all()}

    def _load_latest_profiles(self, asset_type: str = 'all') -> dict:
        """Load the highest-id AssetProfile per active asset, keyed by asset_id.

        NOTE(review): "latest" relies on AssetProfile ids increasing with
        insertion time — confirm.
        """
        from app.extensions import db
        from app.models.asset import Asset, AssetProfile
        from sqlalchemy import func as sa_func

        coin_q = db.session.query(Asset.id).filter(Asset.is_active.is_(True))
        if asset_type and asset_type != 'all':
            coin_q = coin_q.filter(Asset.asset_type == asset_type)
        asset_ids_sq = coin_q.subquery()

        latest_sq = db.session.query(
            AssetProfile.asset_id,
            sa_func.max(AssetProfile.id).label('max_id'),
        ).filter(
            AssetProfile.asset_id.in_(db.session.query(asset_ids_sq.c.id)),
        ).group_by(AssetProfile.asset_id).subquery()

        profiles = AssetProfile.query.join(
            latest_sq, AssetProfile.id == latest_sq.c.max_id,
        ).all()
        return {p.asset_id: p for p in profiles}

    def _load_range_scores(self, asset_type: str = 'all') -> dict:
        """Load RangeTradingScore rows for active assets, keyed by asset_id.

        NOTE(review): presumably one row per asset; if several exist, the last
        one returned by the query wins.
        """
        from app.models.range_score import RangeTradingScore
        from app.models.asset import Asset

        query = RangeTradingScore.query.join(
            Asset, Asset.id == RangeTradingScore.asset_id,
        ).filter(Asset.is_active.is_(True))
        if asset_type and asset_type != 'all':
            query = query.filter(Asset.asset_type == asset_type)
        return {r.asset_id: r for r in query.all()}

    def _load_bullish_scores(self, asset_type: str = 'all') -> dict:
        """Load BullishMomentumScore rows for active assets, keyed by asset_id.

        NOTE(review): presumably one row per asset; if several exist, the last
        one returned by the query wins.
        """
        from app.models.bullish_score import BullishMomentumScore
        from app.models.asset import Asset

        query = BullishMomentumScore.query.join(
            Asset, Asset.id == BullishMomentumScore.asset_id,
        ).filter(Asset.is_active.is_(True))
        if asset_type and asset_type != 'all':
            query = query.filter(Asset.asset_type == asset_type)
        return {b.asset_id: b for b in query.all()}

    def _get_sparkline(self, asset) -> list[float]:
        """Load sparkline prices for *asset* via the shared helper."""
        from app.helpers.sparkline import load_sparkline
        return load_sparkline(asset)

    # ──────────────────────────────────────────────────────────────
    #  Private helpers — stats
    # ──────────────────────────────────────────────────────────────

    def _build_stats(self, records: list[dict], all_vols_sorted: list[float]) -> dict:
        """Build aggregate stats over *records* (the filtered, pre-pagination set).

        ``highest_vol`` / ``lowest_vol`` are picked by ``annualized_vol_pct``
        regardless of how *records* is sorted. ``median_vol_pct`` is the true
        median (mean of the two middle values for an even count).
        ``p25_vol_pct`` / ``p75_vol_pct`` are the ``n // 4`` and ``3n // 4``
        order statistics, and 0.0 when fewer than 4 records exist.

        ``all_vols_sorted`` is accepted for call compatibility but not used.
        """
        if not records:
            return self._empty_stats()

        vols = [r['annualized_vol_pct'] for r in records]
        vols_sorted = sorted(vols)
        n = len(vols_sorted)

        # Regime / trend distributions (all known buckets present, even if 0).
        regime_dist = {'low': 0, 'normal': 0, 'high': 0, 'extreme': 0}
        trend_dist = {'expanding': 0, 'contracting': 0, 'stable': 0}
        for rec in records:
            regime_dist[rec['vol_regime']] = regime_dist.get(rec['vol_regime'], 0) + 1
            trend_dist[rec['vol_trend']] = trend_dist.get(rec['vol_trend'], 0) + 1

        # Select by vol explicitly: records[0] is only the maximum when the
        # caller sorted by 'vol_desc'.
        highest = max(records, key=lambda r: r['annualized_vol_pct'])
        lowest = min(records, key=lambda r: r['annualized_vol_pct'])

        return {
            'total_analyzed': n,
            'avg_vol_pct': round(sum(vols) / n, 2),
            'median_vol_pct': round(statistics.median(vols_sorted), 2),
            'regime_distribution': regime_dist,
            'trend_distribution': trend_dist,
            'highest_vol': self._vol_summary(highest),
            'lowest_vol': self._vol_summary(lowest),
            'expanding_count': trend_dist.get('expanding', 0),
            'contracting_count': trend_dist.get('contracting', 0),
            'p25_vol_pct': round(vols_sorted[n // 4], 2) if n >= 4 else 0.0,
            'p75_vol_pct': round(vols_sorted[3 * n // 4], 2) if n >= 4 else 0.0,
        }

    @staticmethod
    def _vol_summary(rec: dict) -> dict:
        """Compact {symbol, vol_pct, regime} view of one record for the stats block."""
        return {
            'symbol': rec['symbol'],
            'vol_pct': rec['annualized_vol_pct'],
            'regime': rec['vol_regime'],
        }

    @staticmethod
    def _empty_stats() -> dict:
        """Zero-valued stats block with exactly the keys ``_build_stats`` emits."""
        return {
            'total_analyzed': 0,
            'avg_vol_pct': 0.0,
            'median_vol_pct': 0.0,
            'regime_distribution': {'low': 0, 'normal': 0, 'high': 0, 'extreme': 0},
            'trend_distribution': {'expanding': 0, 'contracting': 0, 'stable': 0},
            'highest_vol': None,
            'lowest_vol': None,
            'expanding_count': 0,
            'contracting_count': 0,
            'p25_vol_pct': 0.0,
            'p75_vol_pct': 0.0,
        }

    # ──────────────────────────────────────────────────────────────
    #  Empty result
    # ──────────────────────────────────────────────────────────────

    def _empty_result(self, page: int, limit: int) -> dict:
        """Return a valid but empty response with a zero-valued stats block."""
        return {
            'items': [],
            'total': 0,
            'page': page,
            'limit': limit,
            'has_more': False,
            'stats': self._empty_stats(),
        }
