#!/usr/bin/env python3
"""Migration script: Split kamucoid_investasi_market → 3 databases.

Source: kamucoid_investasi_market (legacy single market DB)
Target:
  - kamucoid_investasi_market_crypto  (COIN.* assets)
  - kamucoid_investasi_market_id      (IDX.* assets)
  - kamucoid_investasi_market_us      (NYSE.*, NASDAQ.* assets)

This script:
1. Creates the 3 target databases if they don't exist
2. Copies table structures from the source DB
3. Splits data by asset_id prefix (COIN.*, IDX.*, NYSE.*/NASDAQ.*)
4. Copies tables without asset_id (alert_logs, market_cache) in full to all 3 targets
5. Verifies row counts between source and targets (skipped with --dry-run)

Usage:
  python scripts/migrate_split_market_db.py [--dry-run] [--skip-ohlcv] [--verify-only] [--batch-size N]
"""
import os
import sys
import argparse
import time

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from dotenv import load_dotenv
load_dotenv()

# Connection settings, normally supplied through .env / the environment.
DB_HOST = os.getenv('DB_HOST', 'localhost')
DB_PORT = int(os.getenv('DB_PORT', 3306))
DB_USER = os.getenv('DB_USER', 'kamucoid_investasi')
# The password must come from the environment (.env). No real credential is
# hard-coded as a fallback: a missing setting should fail at connect time
# (typically "access denied") rather than silently use a secret committed to
# the repository.
DB_PASSWORD = os.getenv('DB_PASSWORD', '')

# Source (legacy, single) and target (per-market) schema names.
SOURCE_DB = os.getenv('DB_NAME_MARKET', 'kamucoid_investasi_market')
TARGET_CRYPTO = os.getenv('DB_NAME_MARKET_CRYPTO', 'kamucoid_investasi_market_crypto')
TARGET_ID = os.getenv('DB_NAME_MARKET_ID', 'kamucoid_investasi_market_id')
TARGET_US = os.getenv('DB_NAME_MARKET_US', 'kamucoid_investasi_market_us')

# Tables that have an asset_id column → split by prefix
ASSET_TABLES = [
    'assets',             # PK = asset_id itself (id column)
    'asset_profiles',     # FK asset_id
    'ohlcv_data',         # FK asset_id (BIGGEST table — 75M rows)
    'market_tickers',     # FK asset_id
    'trading_signals',    # FK asset_id
    'prediction_queues',  # FK asset_id
    'asset_source_mappings',  # FK asset_id
    'range_trading_scores',   # FK asset_id
    'bullish_momentum_scores',  # FK asset_id
    'asset_extended_data',  # FK asset_id
]

# Tables without asset_id → copied in full to ALL 3 databases
# NOTE: alert_configs is NOT in legacy market DB (it's created by create_all)
SHARED_TABLES = [
    'alert_logs',
    'market_cache',
]

# Asset-id prefix (the text before the first '.') → target DB.
# Several prefixes may share one target (NYSE and NASDAQ both go to the US DB).
PREFIX_MAP = {
    'COIN': TARGET_CRYPTO,
    'IDX': TARGET_ID,
    'NYSE': TARGET_US,
    'NASDAQ': TARGET_US,
}


def get_connection():
    """Open and return a PyMySQL connection to the configured MySQL server.

    No default database is selected, so statements must use fully qualified
    `db`.`table` names. Autocommit is enabled, so every executed statement
    is committed immediately.
    """
    import pymysql  # deferred: only needed once a connection is actually opened

    settings = {
        'host': DB_HOST,
        'port': DB_PORT,
        'user': DB_USER,
        'password': DB_PASSWORD,
        'charset': 'utf8mb4',
        'autocommit': True,
    }
    return pymysql.connect(**settings)


