"""
Probability Matrix Service  (Feature 112)
------------------------------------------
Calculates a heuristic win-probability score (0-100 %) for each asset as the
sum of six capped component scores: ML direction (max 25), bullish phase
quality (max 20), range-position proximity to support (max 15), signal
confidence (max 15), mean-reversion likelihood (max 10), and momentum
characteristics (max 15).

Tiers:
    high_prob   >= 70
    medium_prob >= 45
    low_prob    <  45
"""

import math
import logging
from datetime import datetime

# Module-level logger named after this module; the service below reports
# errors and warnings through it.
logger = logging.getLogger(name=__name__)


# ---------------------------------------------------------------------------
# Safe-value helpers
# ---------------------------------------------------------------------------
def _safe_float(val, default=0.0):
    """Convert *val* to float, returning *default* on None / NaN / Inf."""
    if val is None:
        return default
    try:
        f = float(val)
        if math.isnan(f) or math.isinf(f):
            return default
        return f
    except (ValueError, TypeError):
        return default


def _safe_int(val, default=0):
    """Convert *val* to int safely."""
    if val is None:
        return default
    try:
        return int(val)
    except (ValueError, TypeError):
        return default


def _clamp(val, lo=0.0, hi=100.0):
    """Clamp a numeric value between *lo* and *hi*."""
    return max(lo, min(hi, val))


