"""
Real-Time Capture Integration

Enables UCTS to connect directly to AI assistant sessions:
- Claude API direct streaming
- OpenAI API integration
- WebSocket/SSE support for streaming
- Incremental analysis as conversation progresses
- Auto-save checkpoints
"""

import asyncio
import codecs
import io
import json
import logging
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional

logger = logging.getLogger(__name__)


class CaptureSource(Enum):
    """Supported capture sources.

    The string value of each member is what ``CaptureSession.to_dict``
    writes under ``"source"`` and what ``RealtimeCapture.list_sources``
    reports as ``"name"``.
    """
    CLAUDE_API = "claude_api"
    OPENAI_API = "openai_api"
    ANTHROPIC_MCP = "anthropic_mcp"  # NOTE(review): no handler in RealtimeCapture.HANDLERS, so start() rejects it
    WEBSOCKET = "websocket"
    SSE = "sse"
    FILE_WATCH = "file_watch"


class CaptureState(Enum):
    """Lifecycle states of a ``CaptureSession``.

    ``RealtimeCapture.start()`` creates a session in CONNECTING and moves it
    to STREAMING after a successful connect (ERROR if connect fails);
    ``pause()``/``resume()`` set PAUSED/STREAMING; the stream loop sets ERROR
    if it raises; ``stop()`` sets COMPLETED.
    """
    IDLE = "idle"  # dataclass default; RealtimeCapture.start() creates sessions as CONNECTING instead
    CONNECTING = "connecting"
    STREAMING = "streaming"
    PAUSED = "paused"  # the stream loop discards messages that arrive in this state
    ERROR = "error"
    COMPLETED = "completed"


@dataclass
class CaptureConfig:
    """Configuration for real-time capture.

    Not every field is consumed by the code in this file; see the per-field
    notes below.
    """
    source: CaptureSource = CaptureSource.CLAUDE_API  # picks the handler class from RealtimeCapture.HANDLERS
    api_key: Optional[str] = None  # passed to the Anthropic / OpenAI client constructors
    api_url: Optional[str] = None  # WebSocket URL, SSE endpoint URL, or the file path for FILE_WATCH
    checkpoint_interval: int = 30  # seconds between periodic checkpoints
    checkpoint_dir: str = ".ucts/checkpoints"  # directory checkpoint JSON files are written into
    auto_analyze: bool = True  # NOTE(review): not read anywhere in this file
    max_messages: int = 1000  # the stream loop stops once a session holds this many messages
    buffer_size: int = 100  # NOTE(review): not read anywhere in this file
    reconnect_attempts: int = 3  # NOTE(review): not read anywhere in this file
    reconnect_delay: float = 1.0  # NOTE(review): not read anywhere in this file


@dataclass
class CapturedMessage:
    """A single message produced by a stream handler.

    ``CaptureSession.add_message`` adds ``tokens`` to the session total;
    handlers fill ``tokens``, ``model`` and ``metadata`` only when their
    source provides them.
    """
    role: str  # user, assistant, system
    content: str  # message text
    timestamp: datetime = field(default_factory=datetime.now)  # defaults to construction time (naive local datetime)
    tokens: int = 0  # token count reported by the source, 0 if unknown
    model: Optional[str] = None  # model name reported by the source, if any
    metadata: Dict[str, Any] = field(default_factory=dict)  # free-form, source-specific extras


