"""VolatilityRankService — Feature 137: Ranks assets by volatility.

Classifies volatility using ATR-based data from range scores (resistance -
support as ATR proxy), price change magnitude, and range width percentage.

Data sources (all lazy-imported):
- Asset: id, symbol, name, asset_type, is_active, icon_thumb_url
- AssetProfile: current_price_idr, price_change_24h/7d/30d, total_volume_idr
- RangeTradingScore: nearest_support/resistance, adaptive_support/resistance,
  oscillation_score, bb_width, atr_pct, range_width_pct,
  is_mean_reverting, mr_confidence, current_price_idr
- BullishMomentumScore: score, momentum_state, velocity_short_pct,
  velocity_medium_pct, bb_position, current_price_idr, rsi_latest, zscore
- TradingSignal: signal_type, confidence, score, regime, status='active'

Scoring (0-100) — normalized volatility rank:
  ATR Component       35pts — ATR percentage (fallback: support/resistance span as % of price)
  Price Swing          25pts — 24h/7d/30d price change magnitude
  Range Width          20pts — range width percentage plus estimated daily range
  Regime Indicator     20pts — BB width, momentum state, velocity, signal regime

Volatility Regime (lower bounds are inclusive):
  EXTREME    (>= 80)   — very high volatility, potentially dangerous
  HIGH       (60-<80)  — elevated volatility, active trading opportunities
  MODERATE   (40-<60)  — normal market volatility
  LOW        (20-<40)  — below-average volatility, consolidation
  DORMANT    (< 20)    — very low volatility, potential squeeze setup

Historical Volatility Trend:
  EXPANDING    — volatility increasing over recent period
  STABLE       — volatility relatively unchanged
  CONTRACTING  — volatility decreasing, potential squeeze

Response format:
  {'items': [...], 'total': int, 'page': int, 'limit': int,
   'has_more': bool, 'stats': {...}}
"""
from __future__ import annotations

import logging
import math
from datetime import datetime

logger = logging.getLogger(__name__)


def _safe_float(val, default: float = 0.0) -> float:
    """Safely convert a DB Numeric / Decimal to float."""
    if val is None:
        return default
    try:
        f = float(val)
        return default if (math.isnan(f) or math.isinf(f)) else f
    except (ValueError, TypeError):
        return default


