"""
Precision Timer Service  (Feature 113)
---------------------------------------
Determines the optimal entry timing for each asset and assigns a verdict:

    NOW   (>= 75)  –  Strong buy zone, enter immediately
    SOON  (>= 55)  –  Approaching entry, prepare orders
    WAIT  (>= 35)  –  Not yet ideal, monitor closely
    AVOID (<  35)  –  Poor timing, stay away

Scoring pillars (total 100):
    Phase Timing        25 pts
    Z-Score Extremity   20 pts
    Velocity Reversal   20 pts
    Cycle Position      15 pts
    MTF Alignment       20 pts
"""

import math
import logging
from datetime import datetime

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Safe-value helpers
# ---------------------------------------------------------------------------
def _safe_float(val, default=0.0):
    """Convert *val* to a finite float.

    Returns *default* when *val* is None, cannot be parsed as a number,
    converts to NaN / +-Inf, or is an integer too large to represent as a
    float.
    """
    if val is None:
        return default
    try:
        f = float(val)
    except (ValueError, TypeError, OverflowError):
        # OverflowError: float() of a huge int (e.g. 10**400) raises instead
        # of returning inf, so it must be caught here as well.
        return default
    return f if math.isfinite(f) else default


def _safe_int(val, default=0):
    """Convert *val* to int, returning *default* on failure.

    Returns *default* when *val* is None, is not convertible by ``int()``
    (e.g. ``'abc'``, ``'3.5'``, NaN), or is infinite.  Floats are truncated
    toward zero, exactly as ``int()`` does.
    """
    if val is None:
        return default
    try:
        return int(val)
    except (ValueError, TypeError, OverflowError):
        # OverflowError: int(float('inf')) raises this rather than ValueError.
        return default


def _clamp(val, lo=0.0, hi=100.0):
    """Return *val* limited to the closed interval [*lo*, *hi*].

    A value that does not compare below *hi* (including NaN) yields *hi*.
    """
    capped = val if val < hi else hi
    return capped if capped > lo else lo


# ---------------------------------------------------------------------------
# Verdict labels and entry-window descriptions
# ---------------------------------------------------------------------------
# Verdict cut-offs ordered from highest to lowest score: the first threshold
# that a timing score meets or exceeds picks the label; a score below every
# threshold is labelled 'AVOID'.
_VERDICT_THRESHOLDS = [(75, 'NOW'), (55, 'SOON'), (35, 'WAIT')]

# Human-readable entry-window description attached to each verdict label.
_ENTRY_WINDOW = dict(
    NOW='Strong buy zone - enter now',
    SOON='Approaching entry - prepare orders',
    WAIT='Not ideal yet - monitor closely',
    AVOID='Poor timing - stay away',
)