@dataclass
class CaptureSession:
    """Everything recorded for one capture run.

    Holds the captured messages in arrival order together with bookkeeping
    (state, running token total, checkpoint history) and can render itself
    as a JSON-serializable dictionary.
    """
    session_id: str
    source: CaptureSource
    state: CaptureState = CaptureState.IDLE
    messages: List[CapturedMessage] = field(default_factory=list)
    start_time: datetime = field(default_factory=datetime.now)
    last_checkpoint: Optional[datetime] = None
    total_tokens: int = 0
    checkpoints: List[str] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def add_message(self, message: CapturedMessage):
        """Record ``message`` and fold its token count into ``total_tokens``."""
        self.messages.append(message)
        self.total_tokens += message.tokens

    @staticmethod
    def _message_as_dict(msg: CapturedMessage) -> Dict[str, Any]:
        """Render one captured message as a plain dictionary."""
        when = msg.timestamp.isoformat()
        return {
            "role": msg.role,
            "content": msg.content,
            "timestamp": when,
            "tokens": msg.tokens,
            "model": msg.model,
            "metadata": msg.metadata,
        }

    def to_dict(self) -> Dict[str, Any]:
        """Return the session as a dictionary suitable for ``json.dump``.

        Enum fields become their string values and datetimes become ISO-8601
        strings; ``last_checkpoint`` is None until a checkpoint has been saved.
        """
        rendered_messages = [self._message_as_dict(msg) for msg in self.messages]
        checkpoint_time = self.last_checkpoint
        started = self.start_time.isoformat()
        return {
            "session_id": self.session_id,
            "source": self.source.value,
            "state": self.state.value,
            "messages": rendered_messages,
            "start_time": started,
            "last_checkpoint": checkpoint_time.isoformat() if checkpoint_time else None,
            "total_tokens": self.total_tokens,
            "checkpoints": self.checkpoints,
            "metadata": self.metadata,
        }


class StreamHandler(ABC):
    """Interface implemented by every capture source.

    ``RealtimeCapture`` awaits ``connect()``, iterates ``stream()`` with
    ``async for``, and awaits ``disconnect()`` when the capture is stopped.
    """

    @abstractmethod
    async def connect(self) -> bool:
        """Open the connection to the source and report whether it succeeded."""

    @abstractmethod
    async def disconnect(self):
        """Close the connection to the source."""

    @abstractmethod
    async def stream(self) -> AsyncGenerator[CapturedMessage, None]:
        """Produce ``CapturedMessage`` objects as the source emits them."""


class ClaudeAPIHandler(StreamHandler):
    """Stream handler for the Claude API.

    ``connect()`` builds an ``anthropic.AsyncAnthropic`` client from
    ``config.api_key`` when the ``anthropic`` package can be imported and
    otherwise enters mock mode. ``stream()`` is currently a placeholder that
    yields one synthetic system message once connected.
    """

    def __init__(self, config: CaptureConfig):
        self.config = config
        self.connected = False
        self.client = None  # AsyncAnthropic instance once connect() succeeds with the SDK

    async def connect(self) -> bool:
        """Prepare the API client.

        Returns:
            True when the client was created or mock mode was entered,
            False if creating the client raised.
        """
        try:
            try:
                import anthropic
                self.client = anthropic.AsyncAnthropic(
                    api_key=self.config.api_key
                )
                self.connected = True
                logger.info("Connected to Claude API")
            except ImportError:
                # SDK unavailable: keep going in mock mode rather than failing.
                logger.warning("anthropic package not installed, using mock mode")
                self.connected = True
            return True
        except Exception as exc:
            logger.error(f"Failed to connect to Claude API: {exc}")
            return False

    async def disconnect(self):
        """Forget the client and mark the handler as disconnected."""
        self.connected = False
        self.client = None
        logger.info("Disconnected from Claude API")

    async def stream(self) -> AsyncGenerator[CapturedMessage, None]:
        """Yield a single placeholder message when connected, nothing otherwise."""
        if self.connected:
            # Placeholder until real API streaming is wired in.
            yield CapturedMessage(
                role="system",
                content="Connected to Claude API stream",
                model="claude-3-opus-20240229",
                metadata={"source": "claude_api"},
            )


