"""AlgorithmContext - shared state passed through the algorithm pipeline."""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
    # Only needed for the ``ohlcv`` annotation; annotations are lazy strings
    # (PEP 563), so pandas is not imported at runtime by this module.
    import pandas as pd


@dataclass
class AlgorithmContext:
    """Shared state object passed through the algorithm pipeline.

    Every algorithm reads from and writes to this context.
    This ensures all algorithms share the same data without coupling.

    Field order defines the positional order of the generated ``__init__``;
    keep it stable so positional callers are not broken.
    """
    # Input data
    asset_id: str = ''
    timeframe: str = '1h'
    source: str = 'coingecko'
    ohlcv: Optional[pd.DataFrame] = None  # columns: timestamp, datetime, open, high, low, close, volume

    # Current market state
    current_price: float = 0.0
    current_volume: float = 0.0

    # Settings
    capital_idr: float = 1_000_000
    active_pct: float = 75.0
    reserve_pct: float = 25.0
    buy_fee_pct: float = 0.31
    sell_fee_pct: float = 0.31

    # Stock-aware parameters
    asset_type: str = 'crypto'          # 'crypto' or 'stock'
    lot_size: int = 0                   # 0 = fractional (crypto), 100 = IDX stock lot
    sell_tax_pct: float = 0.0           # PPh Final 0.10% for stocks, 0 for crypto
    market_hours: tuple = ()            # () = 24/7, (9, 16) = IDX market hours; close hour is exclusive

    # Computed indicators (populated by technical algorithms)
    indicators: dict = field(default_factory=dict)

    # Pattern detections (populated by pattern algorithms)
    patterns: dict = field(default_factory=dict)

    # Quantitative results (ETA, multiplier, laddering)
    quantitative: dict = field(default_factory=dict)

    # Money management results
    money_mgmt: dict = field(default_factory=dict)

    # Velocity / time projections
    velocity: dict = field(default_factory=dict)

    # ML predictions
    ml_predictions: dict = field(default_factory=dict)

    # Strategy recommendations
    strategy: dict = field(default_factory=dict)

    # Signal contribution from each algorithm
    signal_contributions: dict = field(default_factory=dict)

    # Algorithm versions used (for traceability)
    algorithm_versions: dict = field(default_factory=dict)

    # Errors encountered during pipeline
    errors: list = field(default_factory=list)

    # Batch scan mode (skip Optuna & cross-asset for speed, keep full accuracy)
    batch_scan: bool = False

    def get(self, key: str, default: Any = None) -> Any:
        """Look up ``key`` in indicators, then patterns, then quantitative.

        The first dict containing ``key`` wins, even if the stored value is
        ``None``. Other result dicts (money_mgmt, velocity, ml_predictions,
        strategy, ...) are NOT searched.

        Args:
            key: Result name to look up.
            default: Value returned when no searched dict contains ``key``.

        Returns:
            The stored value, or ``default``.
        """
        for bucket in (self.indicators, self.patterns, self.quantitative):
            if key in bucket:
                return bucket[key]
        return default

    def to_dict(self) -> dict:
        """Serialize a subset of the context to a dict (for API responses).

        The raw OHLCV DataFrame, the settings/stock parameters (except
        ``asset_type``) and raw indicator series are left out to keep the
        response size reasonable.

        Filtering rules:
          * indicators: entries whose value is a list or ``None`` are dropped
            unless the key ends with ``'_latest'``.
          * ml_predictions: a dict under ``'ensemble'`` is reduced to its
            ``n_steps`` and ``volatility`` entries (dropping raw predicted
            prices); a non-dict value there (e.g. ``None``) is passed through.

        Note: nested dicts such as ``patterns`` are returned by reference,
        not copied.
        """
        # Keep *_latest keys and scalar values; drop raw series and None
        # placeholders.
        filtered_indicators = {
            k: v
            for k, v in self.indicators.items()
            if k.endswith('_latest') or not isinstance(v, (list, type(None)))
        }

        filtered_ml = {}
        for k, v in self.ml_predictions.items():
            # Guard with isinstance: a failed/absent ensemble may be stored as
            # None, which has no .get() and previously crashed serialization.
            if k == 'ensemble' and isinstance(v, dict):
                filtered_ml[k] = {
                    'n_steps': v.get('n_steps'),
                    'volatility': v.get('volatility'),
                }
            else:
                filtered_ml[k] = v

        return {
            'asset_id': self.asset_id,
            'timeframe': self.timeframe,
            'asset_type': self.asset_type,
            'current_price': self.current_price,
            'indicators': filtered_indicators,
            'patterns': self.patterns,
            'quantitative': self.quantitative,
            'money_mgmt': self.money_mgmt,
            'velocity': self.velocity,
            'ml_predictions': filtered_ml,
            'strategy': self.strategy,
            'signal_contributions': self.signal_contributions,
            'algorithm_versions': self.algorithm_versions,
            'errors': self.errors,
        }

    def round_to_lot(self, units: float) -> float:
        """Round ``units`` down to a whole multiple of ``lot_size``.

        When ``lot_size <= 0`` (fractional trading, e.g. crypto) ``units`` is
        returned unchanged. Otherwise the result is floored to a multiple of
        ``lot_size`` (e.g. 250.7 with lot_size=100 -> 200.0) and returned as
        a float.
        """
        if self.lot_size <= 0:
            return units  # fractional (crypto)
        return float(int(units // self.lot_size) * self.lot_size)

    def round_price(self, price: float, direction: str = 'nearest') -> float:
        """Round ``price`` to a valid tick for this context's asset type.

        Delegates to ``app.helpers.tick_size.snap_to_tick`` with
        ``self.asset_type``. Per that helper's intended contract, stocks snap
        to the BEI tick size for the price level and crypto is rounded to 8
        decimals (behaviour lives in the helper, not here).

        Args:
            price: Raw price value.
            direction: 'nearest', 'down' (buy/SL), or 'up' (sell/TP).
        """
        # Local import, as in the original; presumably avoids an import cycle
        # or import-time cost — NOTE(review): confirm.
        from app.helpers.tick_size import snap_to_tick
        return snap_to_tick(price, self.asset_type, direction)

    @property
    def total_sell_fee_pct(self) -> float:
        """Total sell cost = broker sell fee + sell tax (PPh Final for stocks)."""
        return self.sell_fee_pct + self.sell_tax_pct

    def is_market_hours(self, hour: int) -> bool:
        """Check whether ``hour`` falls within the configured trading hours.

        Returns True for any hour when ``market_hours`` is empty (24/7, the
        default used for crypto). Otherwise ``market_hours`` is unpacked as
        ``(open_h, close_h)`` and the check is ``open_h <= hour < close_h``,
        i.e. the closing hour itself is outside market hours.
        """
        if not self.market_hours:
            return True  # 24/7
        open_h, close_h = self.market_hours
        return open_h <= hour < close_h