# ---------------------------------------------------------------------------
# Service
# ---------------------------------------------------------------------------
class PrecisionTimerService:
    """Optimal-entry timing scorer for every active asset.

    Each asset's latest profile / range / bullish / signal rows are scored on
    five pillars (phase timing 25, z-score extremity 20, velocity reversal 20,
    cycle position 15, MTF alignment 20) and the 0-100 total is mapped to a
    NOW / SOON / WAIT / AVOID verdict.
    """

    # sort_by value -> (item key to sort on, descending?)
    SORT_OPTIONS = {
        'timing_score_desc': ('timing_score', True),
        'timing_score_asc': ('timing_score', False),
        'name_asc': ('symbol', False),
    }

    # Base phase scores for pillar 1 (max 25); phases not listed score 5.
    PHASE_TIMING = {
        'EARLY_BULLISH': 22,
        'ACCUMULATION': 18,
        'PEAK_BULLISH': 10,
        'NOT_BULLISH': 5,
        'MARKDOWN': 2,
    }

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def scan_all(
        self,
        asset_type='crypto',
        limit=50,
        page=1,
        sort_by='timing_score_desc',
        verdict='all',
    ):
        """Return paginated timing verdicts for all active assets.

        Parameters
        ----------
        asset_type : str
            'crypto', 'stock', 'stock_us', etc.
        limit : int
            Items per page, clamped to 1..200 (unparseable values -> 50).
        page : int
            1-based page number (unparseable or < 1 -> 1).
        sort_by : str
            One of SORT_OPTIONS keys; unknown keys sort by timing score
            descending.
        verdict : str
            'all', 'NOW', 'SOON', 'WAIT', 'AVOID' (case-insensitive).
            Empty / None also means "no filter".

        Returns
        -------
        dict  {'items', 'total', 'page', 'limit', 'has_more', 'stats'}
            On any failure an empty result of the same shape is returned and
            the error is logged.
        """
        try:
            limit = max(1, min(200, _safe_int(limit, 50)))
            page = max(1, _safe_int(page, 1))

            raw = self._fetch_data(asset_type)
            if not raw:
                return self._empty_result(page, limit)

            scored = [
                item
                for item in (self._score_asset(row) for row in raw)
                if item is not None
            ]
            if not scored:
                return self._empty_result(page, limit)

            verdict_key = self._normalize_verdict(verdict)
            if verdict_key is not None:
                scored = [s for s in scored if s.get('verdict') == verdict_key]

            scored = self._sort_items(scored, sort_by)

            total = len(scored)
            start = (page - 1) * limit
            end = start + limit

            return {
                'items': scored[start:end],
                'total': total,
                'page': page,
                'limit': limit,
                'has_more': end < total,
                # NOTE(review): stats are computed after the verdict filter,
                # so with a filter active only that verdict's count is
                # non-zero -- confirm this is what consumers expect.
                'stats': self._compute_stats(scored),
            }
        except Exception as exc:
            logger.exception("PrecisionTimerService.scan_all failed: %s", exc)
            return self._empty_result(page, limit)

    @staticmethod
    def _normalize_verdict(verdict):
        """Return the upper-cased verdict to filter on, or None for no filter."""
        if not verdict:
            return None
        key = str(verdict).strip().upper()
        # 'all' in any casing (or a blank string) disables the filter instead
        # of filtering everything out.
        if not key or key == 'ALL':
            return None
        return key

    # ------------------------------------------------------------------
    # Data fetching  (lazy imports inside method)
    # ------------------------------------------------------------------
    def _fetch_data(self, asset_type):
        """Fetch active assets of *asset_type* with their latest related rows.

        Returns a list of ``(Asset, AssetProfile, RangeTradingScore,
        BullishMomentumScore, TradingSignal)`` tuples.  "Latest" means the row
        with the highest id per asset; any of the four related rows is None
        when the asset has none (outer joins).  Returns ``[]`` on query
        failure after logging and rolling back the session.
        """
        from app.extensions import db
        from app.models.asset import Asset, AssetProfile
        from app.models.range_score import RangeTradingScore
        from app.models.bullish_score import BullishMomentumScore
        from app.models.signal import TradingSignal
        from sqlalchemy import func

        try:
            latest_profile = (
                db.session.query(
                    AssetProfile.asset_id,
                    func.max(AssetProfile.id).label('max_id'),
                )
                .group_by(AssetProfile.asset_id)
                .subquery()
            )
            latest_range = (
                db.session.query(
                    RangeTradingScore.asset_id,
                    func.max(RangeTradingScore.id).label('max_id'),
                )
                .group_by(RangeTradingScore.asset_id)
                .subquery()
            )
            latest_bullish = (
                db.session.query(
                    BullishMomentumScore.asset_id,
                    func.max(BullishMomentumScore.id).label('max_id'),
                )
                .group_by(BullishMomentumScore.asset_id)
                .subquery()
            )
            latest_signal = (
                db.session.query(
                    TradingSignal.asset_id,
                    func.max(TradingSignal.id).label('max_id'),
                )
                .group_by(TradingSignal.asset_id)
                .subquery()
            )

            query = (
                db.session.query(
                    Asset, AssetProfile, RangeTradingScore,
                    BullishMomentumScore, TradingSignal,
                )
                .filter(Asset.is_active.is_(True))
                .filter(Asset.asset_type == asset_type)
                .outerjoin(latest_profile, latest_profile.c.asset_id == Asset.id)
                .outerjoin(AssetProfile, AssetProfile.id == latest_profile.c.max_id)
                .outerjoin(latest_range, latest_range.c.asset_id == Asset.id)
                .outerjoin(RangeTradingScore, RangeTradingScore.id == latest_range.c.max_id)
                .outerjoin(latest_bullish, latest_bullish.c.asset_id == Asset.id)
                .outerjoin(BullishMomentumScore, BullishMomentumScore.id == latest_bullish.c.max_id)
                .outerjoin(latest_signal, latest_signal.c.asset_id == Asset.id)
                .outerjoin(TradingSignal, TradingSignal.id == latest_signal.c.max_id)
            )
            return query.all()
        except Exception as exc:
            logger.exception("PrecisionTimerService._fetch_data error: %s", exc)
            # A failed statement can leave the session's transaction in an
            # aborted state; roll back so later queries on the same session
            # are not rejected.
            try:
                db.session.rollback()
            except Exception:
                logger.warning(
                    "PrecisionTimerService._fetch_data rollback failed",
                    exc_info=True,
                )
            return []

    # ------------------------------------------------------------------
    # Per-asset scoring
    # ------------------------------------------------------------------
    def _score_asset(self, row):
        """Score one ``(Asset, AssetProfile, RTS, Bullish, Signal)`` row.

        Any of the four related rows may be None; missing attributes fall
        back to neutral defaults via ``getattr(..., None)``.  Returns the
        item dict, or None when the row has no asset or scoring raises (the
        error is logged).
        """
        try:
            asset, profile, rts, bullish, signal = row

            if asset is None:
                return None

            symbol = str(getattr(asset, 'symbol', '') or '')
            name = getattr(asset, 'name', '') or ''
            icon = getattr(asset, 'icon_thumb_url', '') or ''

            # --- Profile ---
            current_price = _safe_float(getattr(profile, 'current_price_idr', None))

            # --- Range ---
            support = _safe_float(getattr(rts, 'nearest_support', None))
            resistance = _safe_float(getattr(rts, 'nearest_resistance', None))
            zscore = _safe_float(getattr(rts, 'zscore', None))
            rts_price = _safe_float(getattr(rts, 'current_price_idr', None))
            # The range model has no position attribute, so derive it.
            range_pos = self._range_position(rts_price, support, resistance)

            # --- Bullish ---
            phase = getattr(bullish, 'bullish_phase', None) or 'NOT_BULLISH'
            mom_state = getattr(bullish, 'momentum_state', None) or 'STABLE'
            vel_short = _safe_float(getattr(bullish, 'velocity_short_pct', None))
            accel = _safe_float(getattr(bullish, 'acceleration', None))
            ml_trend = getattr(bullish, 'ml_trend', None) or ''

            # --- Signal ---
            regime = getattr(signal, 'regime', None) or ''
            mtf = bool(getattr(signal, 'mtf_confirmed', False))

            phase_pts = self._phase_points(phase, regime)
            zscore_pts = self._zscore_points(zscore)
            vel_pts = self._velocity_points(vel_short, accel)
            cycle_pts = self._cycle_points(range_pos)
            mtf_pts = self._mtf_points(mtf, ml_trend)

            timing_score = round(
                _clamp(phase_pts + zscore_pts + vel_pts + cycle_pts + mtf_pts, 0, 100),
                1,
            )
            verdict_label = self._verdict_for(timing_score)

            return {
                'asset_id': asset.id,
                'symbol': symbol.upper(),
                'name': name,
                'icon_thumb_url': icon,
                'current_price_idr': current_price,
                'timing_score': timing_score,
                'verdict': verdict_label,
                'components': {
                    'phase_timing': round(phase_pts, 2),
                    'zscore_extremity': round(zscore_pts, 2),
                    'velocity_reversal': round(vel_pts, 2),
                    'cycle_position': round(cycle_pts, 2),
                    'mtf_alignment': round(mtf_pts, 2),
                },
                'entry_window': _ENTRY_WINDOW.get(verdict_label, 'Unknown'),
                'zscore': round(zscore, 4),
                'range_position_pct': round(range_pos, 2) if range_pos is not None else None,
                'bullish_phase': phase,
                'momentum_state': mom_state,
                'ml_trend': ml_trend,
                'support_price': support,
                'resistance_price': resistance,
            }
        except Exception as exc:
            logger.warning("PrecisionTimerService._score_asset error: %s", exc)
            return None

    @staticmethod
    def _range_position(price, support, resistance):
        """Return *price*'s position in [support, resistance] as a percent, or None.

        None when the range is unusable (support <= 0 or resistance not above
        support) or *price* is not positive.  Results outside 0-100 mean the
        price lies outside the range.
        """
        if resistance > support > 0 and price > 0:
            return (price - support) / (resistance - support) * 100.0
        return None

    @classmethod
    def _phase_points(cls, phase, regime):
        """Pillar 1 (max 25): base score for *phase*, raised by a matching *regime*."""
        pts = float(cls.PHASE_TIMING.get(phase, 5))
        regime_lower = str(regime).lower() if regime else ''
        if phase == 'EARLY_BULLISH' and 'accum' in regime_lower:
            pts = 25.0
        elif phase == 'ACCUMULATION' and 'accum' in regime_lower:
            pts = 20.0
        elif phase == 'EARLY_BULLISH' and 'bull' in regime_lower:
            pts = 23.0
        return _clamp(pts, 0, 25)

    @staticmethod
    def _zscore_points(zscore):
        """Pillar 2 (max 20): lower (more oversold) z-scores score higher."""
        if zscore < -1.5:
            pts = 20.0   # deeply oversold = excellent entry
        elif zscore < -1.0:
            pts = 17.0
        elif zscore < -0.5:
            pts = 15.0
        elif zscore < 0:
            pts = 12.0
        elif zscore <= 0.5:
            pts = 10.0
        elif zscore <= 1.0:
            pts = 8.0
        elif zscore <= 1.5:
            pts = 5.0
        else:
            pts = 2.0    # overbought = terrible entry
        return _clamp(pts, 0, 20)

    @staticmethod
    def _velocity_points(vel_short, accel):
        """Pillar 3 (max 20): reward rising or turning velocity, cap chasing."""
        if vel_short > 0:
            # Rising: ideal while still accelerating, acceptable when slowing.
            pts = 20.0 if accel > 0 else 12.0
        elif accel > 0:
            # Still falling, but acceleration has turned up: reversal starting.
            pts = 14.0
        elif accel > -0.5:
            pts = 6.0    # slow decline
        else:
            pts = 2.0    # fast decline
        if vel_short > 5:
            # Already moving up fast: cap the score to limit chasing risk.
            pts = min(pts, 10.0)
        return _clamp(pts, 0, 20)

    @staticmethod
    def _cycle_points(range_pos):
        """Pillar 4 (max 15): positions near support score highest; 5 when unknown."""
        if range_pos is None:
            pts = 5.0
        elif range_pos < 25:
            pts = 15.0   # near support = ideal
        elif range_pos < 40:
            pts = 12.0
        elif range_pos < 60:
            pts = 8.0
        elif range_pos < 75:
            pts = 4.0
        else:
            pts = 2.0    # near resistance = bad timing
        return _clamp(pts, 0, 15)

    @staticmethod
    def _mtf_points(mtf, ml_trend):
        """Pillar 5 (max 20): MTF confirmation combined with the ML trend label."""
        ml_lower = str(ml_trend).lower() if ml_trend else ''
        if mtf:
            pts = {'bullish': 20.0, 'neutral': 14.0}.get(ml_lower, 10.0)
        else:
            pts = {'bullish': 12.0, 'neutral': 6.0}.get(ml_lower, 2.0)
        return _clamp(pts, 0, 20)

    @staticmethod
    def _verdict_for(score):
        """Map a timing score to the first verdict whose threshold it meets."""
        for threshold, label in _VERDICT_THRESHOLDS:
            if score >= threshold:
                return label
        return 'AVOID'

    # ------------------------------------------------------------------
    # Stats
    # ------------------------------------------------------------------
    def _compute_stats(self, scored):
        """Build verdict counts and the average timing score for *scored*."""
        total = len(scored)
        if total == 0:
            return self._empty_stats()

        counts = {'NOW': 0, 'SOON': 0, 'WAIT': 0, 'AVOID': 0}
        score_sum = 0.0
        for item in scored:
            score_sum += item['timing_score']
            label = item['verdict']
            if label in counts:
                counts[label] += 1

        return {
            'total': total,
            'now_count': counts['NOW'],
            'soon_count': counts['SOON'],
            'wait_count': counts['WAIT'],
            'avoid_count': counts['AVOID'],
            'avg_timing_score': round(score_sum / total, 1),
        }

    # ------------------------------------------------------------------
    # Sorting
    # ------------------------------------------------------------------
    def _sort_items(self, items, sort_by):
        """Sort *items* in place by the SORT_OPTIONS entry for *sort_by*; return it.

        Unknown *sort_by* values fall back to timing score descending.  Items
        are first ordered by (symbol, asset_id) so that ties on the primary
        key keep a deterministic order across requests.  Ties are common
        because timing scores are sums of a few discrete values.  The DB query
        has no ORDER BY, so without this step page boundaries could shift
        between calls.
        """
        key_name, descending = self.SORT_OPTIONS.get(
            sort_by, ('timing_score', True),
        )

        def tiebreak(item):
            return ((item.get('symbol') or '').lower(), str(item.get('asset_id')))

        def primary(item):
            value = item.get(key_name)
            if key_name == 'symbol':
                return (value or '').lower()
            return _safe_float(value)

        try:
            # list.sort is stable (also with reverse=True), so the first pass
            # acts as the tie-breaker for the second.
            items.sort(key=tiebreak)
            items.sort(key=primary, reverse=descending)
        except Exception:
            logger.warning(
                "PrecisionTimerService._sort_items failed for sort_by=%r",
                sort_by,
                exc_info=True,
            )
        return items

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _empty_result(page=1, limit=50):
        """Return a result dict with no items, in the same shape as scan_all."""
        return {
            'items': [],
            'total': 0,
            'page': page,
            'limit': limit,
            'has_more': False,
            'stats': PrecisionTimerService._empty_stats(),
        }

    @staticmethod
    def _empty_stats():
        """Return a zeroed stats dict, in the same shape as _compute_stats."""
        return {
            'total': 0,
            'now_count': 0,
            'soon_count': 0,
            'wait_count': 0,
            'avoid_count': 0,
            'avg_timing_score': 0.0,
        }
