"""Momentum Heatmap — visual heatmap data for all assets colored by momentum.

Generates a grouped data structure showing every active asset with heat
intensity derived from price changes across multiple timeframes (1h, 24h,
7d, 30d).  Each asset is enriched with momentum metrics from
BullishMomentumScore and assigned a CSS-friendly heat color class.

Data sources (all lazy-imported):
- Asset: id, symbol, name, icon_thumb_url, asset_type, sector, sub_sector,
  market_cap_rank, is_active
- AssetProfile: current_price_idr, market_cap_idr, total_volume_idr,
  price_change_1h/24h/7d/30d (latest via max(id) GROUP BY asset_id)
- BullishMomentumScore: score, ml_trend, momentum_state, bullish_phase,
  rsi_latest, zscore, velocity_short_pct, velocity_medium_pct,
  acceleration, safety_rating
"""
from __future__ import annotations

import math
from typing import Optional


def _safe_float(val, default: float = 0.0) -> float:
    """Convert a DB ``Numeric`` / ``Decimal`` (or other number-like value) to float.

    Returns *default* instead of raising when *val* is ``None``, cannot be
    converted (``ValueError`` / ``TypeError``), is an integer too large for a
    float (``OverflowError``), or converts to a non-finite value (NaN, ±inf).
    """
    if val is None:
        return default
    try:
        f = float(val)
    except (ValueError, TypeError, OverflowError):
        return default
    return f if math.isfinite(f) else default


# ── Heat color mapping (Tailwind CSS classes) ──────────────────────
# Heat level (-5 … +5) → CSS classes for a heatmap cell.  Negative levels
# use red classes, positive levels emerald classes, and 0 is neutral gray.
HEAT_COLORS: dict[int, str] = {
    # Declines
    -5: 'bg-red-900 text-red-200',
    -4: 'bg-red-800 text-red-200',
    -3: 'bg-red-700 text-red-100',
    -2: 'bg-red-600/60 text-red-200',
    -1: 'bg-red-500/30 text-red-300',
    # Flat
    0: 'bg-gray-800 text-gray-300',
    # Gains
    1: 'bg-emerald-500/30 text-emerald-300',
    2: 'bg-emerald-600/60 text-emerald-200',
    3: 'bg-emerald-700 text-emerald-100',
    4: 'bg-emerald-800 text-emerald-200',
    5: 'bg-emerald-900 text-emerald-200',
}

# Percentage move treated as "extreme" for each timeframe, i.e. the move
# that corresponds to the outermost heat level.
TIMEFRAME_SCALES: dict[str, int] = {'1h': 3, '24h': 8, '7d': 20, '30d': 40}

# Timeframe key → name of the AssetProfile attribute holding that % change.
TIMEFRAME_FIELD_MAP: dict[str, str] = {
    '1h': 'price_change_1h', '24h': 'price_change_24h',
    '7d': 'price_change_7d', '30d': 'price_change_30d',
}

# Human-readable group titles keyed by Asset.asset_type.
ASSET_TYPE_LABELS: dict[str, str] = {
    'crypto': 'Crypto', 'stock': 'Saham IDX', 'stock_us': 'Stock US',
}


def _calc_heat_level(change_pct: Optional[float], timeframe: str) -> int:
    """Map a price-change percentage to an integer heat level in [-5, +5].

    The timeframe's entry in ``TIMEFRAME_SCALES`` is the move that maps to
    the extreme level (±5); unknown timeframes use a scale of 8.  The
    normalized value is clamped first and then rounded half away from zero,
    so every bucket has the same width and gains/losses are treated
    symmetrically.  (The built-in ``round()`` rounds halves to even, which
    sends 0.5 → 0 but 1.5 → 2 and 2.5 → 2 but 3.5 → 4.)

    ``None`` and NaN map to 0 (neutral); ±inf saturates at ±5.
    """
    if change_pct is None or math.isnan(change_pct):
        return 0

    scale = TIMEFRAME_SCALES.get(timeframe, 8)

    # Normalize to -5..+5 and clamp before rounding so ±inf cannot overflow.
    normalized = max(-5.0, min(5.0, (change_pct / scale) * 5))

    # Round half away from zero: floor(|x| + 0.5), carrying the sign of x.
    magnitude = math.floor(abs(normalized) + 0.5)
    return magnitude if normalized >= 0 else -magnitude