def create_databases(conn, dry_run=False):
    """Create the three target databases (utf8mb4 / utf8mb4_unicode_ci) if missing.

    Args:
        conn: open DB-API connection (as returned by get_connection()).
        dry_run: when True, only print the CREATE DATABASE statements.
    """
    # The context manager closes the cursor even if a CREATE statement fails.
    with conn.cursor() as cur:
        for db_name in (TARGET_CRYPTO, TARGET_ID, TARGET_US):
            sql = f"CREATE DATABASE IF NOT EXISTS `{db_name}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
            if dry_run:
                print(f"  [DRY-RUN] {sql}")
                continue
            cur.execute(sql)
            print(f"  ✓ Database `{db_name}` ready")


def table_exists(conn, db_name, table):
    """Return True if ``db_name.table`` exists, according to information_schema.

    The schema and table names are passed as query parameters, not
    interpolated. The cursor is closed even if the query raises.
    """
    with conn.cursor() as cur:
        cur.execute("SELECT COUNT(*) FROM information_schema.TABLES "
                    "WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s", (db_name, table))
        return cur.fetchone()[0] > 0


def copy_table_structure(conn, table, dry_run=False):
    """Create ``table`` in each target DB with the same definition as in SOURCE_DB.

    Uses ``CREATE TABLE ... LIKE``. Target DBs that already have the table are
    left untouched. With ``dry_run`` the statements are only printed.

    Returns:
        False if the table does not exist in SOURCE_DB (nothing is done),
        otherwise True.
    """
    if not table_exists(conn, SOURCE_DB, table):
        print(f"  ⚠ {SOURCE_DB}.{table} does not exist in source, skipping")
        return False

    # The context manager closes the cursor even if a CREATE statement fails.
    with conn.cursor() as cur:
        for db_name in (TARGET_CRYPTO, TARGET_ID, TARGET_US):
            if table_exists(conn, db_name, table):
                print(f"  ⏭ {db_name}.{table} already exists, skipping structure")
                continue

            # NOTE(review): per the MySQL manual, CREATE TABLE ... LIKE does not
            # copy FOREIGN KEY definitions — confirm the targets don't need them.
            sql = f"CREATE TABLE `{db_name}`.`{table}` LIKE `{SOURCE_DB}`.`{table}`"
            if dry_run:
                print(f"  [DRY-RUN] {sql}")
            else:
                cur.execute(sql)
                print(f"  ✓ Created {db_name}.{table}")
    return True


def get_row_count(conn, db_name, table, where=''):
    """Return ``COUNT(*)`` of ``db_name.table``, optionally filtered.

    Args:
        conn: open DB-API connection.
        db_name, table: identifiers, back-quoted into the statement.
        where: optional raw SQL condition (without the WHERE keyword). It is
            interpolated verbatim, so it must only come from trusted,
            internal code — never from user input.
    """
    sql = f"SELECT COUNT(*) FROM `{db_name}`.`{table}`"
    if where:
        sql = f"{sql} WHERE {where}"
    # The context manager closes the cursor even if the query raises.
    with conn.cursor() as cur:
        cur.execute(sql)
        return cur.fetchone()[0]


def _prefix_splits(id_col):
    """Build ``(target_db, where_clause)`` pairs from PREFIX_MAP, one per target DB.

    Prefixes that share a target DB are OR-ed together, e.g. for 'asset_id':
    ``(`asset_id` LIKE 'NYSE.%' OR `asset_id` LIKE 'NASDAQ.%')``.
    """
    prefixes_by_db = {}
    for prefix, target_db in PREFIX_MAP.items():
        prefixes_by_db.setdefault(target_db, []).append(prefix)

    splits = []
    for target_db, prefixes in prefixes_by_db.items():
        conditions = [f"`{id_col}` LIKE '{prefix}.%'" for prefix in prefixes]
        if len(conditions) == 1:
            where_clause = conditions[0]
        else:
            where_clause = "(" + " OR ".join(conditions) + ")"
        splits.append((target_db, where_clause))
    return splits