class OpenAIAPIHandler(StreamHandler):
    """Stream handler for the OpenAI API.

    ``connect()`` builds an ``openai.AsyncOpenAI`` client from
    ``config.api_key`` when the ``openai`` package can be imported and
    otherwise enters mock mode. ``stream()`` is currently a placeholder that
    yields one synthetic system message once connected.
    """

    def __init__(self, config: CaptureConfig):
        self.config = config
        self.connected = False
        self.client = None  # AsyncOpenAI instance once connect() succeeds with the SDK

    async def connect(self) -> bool:
        """Prepare the API client.

        Returns:
            True when the client was created or mock mode was entered,
            False if creating the client raised.
        """
        try:
            try:
                import openai
                self.client = openai.AsyncOpenAI(
                    api_key=self.config.api_key
                )
                self.connected = True
                logger.info("Connected to OpenAI API")
            except ImportError:
                # SDK unavailable: keep going in mock mode rather than failing.
                logger.warning("openai package not installed, using mock mode")
                self.connected = True
            return True
        except Exception as exc:
            logger.error(f"Failed to connect to OpenAI API: {exc}")
            return False

    async def disconnect(self):
        """Forget the client and mark the handler as disconnected."""
        self.connected = False
        self.client = None
        logger.info("Disconnected from OpenAI API")

    async def stream(self) -> AsyncGenerator[CapturedMessage, None]:
        """Yield a single placeholder message when connected, nothing otherwise."""
        if self.connected:
            yield CapturedMessage(
                role="system",
                content="Connected to OpenAI API stream",
                model="gpt-4",
                metadata={"source": "openai_api"},
            )


class WebSocketHandler(StreamHandler):
    """Stream handler for a WebSocket endpoint.

    Connects to ``config.api_url`` (``ws://localhost:8080`` when unset) with
    the optional ``websockets`` package and turns every JSON-object frame
    into a ``CapturedMessage``. Frames that are not a JSON object are logged
    and skipped, so one bad frame does not end the capture. Without
    ``websockets`` installed the handler runs in mock mode and yields nothing.
    """

    DEFAULT_URL = "ws://localhost:8080"

    def __init__(self, config: CaptureConfig):
        self.config = config
        self.ws = None
        self.connected = False

    async def connect(self) -> bool:
        """Open the WebSocket connection.

        Returns:
            True on success or when entering mock mode, False if the
            connection attempt raised.
        """
        url = self.config.api_url or self.DEFAULT_URL
        try:
            try:
                import websockets
                self.ws = await websockets.connect(url)
                self.connected = True
                # Log the URL actually dialed, not the possibly-None config value.
                logger.info(f"Connected to WebSocket: {url}")
                return True
            except ImportError:
                logger.warning("websockets package not installed, using mock mode")
                self.connected = True
                return True
        except Exception as e:
            logger.error(f"Failed to connect to WebSocket: {e}")
            return False

    async def disconnect(self):
        """Close the socket (if any) and mark the handler disconnected.

        State is reset before closing so the handler is left consistent even
        if ``close()`` raises; that exception still propagates.
        """
        ws = self.ws
        self.ws = None
        self.connected = False
        if ws:
            await ws.close()
        logger.info("Disconnected from WebSocket")

    @staticmethod
    def _parse_frame(raw: Any) -> Optional[Dict[str, Any]]:
        """Decode one frame as a JSON object, or return None (logged) if it is not one."""
        try:
            data = json.loads(raw)
        except (TypeError, ValueError) as e:
            # ValueError covers JSONDecodeError and invalid UTF-8 in byte frames.
            logger.warning(f"Skipping non-JSON WebSocket message: {e}")
            return None
        if not isinstance(data, dict):
            logger.warning(f"Skipping WebSocket message that is not a JSON object: {type(data).__name__}")
            return None
        return data

    async def stream(self) -> AsyncGenerator[CapturedMessage, None]:
        """Yield one ``CapturedMessage`` per JSON-object frame until the socket closes.

        Missing fields default to role ``"assistant"``, empty content, 0
        tokens, no model and empty metadata. Transport errors are logged and
        end the stream.
        """
        if not self.connected or not self.ws:
            return

        try:
            async for raw in self.ws:
                data = self._parse_frame(raw)
                if data is None:
                    continue
                yield CapturedMessage(
                    role=data.get("role", "assistant"),
                    content=data.get("content", ""),
                    tokens=data.get("tokens", 0),
                    model=data.get("model"),
                    metadata=data.get("metadata", {}),
                )
        except Exception as e:
            logger.error(f"WebSocket stream error: {e}")


