"""HarmonicPatternService — detects harmonic patterns using Fibonacci ratios.

Identifies Gartley, Butterfly, Bat, Crab, and Shark patterns from swing
point analysis on price data.  Each pattern is scored by completion
percentage and reliability to rank trade setups.

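Data sources (all lazy-imported):
- Asset: id, symbol, name, asset_type, is_active, market_cap_rank
- OHLCVData: hourly ('1h') close prices — the preferred price source
- AssetProfile: profile_json -> sparkline_in_7d (168 hourly data points),
  used as a fallback when OHLCV data is unavailable or insufficient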
"""
from __future__ import annotations

import logging
import math
from datetime import datetime

# Module-level logger; HarmonicPatternService reports analysis errors through it.
logger = logging.getLogger(__name__)


def _safe_float(val, default: float = 0.0) -> float:
    if val is None:
        return default
    try:
        f = float(val)
        return default if (math.isnan(f) or math.isinf(f)) else f
    except (ValueError, TypeError):
        return default


class HarmonicPatternService:
    """Detects harmonic patterns (Gartley, Butterfly, Bat, Crab, Shark).

    Prices are loaded per asset (hourly OHLCV first, profile sparkline as a
    fallback), reduced to alternating swing highs/lows, and every window of
    five consecutive swings is tested against each pattern's Fibonacci ratio
    rules.  Each match gets a 0-100 reliability score that blends how close
    the ratios are to their ideal values with the completion percentage.
    """

    # Absolute tolerance applied to Fibonacci ratios (e.g. 0.618 +- 0.06).
    TOLERANCE = 0.06
    # Minimum number of finite price points required before analysing.
    MIN_PRICES = 30
    # Maximum number of hourly OHLCV bars loaded per asset (the most recent).
    OHLCV_LIMIT = 500

    def __init__(self):
        pass

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def analyze(self, symbol=None, asset_type='crypto', **kwargs):
        """Find harmonic patterns for a single asset.

        Args:
            symbol: Asset symbol, matched case-insensitively.
            asset_type: Asset class to search in, e.g. ``'crypto'``.
            **kwargs: Accepted for call-site compatibility; ignored.

        Returns:
            On success ``{'status': 'success', 'data': {...}}``, where data
            holds ``symbol``, ``current_price``, ``patterns`` and
            ``swing_points``.  ``patterns`` keeps the best match per
            (type, direction), sorted by reliability in descending order.
            On failure ``{'status': 'error', 'message': ...}``.
            Exceptions are logged and converted into the error form.
        """
        import numpy as np

        try:
            asset = self._resolve_coin(symbol, asset_type)
            if asset is None:
                return {'status': 'error', 'message': f'Asset not found: {symbol}'}

            prices = self._get_prices(asset)
            if prices is None or len(prices) < self.MIN_PRICES:
                return {'status': 'error', 'message': f'Insufficient data for {symbol}'}

            prices = np.asarray(prices, dtype=float)
            current_price = float(prices[-1])
            swing_points = self._find_swing_points(prices, order=5)

            checkers = (
                (self._check_gartley, 'Gartley'),
                (self._check_butterfly, 'Butterfly'),
                (self._check_bat, 'Bat'),
                (self._check_crab, 'Crab'),
                (self._check_shark, 'Shark'),
            )

            # Slide a window of 5 consecutive swing points across all detected
            # swings; with fewer than 5 swings there are no windows and the
            # result simply has no patterns.
            best = {}
            for start in range(len(swing_points) - 4):
                swings = swing_points[start:start + 5]
                for checker, name in checkers:
                    result = checker(swings, current_price)
                    if result is None:
                        continue
                    result['type'] = name
                    # Deduplicate: keep the most reliable match per
                    # (type, direction); on ties the earliest window wins.
                    key = (name, result['direction'])
                    if key not in best or result['reliability'] > best[key]['reliability']:
                        best[key] = result

            patterns = sorted(best.values(), key=lambda p: p['reliability'], reverse=True)

            return {
                'status': 'success',
                'data': {
                    'symbol': asset.symbol,
                    'current_price': current_price,
                    'patterns': patterns,
                    'swing_points': swing_points,
                },
            }

        except Exception as e:
            logger.exception('HarmonicPatternService.analyze error for %s', symbol)
            return {'status': 'error', 'message': str(e)[:200]}

    def scan_all(self, asset_type='crypto', limit=50, **kwargs):
        """Scan the top active assets (by market-cap rank) for harmonic patterns.

        Runs :meth:`analyze` on up to *limit* assets of *asset_type*.  It keeps
        only assets with at least one pattern and summarises each one's
        top-ranked pattern.  Results are sorted by that pattern's reliability,
        descending.  Any exception is logged and returned in the
        ``{'status': 'error', ...}`` form.
        """
        from datetime import timezone

        from app.models.asset import Asset

        try:
            assets = (
                Asset.query
                .filter_by(asset_type=asset_type, is_active=True)
                .order_by(Asset.market_cap_rank.asc())
                .limit(limit)
                .all()
            )

            results = []
            for asset in assets:
                result = self.analyze(symbol=asset.symbol, asset_type=asset_type)
                if result.get('status') != 'success':
                    continue
                d = result['data']
                if not d['patterns']:
                    continue
                top = d['patterns'][0]
                results.append({
                    'symbol': asset.symbol,
                    'asset_id': asset.id,
                    'name': asset.name,
                    'current_price': d['current_price'],
                    'total_patterns': len(d['patterns']),
                    'top_pattern': top.get('type', ''),
                    'top_direction': top.get('direction', 'neutral'),
                    'top_reliability': top.get('reliability', 0),
                    'completion': top.get('completion', 0),
                })

            results.sort(key=lambda x: x['top_reliability'], reverse=True)

            return {
                'status': 'success',
                'data': {
                    'results': results,
                    'total_with_patterns': len(results),
                    'asset_type': asset_type,
                    # Naive UTC ISO timestamp (same format as utcnow(), which
                    # is deprecated since Python 3.12).
                    'scanned_at': datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
                },
            }

        except Exception as e:
            logger.exception('HarmonicPatternService.scan_all error')
            return {'status': 'error', 'message': str(e)[:200]}

    # ------------------------------------------------------------------
    # Swing point detection
    # ------------------------------------------------------------------

    def _find_swing_points(self, prices, order=5):
        """Find significant peaks and troughs.

        A bar is a swing high (low) when its price equals the maximum
        (minimum) of the ``2 * order + 1`` bars centred on it.  The first and
        last ``order`` bars are never candidates, and a perfectly flat window
        counts as a high.  Consecutive points of the same type are then
        collapsed to the more extreme one (the earlier one wins on ties), so
        the result alternates between high and low.

        Returns a list of dicts with 'price', 'index', 'type' ('high'/'low'),
        ordered by index.
        """
        import numpy as np

        prices = np.asarray(prices, dtype=float)
        n = len(prices)
        points = []

        for i in range(order, n - order):
            window = prices[i - order: i + order + 1]
            val = float(prices[i])
            if val == float(np.max(window)):
                points.append({'price': val, 'index': int(i), 'type': 'high'})
            elif val == float(np.min(window)):
                points.append({'price': val, 'index': int(i), 'type': 'low'})

        # Ensure alternating high/low by keeping the more extreme of
        # consecutive same-type points.
        filtered = []
        for pt in points:
            if filtered and filtered[-1]['type'] == pt['type']:
                prev = filtered[-1]
                if pt['type'] == 'high':
                    more_extreme = pt['price'] > prev['price']
                else:
                    more_extreme = pt['price'] < prev['price']
                if more_extreme:
                    filtered[-1] = pt
            else:
                filtered.append(pt)

        return filtered

    # ------------------------------------------------------------------
    # Fibonacci helpers
    # ------------------------------------------------------------------

    def _fibonacci_ratio(self, a, b, c):
        """Calculate retracement ratio of c relative to move a->b.

        Returns abs(c - b) / abs(b - a).  Returns None if the a->b move is
        (effectively) zero.
        """
        move = abs(b - a)
        if move < 1e-12:
            return None
        return abs(c - b) / move

    def _in_range(self, value, target, tol=None):
        """Check whether *value* is within *tol* of *target* (absolute).

        *tol* defaults to ``TOLERANCE``; an explicit ``0`` is honoured.
        A ``None`` value (undefined ratio) never matches.
        """
        if value is None:
            return False
        if tol is None:
            tol = self.TOLERANCE
        return abs(value - target) <= tol

    def _in_band(self, value, low, high, tol=None):
        """Check whether *value* falls within ``[low - tol, high + tol]``.

        *tol* defaults to ``TOLERANCE``; an explicit ``0`` is honoured.
        A ``None`` value (undefined ratio) never matches.
        """
        if value is None:
            return False
        if tol is None:
            tol = self.TOLERANCE
        return (low - tol) <= value <= (high + tol)

    @staticmethod
    def _closeness(value, ideal, span, weight):
        """Score in ``[0, weight]``: *weight* at *ideal*, falling linearly to
        0 at ``ideal +- span`` and clamped at 0 beyond that."""
        return max(0, 1.0 - abs(value - ideal) / span) * weight

    def _xabcd_ratios(self, swings):
        """Return ``(AB/XA, BC/AB, CD/BC)`` ratios for five X, A, B, C, D swings.

        Any entry is ``None`` when its reference leg has zero length.
        """
        x, a, b, c, d = (s['price'] for s in swings)
        return (
            self._fibonacci_ratio(x, a, b),
            self._fibonacci_ratio(a, b, c),
            self._fibonacci_ratio(b, c, d),
        )

    def _build_pattern(self, swings, direction, completion, reliability,
                       ratios, current_price):
        """Build a standard pattern result dict.

        Swings are labelled X, A, B, C, D in order.  The potential reversal
        zone (PRZ) is D +- 1%.  Targets sit at 38.2% / 61.8% of the A-D
        distance back toward A.  The stop is 3% beyond D, below it for
        bullish patterns and above it for bearish ones.  *current_price* is
        accepted for interface compatibility but not used.
        """
        pts = {
            'X': swings[0]['price'], 'A': swings[1]['price'],
            'B': swings[2]['price'], 'C': swings[3]['price'],
            'D': swings[4]['price'],
        }
        d_price = pts['D']
        leg = abs(pts['A'] - d_price)
        prz_low = d_price * 0.99
        prz_high = d_price * 1.01

        if direction == 'bullish':
            target_1 = d_price + leg * 0.382
            target_2 = d_price + leg * 0.618
            stop_loss = d_price * 0.97
        else:
            target_1 = d_price - leg * 0.382
            target_2 = d_price - leg * 0.618
            stop_loss = d_price * 1.03

        return {
            'direction': direction,
            'completion': round(completion, 1),
            'reliability': round(reliability, 1),
            'prz_high': round(prz_high, 8),
            'prz_low': round(prz_low, 8),
            'target_1': round(target_1, 8),
            'target_2': round(target_2, 8),
            'stop_loss': round(stop_loss, 8),
            'ratios': {k: round(v, 4) if v is not None else None for k, v in ratios.items()},
            'swing_points': {k: round(v, 8) for k, v in pts.items()},
        }

    def _pattern_direction(self, swings):
        """Determine bullish/bearish from swing structure.

        Bullish when X is a low and A is a high (price rises first);
        every other structure is reported as bearish.
        """
        if swings[0]['type'] == 'low' and swings[1]['type'] == 'high':
            return 'bullish'
        return 'bearish'

    def _completion_pct(self, swings, current_price):
        """Percentage (0-100) of the C->D distance covered by the current price.

        Measured as ``|current - C| / |D - C| * 100`` and capped at 100.
        Returns 100 when C and D coincide.
        """
        d = swings[4]['price']
        c = swings[3]['price']
        if abs(d - c) < 1e-12:
            return 100.0
        progress = abs(current_price - c) / abs(d - c) * 100
        return min(progress, 100.0)

    # ------------------------------------------------------------------
    # Pattern checkers
    # ------------------------------------------------------------------

    def _check_gartley(self, swings, current_price):
        """Gartley: AB=0.618 of XA, BC=0.382-0.886 of AB, CD=1.27-1.618 of BC."""
        ab_xa, bc_ab, cd_bc = self._xabcd_ratios(swings)

        if not self._in_range(ab_xa, 0.618):
            return None
        if not self._in_band(bc_ab, 0.382, 0.886):
            return None
        if not self._in_band(cd_bc, 1.27, 1.618):
            return None

        direction = self._pattern_direction(swings)
        completion = self._completion_pct(swings, current_price)

        # Reliability: closeness of each ratio to its ideal (band midpoints
        # for BC/CD) plus a completion component; weights sum to 100.
        reliability = min(
            self._closeness(ab_xa, 0.618, 0.618, 30)
            + self._closeness(bc_ab, 0.634, 0.5, 25)
            + self._closeness(cd_bc, 1.444, 0.5, 25)
            + min(completion, 100) / 100 * 20,
            100,
        )

        ratios = {'XA': 1.0, 'AB': ab_xa, 'BC': bc_ab, 'CD': cd_bc}
        return self._build_pattern(swings, direction, completion, reliability,
                                   ratios, current_price)

    def _check_butterfly(self, swings, current_price):
        """Butterfly: AB=0.786 of XA, BC=0.382-0.886, CD=1.618-2.618 of BC."""
        ab_xa, bc_ab, cd_bc = self._xabcd_ratios(swings)

        if not self._in_range(ab_xa, 0.786):
            return None
        if not self._in_band(bc_ab, 0.382, 0.886):
            return None
        if not self._in_band(cd_bc, 1.618, 2.618):
            return None

        direction = self._pattern_direction(swings)
        completion = self._completion_pct(swings, current_price)

        reliability = min(
            self._closeness(ab_xa, 0.786, 0.786, 28)
            + self._closeness(bc_ab, 0.634, 0.5, 24)
            + self._closeness(cd_bc, 2.118, 1.0, 28)
            + min(completion, 100) / 100 * 20,
            100,
        )

        ratios = {'XA': 1.0, 'AB': ab_xa, 'BC': bc_ab, 'CD': cd_bc}
        return self._build_pattern(swings, direction, completion, reliability,
                                   ratios, current_price)

    def _check_bat(self, swings, current_price):
        """Bat: AB=0.382-0.5 of XA, BC=0.382-0.886, CD=1.618-2.618 of BC."""
        ab_xa, bc_ab, cd_bc = self._xabcd_ratios(swings)

        if not self._in_band(ab_xa, 0.382, 0.5):
            return None
        if not self._in_band(bc_ab, 0.382, 0.886):
            return None
        if not self._in_band(cd_bc, 1.618, 2.618):
            return None

        direction = self._pattern_direction(swings)
        completion = self._completion_pct(swings, current_price)

        reliability = min(
            self._closeness(ab_xa, 0.441, 0.2, 28)
            + self._closeness(bc_ab, 0.634, 0.5, 24)
            + self._closeness(cd_bc, 2.118, 1.0, 28)
            + min(completion, 100) / 100 * 20,
            100,
        )

        ratios = {'XA': 1.0, 'AB': ab_xa, 'BC': bc_ab, 'CD': cd_bc}
        return self._build_pattern(swings, direction, completion, reliability,
                                   ratios, current_price)

    def _check_crab(self, swings, current_price):
        """Crab: AB=0.382-0.618 of XA, BC=0.382-0.886, CD=2.618-3.618."""
        ab_xa, bc_ab, cd_bc = self._xabcd_ratios(swings)

        if not self._in_band(ab_xa, 0.382, 0.618):
            return None
        if not self._in_band(bc_ab, 0.382, 0.886):
            return None
        if not self._in_band(cd_bc, 2.618, 3.618):
            return None

        direction = self._pattern_direction(swings)
        completion = self._completion_pct(swings, current_price)

        # Crab weights completion more heavily (24) than the other patterns.
        reliability = min(
            self._closeness(ab_xa, 0.5, 0.3, 26)
            + self._closeness(bc_ab, 0.634, 0.5, 24)
            + self._closeness(cd_bc, 3.118, 1.0, 26)
            + min(completion, 100) / 100 * 24,
            100,
        )

        ratios = {'XA': 1.0, 'AB': ab_xa, 'BC': bc_ab, 'CD': cd_bc}
        return self._build_pattern(swings, direction, completion, reliability,
                                   ratios, current_price)

    def _check_shark(self, swings, current_price):
        """Shark: XA=1.13-1.618 of OX, AB=1.618-2.24 of XA, last leg ~0.886.

        The five swings are read as O, X, A, B and a final point.  The last
        check requires the final point's retracement of the A->B leg to be
        within 0.08 of 0.886.  ``_build_pattern`` still labels the swings
        X/A/B/C/D in its output.
        """
        o, x, a, b, d = (s['price'] for s in swings)
        xa_ox = self._fibonacci_ratio(o, x, a)
        ab_xa = self._fibonacci_ratio(x, a, b)
        cd_ratio = self._fibonacci_ratio(a, b, d)

        if not self._in_band(xa_ox, 1.13, 1.618):
            return None
        if not self._in_band(ab_xa, 1.618, 2.24):
            return None
        if not self._in_range(cd_ratio, 0.886, tol=0.08):
            return None

        direction = self._pattern_direction(swings)
        completion = self._completion_pct(swings, current_price)

        reliability = min(
            self._closeness(xa_ox, 1.374, 0.5, 28)
            + self._closeness(ab_xa, 1.929, 0.6, 28)
            + self._closeness(cd_ratio, 0.886, 0.2, 24)
            + min(completion, 100) / 100 * 20,
            100,
        )

        ratios = {'OX': 1.0, 'XA': xa_ox, 'AB': ab_xa, 'CD': cd_ratio}
        return self._build_pattern(swings, direction, completion, reliability,
                                   ratios, current_price)

    # ------------------------------------------------------------------
    # Data helpers
    # ------------------------------------------------------------------

    def _resolve_coin(self, symbol, asset_type):
        """Look up an active Asset by symbol (case-insensitive) and asset_type.

        Returns the first matching ``Asset`` row, or ``None`` when *symbol*
        is ``None`` or nothing matches.
        """
        if symbol is None:
            return None

        from app.models.asset import Asset

        # ILIKE treats '%' and '_' as wildcards; escape them (and the escape
        # character itself) so the symbol is matched literally.
        pattern = (
            str(symbol)
            .replace('\\', '\\\\')
            .replace('%', '\\%')
            .replace('_', '\\_')
        )
        return (
            Asset.query
            .filter(
                Asset.symbol.ilike(pattern, escape='\\'),
                Asset.asset_type == asset_type,
                Asset.is_active.is_(True),
            )
            .first()
        )

    @staticmethod
    def _finite_prices(values, min_len=30):
        """Return *values* as a float array without NaN/inf entries.

        Returns ``None`` when fewer than *min_len* finite values remain.
        """
        import numpy as np

        arr = np.asarray(values, dtype=float)
        arr = arr[np.isfinite(arr)]
        return arr if len(arr) >= min_len else None

    def _get_prices(self, asset):
        """Get a chronological close-price array for an asset, or ``None``.

        Sources are tried in order:

        1. Hourly (``'1h'``) OHLCV rows: the most recent ``OHLCV_LIMIT``.
        2. ``sparkline_in_7d`` from the most recently fetched AssetProfile.

        ``None``/NaN/inf values are dropped.  A source is used only if at
        least ``MIN_PRICES`` finite points remain.  Lookup failures are
        logged at debug level and the next source is tried.
        """
        import json

        # 1) OHLCV hourly data
        try:
            from app.models.ohlcv import OHLCVData
            rows = (
                OHLCVData.query
                .filter_by(asset_id=asset.id, timeframe='1h')
                # Newest first so LIMIT keeps the most recent bars (ascending
                # order + LIMIT would keep the oldest ones); reversed below to
                # restore chronological order.
                .order_by(OHLCVData.datetime_wib.desc())
                .limit(self.OHLCV_LIMIT)
                .all()
            )
            closes = [float(r.close) for r in reversed(rows) if r.close is not None]
            prices = self._finite_prices(closes, self.MIN_PRICES)
            if prices is not None:
                return prices
        except Exception:
            logger.debug('OHLCV price lookup failed for asset %s', asset.id, exc_info=True)

        # 2) Sparkline fallback
        try:
            from app.models.asset import AssetProfile
            profile = (
                AssetProfile.query
                .filter_by(asset_id=asset.id)
                .order_by(AssetProfile.fetched_at.desc())
                .first()
            )
            if profile and profile.profile_json:
                pj = profile.profile_json
                if isinstance(pj, str):
                    pj = json.loads(pj)
                sparkline = pj.get('sparkline_in_7d', {}) if isinstance(pj, dict) else None
                if isinstance(sparkline, dict):
                    price_list = sparkline.get('price', [])
                elif isinstance(sparkline, list):
                    price_list = sparkline
                else:
                    price_list = []
                if price_list:
                    prices = self._finite_prices(price_list, self.MIN_PRICES)
                    if prices is not None:
                        return prices
        except Exception:
            logger.debug('Sparkline price lookup failed for asset %s', asset.id, exc_info=True)

        return None