def _classify_group_momentum(avg_change: float, timeframe: str) -> str:
    """Bucket a group's average percentage change into a momentum label.

    Thresholds are multiples of the timeframe's scale from
    ``TIMEFRAME_SCALES`` (8 for unknown timeframes): above 2x the scale is
    'very_bullish', above 0.5x 'bullish', above -0.5x 'neutral', above -2x
    'bearish', and anything else (including NaN) 'very_bearish'.
    """
    unit = TIMEFRAME_SCALES.get(timeframe, 8)

    # Walked from the top down; the first threshold that is exceeded wins.
    ladder = (
        (2, 'very_bullish'),
        (0.5, 'bullish'),
        (-0.5, 'neutral'),
        (-2, 'bearish'),
    )
    for multiple, label in ladder:
        if avg_change > multiple * unit:
            return label
    return 'very_bearish'


class MomentumHeatmap:
    """Generates heatmap data for all assets colored by momentum.

    The class keeps no instance state: every :meth:`get_heatmap` call loads
    fresh data through the private ``_load_*`` helpers.
    """

    # ──────────────────────────────────────────────────────────
    #  Public API
    # ──────────────────────────────────────────────────────────

    def get_heatmap(
        self,
        group_by: str = 'asset_type',
        timeframe: str = '24h',
        sort_by: str = 'change',
        limit: int = 200,
        page: int = 1,
        search_q: str = '',
        asset_type: str = 'all',
    ) -> dict:
        """Return heatmap data structure organized by groups.

        Parameters
        ----------
        group_by : str
            'asset_type' (Crypto / Stock IDX / Stock US),
            'sector' (by Asset.sector), or 'none' (flat list).  Any other
            value falls back to 'asset_type'.
        timeframe : str
            '1h', '24h', '7d', '30d' — which price change to use for heat.
            Unknown values fall back to '24h'.
        sort_by : str
            'change' (desc), 'change_asc', 'market_cap', 'name', 'momentum'.
            Unknown values fall back to 'change'.  This orders assets inside
            each group; groups themselves are always ordered by average
            change, best first.
        limit : int
            Max total assets across all groups (paginated).  Values below 1
            are treated as 1.
        page : int
            Page number (1-based).  Values below 1 are treated as 1.
        search_q : str
            Case-insensitive substring filter on symbol or name.  Surrounding
            whitespace is ignored; an empty query matches every asset.
        asset_type : str
            'all', 'crypto', 'stock', 'stock_us'.

        Returns
        -------
        dict with 'groups', 'summary', 'total', 'page', 'limit', 'has_more'.
        'summary' and 'total' cover every matching asset, not only the
        returned page.
        """
        if timeframe not in TIMEFRAME_SCALES:
            timeframe = '24h'
        # A page or limit below 1 would produce a negative offset, which list
        # slicing silently turns into a wrap-around from the end of a group.
        page = max(1, page)
        limit = max(1, limit)

        # ── 1. Load data ──────────────────────────────────────
        coins_map = self._load_coins(asset_type)
        if not coins_map:
            # Nothing to show: skip the profile and score queries entirely.
            return self._empty_result(timeframe, page, limit)
        profiles_map = self._load_latest_profiles(asset_type)
        bullish_map = self._load_bullish_scores(asset_type)

        # ── 2. Build per-asset records ─────────────────────────
        query = (search_q or '').strip().lower()
        records = [
            self._build_record(
                asset_id,
                asset,
                profiles_map.get(asset_id),
                bullish_map.get(asset_id),
                timeframe,
            )
            for asset_id, asset in coins_map.items()
            if self._matches_search(asset, query)
        ]
        if not records:
            return self._empty_result(timeframe, page, limit)

        # ── 3. Sort records ───────────────────────────────────
        records = self._sort_records(records, sort_by)

        # ── 4. Build summary (before pagination) ──────────────
        summary = self._build_summary(records, timeframe)

        # ── 5. Group records, then order groups best-first ────
        grouped = self._group_records(records, group_by)
        for g in grouped:
            changes = [c['change_value'] for c in g['assets']]
            g['avg_change'] = round(sum(changes) / len(changes), 4) if changes else 0.0
            g['coin_count'] = len(changes)
            g['group_momentum'] = _classify_group_momentum(g['avg_change'], timeframe)
        grouped.sort(key=lambda g: g['avg_change'], reverse=True)

        # ── 6. Paginate across all groups ─────────────────────
        total_coins = sum(g['coin_count'] for g in grouped)
        offset = (page - 1) * limit

        return {
            'groups': self._paginate_groups(grouped, offset, limit),
            'summary': summary,
            'total': total_coins,
            'page': page,
            'limit': limit,
            'has_more': offset + limit < total_coins,
        }

    # ──────────────────────────────────────────────────────────
    #  Private helpers — record construction
    # ──────────────────────────────────────────────────────────

    def _matches_search(self, asset, query: str) -> bool:
        """Return True if *query* is empty or occurs in the asset's symbol or name.

        *query* must already be stripped and lower-cased.  A missing symbol
        or name is treated as an empty string rather than raising.
        """
        if not query:
            return True
        return (query in (asset.symbol or '').lower()
                or query in (asset.name or '').lower())

    def _build_record(self, asset_id, asset, profile, bullish, timeframe: str) -> dict:
        """Flatten one asset, its latest profile and its momentum score into a record.

        *profile* (AssetProfile) and *bullish* (BullishMomentumScore) may be
        ``None``.  Missing values fall back to neutral defaults (0 % change,
        RSI 50, 'STABLE', 'unknown', score 0, 'MODERATE') so every active
        asset still appears on the map.
        """
        # Percentage change for every timeframe, read via TIMEFRAME_FIELD_MAP.
        changes = {
            tf: _safe_float(getattr(profile, field)) if profile else 0.0
            for tf, field in TIMEFRAME_FIELD_MAP.items()
        }
        change_value = changes.get(timeframe, changes['24h'])

        heat_level = _calc_heat_level(change_value, timeframe)
        heat_color = HEAT_COLORS.get(heat_level, HEAT_COLORS[0])

        if bullish:
            rsi = _safe_float(bullish.rsi_latest, 50.0)
            zscore = _safe_float(bullish.zscore)
            momentum_state = bullish.momentum_state or 'STABLE'
            ml_trend = bullish.ml_trend or 'unknown'
            momentum_score = _safe_float(bullish.score)
            safety_rating = bullish.safety_rating or 'MODERATE'
        else:
            rsi, zscore = 50.0, 0.0
            momentum_state, ml_trend = 'STABLE', 'unknown'
            momentum_score, safety_rating = 0.0, 'MODERATE'

        current_price = _safe_float(profile.current_price_idr) if profile else 0.0
        market_cap = _safe_float(profile.market_cap_idr) if profile else 0.0

        return {
            'asset_id': asset_id,
            'symbol': asset.symbol,
            'name': asset.name,
            'icon_thumb_url': asset.icon_thumb_url or '',
            'asset_type': asset.asset_type or 'crypto',
            'sector': asset.sector or '',
            'current_price_idr': current_price,
            'market_cap_idr': market_cap,
            'market_cap_rank': asset.market_cap_rank or 9999,
            'change_1h': changes['1h'],
            'change_24h': changes['24h'],
            'change_7d': changes['7d'],
            'change_30d': changes['30d'],
            'change_value': change_value,
            'rsi': rsi,
            'zscore': zscore,
            'momentum_state': momentum_state,
            'ml_trend': ml_trend,
            'momentum_score': momentum_score,
            'safety_rating': safety_rating,
            'heat_level': heat_level,
            'heat_color': heat_color,
        }

    # ──────────────────────────────────────────────────────────
    #  Private helpers — data loading
    # ──────────────────────────────────────────────────────────

    def _load_coins(self, asset_type: str = 'all') -> dict:
        """Return ``{asset_id: Asset}`` for active assets, optionally of one asset_type."""
        from app.models.asset import Asset

        query = Asset.query.filter(Asset.is_active.is_(True))
        if asset_type and asset_type != 'all':
            query = query.filter(Asset.asset_type == asset_type)

        return {c.id: c for c in query.all()}

    def _load_latest_profiles(self, asset_type: str = 'all') -> dict:
        """Return ``{asset_id: AssetProfile}`` holding the highest-id profile per active asset.

        NOTE(review): "latest" here means max(AssetProfile.id); this assumes
        ids increase with insertion time.
        """
        from app.extensions import db
        from app.models.asset import Asset, AssetProfile
        from sqlalchemy import func as sa_func

        # Subquery: active asset ids of the right asset_type
        coin_q = db.session.query(Asset.id).filter(Asset.is_active.is_(True))
        if asset_type and asset_type != 'all':
            coin_q = coin_q.filter(Asset.asset_type == asset_type)
        asset_ids_sq = coin_q.subquery()

        # Subquery: latest profile id per asset
        latest_sq = db.session.query(
            AssetProfile.asset_id,
            sa_func.max(AssetProfile.id).label('max_id'),
        ).filter(
            AssetProfile.asset_id.in_(db.session.query(asset_ids_sq.c.id)),
        ).group_by(AssetProfile.asset_id).subquery()

        profiles = AssetProfile.query.join(
            latest_sq, AssetProfile.id == latest_sq.c.max_id,
        ).all()

        return {p.asset_id: p for p in profiles}

    def _load_bullish_scores(self, asset_type: str = 'all') -> dict:
        """Return ``{asset_id: BullishMomentumScore}`` for active assets of the given type.

        NOTE(review): no ORDER BY is applied, so if an asset has several score
        rows, whichever the query yields last wins — confirm the table holds
        one row per asset.
        """
        from app.models.bullish_score import BullishMomentumScore
        from app.models.asset import Asset

        query = BullishMomentumScore.query.join(
            Asset, Asset.id == BullishMomentumScore.asset_id,
        ).filter(Asset.is_active.is_(True))

        if asset_type and asset_type != 'all':
            query = query.filter(Asset.asset_type == asset_type)

        return {b.asset_id: b for b in query.all()}

    # ──────────────────────────────────────────────────────────
    #  Private helpers — sorting
    # ──────────────────────────────────────────────────────────

    def _sort_records(self, records: list[dict], sort_by: str) -> list[dict]:
        """Sort *records* in place by *sort_by* and return the same list.

        Unknown *sort_by* values fall back to 'change' (descending).
        """
        sort_configs = {
            'change': (lambda r: r['change_value'], True),
            'change_asc': (lambda r: r['change_value'], False),
            'market_cap': (lambda r: r['market_cap_idr'], True),
            # A missing name sorts as '' instead of raising AttributeError.
            'name': (lambda r: (r['name'] or '').lower(), False),
            'momentum': (lambda r: r['momentum_score'], True),
        }

        key_fn, reverse = sort_configs.get(sort_by, sort_configs['change'])
        records.sort(key=key_fn, reverse=reverse)
        return records

    # ──────────────────────────────────────────────────────────
    #  Private helpers — grouping
    # ──────────────────────────────────────────────────────────

    @staticmethod
    def _new_group(group_name: str, assets: list[dict]) -> dict:
        """Return a group dict whose stats are placeholders filled in by get_heatmap."""
        return {
            'group_name': group_name,
            'avg_change': 0.0,
            'coin_count': 0,
            'group_momentum': 'neutral',
            'assets': assets,
        }

    def _group_records(self, records: list[dict], group_by: str) -> list[dict]:
        """Group sorted records by the requested grouping strategy.

        Returns a list of group dicts, each with 'group_name' and 'assets'
        (in input order).  The 'avg_change', 'coin_count' and
        'group_momentum' fields are placeholders computed later by the caller.
        Unknown *group_by* values group by asset_type.
        """
        if group_by == 'none':
            return [self._new_group('All Assets', list(records))]

        if group_by == 'sector':
            return self._group_by_sector(records)

        return self._group_by_asset_type(records)

    def _group_by_asset_type(self, records: list[dict]) -> list[dict]:
        """Group records by asset_type, titled via ASSET_TYPE_LABELS (raw value if unlisted)."""
        buckets: dict[str, list[dict]] = {}
        for rec in records:
            buckets.setdefault(rec['asset_type'], []).append(rec)

        return [
            self._new_group(ASSET_TYPE_LABELS.get(at, at), assets)
            for at, assets in buckets.items()
        ]

    def _group_by_sector(self, records: list[dict]) -> list[dict]:
        """Group records by sector; an empty sector goes to 'Uncategorized'."""
        buckets: dict[str, list[dict]] = {}
        for rec in records:
            buckets.setdefault(rec['sector'] or 'Uncategorized', []).append(rec)

        return [
            self._new_group(sector_name, assets)
            for sector_name, assets in buckets.items()
        ]

    # ──────────────────────────────────────────────────────────
    #  Private helpers — pagination
    # ──────────────────────────────────────────────────────────

    def _paginate_groups(
        self,
        groups: list[dict],
        offset: int,
        limit: int,
    ) -> list[dict]:
        """Paginate total assets across groups.

        Walks through groups in order, skipping *offset* assets (a negative
        offset is treated as 0), then collecting up to *limit* assets.
        Groups left with no assets after slicing are omitted.  Each returned
        group keeps its full-group 'coin_count', 'avg_change' and
        'group_momentum', not values for the visible slice.
        """
        result: list[dict] = []
        to_skip = max(0, offset)
        to_take = limit

        for g in groups:
            if to_take <= 0:
                break

            assets = g['assets']
            if to_skip >= len(assets):
                # This whole group lies before the requested page.
                to_skip -= len(assets)
                continue

            page_coins = assets[to_skip:to_skip + to_take]
            to_skip = 0
            to_take -= len(page_coins)

            if page_coins:
                result.append({
                    'group_name': g['group_name'],
                    'coin_count': g['coin_count'],       # full group size
                    'avg_change': g['avg_change'],        # full group average
                    'group_momentum': g['group_momentum'],
                    'assets': page_coins,
                })

        return result

    # ──────────────────────────────────────────────────────────
    #  Private helpers — summary
    # ──────────────────────────────────────────────────────────

    def _build_summary(self, records: list[dict], timeframe: str) -> dict:
        """Build aggregate summary from all records (pre-pagination).

        Changes above +0.01 count as positive, below -0.01 as negative, and
        everything else as neutral.  On ties, 'strongest' / 'weakest' report
        the earliest record in *records*.
        """
        if not records:
            return {
                'total_coins': 0,
                'positive_count': 0,
                'negative_count': 0,
                'neutral_count': 0,
                'avg_change': 0.0,
                'strongest': {'symbol': '', 'change': 0.0},
                'weakest': {'symbol': '', 'change': 0.0},
                'timeframe': timeframe,
            }

        changes = [rec['change_value'] for rec in records]
        total = len(changes)
        positive = sum(1 for cv in changes if cv > 0.01)
        negative = sum(1 for cv in changes if cv < -0.01)
        neutral = total - positive - negative

        # max/min return the first extreme element, matching "earliest wins".
        strongest = max(records, key=lambda r: r['change_value'])
        weakest = min(records, key=lambda r: r['change_value'])

        return {
            'total_coins': total,
            'positive_count': positive,
            'negative_count': negative,
            'neutral_count': neutral,
            'avg_change': round(sum(changes) / total, 4),
            'strongest': {
                'symbol': strongest['symbol'],
                'change': strongest['change_value'],
            },
            'weakest': {
                'symbol': weakest['symbol'],
                'change': weakest['change_value'],
            },
            'timeframe': timeframe,
        }

    # ──────────────────────────────────────────────────────────
    #  Private helpers — empty result
    # ──────────────────────────────────────────────────────────

    def _empty_result(self, timeframe: str, page: int, limit: int) -> dict:
        """Return a valid but empty heatmap response (same shape as get_heatmap)."""
        return {
            'groups': [],
            'summary': self._build_summary([], timeframe),
            'total': 0,
            'page': page,
            'limit': limit,
            'has_more': False,
        }