class SSEHandler(StreamHandler):
    """Stream handler for a Server-Sent Events endpoint.

    Issues a GET to ``config.api_url`` with the optional ``aiohttp`` package
    and turns each ``data:`` line whose payload is a JSON object into a
    ``CapturedMessage``. Other lines and non-object payloads are skipped.
    Without ``aiohttp`` installed the handler runs in mock mode and yields
    nothing.
    """

    def __init__(self, config: CaptureConfig):
        self.config = config
        self.session = None
        self.connected = False

    async def connect(self) -> bool:
        """Create the HTTP client session.

        Returns:
            True on success or when entering mock mode. False when no
            ``api_url`` is configured (``stream()`` has nothing to request) or
            when creating the session raised.
        """
        try:
            try:
                import aiohttp
            except ImportError:
                logger.warning("aiohttp package not installed, using mock mode")
                self.connected = True
                return True

            if not self.config.api_url:
                # Fail fast instead of letting the first GET fail on a None URL.
                logger.error("Failed to connect to SSE: no api_url configured")
                return False

            self.session = aiohttp.ClientSession()
            self.connected = True
            logger.info(f"Connected to SSE: {self.config.api_url}")
            return True
        except Exception as e:
            logger.error(f"Failed to connect to SSE: {e}")
            return False

    async def disconnect(self):
        """Close the HTTP session (if any) and mark the handler disconnected.

        State is reset before closing so the handler is left consistent even
        if ``close()`` raises; that exception still propagates.
        """
        session = self.session
        self.session = None
        self.connected = False
        if session:
            await session.close()
        logger.info("Disconnected from SSE")

    @staticmethod
    def _parse_data_line(raw_line: bytes) -> Optional[Dict[str, Any]]:
        """Return the JSON object carried by one SSE line, or None to skip it.

        Lines that are not ``data:`` fields (comments, ``event:``/``id:``
        fields, blank separators) and payloads that are not a JSON object
        (e.g. a ``[DONE]`` sentinel) return None so the stream keeps going.
        """
        line = raw_line.decode("utf-8", errors="replace").strip()
        # The space after "data:" is optional in the SSE format.
        if not line.startswith("data:"):
            return None
        payload = line[len("data:"):].strip()
        if not payload:
            return None
        try:
            data = json.loads(payload)
        except ValueError:
            logger.debug(f"Skipping non-JSON SSE payload: {payload[:80]}")
            return None
        if not isinstance(data, dict):
            logger.debug(f"Skipping SSE payload that is not a JSON object: {payload[:80]}")
            return None
        return data

    async def stream(self) -> AsyncGenerator[CapturedMessage, None]:
        """Yield one ``CapturedMessage`` per JSON-object ``data:`` line.

        Missing fields default to role ``"assistant"``, empty content, 0
        tokens, no model and empty metadata. HTTP error statuses and
        transport errors are logged and end the stream.
        """
        if not self.connected or not self.session:
            return

        try:
            async with self.session.get(self.config.api_url) as response:
                # Without this an error page would be read line by line and
                # silently produce nothing.
                response.raise_for_status()
                async for raw_line in response.content:
                    data = self._parse_data_line(raw_line)
                    if data is None:
                        continue
                    yield CapturedMessage(
                        role=data.get("role", "assistant"),
                        content=data.get("content", ""),
                        tokens=data.get("tokens", 0),
                        model=data.get("model"),
                        metadata=data.get("metadata", {}),
                    )
        except Exception as e:
            logger.error(f"SSE stream error: {e}")