# ---------------------------------------------------------------------------
# Service
# ---------------------------------------------------------------------------
class ProbabilityMatrixService:
    """Win-probability matrix scoring every active asset 0-100 %.

    An asset's ``win_probability`` is the sum of six capped components:
    ML direction (max 25), bullish phase (max 20), range position (max 15),
    signal confidence (max 15), mean reversion (max 10) and momentum
    (max 15).  The caps add up to 100.  The total is bucketed into a tier
    with ``HIGH_PROB_THRESHOLD`` / ``MEDIUM_PROB_THRESHOLD``.
    """

    # Phase quality (0-100), scaled into a bonus of up to 3 pts inside the
    # momentum component.
    PHASE_QUALITY = {
        'EARLY_BULLISH': 95,
        'PEAK_BULLISH': 75,
        'ACCUMULATION': 60,
        'NOT_BULLISH': 25,
        'MARKDOWN': 10,
    }

    # Momentum-state quality (0-100), scaled into a bonus of up to 2 pts
    # inside the momentum component.
    MOMENTUM_QUALITY = {
        'ACCELERATING_UP': 95,
        'REVERSING_UP': 85,
        'DECELERATING_UP': 60,
        'STABLE': 45,
        'DECELERATING_DOWN': 25,
        'REVERSING_DOWN': 20,
        'ACCELERATING_DOWN': 10,
    }

    # Points for the bullish-phase component; unknown phases get 3.0.
    PHASE_POINTS = {
        'EARLY_BULLISH': 20.0,
        'PEAK_BULLISH': 15.0,
        'ACCUMULATION': 10.0,
        'NOT_BULLISH': 3.0,
        'MARKDOWN': 1.0,
    }

    # Points for the signal-confidence label (compared after strip() and
    # capitalize()); any other label gets 2.0.
    CONFIDENCE_POINTS = {
        'High': 15.0,
        'Medium': 10.0,
        'Low': 5.0,
    }

    # Tier thresholds applied to the rounded win probability.
    HIGH_PROB_THRESHOLD = 70
    MEDIUM_PROB_THRESHOLD = 45

    # sort_by value -> (item key, descending)
    SORT_OPTIONS = {
        'probability_desc': ('win_probability', True),
        'probability_asc': ('win_probability', False),
        'name_asc': ('symbol', False),
        'name_desc': ('symbol', True),
    }

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def scan_all(
        self,
        asset_type='crypto',
        limit=50,
        page=1,
        sort_by='probability_desc',
        tier='all',
    ):
        """Return paginated, scored probability matrix.

        Parameters
        ----------
        asset_type : str
            'crypto', 'stock', 'stock_us', etc.
        limit : int
            Items per page; coerced to int and clamped to [1, 200].
        page : int
            1-based page number; coerced to int and floored at 1.
        sort_by : str
            One of SORT_OPTIONS keys; unknown values sort by probability desc.
        tier : str
            'all', 'high_prob', 'medium_prob', 'low_prob'.

        Returns
        -------
        dict  {'items', 'total', 'page', 'limit', 'has_more', 'stats'}
            On any error the failure is logged and an empty result is
            returned instead of raising.
        """
        try:
            limit = max(1, min(200, _safe_int(limit, 50)))
            page = max(1, _safe_int(page, 1))

            raw = self._fetch_data(asset_type)
            if not raw:
                return self._empty_result(page, limit)

            scored = [item for item in map(self._score_asset, raw) if item is not None]
            if not scored:
                return self._empty_result(page, limit)

            # Tier filter
            if tier and tier != 'all':
                scored = [s for s in scored if s.get('tier') == tier]

            scored = self._sort_items(scored, sort_by)

            total = len(scored)
            start = (page - 1) * limit
            end = start + limit

            # NOTE(review): stats are computed after the tier filter, so with a
            # specific tier selected the per-tier counts only reflect that
            # tier — confirm this is intended.
            stats = self._compute_stats(scored)

            return {
                'items': scored[start:end],
                'total': total,
                'page': page,
                'limit': limit,
                'has_more': end < total,
                'stats': stats,
            }
        except Exception as exc:
            logger.exception("ProbabilityMatrixService.scan_all failed: %s", exc)
            return self._empty_result(page, limit)

    # ------------------------------------------------------------------
    # Data fetching  (lazy imports inside method)
    # ------------------------------------------------------------------
    def _fetch_data(self, asset_type):
        """Fetch active assets with their latest profile, range, bullish, and signal rows.

        "Latest" means the row with the highest ``id`` per ``asset_id`` in
        each table.  Every joined table is outer-joined, so the returned
        5-tuples may contain ``None`` for any part except ``Asset``.
        Returns ``[]`` (after logging and rolling back) on a query error.
        """
        from app.extensions import db
        from app.models.asset import Asset, AssetProfile
        from app.models.range_score import RangeTradingScore
        from app.models.bullish_score import BullishMomentumScore
        from app.models.signal import TradingSignal
        from sqlalchemy import func

        def _latest_ids(model):
            # Subquery: (asset_id, max_id) for the newest row of *model* per asset.
            return (
                db.session.query(
                    model.asset_id,
                    func.max(model.id).label('max_id'),
                )
                .group_by(model.asset_id)
                .subquery()
            )

        try:
            latest_profile = _latest_ids(AssetProfile)
            latest_range = _latest_ids(RangeTradingScore)
            latest_bullish = _latest_ids(BullishMomentumScore)
            latest_signal = _latest_ids(TradingSignal)

            query = (
                db.session.query(
                    Asset, AssetProfile, RangeTradingScore,
                    BullishMomentumScore, TradingSignal,
                )
                .filter(Asset.is_active.is_(True))
                .filter(Asset.asset_type == asset_type)
                .outerjoin(latest_profile, latest_profile.c.asset_id == Asset.id)
                .outerjoin(AssetProfile, AssetProfile.id == latest_profile.c.max_id)
                .outerjoin(latest_range, latest_range.c.asset_id == Asset.id)
                .outerjoin(RangeTradingScore, RangeTradingScore.id == latest_range.c.max_id)
                .outerjoin(latest_bullish, latest_bullish.c.asset_id == Asset.id)
                .outerjoin(BullishMomentumScore, BullishMomentumScore.id == latest_bullish.c.max_id)
                .outerjoin(latest_signal, latest_signal.c.asset_id == Asset.id)
                .outerjoin(TradingSignal, TradingSignal.id == latest_signal.c.max_id)
            )
            return query.all()
        except Exception as exc:
            logger.exception("ProbabilityMatrixService._fetch_data error: %s", exc)
            # A failed statement can leave the session's transaction unusable
            # (e.g. PostgreSQL aborts the whole transaction); roll back so later
            # queries on the same session still work after we swallow the error.
            try:
                db.session.rollback()
            except Exception:
                logger.exception("ProbabilityMatrixService._fetch_data rollback failed")
            return []

    # ------------------------------------------------------------------
    # Per-asset scoring
    # ------------------------------------------------------------------
    def _score_asset(self, row):
        """Score one (Asset, AssetProfile, RangeTradingScore, BullishMomentumScore, TradingSignal) row.

        Missing joined rows (``None``) or missing attributes fall back to the
        defaults used below.  Returns the scored item dict, or ``None`` when
        the row has no asset or scoring raised (logged as a warning).
        """
        try:
            asset, profile, rts, bullish, signal = row

            if asset is None:
                return None

            symbol = getattr(asset, 'symbol', '') or ''
            name = getattr(asset, 'name', '') or ''
            icon = getattr(asset, 'icon_thumb_url', '') or ''

            # --- Profile ---
            current_price = _safe_float(getattr(profile, 'current_price_idr', None))

            # --- Range ---
            support = _safe_float(getattr(rts, 'nearest_support', None))
            resistance = _safe_float(getattr(rts, 'nearest_resistance', None))
            zscore = _safe_float(getattr(rts, 'zscore', None))
            is_mr = bool(getattr(rts, 'is_mean_reverting', False))
            mr_conf_raw = getattr(rts, 'mr_confidence', None) or ''
            rts_price = _safe_float(getattr(rts, 'current_price_idr', None))
            range_pos = self._range_position(support, resistance, rts_price)

            # --- Bullish ---
            phase = getattr(bullish, 'bullish_phase', None) or 'NOT_BULLISH'
            mom_state = getattr(bullish, 'momentum_state', None) or 'STABLE'
            vel_short = _safe_float(getattr(bullish, 'velocity_short_pct', None))
            accel = _safe_float(getattr(bullish, 'acceleration', None))
            ml_trend = getattr(bullish, 'ml_trend', None) or ''

            # --- Signal ---
            confidence_str = getattr(signal, 'confidence', None) or ''
            dir_prob = _safe_float(getattr(signal, 'direction_probability', None))
            mtf = bool(getattr(signal, 'mtf_confirmed', False))

            # --- Components ---
            ml_score = self._ml_points(ml_trend, dir_prob)
            phase_score = self._phase_points(phase)
            range_score = self._range_points(range_pos)
            sig_score = self._signal_points(confidence_str, mtf)
            mr_score = self._mean_reversion_points(is_mr, mr_conf_raw, zscore)
            mom_score = self._momentum_points(vel_short, accel, phase, mom_state)

            # --- Aggregate ---
            win_probability = round(
                _clamp(ml_score + phase_score + range_score + sig_score + mr_score + mom_score, 0, 100),
                1,
            )

            return {
                'asset_id': asset.id,
                'symbol': symbol.upper(),
                'name': name,
                'icon_thumb_url': icon,
                'current_price_idr': current_price,
                'win_probability': win_probability,
                'tier': self._tier_for(win_probability),
                'components': {
                    'ml_score': round(ml_score, 2),
                    'phase_score': round(phase_score, 2),
                    'range_score': round(range_score, 2),
                    'signal_score': round(sig_score, 2),
                    'mr_score': round(mr_score, 2),
                    'momentum_score': round(mom_score, 2),
                },
                'bullish_phase': phase,
                'ml_trend': ml_trend,
                'momentum_state': mom_state,
                'direction_probability': round(dir_prob, 4),
                'range_position_pct': round(range_pos, 2) if range_pos is not None else None,
                'support_price': support,
                'resistance_price': resistance,
            }
        except Exception as exc:
            logger.warning("ProbabilityMatrixService._score_asset error: %s", exc)
            return None

    @staticmethod
    def _range_position(support, resistance, price):
        """Return *price*'s position in [support, resistance] as a percentage, or None.

        Computed manually because the range model has no such attribute.
        Requires ``resistance > support > 0`` and ``price > 0``; the result
        is not clamped, so it can be < 0 or > 100 outside the range.
        """
        if resistance > support > 0 and price > 0:
            return (price - support) / (resistance - support) * 100.0
        return None

    @staticmethod
    def _ml_points(ml_trend, dir_prob):
        """Component 1: ML direction (max 25 pts)."""
        trend = ml_trend.lower() if isinstance(ml_trend, str) else ''
        score = 0.0
        if trend == 'bullish':
            score += 20.0
            if dir_prob > 0.7:
                score += 5.0
            elif dir_prob > 0.5:
                score += 3.0
        elif trend == 'neutral':
            score += 10.0
            if dir_prob > 0.6:
                score += 2.0
        else:
            # bearish / unknown
            score += 3.0
        return _clamp(score, 0, 25)

    def _phase_points(self, phase):
        """Component 2: bullish phase (max 20 pts)."""
        return _clamp(self.PHASE_POINTS.get(phase, 3.0), 0, 20)

    @staticmethod
    def _range_points(range_pos):
        """Component 3: range position (max 15 pts); nearer support scores higher."""
        if range_pos is None:
            score = 5.0  # default when no data
        elif range_pos < 30:
            score = 15.0
        elif range_pos < 50:
            score = 10.0
        elif range_pos < 70:
            score = 7.0
        else:
            score = 3.0
        return _clamp(score, 0, 15)

    def _signal_points(self, confidence_str, mtf):
        """Component 4: signal confidence (max 15 pts), +3 when MTF-confirmed."""
        # Non-string labels are treated as unknown rather than dropping the asset.
        label = confidence_str.strip().capitalize() if isinstance(confidence_str, str) else ''
        score = 0.0
        score += self.CONFIDENCE_POINTS.get(label, 2.0)
        if mtf:
            score += 3.0
        return _clamp(score, 0, 15)

    @staticmethod
    def _mean_reversion_points(is_mr, mr_conf_raw, zscore):
        """Component 5: mean reversion (max 10 pts), with an oversold z-score bonus."""
        score = 0.0
        if is_mr:
            label = mr_conf_raw.upper() if isinstance(mr_conf_raw, str) else ''
            if label == 'HIGH':
                score += 10.0
            elif label == 'MEDIUM':
                score += 7.0
            else:
                score += 5.0
        else:
            score += 1.0

        # Bonus for negative z-score (oversold)
        if zscore < -1.5:
            score += 3.0
        elif zscore < -1.0:
            score += 2.0
        elif zscore < -0.5:
            score += 1.0
        return _clamp(score, 0, 10)

    def _momentum_points(self, vel_short, accel, phase, mom_state):
        """Component 6: momentum (max 15 pts) plus phase/momentum-state quality bonuses."""
        score = 0.0
        if vel_short > 0 and accel > 0:
            # Both velocity and acceleration positive = strong momentum
            score += 12.0
            if vel_short > 3:
                score += 3.0
            elif vel_short > 1:
                score += 1.5
        elif vel_short > 0:
            # Velocity positive but decelerating
            score += 7.0
        elif accel > 0:
            # Reversing upward
            score += 6.0
        else:
            score += 2.0

        # Phase quality bonus (up to 3 pts)
        score += (self.PHASE_QUALITY.get(phase, 25) / 100.0) * 3.0
        # Momentum state bonus (up to 2 pts)
        score += (self.MOMENTUM_QUALITY.get(mom_state, 45) / 100.0) * 2.0
        return _clamp(score, 0, 15)

    def _tier_for(self, win_probability):
        """Map a win probability to 'high_prob', 'medium_prob' or 'low_prob'."""
        if win_probability >= self.HIGH_PROB_THRESHOLD:
            return 'high_prob'
        if win_probability >= self.MEDIUM_PROB_THRESHOLD:
            return 'medium_prob'
        return 'low_prob'

    # ------------------------------------------------------------------
    # Stats
    # ------------------------------------------------------------------
    def _compute_stats(self, scored):
        """Build aggregate statistics over every item in *scored* (all pages, not just one)."""
        total = len(scored)
        if total == 0:
            return self._empty_stats()

        probs = [s['win_probability'] for s in scored]
        avg_prob = round(sum(probs) / total, 1)

        high_count = sum(1 for s in scored if s['tier'] == 'high_prob')
        medium_count = sum(1 for s in scored if s['tier'] == 'medium_prob')

        return {
            'total_assets': total,
            'avg_probability': avg_prob,
            'high_prob_count': high_count,
            'medium_prob_count': medium_count,
        }

    # ------------------------------------------------------------------
    # Sorting
    # ------------------------------------------------------------------
    def _sort_items(self, items, sort_by):
        """Sort *items* in place by *sort_by* and return the same list.

        Unknown *sort_by* values fall back to ``probability_desc``.  Ties on
        the primary key are ordered by case-insensitive symbol, then
        ``asset_id``, both ascending, so pagination is deterministic.
        """
        key_name, descending = self.SORT_OPTIONS.get(
            sort_by, ('win_probability', True),
        )
        try:
            # Python's sort is stable (also with reverse=True), so sorting by the
            # tie-breaker first and by the primary key second leaves equal primary
            # keys in tie-breaker order.  Without this, ties keep whatever order
            # the un-ordered database query returned, which can shift items
            # between pages across requests.
            items.sort(
                key=lambda x: ((x.get('symbol') or '').lower(), x.get('asset_id') or 0),
            )
            if key_name == 'symbol':
                items.sort(
                    key=lambda x: (x.get(key_name) or '').lower(),
                    reverse=descending,
                )
            else:
                items.sort(
                    key=lambda x: _safe_float(x.get(key_name, 0)),
                    reverse=descending,
                )
        except Exception:
            # Best effort: keep serving (possibly unsorted) items, but make it visible.
            logger.exception("ProbabilityMatrixService._sort_items failed for sort_by=%r", sort_by)
        return items

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _empty_result(page=1, limit=50):
        """Return the result shape of ``scan_all`` with no items."""
        return {
            'items': [],
            'total': 0,
            'page': page,
            'limit': limit,
            'has_more': False,
            'stats': ProbabilityMatrixService._empty_stats(),
        }

    @staticmethod
    def _empty_stats():
        """Return the stats shape of ``_compute_stats`` for an empty list."""
        return {
            'total_assets': 0,
            'avg_probability': 0.0,
            'high_prob_count': 0,
            'medium_prob_count': 0,
        }
