"""Trailing Stop Engine — calculates 6 trailing stop strategies and recommends the best.

Strategies:
  1. ATR Trailing       — price minus ATR * multiplier
  2. Chandelier Exit    — highest high minus ATR * multiplier
  3. Parabolic SAR      — classic Parabolic Stop and Reverse
  4. Percentage Trailing — simple percentage from highest price
  5. Keltner Trailing    — lower Keltner channel as trailing stop
  6. Supertrend          — Supertrend indicator based stop

Each strategy is backtested over the 168-hour sparkline data to evaluate:
  - profit retained percentage
  - times stopped out
  - average profit per trade
  - max profit given back

Scoring (0-100):
  Best profit retention contributes 40 pts
  Fewest false stops contributes 25 pts
  Regime fit contributes 20 pts
  Best average profit per trade contributes 15 pts
"""
from __future__ import annotations

from app.helpers.sparkline import load_sparkline


class TrailingStopEngineService:
    """Calculates and compares 6 trailing stop strategies per asset.

    All strategies work on a 1-D series of closing prices only (no OHLC data),
    so "true range" is approximated by the absolute close-to-close change.
    """

    # (output key, internal strategy name, default parameters), in report order.
    _STRATEGY_SPECS = (
        ('atr_trailing', 'atr', {'period': 14, 'multiplier': 3.0}),
        ('chandelier', 'chandelier', {'period': 22, 'multiplier': 3.0}),
        ('parabolic_sar', 'parabolic_sar', {'af_start': 0.02, 'af_step': 0.02, 'af_max': 0.2}),
        ('percentage', 'percentage', {'pct': 5.0}),
        ('keltner', 'keltner', {'period': 20, 'mult': 2.0}),
        ('supertrend', 'supertrend', {'period': 10, 'multiplier': 3.0}),
    )

    # Number of data points before the latest price used as the default entry.
    _ENTRY_LOOKBACK = 24

    # Bars to skip after a stop-out before the backtest may re-enter.
    _REENTRY_WAIT_BARS = 2

    def __init__(self):
        pass

    # ── Public API ────────────────────────────────────────────────────

    def analyze(self, symbol, asset_type='crypto', entry_price=None):
        """Calculate all trailing stop variants for a symbol.

        Args:
            symbol: Asset ticker; upper-cased before the lookup.
            asset_type: Asset category used in the lookup (default 'crypto').
            entry_price: Position entry price. When None, the price
                ``_ENTRY_LOOKBACK`` data points before the latest one is used.

        Returns:
            ``{'status': 'success', 'data': {...}}`` with per-strategy stop
            levels, backtest summaries, a recommendation and an overall score,
            or ``{'status': 'error', 'message': ...}`` on any failure.
        """
        import numpy as np
        from app.models.asset import Asset

        try:
            asset = Asset.query.filter_by(symbol=symbol.upper(), asset_type=asset_type).first()
            if not asset:
                return {'status': 'error', 'message': f'Asset {symbol} not found'}

            raw_prices = load_sparkline(asset)
            # Treat a missing sparkline like an empty one instead of failing on len(None).
            if raw_prices is None or len(raw_prices) < 30:
                return {'status': 'error', 'message': 'Insufficient data'}

            # Drop missing / non-positive points (percent maths divides by price).
            prices = [float(p) for p in raw_prices if p and float(p) > 0]
            if len(prices) < 30:
                return {'status': 'error', 'message': 'Insufficient valid price data'}

            prices_arr = np.array(prices, dtype=float)
            current_price = float(prices_arr[-1])

            if entry_price is None:
                # Index -1 is the latest point, so N points back is index len - 1 - N
                # (~24h ago when the sparkline is hourly).
                idx_back = max(0, len(prices_arr) - 1 - self._ENTRY_LOOKBACK)
                entry_price = float(prices_arr[idx_back])
            else:
                entry_price = float(entry_price)

            current_pnl_pct = (
                round((current_price - entry_price) / entry_price * 100, 2)
                if entry_price > 0 else 0
            )

            strategies = {}
            backtests = {}
            for key, strategy, default_params in self._STRATEGY_SPECS:
                # Copy so callers mutating the returned 'params' can't alter the defaults.
                params = dict(default_params)
                result = self._compute_stops(prices_arr, strategy, params)
                # Reuse the stop series rather than recomputing it inside the backtest.
                backtest = self._backtest_trailing(
                    prices_arr, strategy, params, stop_levels=result['stop_levels']
                )
                strategies[key] = self._format_strategy(result, params, backtest, current_price)
                backtests[key] = backtest

            recommended, reason = self._recommend_best(backtests)
            score = self._calculate_score(backtests, prices_arr)
            regime_fit = self._detect_regime(prices_arr)

            return {
                'status': 'success',
                'data': {
                    'symbol': symbol.upper(),
                    'name': asset.name or symbol.upper(),
                    'current_price': current_price,
                    'entry_price': entry_price,
                    'current_pnl_pct': current_pnl_pct,
                    'strategies': strategies,
                    'recommended': recommended,
                    'recommended_reason': reason,
                    'score': score,
                    'regime_fit': regime_fit,
                }
            }

        except Exception as e:
            # API boundary: report any failure (DB, malformed data) as an error payload.
            return {'status': 'error', 'message': str(e)}

    def scan_all(self, asset_type='crypto', limit=50):
        """Scan top assets with best trailing stop recommendations.

        Analyzes up to ``limit`` active assets of ``asset_type`` ordered by
        ascending ``market_cap_rank``, keeps only successful analyses and sorts
        them by score, highest first.
        """
        from app.models.asset import Asset

        try:
            assets = Asset.query.filter_by(
                asset_type=asset_type, is_active=True
            ).order_by(Asset.market_cap_rank.asc()).limit(limit).all()

            results = []
            for asset in assets:
                result = self.analyze(symbol=asset.symbol, asset_type=asset_type)
                if result.get('status') == 'success':
                    results.append(result['data'])

            results.sort(key=lambda x: x.get('score', 0), reverse=True)
            return {
                'status': 'success',
                'data': {'items': results, 'total': len(results)}
            }

        except Exception as e:
            return {'status': 'error', 'message': str(e)}

    # ── Shared helpers ────────────────────────────────────────────────

    def _compute_stops(self, prices, strategy, params):
        """Run the stop calculator named ``strategy``; return None if unknown."""
        calculators = {
            'atr': self._atr_trailing,
            'chandelier': self._chandelier_exit,
            'parabolic_sar': self._parabolic_sar,
            'percentage': self._percentage_trailing,
            'keltner': self._keltner_trailing,
            'supertrend': self._supertrend,
        }
        calculator = calculators.get(strategy)
        if calculator is None:
            return None
        return calculator(prices, **params)

    @staticmethod
    def _fallback_stops(prices):
        """Flat stop 5% below the last price, used when history is too short."""
        import numpy as np

        fallback = float(prices[-1]) * 0.95
        return {'stop_levels': np.full(len(prices), fallback), 'current_stop': fallback}

    @staticmethod
    def _ema_atr(prices, period):
        """EMA-smoothed average true range computed from closes.

        True range is approximated by ``|close[i] - close[i-1]|`` (``tr[0] = 0``).
        The EMA (alpha = 2 / (period + 1)) is seeded at index ``period`` with the
        mean of ``tr[1..period]``; the first ``period`` entries are back-filled
        with that seed. Requires ``len(prices) > period``.
        """
        import numpy as np

        prices = np.asarray(prices, dtype=float)
        n = len(prices)
        tr = np.zeros(n)
        tr[1:] = np.abs(np.diff(prices))

        atr = np.zeros(n)
        atr[period] = np.mean(tr[1:period + 1])
        alpha = 2.0 / (period + 1)
        for i in range(period + 1, n):
            atr[i] = alpha * tr[i] + (1 - alpha) * atr[i - 1]
        atr[:period] = atr[period]
        return atr

    # ── Strategy Calculations ─────────────────────────────────────────

    def _atr_trailing(self, prices, period=14, multiplier=3.0):
        """ATR-based trailing stop: price minus ATR * multiplier.

        The stop only ratchets up on bars whose close rose; otherwise it holds.
        """
        import numpy as np

        if len(prices) < period + 1:
            return self._fallback_stops(prices)

        atr = self._ema_atr(prices, period)
        n = len(prices)
        stop_levels = np.zeros(n)
        stop_levels[0] = prices[0] - atr[0] * multiplier

        for i in range(1, n):
            if prices[i] > prices[i - 1]:
                stop_levels[i] = max(stop_levels[i - 1], prices[i] - atr[i] * multiplier)
            else:
                stop_levels[i] = stop_levels[i - 1]

        return {
            'stop_levels': stop_levels,
            'current_stop': float(stop_levels[-1]),
            'atr_value': float(atr[-1]),
        }

    def _chandelier_exit(self, prices, period=22, multiplier=3.0):
        """Chandelier Exit: highest close of the last ``period`` bars minus ATR * multiplier."""
        import numpy as np

        if len(prices) < period + 1:
            return self._fallback_stops(prices)

        atr = self._ema_atr(prices, period)
        n = len(prices)
        stop_levels = np.zeros(n)
        for i in range(n):
            window_start = max(0, i - period + 1)
            highest_high = np.max(prices[window_start:i + 1])
            stop_levels[i] = highest_high - atr[i] * multiplier

        return {
            'stop_levels': stop_levels,
            'current_stop': float(stop_levels[-1]),
        }

    def _parabolic_sar(self, prices, af_start=0.02, af_step=0.02, af_max=0.2):
        """Parabolic SAR computed on closing prices.

        While short (``is_long`` False) the SAR sits above price, so it is not a
        valid long stop at that point.
        """
        import numpy as np

        n = len(prices)
        if n < 5:
            return self._fallback_stops(prices)

        sar = np.zeros(n)
        af = af_start
        is_long = True
        ep = prices[0]  # extreme point of the current trend
        sar[0] = prices[0] * 0.98  # start slightly below the first price

        for i in range(1, n):
            prev_sar = sar[i - 1]
            candidate = prev_sar + af * (ep - prev_sar)
            # The SAR may not penetrate the previous one or two closes.
            window = prices[max(0, i - 2):i]

            if is_long:
                sar[i] = min(candidate, *window)
                if prices[i] < sar[i]:
                    # Reversal: the short SAR starts at the prior extreme high.
                    is_long = False
                    sar[i] = ep
                    ep = prices[i]
                    af = af_start
                elif prices[i] > ep:
                    ep = prices[i]
                    af = min(af + af_step, af_max)
            else:
                sar[i] = max(candidate, *window)
                if prices[i] > sar[i]:
                    # Reversal: the long SAR starts at the prior extreme low.
                    is_long = True
                    sar[i] = ep
                    ep = prices[i]
                    af = af_start
                elif prices[i] < ep:
                    ep = prices[i]
                    af = min(af + af_step, af_max)

        return {
            'stop_levels': sar,
            'current_stop': float(sar[-1]),
            'is_long': is_long,
        }

    def _percentage_trailing(self, prices, pct=5.0):
        """Simple percentage trailing stop from the running highest price."""
        import numpy as np

        running_highs = np.maximum.accumulate(np.asarray(prices, dtype=float))
        stop_levels = running_highs * (1 - pct / 100.0)

        return {
            'stop_levels': stop_levels,
            'current_stop': float(stop_levels[-1]),
            'running_high': float(running_highs[-1]),
        }

    def _keltner_trailing(self, prices, period=20, mult=2.0):
        """Lower Keltner channel (EMA - ATR * mult) as a trailing stop."""
        import numpy as np

        n = len(prices)
        if n < period + 1:
            return self._fallback_stops(prices)

        # EMA of price (midline); the ATR helper uses the same smoothing factor.
        ema = np.zeros(n)
        ema[0] = prices[0]
        alpha = 2.0 / (period + 1)
        for i in range(1, n):
            ema[i] = alpha * prices[i] + (1 - alpha) * ema[i - 1]

        atr = self._ema_atr(prices, period)
        stop_levels = ema - atr * mult

        # Hold the level instead of letting it drop on bars where the close did not fall.
        for i in range(1, n):
            if stop_levels[i] < stop_levels[i - 1] and prices[i] >= prices[i - 1]:
                stop_levels[i] = stop_levels[i - 1]

        return {
            'stop_levels': stop_levels,
            'current_stop': float(stop_levels[-1]),
        }

    def _supertrend(self, prices, period=10, multiplier=3.0):
        """Supertrend indicator as trailing stop.

        ``direction`` is 1 (bullish, stop = final lower band) or -1 (bearish,
        value = final upper band, above price).
        """
        import numpy as np

        n = len(prices)
        if n < period + 1:
            return self._fallback_stops(prices)

        atr = self._ema_atr(prices, period)
        closes = np.array(prices, dtype=float)  # close used as the mid price
        upper_band = closes + multiplier * atr
        lower_band = closes - multiplier * atr

        supertrend = np.zeros(n)
        direction = np.ones(n)
        supertrend[0] = lower_band[0]
        direction[0] = 1

        for i in range(1, n):
            # Final lower band only rises, unless the previous close broke below it.
            if not (lower_band[i] > lower_band[i - 1] or prices[i - 1] < lower_band[i - 1]):
                lower_band[i] = lower_band[i - 1]
            # Final upper band only falls, unless the previous close broke above it.
            if not (upper_band[i] < upper_band[i - 1] or prices[i - 1] > upper_band[i - 1]):
                upper_band[i] = upper_band[i - 1]

            if direction[i - 1] == 1:
                bullish = prices[i] >= lower_band[i]
            else:
                bullish = prices[i] > upper_band[i]
            direction[i] = 1 if bullish else -1
            supertrend[i] = lower_band[i] if bullish else upper_band[i]

        return {
            'stop_levels': supertrend,
            'current_stop': float(supertrend[-1]),
            'direction': int(direction[-1]),
        }

    # ── Backtesting ───────────────────────────────────────────────────

    def _backtest_trailing(self, prices, strategy, params, stop_levels=None):
        """Backtest a long-only trailing stop strategy over the price data.

        Starts long at the first price. After a stop-out it waits
        ``_REENTRY_WAIT_BARS`` bars and re-enters at the first close that is
        above the stop. A still-open trade is marked to the last price.

        Args:
            prices: Closing prices.
            strategy: Internal strategy name (see ``_compute_stops``).
            params: Keyword parameters for the strategy calculator.
            stop_levels: Optional precomputed stop series; when given, the
                strategy is not recomputed.

        Returns:
            Summary dict with the same keys as ``_empty_backtest``.
        """
        if stop_levels is None:
            result = self._compute_stops(prices, strategy, params)
            if result is None:
                return self._empty_backtest()
            stop_levels = result['stop_levels']

        n = len(prices)
        trades = []
        in_trade = True
        entry_idx = 0
        entry_price = float(prices[0])
        peak_price = entry_price
        last_exit_idx = 0

        for i in range(1, n):
            price = float(prices[i])
            stop = float(stop_levels[i])

            if in_trade:
                peak_price = max(peak_price, price)
                if price <= stop:
                    # Only closes are observed, so the breach is filled at this close.
                    # Filling at `stop` overstates results, e.g. on a SAR/Supertrend
                    # flip the stop value jumps to the far side of price.
                    trades.append(self._trade_record(entry_idx, i, entry_price, price, peak_price))
                    in_trade = False
                    last_exit_idx = i
            elif i > last_exit_idx + self._REENTRY_WAIT_BARS and price > stop:
                # Wait counted from the exit (not the old entry), and never enter
                # a position whose stop is already at or above price.
                in_trade = True
                entry_idx = i
                entry_price = price
                peak_price = price

        if in_trade and entry_idx < n - 1:
            trades.append(self._trade_record(
                entry_idx, n - 1, entry_price, float(prices[-1]), peak_price, is_open=True
            ))

        if not trades:
            return self._empty_backtest()

        n_trades = len(trades)
        wins = sum(1 for t in trades if t['profit_pct'] > 0)
        total_profit = sum(t['profit_pct'] for t in trades)
        total_max_profit = sum(t['max_profit_pct'] for t in trades)
        profit_retained = (total_profit / total_max_profit * 100) if total_max_profit > 0 else 0
        times_stopped = sum(1 for t in trades if not t.get('open'))

        return {
            'profit_retained_pct': round(profit_retained, 1),
            'times_stopped': times_stopped,
            'avg_profit': round(total_profit / n_trades, 2),
            'total_trades': n_trades,
            'win_rate': round(wins / n_trades * 100, 1),
            'max_given_back': round(max(t['given_back_pct'] for t in trades), 2),
            'total_profit_pct': round(total_profit, 2),
        }

    @staticmethod
    def _trade_record(entry_idx, exit_idx, entry_price, exit_price, peak_price, is_open=False):
        """Build one backtest trade dict; percentages are relative to ``entry_price``."""
        if entry_price > 0:
            profit_pct = ((exit_price - entry_price) / entry_price) * 100
            max_profit_pct = ((peak_price - entry_price) / entry_price) * 100
        else:
            profit_pct = 0
            max_profit_pct = 0

        trade = {
            'entry_idx': entry_idx,
            'exit_idx': exit_idx,
            'entry_price': entry_price,
            'exit_price': exit_price,
            'profit_pct': round(profit_pct, 2),
            'max_profit_pct': round(max_profit_pct, 2),
            'given_back_pct': round(max_profit_pct - profit_pct, 2),
        }
        if is_open:
            trade['open'] = True
        return trade

    def _empty_backtest(self):
        """Return empty backtest result."""
        return {
            'profit_retained_pct': 0,
            'times_stopped': 0,
            'avg_profit': 0,
            'total_trades': 0,
            'win_rate': 0,
            'max_given_back': 0,
            'total_profit_pct': 0,
        }

    # ── Recommendation & Scoring ──────────────────────────────────────

    def _recommend_best(self, backtests):
        """Recommend best strategy based on net profit retention and reliability.

        Weighted score: retention (40%) + avg profit capped at 10% (30%) +
        low max give-back (30%). Ties resolve to the first strategy in order.

        Returns:
            ``(strategy_key, human-readable reason)``.
        """
        if not backtests:
            return 'percentage', 'Default recommendation — insufficient data'

        def weighted(bt):
            retention_score = min(bt.get('profit_retained_pct', 0), 100) * 0.4
            avg_prof_score = min(max(bt.get('avg_profit', 0), 0), 10) * 10 * 0.3
            given_back_score = max(0, 100 - bt.get('max_given_back', 100)) * 0.3
            return retention_score + avg_prof_score + given_back_score

        scored = {name: weighted(bt) for name, bt in backtests.items()}
        best = max(scored, key=scored.get)
        bt = backtests[best]

        retention = bt.get('profit_retained_pct', 0)
        stops = bt.get('times_stopped', 0)
        reason = (
            f'Best profit retention ({retention}%) with '
            f'{"lowest" if stops <= 2 else "acceptable"} false stops ({stops})'
        )

        return best, reason

    def _format_strategy(self, result, params, backtest, current_price):
        """Format a strategy result for output.

        ``status`` is 'active' when the stop is below the current price,
        otherwise 'stopped'. ``trailing_history_24`` holds the last 24 stop values.
        """
        stop_level = result.get('current_stop', 0)
        distance_pct = ((current_price - stop_level) / current_price * 100) if current_price > 0 else 0

        stop_levels = result.get('stop_levels', [])
        if hasattr(stop_levels, '__len__') and len(stop_levels) > 0:
            trail_24 = [round(float(s), 6) for s in stop_levels[-24:]]
        else:
            trail_24 = []

        status = 'active' if stop_level < current_price else 'stopped'

        return {
            'stop_level': round(float(stop_level), 8),
            'distance_pct': round(float(distance_pct), 2),
            'status': status,
            'params': params,
            'trailing_history_24': trail_24,
            'backtest': {
                'profit_retained_pct': backtest.get('profit_retained_pct', 0),
                'times_stopped': backtest.get('times_stopped', 0),
                'avg_profit': backtest.get('avg_profit', 0),
                'win_rate': backtest.get('win_rate', 0),
                'max_given_back': backtest.get('max_given_back', 0),
            },
        }

    def _calculate_score(self, backtests, prices_arr):
        """Calculate overall score (0-100) for how well trailing stops work for this asset.

        Components: best profit retention (40 pts), fewest stop-outs (25 pts),
        trend regime (20 pts) and best average profit per trade (15 pts).
        Returns 30 when there are no backtests.
        """
        if not backtests:
            return 30

        results = list(backtests.values())

        best_retention = max(bt.get('profit_retained_pct', 0) for bt in results)
        retention_score = min(best_retention, 100) * 0.40

        # Lose 5 pts per stop-out of the least-stopped strategy.
        min_stops = min(bt.get('times_stopped', 99) for bt in results)
        stop_score = max(0, 25 - min_stops * 5)

        # Trailing stops work better in trends; unknown regimes get 8 pts.
        regime_scores = {'strong_trend': 20, 'trending': 15, 'mild_trend': 10, 'ranging': 5}
        regime_score = regime_scores.get(self._detect_regime(prices_arr), 8)

        # Average-profit component: full marks when the best strategy averages
        # 2-10% per trade, partial credit for any positive average.
        best_avg = max(bt.get('avg_profit', 0) for bt in results)
        if 2.0 <= best_avg <= 10.0:
            avg_score = 15
        elif best_avg > 0:
            avg_score = 10
        else:
            avg_score = 0

        total = int(retention_score + stop_score + regime_score + avg_score)
        return max(0, min(100, total))

    def _detect_regime(self, prices_arr):
        """Classify the series as 'strong_trend', 'trending', 'mild_trend' or 'ranging'.

        Combines path efficiency (net move / total path, 60%) with direction
        consistency (share of up moves vs. 50%, 40%). Returns 'unknown' for
        fewer than 20 points. Assumes strictly positive prices (divides by price).
        """
        import numpy as np

        prices_arr = np.asarray(prices_arr, dtype=float)
        if len(prices_arr) < 20:
            return 'unknown'

        steps = np.diff(prices_arr)
        total_path = np.abs(steps).sum()
        if total_path == 0:
            return 'ranging'

        efficiency = abs(prices_arr[-1] - prices_arr[0]) / total_path

        returns = steps / prices_arr[:-1]
        up_share = np.sum(returns > 0) / len(returns)
        direction_bias = abs(up_share - 0.5) * 2  # 0 = balanced, 1 = all one way

        combined = efficiency * 0.6 + direction_bias * 0.4

        if combined > 0.4:
            return 'strong_trend'
        if combined > 0.25:
            return 'trending'
        if combined > 0.15:
            return 'mild_trend'
        return 'ranging'