def _insert_select_sql(target_db, table, where_clause):
    """Return the INSERT IGNORE ... SELECT copying matching SOURCE_DB rows into target_db."""
    return (f"INSERT IGNORE INTO `{target_db}`.`{table}` "
            f"SELECT * FROM `{SOURCE_DB}`.`{table}` WHERE {where_clause}")


def _copy_in_id_batches(cur, target_db, table, where_clause, source_count, batch_size):
    """Copy matching rows in consecutive ``id`` windows of width ``batch_size``.

    Prints progress roughly once per million inserted rows and a final summary.
    Falls back to a single statement when ``id`` is not an integer.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        # The window advances by batch_size; a non-positive value never reaches MAX(id).
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    cur.execute(f"SELECT MIN(id), MAX(id) FROM `{SOURCE_DB}`.`{table}` WHERE {where_clause}")
    min_id, max_id = cur.fetchone()
    if min_id is None:
        return

    t0 = time.time()
    if not isinstance(min_id, int):
        # Window arithmetic needs an integer id ('assets' is keyed by the
        # string asset id itself), so copy everything in one statement.
        cur.execute(_insert_select_sql(target_db, table, where_clause))
        copied = cur.rowcount
    else:
        copied = 0
        next_report = 1_000_000
        current_id = min_id
        while current_id <= max_id:
            batch_where = f"{where_clause} AND id >= {current_id} AND id < {current_id + batch_size}"
            cur.execute(_insert_select_sql(target_db, table, batch_where))
            copied += cur.rowcount
            current_id += batch_size

            # Report once per million crossed. Windows can be sparse (ids of other
            # prefixes interleave), so a modulo test would fire on several batches.
            if copied >= next_report:
                elapsed = time.time() - t0
                rate = copied / elapsed if elapsed > 0 else 0
                pct = (copied / source_count * 100) if source_count > 0 else 0
                print(f"    ... {copied:,}/{source_count:,} ({pct:.1f}%) — {rate:,.0f} rows/s")
                next_report = (copied // 1_000_000 + 1) * 1_000_000

    elapsed = time.time() - t0
    rate = copied / elapsed if elapsed > 0 else 0
    print(f"  ✓ {target_db}.{table}: {copied:,} rows copied ({elapsed:.1f}s, {rate:,.0f} rows/s)")


def split_asset_table(conn, table, dry_run=False, batch_size=10000):
    """Split one asset-keyed table from SOURCE_DB into the per-market target DBs.

    Rows are routed by the prefix of their asset id according to PREFIX_MAP.
    Source rows whose id matches no prefix are not copied anywhere.

    Copies use ``INSERT IGNORE ... SELECT``, so re-running is safe: rows whose
    primary/unique key is already present in the target are skipped.

    Args:
        conn: open connection, presumably from get_connection() (autocommit on).
        table: table name in SOURCE_DB.
        dry_run: only report how many rows would be copied per target.
        batch_size: width of each ``id`` window copied per statement for large
            partitions (>= 100,000 matching rows).

    Raises:
        ValueError: if a large partition must be batched and ``batch_size`` is
            not positive.
    """
    if not table_exists(conn, SOURCE_DB, table):
        print(f"  ⚠ {SOURCE_DB}.{table} does not exist in source, skipping")
        return

    # 'assets' is keyed by the asset id itself ('id'); the other tables
    # reference it through an 'asset_id' column.
    id_col = 'id' if table == 'assets' else 'asset_id'

    with conn.cursor() as cur:
        for target_db, where_clause in _prefix_splits(id_col):
            source_count = get_row_count(conn, SOURCE_DB, table, where_clause)

            if dry_run:
                print(f"  [DRY-RUN] Would copy {source_count:,} rows → {target_db}.{table}")
                continue

            # A partially filled target (e.g. an interrupted earlier run) is
            # re-copied from the start; INSERT IGNORE skips rows already present.
            existing = get_row_count(conn, target_db, table)
            if existing > 0:
                if existing >= source_count:
                    print(f"  ⏭ {target_db}.{table}: already has {existing} rows (source has {source_count}), skipping")
                    continue
                print(f"  ⚠ {target_db}.{table}: has {existing} rows but source has {source_count}. Continuing from where left off...")

            if source_count == 0:
                print(f"  ⏭ {target_db}.{table}: 0 matching rows in source, skipping")
                continue

            if source_count < 100000:
                # Small enough for a single INSERT ... SELECT.
                t0 = time.time()
                cur.execute(_insert_select_sql(target_db, table, where_clause))
                elapsed = time.time() - t0
                # Report rows actually inserted (duplicates skipped by IGNORE excluded).
                print(f"  ✓ {target_db}.{table}: {cur.rowcount:,} rows copied ({elapsed:.1f}s)")
            else:
                print(f"  → {target_db}.{table}: {source_count:,} rows to copy (batching by {batch_size:,})...")
                _copy_in_id_batches(cur, target_db, table, where_clause, source_count, batch_size)


def copy_shared_table(conn, table, dry_run=False):
    """Copy a table without an asset_id column, in full, into every target DB.

    Each target copy is replaced wholesale (TRUNCATE + INSERT ... SELECT). A
    target that already holds at least as many rows as a non-empty source is
    left untouched.

    In dry-run mode the target databases and tables may not exist yet, because
    the create steps are skipped. A missing target table is therefore treated
    as empty instead of being counted, which would raise.
    """
    if not table_exists(conn, SOURCE_DB, table):
        print(f"  ⚠ {SOURCE_DB}.{table} does not exist in source, skipping")
        return

    source_count = get_row_count(conn, SOURCE_DB, table)

    with conn.cursor() as cur:
        for target_db in (TARGET_CRYPTO, TARGET_ID, TARGET_US):
            existing = get_row_count(conn, target_db, table) if table_exists(conn, target_db, table) else 0
            if existing >= source_count and source_count > 0:
                print(f"  ⏭ {target_db}.{table}: already has {existing} rows, skipping")
                continue

            if dry_run:
                print(f"  [DRY-RUN] Would copy {source_count:,} rows → {target_db}.{table}")
                continue

            # Truncate and re-insert (shared tables are small)
            cur.execute(f"TRUNCATE TABLE `{target_db}`.`{table}`")
            if source_count > 0:
                cur.execute(f"INSERT INTO `{target_db}`.`{table}` "
                            f"SELECT * FROM `{SOURCE_DB}`.`{table}`")
            print(f"  ✓ {target_db}.{table}: {source_count:,} rows copied")


def verify_counts(conn):
    """Print a row-count comparison between SOURCE_DB and the target DBs.

    * Asset tables: the crypto + id + us counts must add up to the source
      count, so source rows whose asset id matches no known prefix show up
      as a mismatch.
    * Shared tables: every target must hold exactly the source count. On a
      mismatch, each target's count is printed individually.

    A target table that does not exist counts as 0 rows. Tables missing from
    the source are reported and skipped.

    Returns:
        True when every checked table matches, otherwise False.
    """
    print("\n" + "=" * 70)
    print("VERIFICATION")
    print("=" * 70)

    targets = (TARGET_CRYPTO, TARGET_ID, TARGET_US)

    def _count_or_zero(db_name, table):
        # Missing target tables (e.g. ohlcv_data under --skip-ohlcv) count as empty.
        return get_row_count(conn, db_name, table) if table_exists(conn, db_name, table) else 0

    all_ok = True
    for table in ASSET_TABLES:
        if not table_exists(conn, SOURCE_DB, table):
            print(f"  ⚠ {table}: not in source DB, skipping verification")
            continue

        source_total = get_row_count(conn, SOURCE_DB, table)
        crypto_count, id_count, us_count = (_count_or_zero(db, table) for db in targets)
        target_total = crypto_count + id_count + us_count

        ok = target_total == source_total
        all_ok = all_ok and ok
        status = "✓" if ok else "✗"
        print(f"  {status} {table}: source={source_total:,} | "
              f"crypto={crypto_count:,} + id={id_count:,} + us={us_count:,} = {target_total:,}")

    for table in SHARED_TABLES:
        if not table_exists(conn, SOURCE_DB, table):
            print(f"  ⚠ {table} (shared): not in source DB, skipping")
            continue

        source_count = get_row_count(conn, SOURCE_DB, table)
        counts = [_count_or_zero(db, table) for db in targets]

        ok = all(c == source_count for c in counts)
        all_ok = all_ok and ok
        if ok:
            print(f"  ✓ {table} (shared): source={source_count:,} | "
                  f"all targets={counts[0]:,}")
        else:
            # Show every target so the mismatching one is identifiable.
            print(f"  ✗ {table} (shared): source={source_count:,} | "
                  f"crypto={counts[0]:,}, id={counts[1]:,}, us={counts[2]:,}")

    return all_ok


def main():
    """Parse command-line options and run the migration.

    Steps:
      1. Create the target DBs.
      2. Copy table structures.
      3. Split the asset tables by asset-id prefix.
      4. Copy the shared tables.
      5. Verify row counts (skipped in dry-run).

    ``--verify-only`` runs just the verification. The connection is always
    closed, even if a step raises.
    """
    parser = argparse.ArgumentParser(description='Split market DB into 3 databases')
    parser.add_argument('--dry-run', action='store_true', help='Show what would be done without executing')
    parser.add_argument('--skip-ohlcv', action='store_true', help='Skip ohlcv_data (biggest table)')
    parser.add_argument('--verify-only', action='store_true', help='Only verify row counts')
    parser.add_argument('--batch-size', type=int, default=50000, help='Batch size for large tables')
    args = parser.parse_args()
    if args.batch_size <= 0:
        # Large tables are copied in id windows that advance by batch_size;
        # a non-positive value would never move past MAX(id).
        parser.error('--batch-size must be a positive integer')

    print("=" * 70)
    print("MARKET DB SPLIT MIGRATION")
    print("=" * 70)
    print(f"  Source:  {SOURCE_DB}")
    print(f"  Target:  {TARGET_CRYPTO}")
    print(f"           {TARGET_ID}")
    print(f"           {TARGET_US}")
    print(f"  Options: dry_run={args.dry_run}, skip_ohlcv={args.skip_ohlcv}")
    print()

    conn = get_connection()
    try:
        if args.verify_only:
            verify_counts(conn)
            return

        # Step 1: Create databases
        print("Step 1: Create databases")
        create_databases(conn, args.dry_run)
        print()

        # Step 2: Copy table structures
        print("Step 2: Copy table structures")
        all_tables = ASSET_TABLES + SHARED_TABLES
        if args.skip_ohlcv:
            all_tables = [t for t in all_tables if t != 'ohlcv_data']
        for table in all_tables:
            copy_table_structure(conn, table, args.dry_run)
        print()

        # Step 3: Split asset tables
        print("Step 3: Split asset tables (by asset_id prefix)")
        for table in ASSET_TABLES:
            if args.skip_ohlcv and table == 'ohlcv_data':
                print("  ⏭ Skipping ohlcv_data (--skip-ohlcv)")
                continue
            print(f"\n  Processing: {table}")
            split_asset_table(conn, table, args.dry_run, args.batch_size)
        print()

        # Step 4: Copy shared tables
        print("Step 4: Copy shared tables (to all 3 DBs)")
        for table in SHARED_TABLES:
            print(f"\n  Processing: {table}")
            copy_shared_table(conn, table, args.dry_run)
        print()

        # Step 5: Verify
        if not args.dry_run:
            all_ok = verify_counts(conn)
            if all_ok:
                print("\n✅ Migration complete! All row counts match.")
            else:
                print("\n⚠️ Some counts don't match. Check above for details.")
    finally:
        conn.close()


if __name__ == '__main__':
    main()
