"""ChartPatternService - Detects classic chart patterns from price data.

Identifies Head & Shoulders, Double Top/Bottom, Triangles, Wedges,
Flags & Pennants from local peaks/troughs analysis.

Data sources (all lazy-imported):
- Asset: id, symbol, name, asset_type, is_active, market_cap_rank
- AssetProfile: profile_json (sparkline_in_7d)
- OHLCVData: close, high, low prices
"""
from __future__ import annotations

import logging
import math
from datetime import datetime

logger = logging.getLogger(__name__)


class ChartPatternService:
    """Detects classic chart patterns and generates trade setups.

    Detectors work on a 1-D close-price series ordered oldest -> newest (the
    last element is treated as the current price) and on local pivots found
    by ``_find_peaks_troughs``. Every detected pattern is a dict containing at
    least ``type``, ``direction``, ``reliability``, ``target_price``,
    ``entry_price`` and ``stop_loss``.
    """

    SIMILARITY_PCT = 0.02   # 2% tolerance for matching peak/trough levels
    MIN_DATA_POINTS = 30    # minimum number of prices required for analysis
    OHLCV_LIMIT = 500       # number of most recent hourly candles analyzed
    PIVOT_ORDER = 5         # neighbours on each side a pivot must dominate

    def __init__(self):
        pass

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def analyze(self, symbol=None, asset_type='crypto', **kwargs):
        """Detect all chart patterns for an asset.

        Args:
            symbol: Asset symbol, matched case-insensitively and literally.
            asset_type: Asset type filter, e.g. ``'crypto'``.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            ``{'status': 'success', 'data': {...}}`` with patterns sorted by
            reliability (highest first), or ``{'status': 'error',
            'message': ...}`` when the asset is unknown, has too little data,
            or analysis raises.
        """
        try:
            asset = self._resolve_coin(symbol, asset_type)
            if asset is None:
                return {'status': 'error', 'message': f'Asset not found: {symbol}'}
            return self._analyze_asset(asset, label=symbol)
        except Exception as e:
            logger.exception('ChartPatternService.analyze error for %s', symbol)
            return {'status': 'error', 'message': str(e)[:200]}

    def scan_all(self, asset_type='crypto', limit=50, **kwargs):
        """Scan the top ``limit`` active assets (by market-cap rank) for patterns.

        Assets that fail analysis or have no patterns are skipped. Results are
        sorted by the reliability of each asset's best pattern.
        """
        from app.models.asset import Asset

        try:
            assets = (
                Asset.query
                .filter_by(asset_type=asset_type, is_active=True)
                .order_by(Asset.market_cap_rank.asc())
                .limit(limit)
                .all()
            )

            results = []
            for asset in assets:
                # Analyze the row we already have instead of re-resolving it by
                # symbol (saves a query and avoids picking a different asset
                # when symbols collide case-insensitively).
                try:
                    result = self._analyze_asset(asset)
                except Exception:
                    logger.exception(
                        'ChartPatternService.scan_all error for %s', asset.symbol)
                    continue

                if result.get('status') != 'success':
                    continue
                d = result['data']
                if d['total_patterns'] <= 0:
                    continue

                top_pattern = d['patterns'][0]
                results.append({
                    'symbol': asset.symbol,
                    'asset_id': asset.id,
                    'name': asset.name,
                    'total_patterns': d['total_patterns'],
                    'top_pattern': top_pattern.get('type', ''),
                    'top_reliability': top_pattern.get('reliability', 0),
                    'top_direction': top_pattern.get('direction', 'neutral'),
                    'target_price': top_pattern.get('target_price'),
                })

            results.sort(key=lambda x: x['top_reliability'], reverse=True)

            return {
                'status': 'success',
                'data': {
                    'results': results,
                    'total_with_patterns': len(results),
                    'asset_type': asset_type,
                    'scanned_at': datetime.utcnow().isoformat(),
                },
            }

        except Exception as e:
            logger.exception('ChartPatternService.scan_all error')
            return {'status': 'error', 'message': str(e)[:200]}

    def _analyze_asset(self, asset, label=None):
        """Run every detector on ``asset`` and build the result payload.

        Args:
            asset: An Asset row (needs ``id`` and ``symbol``).
            label: Name used in the insufficient-data message; defaults to
                ``asset.symbol``.

        Returns:
            The same success/error dict shape as ``analyze``. Exceptions are
            not caught here; callers decide how to report them.
        """
        if label is None:
            label = asset.symbol

        prices = self._get_prices(asset)
        if prices is None or len(prices) < self.MIN_DATA_POINTS:
            return {
                'status': 'error',
                'message': f'Insufficient data for {label}',
            }

        peaks, troughs = self._find_peaks_troughs(prices, order=self.PIVOT_ORDER)

        patterns = []
        patterns.extend(self._detect_head_shoulders(peaks, troughs, prices) or [])
        patterns.extend(self._detect_double_top_bottom(peaks, troughs, prices) or [])
        patterns.extend(self._detect_triangle(peaks, troughs, prices) or [])
        patterns.extend(self._detect_wedge(peaks, troughs, prices) or [])
        patterns.extend(self._detect_flag_pennant(prices) or [])

        # Most reliable first; sort is stable so detector order breaks ties.
        patterns.sort(key=lambda p: p.get('reliability', 0), reverse=True)

        return {
            'status': 'success',
            'data': {
                'symbol': asset.symbol,
                'asset_id': asset.id,
                'current_price': float(prices[-1]),
                'patterns': patterns,
                'total_patterns': len(patterns),
                'peaks_found': len(peaks),
                'troughs_found': len(troughs),
                'analyzed_at': datetime.utcnow().isoformat(),
            },
        }

    # ------------------------------------------------------------------
    # Peak / trough detection
    # ------------------------------------------------------------------

    def _find_peaks_troughs(self, prices, order=5):
        """Find local peaks and troughs using neighbour comparison.

        A peak is strictly higher than the ``order`` points before it and at
        least as high as the ``order`` points after it; troughs mirror this.
        The strict left-side comparison means a flat top/bottom (plateau)
        yields a single pivot at its first point instead of one pivot per
        point, which would otherwise create spurious adjacent pivots.

        Args:
            prices: 1-D sequence of prices.
            order: Neighbours on each side to compare against (clamped to >= 1).

        Returns:
            peaks: list of (index, price_value)
            troughs: list of (index, price_value)
        """
        import numpy as np

        prices = np.asarray(prices, dtype=float)
        order = max(int(order), 1)  # order 0 would compare against nothing
        n = len(prices)
        peaks = []
        troughs = []

        for i in range(order, n - order):
            value = prices[i]
            left = prices[i - order:i]
            right = prices[i + 1:i + order + 1]
            if value > left.max() and value >= right.max():
                peaks.append((i, float(value)))
            elif value < left.min() and value <= right.min():
                troughs.append((i, float(value)))

        return peaks, troughs

    # ------------------------------------------------------------------
    # Pattern detectors
    # ------------------------------------------------------------------

    def _detect_head_shoulders(self, peaks, troughs, prices):
        """Detect Head & Shoulders and Inverse H&S patterns.

        H&S: 3 consecutive peaks where the middle is highest; the neckline is
        the lowest trough between the shoulders. Inverse: 3 consecutive
        troughs where the middle is lowest; the neckline is the highest peak
        between the shoulders. Shoulders must be within ``5 * SIMILARITY_PCT``
        of each other. Targets use the measured move (head-to-neckline
        height projected from the neckline).
        """
        import numpy as np

        patterns = []
        prices = np.asarray(prices, dtype=float)
        current = float(prices[-1])
        shoulder_tol = self.SIMILARITY_PCT * 5

        # Regular H&S (bearish)
        if len(peaks) >= 3 and len(troughs) >= 2:
            for left_shoulder, head, right_shoulder in zip(peaks, peaks[1:], peaks[2:]):
                # Head must be highest
                if head[1] <= left_shoulder[1] or head[1] <= right_shoulder[1]:
                    continue

                # Shoulders should be roughly equal (guarded against zero levels)
                if self._relative_gap(left_shoulder[1], right_shoulder[1]) > shoulder_tol:
                    continue

                neckline_troughs = [
                    t for t in troughs
                    if left_shoulder[0] < t[0] < right_shoulder[0]
                ]
                if not neckline_troughs:
                    continue

                neckline = min(t[1] for t in neckline_troughs)
                pattern_height = head[1] - neckline
                target = neckline - pattern_height  # Measured move down

                reliability = 70
                # Boost if right shoulder is lower than left (more bearish)
                if right_shoulder[1] < left_shoulder[1]:
                    reliability += 10
                # Boost if current price is near/below neckline
                if current <= neckline * 1.01:
                    reliability += 10

                patterns.append({
                    'type': 'head_and_shoulders',
                    'direction': 'bearish',
                    'reliability': min(reliability, 95),
                    'neckline': round(neckline, 8),
                    'head_price': round(head[1], 8),
                    'left_shoulder': round(left_shoulder[1], 8),
                    'right_shoulder': round(right_shoulder[1], 8),
                    'target_price': round(max(target, 0), 8),
                    'entry_price': round(neckline * 0.995, 8),
                    'stop_loss': round(right_shoulder[1] * 1.01, 8),
                })

        # Inverse H&S (bullish)
        if len(troughs) >= 3 and len(peaks) >= 2:
            for left_shoulder, head, right_shoulder in zip(troughs, troughs[1:], troughs[2:]):
                # Head must be lowest
                if head[1] >= left_shoulder[1] or head[1] >= right_shoulder[1]:
                    continue

                if self._relative_gap(left_shoulder[1], right_shoulder[1]) > shoulder_tol:
                    continue

                neckline_peaks = [
                    p for p in peaks
                    if left_shoulder[0] < p[0] < right_shoulder[0]
                ]
                if not neckline_peaks:
                    continue

                neckline = max(p[1] for p in neckline_peaks)
                pattern_height = neckline - head[1]
                target = neckline + pattern_height  # Measured move up

                reliability = 70
                if right_shoulder[1] > left_shoulder[1]:
                    reliability += 10
                if current >= neckline * 0.99:
                    reliability += 10

                patterns.append({
                    'type': 'inverse_head_and_shoulders',
                    'direction': 'bullish',
                    'reliability': min(reliability, 95),
                    'neckline': round(neckline, 8),
                    'head_price': round(head[1], 8),
                    'left_shoulder': round(left_shoulder[1], 8),
                    'right_shoulder': round(right_shoulder[1], 8),
                    'target_price': round(target, 8),
                    'entry_price': round(neckline * 1.005, 8),
                    'stop_loss': round(right_shoulder[1] * 0.99, 8),
                })

        return patterns

    def _detect_double_top_bottom(self, peaks, troughs, prices):
        """Detect Double Top (bearish) and Double Bottom (bullish).

        Two consecutive peaks (troughs) within ``SIMILARITY_PCT`` of each
        other with at least one trough (peak) between them qualify. The
        neckline is the extreme pivot between them.
        """
        import numpy as np

        patterns = []
        prices = np.asarray(prices, dtype=float)
        current = float(prices[-1])
        tol = self.SIMILARITY_PCT

        # Double Top
        for p1, p2 in zip(peaks, peaks[1:]):
            if self._relative_gap(p1[1], p2[1]) > tol:
                continue

            mid_troughs = [t for t in troughs if p1[0] < t[0] < p2[0]]
            if not mid_troughs:
                continue

            neckline = min(t[1] for t in mid_troughs)
            top_level = (p1[1] + p2[1]) / 2.0
            height = top_level - neckline
            target = neckline - height

            reliability = 65
            if current < neckline:  # neckline already broken
                reliability += 15
            if p2[1] < p1[1]:       # weaker second push
                reliability += 5

            patterns.append({
                'type': 'double_top',
                'direction': 'bearish',
                'reliability': min(reliability, 90),
                'top_level': round(top_level, 8),
                'neckline': round(neckline, 8),
                'target_price': round(max(target, 0), 8),
                'entry_price': round(neckline * 0.995, 8),
                'stop_loss': round(top_level * 1.01, 8),
            })

        # Double Bottom
        for t1, t2 in zip(troughs, troughs[1:]):
            if self._relative_gap(t1[1], t2[1]) > tol:
                continue

            mid_peaks = [p for p in peaks if t1[0] < p[0] < t2[0]]
            if not mid_peaks:
                continue

            neckline = max(p[1] for p in mid_peaks)
            bottom_level = (t1[1] + t2[1]) / 2.0
            height = neckline - bottom_level
            target = neckline + height

            reliability = 65
            if current > neckline:
                reliability += 15
            if t2[1] > t1[1]:
                reliability += 5

            patterns.append({
                'type': 'double_bottom',
                'direction': 'bullish',
                'reliability': min(reliability, 90),
                'bottom_level': round(bottom_level, 8),
                'neckline': round(neckline, 8),
                'target_price': round(target, 8),
                'entry_price': round(neckline * 1.005, 8),
                'stop_loss': round(bottom_level * 0.99, 8),
            })

        return patterns

    def _detect_triangle(self, peaks, troughs, prices):
        """Detect Ascending, Descending, and Symmetrical triangles.

        Slopes come from ``_fit_recent_trendlines`` and are normalized by the
        mean price; a slope is "flat" when its magnitude is below 0.0002.
        At most one triangle type is reported.
        """
        import numpy as np

        patterns = []
        prices = np.asarray(prices, dtype=float)

        fit = self._fit_recent_trendlines(peaks, troughs, prices)
        if fit is None:
            return patterns

        peak_values = fit['peak_values']
        trough_values = fit['trough_values']
        norm_peak_slope = fit['peak_slope']
        norm_trough_slope = fit['trough_slope']

        flat_threshold = 0.0002  # ~0.02% of average price per candle

        # Ascending Triangle: flat tops + rising bottoms
        if abs(norm_peak_slope) < flat_threshold and norm_trough_slope > flat_threshold:
            resistance = float(np.mean(peak_values))
            height = resistance - float(trough_values[-1])
            target = resistance + height

            patterns.append({
                'type': 'ascending_triangle',
                'direction': 'bullish',
                'reliability': 70,
                'resistance': round(resistance, 8),
                'support_rising': True,
                'target_price': round(target, 8),
                'entry_price': round(resistance * 1.005, 8),
                'stop_loss': round(float(trough_values[-1]) * 0.99, 8),
            })

        # Descending Triangle: flat bottoms + descending tops
        elif abs(norm_trough_slope) < flat_threshold and norm_peak_slope < -flat_threshold:
            support = float(np.mean(trough_values))
            height = float(peak_values[-1]) - support
            target = support - height

            patterns.append({
                'type': 'descending_triangle',
                'direction': 'bearish',
                'reliability': 70,
                'support': round(support, 8),
                'resistance_falling': True,
                'target_price': round(max(target, 0), 8),
                'entry_price': round(support * 0.995, 8),
                'stop_loss': round(float(peak_values[-1]) * 1.01, 8),
            })

        # Symmetrical Triangle: descending peaks + ascending troughs (converging)
        elif norm_peak_slope < -flat_threshold and norm_trough_slope > flat_threshold:
            mid_price = (float(peak_values[-1]) + float(trough_values[-1])) / 2.0
            height = float(peak_values[0]) - float(trough_values[0])
            # Breakout direction assumed to follow the trend over the ~20
            # candles preceding the first fitted trough.
            first_trough_idx = fit['recent_troughs'][0][0]
            pre_trend_start = max(0, first_trough_idx - 20)
            pre_trend = prices[first_trough_idx] - prices[pre_trend_start]
            direction = 'bullish' if pre_trend > 0 else 'bearish'

            target = mid_price + height if direction == 'bullish' else mid_price - height

            patterns.append({
                'type': 'symmetrical_triangle',
                'direction': direction,
                'reliability': 60,
                'apex_price': round(mid_price, 8),
                'target_price': round(max(target, 0), 8),
                'entry_price': round(mid_price, 8),
                'stop_loss': round(
                    float(trough_values[-1]) * 0.99 if direction == 'bullish'
                    else float(peak_values[-1]) * 1.01, 8
                ),
            })

        return patterns

    def _detect_wedge(self, peaks, troughs, prices):
        """Detect Rising Wedge (bearish) and Falling Wedge (bullish).

        Both trendlines must slope the same way (normalized |slope| above
        0.0002) and converge: the peak line is less steep than the trough
        line for a rising wedge, steeper downward for a falling wedge.
        """
        import numpy as np

        patterns = []
        prices = np.asarray(prices, dtype=float)

        fit = self._fit_recent_trendlines(peaks, troughs, prices)
        if fit is None:
            return patterns

        peak_values = fit['peak_values']
        trough_values = fit['trough_values']
        norm_ps = fit['peak_slope']
        norm_ts = fit['trough_slope']

        slope_threshold = 0.0002
        last_peak = float(peak_values[-1])
        last_trough = float(trough_values[-1])

        # Rising Wedge: both rising but converging (peak slope < trough slope)
        if norm_ps > slope_threshold and norm_ts > slope_threshold and norm_ps < norm_ts:
            height = last_peak - last_trough
            target = last_trough - height

            patterns.append({
                'type': 'rising_wedge',
                'direction': 'bearish',
                'reliability': 65,
                'target_price': round(max(target, 0), 8),
                'entry_price': round(last_trough * 0.995, 8),
                'stop_loss': round(last_peak * 1.01, 8),
            })

        # Falling Wedge: both falling but converging (peak slope more negative)
        elif norm_ps < -slope_threshold and norm_ts < -slope_threshold and norm_ps < norm_ts:
            height = last_peak - last_trough
            target = last_peak + height

            patterns.append({
                'type': 'falling_wedge',
                'direction': 'bullish',
                'reliability': 65,
                'target_price': round(target, 8),
                'entry_price': round(last_peak * 1.005, 8),
                'stop_loss': round(last_trough * 0.99, 8),
            })

        return patterns

    def _detect_flag_pennant(self, prices):
        """Detect Flag and Pennant patterns after strong moves.

        The series is split in half: the first half is treated as the impulse
        ("pole"), the second half as the consolidation. A pattern needs an
        impulse of at least 5% and a consolidation that moves against it by
        less than half the impulse. A consolidation whose range is under 30%
        of the impulse range is labelled a pennant, otherwise a flag.
        """
        import numpy as np

        patterns = []
        prices = np.asarray(prices, dtype=float)
        n = len(prices)

        if n < 30:
            return patterns

        # First half = impulse (pole), second half = consolidation (flag).
        mid = n // 2
        impulse_section = prices[:mid]
        flag_section = prices[mid:]

        if len(impulse_section) < 10 or len(flag_section) < 10:
            return patterns

        # Percentage changes are undefined when a section starts at zero.
        if impulse_section[0] == 0 or flag_section[0] == 0:
            return patterns

        impulse_change = (impulse_section[-1] - impulse_section[0]) / impulse_section[0]
        flag_change = (flag_section[-1] - flag_section[0]) / flag_section[0]

        # Need strong impulse (>5%) and mild counter-move (flag)
        if abs(impulse_change) < 0.05:
            return patterns

        if abs(flag_change) >= abs(impulse_change) * 0.5:
            return patterns

        # Flag must move opposite to (or sideways against) the impulse
        is_bullish_flag = impulse_change > 0 and flag_change <= 0
        is_bearish_flag = impulse_change < 0 and flag_change >= 0
        if not (is_bullish_flag or is_bearish_flag):
            return patterns

        direction = 'bullish' if is_bullish_flag else 'bearish'
        # Measured move = pole height projected from the current price
        pole_height = abs(impulse_section[-1] - impulse_section[0])
        current = float(prices[-1])

        if direction == 'bullish':
            target = current + pole_height
            stop = float(np.min(flag_section)) * 0.99
        else:
            target = current - pole_height
            stop = float(np.max(flag_section)) * 1.01

        # Volatility contraction in the consolidation (pennant vs flag).
        # impulse_range > 0 is guaranteed by the >=5% impulse check above.
        flag_range = float(np.max(flag_section) - np.min(flag_section))
        impulse_range = float(np.max(impulse_section) - np.min(impulse_section))
        is_pennant = flag_range < impulse_range * 0.3

        ptype = 'pennant' if is_pennant else 'flag'
        reliability = 60 if ptype == 'flag' else 55

        # Boost if flag is tight
        if flag_range / impulse_range < 0.2:
            reliability += 10

        patterns.append({
            'type': f'bull_{ptype}' if direction == 'bullish' else f'bear_{ptype}',
            'direction': direction,
            'reliability': min(reliability, 85),
            'impulse_change_pct': round(impulse_change * 100, 2),
            'flag_change_pct': round(flag_change * 100, 2),
            'target_price': round(max(target, 0), 8),
            'entry_price': round(current, 8),
            'stop_loss': round(max(stop, 0), 8),
        })

        return patterns

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------

    def _fit_recent_trendlines(self, peaks, troughs, prices):
        """Fit least-squares lines through the most recent peaks and troughs.

        Uses the last 4 pivots of each kind when at least 4 exist, otherwise
        the last 2.

        Returns:
            dict with ``peak_values`` / ``trough_values`` (numpy arrays of the
            fitted pivot prices), ``recent_troughs`` (list of (index, price)),
            and ``peak_slope`` / ``trough_slope`` divided by the mean price;
            or None when either side has fewer than 2 pivots or the mean
            price is 0.
        """
        import numpy as np

        if len(peaks) < 2 or len(troughs) < 2:
            return None

        recent_peaks = peaks[-4:] if len(peaks) >= 4 else peaks[-2:]
        recent_troughs = troughs[-4:] if len(troughs) >= 4 else troughs[-2:]

        avg_price = float(np.mean(np.asarray(prices, dtype=float)))
        if avg_price == 0:
            return None

        peak_values = np.array([p[1] for p in recent_peaks], dtype=float)
        trough_values = np.array([t[1] for t in recent_troughs], dtype=float)
        peak_slope = self._linear_slope([p[0] for p in recent_peaks], peak_values)
        trough_slope = self._linear_slope([t[0] for t in recent_troughs], trough_values)

        return {
            'peak_values': peak_values,
            'trough_values': trough_values,
            'recent_troughs': recent_troughs,
            'peak_slope': peak_slope / avg_price,
            'trough_slope': trough_slope / avg_price,
        }

    @staticmethod
    def _relative_gap(a, b):
        """Return ``|a - b| / max(a, b)``, or ``inf`` when ``max(a, b) <= 0``.

        Returning ``inf`` makes any "within tolerance" check fail instead of
        raising ZeroDivisionError on zero (or non-positive) price levels.
        """
        hi = max(a, b)
        if hi <= 0:
            return math.inf
        return abs(a - b) / hi

    def _linear_slope(self, x, y):
        """Return the ordinary least-squares slope of ``y`` against ``x``.

        Returns 0.0 for fewer than 2 points or when all ``x`` are equal.
        """
        import numpy as np

        if len(x) < 2:
            return 0.0
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        dx = x - np.mean(x)
        denom = np.sum(dx ** 2)
        if denom == 0:
            return 0.0
        return float(np.sum(dx * (y - np.mean(y))) / denom)

    @staticmethod
    def _escape_like(value):
        """Escape ``\\``, ``%`` and ``_`` so a LIKE/ILIKE pattern matches literally.

        Must be used together with ``escape='\\\\'`` on the SQL operator.
        """
        return (
            str(value)
            .replace('\\', '\\\\')
            .replace('%', '\\%')
            .replace('_', '\\_')
        )

    def _resolve_coin(self, symbol, asset_type):
        """Look up an active Asset by symbol (case-insensitive) and asset_type.

        Returns the first matching row, or None when ``symbol`` is None or
        nothing matches.
        """
        from app.models.asset import Asset

        if symbol is None:
            return None
        # Escape LIKE wildcards: a user-supplied '%' or '_' must not match
        # arbitrary symbols.
        pattern = self._escape_like(symbol)
        return (
            Asset.query
            .filter(
                Asset.symbol.ilike(pattern, escape='\\'),
                Asset.asset_type == asset_type,
                Asset.is_active.is_(True),
            )
            .first()
        )

    def _to_price_array(self, values):
        """Convert raw price values to a float array, dropping unusable entries.

        ``None``, non-numeric and non-finite (NaN/inf) values are skipped.

        Returns:
            numpy float array (original order preserved), or None when fewer
            than ``MIN_DATA_POINTS`` usable values remain.
        """
        import numpy as np

        cleaned = []
        for raw in values or ():
            if raw is None:
                continue
            try:
                value = float(raw)
            except (TypeError, ValueError):
                continue
            if math.isfinite(value):
                cleaned.append(value)

        if len(cleaned) < self.MIN_DATA_POINTS:
            return None
        return np.array(cleaned, dtype=float)

    def _get_prices(self, asset):
        """Get a close-price array (oldest -> newest) for an asset.

        Tries the most recent ``OHLCV_LIMIT`` hourly OHLCV candles first and
        falls back to the latest profile's ``sparkline_in_7d``. Lookup errors
        are logged at debug level and the next source is tried.

        Returns:
            numpy float array, or None when no source yields enough data.
        """
        import json

        # 1) OHLCV hourly data: fetch the NEWEST candles (desc + limit), then
        #    reverse to chronological order so prices[-1] is the latest close.
        try:
            from app.models.ohlcv import OHLCVData
            rows = (
                OHLCVData.query
                .filter_by(asset_id=asset.id, timeframe='1h')
                .order_by(OHLCVData.datetime_wib.desc())
                .limit(self.OHLCV_LIMIT)
                .all()
            )
            prices = self._to_price_array([r.close for r in reversed(rows or [])])
            if prices is not None:
                return prices
        except Exception:
            logger.debug('OHLCV price lookup failed for asset %s',
                         getattr(asset, 'id', None), exc_info=True)

        # 2) Sparkline fallback
        try:
            from app.models.asset import AssetProfile
            profile = (
                AssetProfile.query
                .filter_by(asset_id=asset.id)
                .order_by(AssetProfile.fetched_at.desc())
                .first()
            )
            if profile and profile.profile_json:
                pj = profile.profile_json
                if isinstance(pj, str):
                    pj = json.loads(pj)
                sparkline = pj.get('sparkline_in_7d', {}) if isinstance(pj, dict) else {}
                if isinstance(sparkline, dict):
                    price_list = sparkline.get('price', [])
                elif isinstance(sparkline, list):
                    price_list = sparkline
                else:
                    price_list = []
                prices = self._to_price_array(price_list)
                if prices is not None:
                    return prices
        except Exception:
            logger.debug('Sparkline price lookup failed for asset %s',
                         getattr(asset, 'id', None), exc_info=True)

        return None