class FileWatchHandler(StreamHandler):
    """Tail a conversation file and emit text appended to it.

    ``config.api_url`` is the path of the file to watch. The file is polled
    every ``poll_interval`` seconds, and each poll that finds new non-blank
    text emits it as one ``assistant`` message. Bytes are decoded as UTF-8
    (undecodable bytes become U+FFFD) with universal-newline translation. A
    multi-byte character split across two polls is reassembled. If the file
    shrinks (truncated or replaced), reading restarts from its beginning.
    """

    def __init__(self, config: CaptureConfig, poll_interval: float = 0.5):
        self.config = config
        self.watch_path = None
        self.connected = False
        self.poll_interval = poll_interval
        # Byte offset of the first not-yet-read byte of the watched file.
        self._last_position = 0
        self._decoder = self._new_decoder()

    @staticmethod
    def _new_decoder() -> io.IncrementalNewlineDecoder:
        """Return a stateful UTF-8 decoder that also normalizes newlines to '\\n'."""
        return io.IncrementalNewlineDecoder(
            codecs.getincrementaldecoder("utf-8")(errors="replace"),
            translate=True,
        )

    async def connect(self) -> bool:
        """Start watching the configured file.

        Returns:
            True if the path names an existing regular file, else False.
        """
        try:
            self.watch_path = Path(self.config.api_url or ".")
            # is_file(), not exists(): the "." default (or any directory)
            # would pass exists() and then fail on the first open().
            if self.watch_path.is_file():
                self.connected = True
                logger.info(f"Watching file: {self.watch_path}")
                return True
            logger.error(f"File not found or not a regular file: {self.watch_path}")
            return False
        except Exception as e:
            logger.error(f"Failed to start file watch: {e}")
            return False

    async def disconnect(self):
        """Stop watching and reset the read position and decoder state."""
        self.watch_path = None
        self.connected = False
        self._last_position = 0
        self._decoder = self._new_decoder()
        logger.info("Stopped file watch")

    def _read_new_text(self) -> str:
        """Return text appended since the previous poll, or "" if there is none."""
        path = self.watch_path
        if path is None:
            return ""
        try:
            size = path.stat().st_size
        except FileNotFoundError:
            # Temporarily missing (e.g. mid-rotation): try again next poll.
            return ""

        if size < self._last_position:
            logger.info(f"Watched file shrank, re-reading from start: {path}")
            self._last_position = 0
            self._decoder = self._new_decoder()
        if size == self._last_position:
            return ""

        with open(path, "rb") as f:
            f.seek(self._last_position)
            raw = f.read()
            self._last_position = f.tell()
        return self._decoder.decode(raw)

    async def stream(self) -> AsyncGenerator[CapturedMessage, None]:
        """Poll the file until disconnected, yielding each non-blank chunk of new text.

        Appended text that is only whitespace is consumed without being
        emitted. Any unexpected error is logged and ends the stream.
        """
        if not self.connected or not self.watch_path:
            return

        try:
            while self.connected:
                new_content = self._read_new_text()
                if new_content.strip():
                    yield CapturedMessage(
                        role="assistant",
                        content=new_content,
                        metadata={"source": "file_watch", "path": str(self.watch_path)},
                    )
                await asyncio.sleep(self.poll_interval)
        except Exception as e:
            logger.error(f"File watch error: {e}")