class VolatilityRankService:
    """Ranks all tracked assets by volatility and classifies into regimes.

    Produces a normalized volatility score (0-100) using ATR data,
    price swings, range width, and regime indicators.
    """

    WEIGHTS = {
        'atr_component': 35,
        'price_swing': 25,
        'range_width': 20,
        'regime_indicator': 20,
    }

    # Regime thresholds
    REGIME_THRESHOLDS = [
        (80, 'EXTREME'),
        (60, 'HIGH'),
        (40, 'MODERATE'),
        (20, 'LOW'),
        (0, 'DORMANT'),
    ]

    # ──────────────────────────────────────────────────────────────────
    #  Public API
    # ──────────────────────────────────────────────────────────────────

    def scan_all(
        self,
        asset_type: str = 'crypto',
        limit: int = 50,
        page: int = 1,
        sort_by: str = 'volatility_desc',
        regime: str = 'all',
        search_q: str = '',
    ) -> dict:
        """Scan all active assets and rank by volatility.

        Args:
            asset_type: 'crypto', 'stock', 'stock_us', or 'all'
            limit: Max items per page
            page: 1-based page number
            sort_by: 'volatility_desc' or 'name_asc'
            regime: 'all', 'extreme', 'high', 'moderate', 'low', 'dormant'
            search_q: Search filter on name/symbol

        Returns:
            dict with items, total, page, limit, has_more, stats
        """
        from app.extensions import db
        from app.models.range_score import RangeTradingScore
        from app.models.bullish_score import BullishMomentumScore
        from app.models.signal import TradingSignal
        from app.models.asset import Asset, AssetProfile
        from sqlalchemy import func as sa_func

        # ── 1. Bulk-load precomputed data ──────────────────────────
        range_scores = self._load_range_scores()
        bullish_scores = self._load_bullish_scores()
        active_signals = self._load_active_signals()

        all_asset_ids = (
            set(range_scores.keys())
            | set(bullish_scores.keys())
            | set(active_signals.keys())
        )
        if not all_asset_ids:
            return self._empty_response(page, limit)

        # ── 2. Load assets (optionally filtered by asset_type) ──────
        coin_query = Asset.query.filter(
            Asset.id.in_(list(all_asset_ids)),
            Asset.is_active.is_(True),
        )
        if asset_type and asset_type != 'all':
            coin_query = coin_query.filter(Asset.asset_type == asset_type)
        coins_map = {c.id: c for c in coin_query.all()}

        # ── 3. Latest AssetProfile per asset (max(id) subquery) ──────
        latest_sq = db.session.query(
            AssetProfile.asset_id,
            sa_func.max(AssetProfile.id).label('max_id'),
        ).filter(
            AssetProfile.asset_id.in_(list(all_asset_ids)),
        ).group_by(AssetProfile.asset_id).subquery()

        profiles = AssetProfile.query.join(
            latest_sq, AssetProfile.id == latest_sq.c.max_id,
        ).all()
        profiles_map = {p.asset_id: p for p in profiles}

        # ── 4. Score each asset ─────────────────────────────────────
        candidates = []
        for asset_id in all_asset_ids:
            asset = coins_map.get(asset_id)
            if not asset:
                continue

            rd = range_scores.get(asset_id)
            bd = bullish_scores.get(asset_id)
            sd = active_signals.get(asset_id)
            profile = profiles_map.get(asset_id)

            # Resolve current price
            current_price = self._resolve_price(profile, bd, rd)
            if current_price is None or current_price <= 0:
                continue

            # Compute volatility analysis
            result = self._compute_volatility(rd, bd, sd, profile, current_price)
            if result is None:
                continue

            # Apply regime filter
            if regime != 'all' and result['regime'].lower() != regime.lower():
                continue

            # Apply search filter
            if search_q:
                sq = search_q.lower()
                if sq not in asset.name.lower() and sq not in asset.symbol.lower():
                    continue

            # Price changes
            price_change_24h = _safe_float(profile.price_change_24h) if profile and profile.price_change_24h is not None else None
            price_change_7d = _safe_float(profile.price_change_7d) if profile and profile.price_change_7d is not None else None
            price_change_30d = _safe_float(profile.price_change_30d) if profile and profile.price_change_30d is not None else None

            # Daily range from 24h high/low
            daily_range = self._calc_daily_range(profile, current_price)

            # Volume
            volume_24h = _safe_float(profile.total_volume_idr) if profile and profile.total_volume_idr is not None else None

            # Momentum state
            momentum_state = bd.momentum_state if bd else None

            # Regime from signal
            signal_regime = sd.regime if sd else None

            candidates.append({
                'asset_id': asset_id,
                'name': asset.name,
                'symbol': asset.symbol,
                'icon_thumb_url': asset.icon_thumb_url,
                'asset_type': asset.asset_type,
                'current_price_idr': current_price,
                'price_change_24h': price_change_24h,
                'price_change_7d': price_change_7d,
                'price_change_30d': price_change_30d,
                'volatility_score': result['volatility_score'],
                'regime': result['regime'],
                'atr_pct': result['atr_pct'],
                'daily_range': daily_range,
                'historical_vol_trend': result['historical_vol_trend'],
                'components': result['components'],
                'bb_width': result.get('bb_width'),
                'range_width_pct': result.get('range_width_pct'),
                'momentum_state': momentum_state,
                'signal_regime': signal_regime,
                'volume_24h': volume_24h,
                'mean_reverting': result.get('mean_reverting', False),
                'data_sources': {
                    'range_trading': rd is not None,
                    'bullish_momentum': bd is not None,
                    'trading_signal': sd is not None,
                },
            })

        # ── 5. Sort ────────────────────────────────────────────────
        if sort_by == 'name_asc':
            candidates.sort(key=lambda c: c['name'].lower())
        else:  # volatility_desc (default)
            candidates.sort(key=lambda c: c['volatility_score'], reverse=True)

        # ── 6. Stats ───────────────────────────────────────────────
        stats = self._compute_stats(candidates)

        # ── 7. Paginate ────────────────────────────────────────────
        total = len(candidates)
        offset = (page - 1) * limit
        page_items = candidates[offset:offset + limit]

        return {
            'items': page_items,
            'total': total,
            'page': page,
            'limit': limit,
            'has_more': (offset + limit) < total,
            'stats': stats,
        }

    # ──────────────────────────────────────────────────────────────────
    #  Core volatility computation
    # ──────────────────────────────────────────────────────────────────

    def _compute_volatility(self, rd, bd, sd, profile, current_price) -> dict | None:
        """Compute volatility score and classify regime.

        Returns dict with volatility_score, regime, atr_pct, historical_vol_trend,
        components, and auxiliary metrics.
        """
        atr_comp = self._score_atr_component(rd, current_price)
        price_swing = self._score_price_swing(profile)
        range_width = self._score_range_width(rd, profile, current_price)
        regime_ind = self._score_regime_indicator(rd, bd, sd)

        total_score = (
            atr_comp['pts']
            + price_swing['pts']
            + range_width['pts']
            + regime_ind['pts']
        )
        total_score = max(0.0, min(100.0, total_score))

        # Classify regime
        vol_regime = 'DORMANT'
        for threshold, regime_name in self.REGIME_THRESHOLDS:
            if total_score >= threshold:
                vol_regime = regime_name
                break

        # ATR percentage (from range data or computed)
        atr_pct = atr_comp.get('atr_pct', 0.0)

        # Historical volatility trend
        historical_vol_trend = self._determine_vol_trend(rd, bd, profile)

        # BB width
        bb_width = _safe_float(rd.bb_width) if rd and rd.bb_width is not None else None

        # Range width
        range_width_pct = _safe_float(rd.range_width_pct) if rd and rd.range_width_pct is not None else None

        # Mean reverting
        mean_reverting = bool(rd.is_mean_reverting) if rd and rd.is_mean_reverting else False

        return {
            'volatility_score': round(total_score, 1),
            'regime': vol_regime,
            'atr_pct': round(atr_pct, 2),
            'historical_vol_trend': historical_vol_trend,
            'components': {
                'atr_component': round(atr_comp['pts'], 1),
                'price_swing': round(price_swing['pts'], 1),
                'range_width': round(range_width['pts'], 1),
                'regime_indicator': round(regime_ind['pts'], 1),
            },
            'bb_width': round(bb_width, 4) if bb_width is not None else None,
            'range_width_pct': round(range_width_pct, 2) if range_width_pct is not None else None,
            'mean_reverting': mean_reverting,
        }

    # ──────────────────────────────────────────────────────────────────
    #  Component scorers
    # ──────────────────────────────────────────────────────────────────

    def _score_atr_component(self, rd, current_price) -> dict:
        """ATR-based volatility from range data (max 35pts).

        Uses atr_pct directly if available, or computes from
        resistance - support as ATR proxy.
        """
        pts = 0.0
        atr_pct = 0.0

        # 1. Direct ATR percentage from range score
        if rd and rd.atr_pct is not None:
            atr_pct = _safe_float(rd.atr_pct)

        # 2. Fallback: compute from support/resistance range
        if atr_pct <= 0 and rd and current_price > 0:
            support = _safe_float(rd.adaptive_support) if rd.adaptive_support is not None else _safe_float(rd.nearest_support)
            resistance = _safe_float(rd.adaptive_resistance) if rd.adaptive_resistance is not None else _safe_float(rd.nearest_resistance)

            if support > 0 and resistance > support:
                atr_proxy = resistance - support
                atr_pct = (atr_proxy / current_price) * 100

        # Score based on ATR percentage
        if atr_pct >= 20:
            pts = 35
        elif atr_pct >= 15:
            pts = 30
        elif atr_pct >= 10:
            pts = 25
        elif atr_pct >= 7:
            pts = 20
        elif atr_pct >= 5:
            pts = 15
        elif atr_pct >= 3:
            pts = 10
        elif atr_pct >= 1.5:
            pts = 6
        elif atr_pct >= 0.5:
            pts = 3
        elif atr_pct > 0:
            pts = 1

        return {'pts': min(pts, 35), 'atr_pct': atr_pct}

    def _score_price_swing(self, profile) -> dict:
        """Price change magnitude scoring (max 25pts).

        Larger absolute price swings indicate higher volatility.
        """
        pts = 0.0

        if not profile:
            return {'pts': 0}

        # 24h price change magnitude (12pts)
        if profile.price_change_24h is not None:
            abs_24h = abs(_safe_float(profile.price_change_24h))
            if abs_24h >= 15:
                pts += 12
            elif abs_24h >= 10:
                pts += 10
            elif abs_24h >= 7:
                pts += 8
            elif abs_24h >= 5:
                pts += 6
            elif abs_24h >= 3:
                pts += 4
            elif abs_24h >= 1:
                pts += 2
            elif abs_24h > 0:
                pts += 1

        # 7d price change magnitude (8pts)
        if profile.price_change_7d is not None:
            abs_7d = abs(_safe_float(profile.price_change_7d))
            if abs_7d >= 30:
                pts += 8
            elif abs_7d >= 20:
                pts += 6
            elif abs_7d >= 10:
                pts += 4
            elif abs_7d >= 5:
                pts += 2
            elif abs_7d > 0:
                pts += 1

        # 30d price change magnitude (5pts)
        if profile.price_change_30d is not None:
            abs_30d = abs(_safe_float(profile.price_change_30d))
            if abs_30d >= 50:
                pts += 5
            elif abs_30d >= 30:
                pts += 4
            elif abs_30d >= 15:
                pts += 3
            elif abs_30d >= 7:
                pts += 2
            elif abs_30d > 0:
                pts += 1

        return {'pts': min(pts, 25)}

    def _score_range_width(self, rd, profile, current_price) -> dict:
        """Range width percentage scoring (max 20pts).

        Wide range indicates high volatility.
        """
        pts = 0.0

        # 1. Range width from precomputed data
        if rd and rd.range_width_pct is not None:
            rw = _safe_float(rd.range_width_pct)
            if rw >= 30:
                pts += 12
            elif rw >= 20:
                pts += 10
            elif rw >= 15:
                pts += 8
            elif rw >= 10:
                pts += 6
            elif rw >= 5:
                pts += 4
            elif rw >= 2:
                pts += 2
            elif rw > 0:
                pts += 1

        # 2. Daily range from 24h high/low
        daily_range_pct = self._calc_daily_range(profile, current_price)
        if daily_range_pct is not None:
            if daily_range_pct >= 15:
                pts += 8
            elif daily_range_pct >= 10:
                pts += 6
            elif daily_range_pct >= 7:
                pts += 5
            elif daily_range_pct >= 5:
                pts += 4
            elif daily_range_pct >= 3:
                pts += 2
            elif daily_range_pct >= 1:
                pts += 1

        return {'pts': min(pts, 20)}

    def _score_regime_indicator(self, rd, bd, sd) -> dict:
        """BB width, momentum state, and velocity as regime indicators (max 20pts)."""
        pts = 0.0

        # 1. BB width (wide bands = high volatility) (7pts)
        if rd and rd.bb_width is not None:
            bbw = _safe_float(rd.bb_width)
            if bbw >= 0.20:
                pts += 7
            elif bbw >= 0.15:
                pts += 6
            elif bbw >= 0.10:
                pts += 5
            elif bbw >= 0.07:
                pts += 4
            elif bbw >= 0.05:
                pts += 3
            elif bbw >= 0.03:
                pts += 2
            elif bbw > 0:
                pts += 1

        # 2. Momentum state (volatile states = higher volatility) (6pts)
        if bd and bd.momentum_state:
            state = bd.momentum_state
            volatile_states = {
                'ACCELERATING_UP': 5, 'ACCELERATING_DOWN': 5,
                'REVERSING_UP': 6, 'REVERSING_DOWN': 6,
                'DECELERATING_UP': 3, 'DECELERATING_DOWN': 3,
                'STABLE': 1,
            }
            pts += volatile_states.get(state, 0)

        # 3. Velocity magnitude (high velocity = volatile) (4pts)
        if bd and bd.velocity_short_pct is not None:
            abs_vel = abs(_safe_float(bd.velocity_short_pct))
            if abs_vel >= 10:
                pts += 4
            elif abs_vel >= 5:
                pts += 3
            elif abs_vel >= 2:
                pts += 2
            elif abs_vel > 0:
                pts += 1

        # 4. Regime from signal (volatile regime bonus) (3pts)
        if sd and sd.regime:
            regime_vol = {
                'volatile': 3, 'breakout': 2,
                'trending_up': 1, 'trending_down': 1,
                'ranging': 0, 'accumulation': 0,
            }
            pts += regime_vol.get(sd.regime, 0)

        return {'pts': min(pts, 20)}

    # ──────────────────────────────────────────────────────────────────
    #  Helpers
    # ──────────────────────────────────────────────────────────────────

    def _calc_daily_range(self, profile, current_price) -> float | None:
        """Estimate daily range from price_change_24h as proxy.

        AssetProfile does not have high_24h_idr/low_24h_idr, so we use
        the absolute value of 24h change * 2 as a rough daily range estimate.
        """
        if not profile:
            return None

        change_24h = _safe_float(profile.price_change_24h) if profile.price_change_24h is not None else None
        if change_24h is not None:
            # Absolute 24h change * 2 approximates intraday range
            return round(abs(change_24h) * 2, 2)
        return None

    def _determine_vol_trend(self, rd, bd, profile) -> str:
        """Determine historical volatility trend.

        Uses available signals to infer whether volatility is
        expanding, stable, or contracting.
        """
        expanding_signals = 0
        contracting_signals = 0

        # 1. Momentum state as volatility proxy
        if bd and bd.momentum_state:
            state = bd.momentum_state
            if state in ('ACCELERATING_UP', 'ACCELERATING_DOWN', 'REVERSING_UP', 'REVERSING_DOWN'):
                expanding_signals += 1
            elif state == 'STABLE':
                contracting_signals += 1
            elif state in ('DECELERATING_UP', 'DECELERATING_DOWN'):
                contracting_signals += 1

        # 2. BB width compared to range width
        if rd and rd.bb_width is not None and rd.range_width_pct is not None:
            bbw = _safe_float(rd.bb_width)
            rw = _safe_float(rd.range_width_pct)
            if rw > 0:
                ratio = (bbw * 100) / rw
                if ratio > 1.5:
                    expanding_signals += 1
                elif ratio < 0.5:
                    contracting_signals += 1

        # 3. Velocity trend (short vs medium)
        if bd and bd.velocity_short_pct is not None and bd.velocity_medium_pct is not None:
            short_vel = abs(_safe_float(bd.velocity_short_pct))
            med_vel = abs(_safe_float(bd.velocity_medium_pct))
            if med_vel > 0:
                if short_vel / med_vel > 1.3:
                    expanding_signals += 1
                elif short_vel / med_vel < 0.7:
                    contracting_signals += 1

        # 4. 24h vs 7d price change comparison
        if profile and profile.price_change_24h is not None and profile.price_change_7d is not None:
            abs_24h = abs(_safe_float(profile.price_change_24h))
            abs_7d = abs(_safe_float(profile.price_change_7d))
            if abs_7d > 0:
                ratio = abs_24h / (abs_7d / 7)  # Daily average from 7d
                if ratio > 2:
                    expanding_signals += 1
                elif ratio < 0.5:
                    contracting_signals += 1

        if expanding_signals > contracting_signals:
            return 'EXPANDING'
        elif contracting_signals > expanding_signals:
            return 'CONTRACTING'
        return 'STABLE'

    def _resolve_price(self, profile, bd, rd) -> float | None:
        """Resolve current price from available data sources."""
        if profile and profile.current_price_idr is not None:
            p = _safe_float(profile.current_price_idr)
            if p > 0:
                return p
        if bd and bd.current_price_idr is not None:
            p = _safe_float(bd.current_price_idr)
            if p > 0:
                return p
        if rd and rd.current_price_idr is not None:
            p = _safe_float(rd.current_price_idr)
            if p > 0:
                return p
        return None

    def _load_range_scores(self) -> dict:
        """Bulk-load all RangeTradingScore records."""
        from app.models.range_score import RangeTradingScore

        return {
            r.asset_id: r
            for r in RangeTradingScore.query.filter(
                RangeTradingScore.oscillation_score > 0,
            ).all()
        }

    def _load_bullish_scores(self) -> dict:
        """Bulk-load all BullishMomentumScore records."""
        from app.models.bullish_score import BullishMomentumScore

        return {
            b.asset_id: b
            for b in BullishMomentumScore.query.filter(
                BullishMomentumScore.score > 0,
            ).all()
        }

    def _load_active_signals(self) -> dict:
        """Bulk-load active TradingSignal records (best per asset)."""
        from app.models.signal import TradingSignal

        signals: dict = {}
        for s in TradingSignal.query.filter(
            TradingSignal.status == 'active',
        ).all():
            if s.asset_id not in signals or _safe_float(s.score) > _safe_float(signals[s.asset_id].score):
                signals[s.asset_id] = s
        return signals

    def _compute_stats(self, candidates: list) -> dict:
        """Compute aggregate statistics from candidate list."""
        if not candidates:
            return {
                'extreme_count': 0,
                'high_count': 0,
                'moderate_count': 0,
                'low_count': 0,
                'dormant_count': 0,
                'avg_volatility': 0,
                'most_volatile_symbol': None,
                'total_analyzed': 0,
            }

        extreme = sum(1 for c in candidates if c['regime'] == 'EXTREME')
        high = sum(1 for c in candidates if c['regime'] == 'HIGH')
        moderate = sum(1 for c in candidates if c['regime'] == 'MODERATE')
        low = sum(1 for c in candidates if c['regime'] == 'LOW')
        dormant = sum(1 for c in candidates if c['regime'] == 'DORMANT')

        scores = [c['volatility_score'] for c in candidates]
        avg_vol = sum(scores) / len(scores) if scores else 0

        # Most volatile symbol
        most_volatile = None
        if candidates:
            top = max(candidates, key=lambda c: c['volatility_score'])
            most_volatile = top['symbol']

        return {
            'extreme_count': extreme,
            'high_count': high,
            'moderate_count': moderate,
            'low_count': low,
            'dormant_count': dormant,
            'avg_volatility': round(avg_vol, 1),
            'most_volatile_symbol': most_volatile,
            'total_analyzed': len(candidates),
        }

    def _empty_response(self, page: int, limit: int) -> dict:
        """Return empty response structure."""
        return {
            'items': [],
            'total': 0,
            'page': page,
            'limit': limit,
            'has_more': False,
            'stats': {
                'extreme_count': 0,
                'high_count': 0,
                'moderate_count': 0,
                'low_count': 0,
                'dormant_count': 0,
                'avg_volatility': 0,
                'most_volatile_symbol': None,
                'total_analyzed': 0,
            },
        }