class RealtimeCapture:
    """
    Real-time capture manager.

    Connects to the source selected by ``config.source``, records every
    message the handler yields into a ``CaptureSession``, notifies
    registered callbacks, and periodically writes JSON checkpoints to
    ``config.checkpoint_dir``.
    """

    # Sources with a stream handler. CaptureSource.ANTHROPIC_MCP is not
    # registered here, so start() raises ValueError for it.
    HANDLERS = {
        CaptureSource.CLAUDE_API: ClaudeAPIHandler,
        CaptureSource.OPENAI_API: OpenAIAPIHandler,
        CaptureSource.WEBSOCKET: WebSocketHandler,
        CaptureSource.SSE: SSEHandler,
        CaptureSource.FILE_WATCH: FileWatchHandler,
    }

    # States in which a session still owns a connection and background
    # tasks; start() refuses to begin a new session while in any of them.
    _ACTIVE_STATES = (
        CaptureState.CONNECTING,
        CaptureState.STREAMING,
        CaptureState.PAUSED,
    )

    def __init__(self, config: Optional[CaptureConfig] = None):
        self.config = config or CaptureConfig()
        self.session: Optional[CaptureSession] = None
        self.handler: Optional[StreamHandler] = None
        self._callbacks: List[Callable[[CapturedMessage], None]] = []
        self._checkpoint_task: Optional[asyncio.Task] = None
        self._stream_task: Optional[asyncio.Task] = None

    def on_message(self, callback: Callable[[CapturedMessage], None]):
        """Register ``callback`` to be called synchronously with each recorded message.

        Exceptions raised by a callback are logged and do not stop capture.
        """
        self._callbacks.append(callback)

    async def start(self, session_id: Optional[str] = None) -> CaptureSession:
        """
        Start a new capture session.

        Args:
            session_id: Optional session identifier; defaults to
                ``capture-<YYYYmmddHHMMSS>``.

        Returns:
            The active capture session (state STREAMING).

        Raises:
            RuntimeError: A capture is already connecting, streaming or paused.
            ValueError: ``config.source`` has no registered handler.
            ConnectionError: The handler's ``connect()`` reported failure.
        """
        # PAUSED and CONNECTING count as in progress too: starting over them
        # would orphan the running tasks and the open connection.
        if self.session and self.session.state in self._ACTIVE_STATES:
            raise RuntimeError("Capture already in progress")

        # Validate the source before touching any state, so an unsupported
        # source cannot leave a session stuck in CONNECTING.
        handler_class = self.HANDLERS.get(self.config.source)
        if not handler_class:
            raise ValueError(f"Unsupported capture source: {self.config.source}")

        # A previous session that ended in ERROR may still own a checkpoint
        # task and an open connection; release them first.
        await self._release_previous()

        self.session = CaptureSession(
            session_id=session_id or f"capture-{datetime.now().strftime('%Y%m%d%H%M%S')}",
            source=self.config.source,
            state=CaptureState.CONNECTING
        )
        self.handler = handler_class(self.config)

        try:
            connected = await self.handler.connect()
        except BaseException:
            # Includes cancellation: never leave the session in CONNECTING,
            # which would block every later start().
            self.session.state = CaptureState.ERROR
            raise
        if not connected:
            self.session.state = CaptureState.ERROR
            raise ConnectionError(f"Failed to connect to {self.config.source.value}")

        self.session.state = CaptureState.STREAMING

        self._checkpoint_task = asyncio.create_task(self._checkpoint_loop())
        self._stream_task = asyncio.create_task(self._stream_loop())

        logger.info(f"Started capture session: {self.session.session_id}")
        return self.session

    async def stop(self) -> CaptureSession:
        """
        Stop the current capture session.

        Cancels the background tasks, disconnects the handler, marks the
        session COMPLETED and writes a final checkpoint.

        Returns:
            The completed capture session

        Raises:
            RuntimeError: No session has been started.
        """
        if not self.session:
            raise RuntimeError("No capture session active")

        await self._cancel_tasks()

        if self.handler:
            try:
                await self.handler.disconnect()
            except Exception as e:
                # Still finalize the session even if the transport fails to close.
                logger.error(f"Failed to disconnect handler: {e}")

        # Mark completion before the final checkpoint so the saved file
        # records the final state.
        self.session.state = CaptureState.COMPLETED
        await self._save_checkpoint()

        logger.info(f"Stopped capture session: {self.session.session_id}")
        return self.session

    async def pause(self):
        """Pause a streaming session; no-op in any other state.

        While paused, messages arriving from the handler are discarded.
        """
        if self.session and self.session.state == CaptureState.STREAMING:
            self.session.state = CaptureState.PAUSED
            logger.info("Capture paused")

    async def resume(self):
        """Resume a paused session; no-op in any other state.

        Restricting this to PAUSED keeps a COMPLETED or ERROR session from
        being flipped back to STREAMING, which would make start() refuse to run.
        """
        if self.session and self.session.state == CaptureState.PAUSED:
            self.session.state = CaptureState.STREAMING
            logger.info("Capture resumed")

    async def _cancel_tasks(self):
        """Cancel the checkpoint and stream tasks (in that order) and wait for them."""
        for task in (self._checkpoint_task, self._stream_task):
            if task is None:
                continue
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                logger.error(f"Background capture task failed: {e}")
        self._checkpoint_task = None
        self._stream_task = None

    async def _release_previous(self):
        """Best-effort cleanup of tasks and a still-connected handler from an earlier session."""
        await self._cancel_tasks()
        if self.handler is not None and getattr(self.handler, "connected", False):
            try:
                await self.handler.disconnect()
            except Exception as e:
                logger.error(f"Failed to disconnect previous handler: {e}")

    async def _stream_loop(self):
        """Consume ``handler.stream()`` into the session.

        Messages that arrive while the session is PAUSED are discarded. The
        loop exits when the handler's stream is exhausted, when a message
        arrives and the session is neither STREAMING nor PAUSED, or once
        ``config.max_messages`` messages have been recorded. An exception
        escaping the loop marks the session ERROR.
        """
        if not self.handler or not self.session:
            return

        try:
            async for message in self.handler.stream():
                if self.session.state == CaptureState.PAUSED:
                    # Drop this message; back off briefly before pulling the next one.
                    await asyncio.sleep(0.1)
                    continue

                if self.session.state != CaptureState.STREAMING:
                    break

                self.session.add_message(message)

                # Callbacks run synchronously; a failing callback must not stop
                # capture or prevent the remaining callbacks from running.
                for callback in self._callbacks:
                    try:
                        callback(message)
                    except Exception as e:
                        logger.error(f"Callback error: {e}")

                if len(self.session.messages) >= self.config.max_messages:
                    logger.warning("Max messages reached, stopping capture")
                    break

        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"Stream loop error: {e}")
            if self.session:
                self.session.state = CaptureState.ERROR

    async def _checkpoint_loop(self):
        """Save a checkpoint every ``config.checkpoint_interval`` seconds while STREAMING or PAUSED."""
        try:
            while self.session and self.session.state in [CaptureState.STREAMING, CaptureState.PAUSED]:
                await asyncio.sleep(self.config.checkpoint_interval)
                await self._save_checkpoint()
        except asyncio.CancelledError:
            pass

    async def _save_checkpoint(self):
        """Write the current session as a JSON file in ``config.checkpoint_dir``.

        On success the file path is appended to ``session.checkpoints`` and
        ``session.last_checkpoint`` is updated. Failures, including creating
        the directory, are logged rather than raised, so they cannot kill the
        checkpoint loop or abort ``stop()``.
        """
        if not self.session:
            return

        checkpoint_dir = Path(self.config.checkpoint_dir)
        # Full date plus microseconds: an HHMMSS-only stamp let a periodic
        # checkpoint and the final one from stop() in the same second, or two
        # checkpoints a day apart, overwrite each other.
        stamp = datetime.now().strftime('%Y%m%d%H%M%S%f')
        checkpoint_file = checkpoint_dir / f"{self.session.session_id}_{stamp}.json"

        try:
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            with open(checkpoint_file, 'w', encoding='utf-8') as f:
                json.dump(self.session.to_dict(), f, indent=2)

            self.session.last_checkpoint = datetime.now()
            self.session.checkpoints.append(str(checkpoint_file))
            logger.info(f"Checkpoint saved: {checkpoint_file}")

        except Exception as e:
            logger.error(f"Failed to save checkpoint: {e}")

    def get_session(self) -> Optional[CaptureSession]:
        """Get the current capture session"""
        return self.session

    def get_messages(self) -> List[CapturedMessage]:
        """Get all captured messages (empty list if no session was started)"""
        return self.session.messages if self.session else []

    async def analyze_incremental(self) -> Dict[str, Any]:
        """
        Perform incremental analysis on captured content.

        Converts the captured messages into a ``ucts`` Session, extracting
        fenced code blocks, and runs ``AnalysisEngine`` over it.

        Returns:
            Analysis results (empty dict if nothing has been captured yet)
        """
        if not self.session or not self.session.messages:
            return {}

        from ucts.analysis import AnalysisEngine
        from ucts.core.models import Session, Message, CodeBlock
        import re

        # Fenced code block: the language tag may contain + # . - (c++, c#,
        # objective-c). Trailing spaces and CRLF line endings after the tag
        # are tolerated. Compiled once per call instead of once per message.
        fence = re.compile(r'```([\w+#.-]*)[ \t]*\r?\n(.*?)```', re.DOTALL)

        messages = []
        code_blocks = []

        for msg in self.session.messages:
            messages.append(Message(
                role=msg.role,
                content=msg.content,
                timestamp=msg.timestamp.isoformat()
            ))

            for match in fence.finditer(msg.content):
                code_blocks.append(CodeBlock(
                    language=match.group(1) or "text",
                    code=match.group(2),
                    source_file=None
                ))

        session = Session(
            source=f"realtime:{self.config.source.value}",
            messages=messages,
            code_blocks=code_blocks
        )

        engine = AnalysisEngine()
        structure = engine.analyze(session)

        return {
            "name": structure.name,
            "languages": structure.languages,
            "dependencies": structure.dependencies,
            "files": len(structure.files),
            "todos": structure.todos,
            "message_count": len(messages),
            "code_block_count": len(code_blocks),
            "total_tokens": self.session.total_tokens,
        }

    @staticmethod
    def list_sources() -> List[Dict[str, str]]:
        """List capture sources by name.

        NOTE(review): ``anthropic_mcp`` is listed although ``HANDLERS`` has no
        handler for it, so ``start()`` rejects that source.
        """
        return [
            {"name": "claude_api", "description": "Claude API direct streaming"},
            {"name": "openai_api", "description": "OpenAI API streaming"},
            {"name": "anthropic_mcp", "description": "Anthropic MCP server"},
            {"name": "websocket", "description": "WebSocket endpoint"},
            {"name": "sse", "description": "Server-Sent Events"},
            {"name": "file_watch", "description": "Watch conversation file"},
        ]


# Module-level singleton managed by get_capture_manager().
_capture_manager: Optional[RealtimeCapture] = None


def get_capture_manager(config: Optional[CaptureConfig] = None) -> RealtimeCapture:
    """Return the process-wide capture manager, creating it on first use.

    Passing ``config`` always builds a fresh manager with that configuration
    and replaces the current one. If the current manager still has a capture
    connecting, streaming or paused, a warning is logged: its background
    tasks keep running but are no longer reachable through this accessor,
    so callers should ``stop()`` it first.

    Args:
        config: Optional configuration for a new manager.

    Returns:
        The current global ``RealtimeCapture`` instance.
    """
    global _capture_manager
    if _capture_manager is None or config is not None:
        previous = _capture_manager
        if (
            previous is not None
            and previous.session is not None
            and previous.session.state in (
                CaptureState.CONNECTING,
                CaptureState.STREAMING,
                CaptureState.PAUSED,
            )
        ):
            logger.warning(
                f"Replacing capture manager with an active session: "
                f"{previous.session.session_id}"
            )
        _capture_manager = RealtimeCapture(config)
    return _capture_manager
