"""
Tests for Real-Time Capture Module

Tests the capture/realtime.py module including:
- CaptureSource and CaptureState enums
- CaptureConfig, CapturedMessage, CaptureSession dataclasses
- Stream handlers (Claude, OpenAI, WebSocket, SSE, FileWatch)
- RealtimeCapture main class
- Singleton capture manager
"""

import asyncio
import json
import pytest
from datetime import datetime
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

from ucts.capture.realtime import (
    CaptureSource,
    CaptureState,
    CaptureConfig,
    CapturedMessage,
    CaptureSession,
    StreamHandler,
    ClaudeAPIHandler,
    OpenAIAPIHandler,
    WebSocketHandler,
    SSEHandler,
    FileWatchHandler,
    RealtimeCapture,
    get_capture_manager,
)


# ============================================================================
# CaptureSource Enum Tests
# ============================================================================

class TestCaptureSource:
    """Verify the string value backing each CaptureSource member."""

    def test_claude_api_value(self):
        """CLAUDE_API serializes as 'claude_api'."""
        expected = "claude_api"
        assert CaptureSource.CLAUDE_API.value == expected

    def test_openai_api_value(self):
        """OPENAI_API serializes as 'openai_api'."""
        expected = "openai_api"
        assert CaptureSource.OPENAI_API.value == expected

    def test_websocket_value(self):
        """WEBSOCKET serializes as 'websocket'."""
        expected = "websocket"
        assert CaptureSource.WEBSOCKET.value == expected

    def test_sse_value(self):
        """SSE serializes as 'sse'."""
        expected = "sse"
        assert CaptureSource.SSE.value == expected

    def test_file_watch_value(self):
        """FILE_WATCH serializes as 'file_watch'."""
        expected = "file_watch"
        assert CaptureSource.FILE_WATCH.value == expected

    def test_all_sources_defined(self):
        """Every expected source value is present on the enum."""
        defined = {member.value for member in CaptureSource}
        required = {"claude_api", "openai_api", "websocket", "sse", "file_watch"}
        # Report exactly which values are absent if this ever fails.
        missing = required - defined
        assert not missing, missing


# ============================================================================
# CaptureState Enum Tests
# ============================================================================

class TestCaptureState:
    """Verify the string value backing each CaptureState member."""

    def test_idle_value(self):
        """IDLE serializes as 'idle'."""
        state = CaptureState.IDLE
        assert state.value == "idle"

    def test_connecting_value(self):
        """CONNECTING serializes as 'connecting'."""
        state = CaptureState.CONNECTING
        assert state.value == "connecting"

    def test_streaming_value(self):
        """STREAMING serializes as 'streaming'."""
        state = CaptureState.STREAMING
        assert state.value == "streaming"

    def test_paused_value(self):
        """PAUSED serializes as 'paused'."""
        state = CaptureState.PAUSED
        assert state.value == "paused"

    def test_error_value(self):
        """ERROR serializes as 'error'."""
        state = CaptureState.ERROR
        assert state.value == "error"

    def test_completed_value(self):
        """COMPLETED serializes as 'completed'."""
        state = CaptureState.COMPLETED
        assert state.value == "completed"


# ============================================================================
# CaptureConfig Dataclass Tests
# ============================================================================

class TestCaptureConfig:
    """Verify CaptureConfig defaults and keyword overrides."""

    def test_default_values(self):
        """A bare CaptureConfig carries the documented defaults."""
        cfg = CaptureConfig()
        expected = {
            "source": CaptureSource.CLAUDE_API,
            "checkpoint_interval": 30,
            "checkpoint_dir": ".ucts/checkpoints",
            "max_messages": 1000,
            "buffer_size": 100,
            "reconnect_attempts": 3,
            "reconnect_delay": 1.0,
        }
        for field_name, value in expected.items():
            assert getattr(cfg, field_name) == value, field_name
        # Singletons are checked by identity, not equality.
        assert cfg.api_key is None
        assert cfg.api_url is None
        assert cfg.auto_analyze is True

    def test_custom_values(self):
        """Keyword arguments override the corresponding defaults."""
        overrides = {
            "source": CaptureSource.OPENAI_API,
            "api_key": "test-key",
            "api_url": "https://api.example.com",
            "checkpoint_interval": 60,
            "max_messages": 500,
        }
        cfg = CaptureConfig(**overrides)
        for field_name, value in overrides.items():
            assert getattr(cfg, field_name) == value, field_name


# ============================================================================
# CapturedMessage Dataclass Tests
# ============================================================================

class TestCapturedMessage:
    """Verify CapturedMessage construction and default fields."""

    def test_creation(self):
        """Only role and content are required; other fields use defaults."""
        message = CapturedMessage(role="user", content="Hello")
        assert (message.role, message.content) == ("user", "Hello")
        assert message.tokens == 0
        assert message.model is None
        assert message.metadata == {}

    def test_with_all_fields(self):
        """Every field passed to the constructor is stored unchanged."""
        message = CapturedMessage(
            role="assistant",
            content="Hi there!",
            tokens=10,
            model="claude-3",
            metadata={"source": "test"},
        )
        expected = {
            "role": "assistant",
            "content": "Hi there!",
            "tokens": 10,
            "model": "claude-3",
            "metadata": {"source": "test"},
        }
        for name, value in expected.items():
            assert getattr(message, name) == value, name

    def test_timestamp_default(self):
        """The timestamp is filled in at construction time."""
        lower = datetime.now()
        message = CapturedMessage(role="user", content="Test")
        upper = datetime.now()
        assert lower <= message.timestamp
        assert message.timestamp <= upper


# ============================================================================
# CaptureSession Dataclass Tests
# ============================================================================

class TestCaptureSession:
    """Verify CaptureSession message bookkeeping and dict serialization."""

    @staticmethod
    def _new_session(session_id="test", source=CaptureSource.CLAUDE_API, **extra):
        """Build a CaptureSession with test-friendly defaults."""
        return CaptureSession(session_id=session_id, source=source, **extra)

    def test_creation(self):
        """A new session starts idle with no messages or tokens."""
        session = self._new_session("test-session")
        assert session.session_id == "test-session"
        assert session.source == CaptureSource.CLAUDE_API
        assert session.state == CaptureState.IDLE
        assert session.messages == []
        assert session.total_tokens == 0

    def test_add_message(self):
        """add_message stores the message and counts its tokens."""
        session = self._new_session()
        message = CapturedMessage(role="user", content="Hello", tokens=5)
        session.add_message(message)

        (only,) = session.messages  # exactly one message expected
        assert only == message
        assert session.total_tokens == 5

    def test_add_multiple_messages(self):
        """Token counts accumulate across added messages."""
        session = self._new_session()
        for role, text, count in (("user", "Hi", 2), ("assistant", "Hello!", 3)):
            session.add_message(CapturedMessage(role=role, content=text, tokens=count))

        assert len(session.messages) == 2
        assert session.total_tokens == 5

    def test_to_dict(self):
        """to_dict exposes ids, enum values, messages and token totals."""
        session = self._new_session("test-123", state=CaptureState.STREAMING)
        session.add_message(CapturedMessage(role="user", content="Test", tokens=1))

        payload = session.to_dict()
        assert payload["session_id"] == "test-123"
        assert payload["source"] == "claude_api"
        assert payload["state"] == "streaming"
        (first,) = payload["messages"]  # exactly one serialized message
        assert first["role"] == "user"
        assert payload["total_tokens"] == 1

    def test_to_dict_empty_session(self):
        """An empty session serializes to empty collections and zero tokens."""
        payload = self._new_session("empty", CaptureSource.WEBSOCKET).to_dict()
        assert payload["messages"] == []
        assert payload["total_tokens"] == 0
        assert payload["checkpoints"] == []


# ============================================================================
# ClaudeAPIHandler Tests
# ============================================================================

class TestClaudeAPIHandler:
    """Tests for ClaudeAPIHandler connection state and streaming."""

    @staticmethod
    def _make_handler():
        """Return a fresh, unconnected Claude handler."""
        return ClaudeAPIHandler(CaptureConfig(source=CaptureSource.CLAUDE_API))

    @pytest.mark.asyncio
    async def test_connect_mock_mode(self):
        """The connected flag holds while the anthropic SDK is unimportable."""
        handler = self._make_handler()
        # Mapping the module name to None makes `import anthropic` raise
        # ImportError for the duration of the block.
        # NOTE(review): connect() is never awaited, so this only checks that the
        # flag can be assigned; consider awaiting handler.connect() under the
        # patch once the handler's mock-mode behavior is confirmed.
        with patch.dict("sys.modules", {"anthropic": None}):
            handler.connected = True

        assert handler.connected is True

    @pytest.mark.asyncio
    async def test_disconnect(self):
        """disconnect() clears both the connected flag and the client."""
        handler = self._make_handler()
        handler.connected = True  # pretend a connection was established
        await handler.disconnect()

        assert handler.connected is False
        assert handler.client is None

    @pytest.mark.asyncio
    async def test_stream_yields_message_when_connected(self):
        """A connected handler without a real client yields a system message."""
        handler = self._make_handler()
        handler.connected = True  # pretend a connection was established

        received = [item async for item in handler.stream()]

        assert len(received) >= 1
        assert received[0].role == "system"

    @pytest.mark.asyncio
    async def test_stream_empty_when_disconnected(self):
        """An unconnected handler's stream produces nothing."""
        handler = self._make_handler()

        received = [item async for item in handler.stream()]

        assert received == []


# ============================================================================
# OpenAIAPIHandler Tests
# ============================================================================

class TestOpenAIAPIHandler:
    """Tests for OpenAIAPIHandler connection state."""

    @staticmethod
    def _make_handler():
        """Return a fresh, unconnected OpenAI handler."""
        return OpenAIAPIHandler(CaptureConfig(source=CaptureSource.OPENAI_API))

    @pytest.mark.asyncio
    async def test_connect_mock_mode(self):
        """The connected flag reflects a simulated successful connection."""
        handler = self._make_handler()
        # NOTE(review): connect() is not awaited; this only exercises the flag.
        handler.connected = True
        assert handler.connected is True

    @pytest.mark.asyncio
    async def test_disconnect(self):
        """disconnect() clears the connected flag."""
        handler = self._make_handler()
        handler.connected = True  # pretend a connection was established
        await handler.disconnect()

        assert handler.connected is False


# ============================================================================
# WebSocketHandler Tests
# ============================================================================

class TestWebSocketHandler:
    """Tests for WebSocketHandler initial state and teardown."""

    @pytest.mark.asyncio
    async def test_init(self):
        """A new handler starts disconnected and holds no socket."""
        handler = WebSocketHandler(
            CaptureConfig(source=CaptureSource.WEBSOCKET, api_url="ws://localhost:8080")
        )

        assert handler.connected is False
        assert handler.ws is None

    @pytest.mark.asyncio
    async def test_disconnect(self):
        """disconnect() clears the connected flag and drops the socket."""
        handler = WebSocketHandler(CaptureConfig(source=CaptureSource.WEBSOCKET))
        handler.connected = True  # pretend a connection was established
        await handler.disconnect()

        assert handler.connected is False
        assert handler.ws is None


# ============================================================================
# SSEHandler Tests
# ============================================================================

class TestSSEHandler:
    """Tests for SSEHandler connect/disconnect."""

    @pytest.mark.asyncio
    async def test_connect(self):
        """connect() reports success and sets the connected flag."""
        handler = SSEHandler(
            CaptureConfig(source=CaptureSource.SSE, api_url="https://api.example.com/stream")
        )

        connected_ok = await handler.connect()

        assert connected_ok is True
        assert handler.connected is True

    @pytest.mark.asyncio
    async def test_disconnect(self):
        """disconnect() after connect() clears the connected flag."""
        handler = SSEHandler(CaptureConfig(source=CaptureSource.SSE))
        await handler.connect()
        await handler.disconnect()

        assert handler.connected is False


# ============================================================================
# FileWatchHandler Tests
# ============================================================================

class TestFileWatchHandler:
    """Tests for FileWatchHandler connect/disconnect against real files."""

    @staticmethod
    def _handler_for(path):
        """Build a file-watch handler whose api_url points at *path*."""
        return FileWatchHandler(
            CaptureConfig(source=CaptureSource.FILE_WATCH, api_url=str(path))
        )

    @pytest.mark.asyncio
    async def test_connect_existing_file(self, tmp_path):
        """Connecting to an existing file succeeds."""
        test_file = tmp_path / "conversation.txt"
        test_file.write_text("Initial content")
        handler = self._handler_for(test_file)

        result = await handler.connect()

        assert result is True
        assert handler.connected is True

    @pytest.mark.asyncio
    async def test_connect_nonexistent_file(self, tmp_path):
        """Connecting to a path that does not exist fails."""
        # A path under the per-test temp dir is guaranteed absent on every
        # platform, unlike a hard-coded absolute path.
        missing = tmp_path / "nonexistent" / "file.txt"
        assert not missing.exists()
        handler = self._handler_for(missing)

        result = await handler.connect()

        assert result is False
        assert handler.connected is False

    @pytest.mark.asyncio
    async def test_disconnect(self, tmp_path):
        """disconnect() clears the flag and resets the read position."""
        test_file = tmp_path / "test.txt"
        test_file.write_text("content")
        handler = self._handler_for(test_file)

        await handler.connect()
        await handler.disconnect()

        assert handler.connected is False
        assert handler._last_position == 0


# ============================================================================
# RealtimeCapture Tests
# ============================================================================

class TestRealtimeCapture:
    """Tests for the main RealtimeCapture class.

    Tests that call ``start()`` stop the capture in a ``finally`` block, so a
    failing assertion does not leave an active session behind for later tests.
    """

    @staticmethod
    def _file_watch_capture():
        """Return a capture configured for file-watch on the current directory."""
        return RealtimeCapture(
            CaptureConfig(source=CaptureSource.FILE_WATCH, api_url=".")
        )

    def test_init_default_config(self):
        """Without a config the capture defaults to the Claude API source."""
        capture = RealtimeCapture()
        assert capture.config.source == CaptureSource.CLAUDE_API
        assert capture.session is None
        assert capture.handler is None

    def test_init_custom_config(self):
        """A supplied config is used as-is."""
        config = CaptureConfig(source=CaptureSource.OPENAI_API, max_messages=100)
        capture = RealtimeCapture(config)
        assert capture.config.source == CaptureSource.OPENAI_API
        assert capture.config.max_messages == 100

    def test_on_message_callback(self):
        """on_message registers the callback."""
        capture = RealtimeCapture()
        callback = MagicMock()

        capture.on_message(callback)
        assert callback in capture._callbacks

    def test_on_message_multiple_callbacks(self):
        """Several callbacks can be registered."""
        capture = RealtimeCapture()
        cb1 = MagicMock()
        cb2 = MagicMock()

        capture.on_message(cb1)
        capture.on_message(cb2)

        assert len(capture._callbacks) == 2

    @pytest.mark.asyncio
    async def test_start_creates_session(self):
        """start() returns a streaming session with the requested id."""
        capture = self._file_watch_capture()
        session = await capture.start("test-session-123")
        try:
            assert session is not None
            assert session.session_id == "test-session-123"
            assert session.source == CaptureSource.FILE_WATCH
            assert session.state == CaptureState.STREAMING
        finally:
            await capture.stop()

    @pytest.mark.asyncio
    async def test_start_auto_generates_session_id(self):
        """start() without an id generates one prefixed with 'capture-'."""
        capture = self._file_watch_capture()
        session = await capture.start()
        try:
            assert session.session_id.startswith("capture-")
        finally:
            await capture.stop()

    @pytest.mark.asyncio
    async def test_start_raises_if_already_streaming(self):
        """A second start() while streaming raises RuntimeError."""
        capture = self._file_watch_capture()
        await capture.start()
        try:
            with pytest.raises(RuntimeError, match="Capture already in progress"):
                await capture.start()
        finally:
            await capture.stop()

    @pytest.mark.asyncio
    async def test_stop_without_session_raises(self):
        """stop() with no active session raises RuntimeError."""
        capture = RealtimeCapture()

        with pytest.raises(RuntimeError, match="No capture session active"):
            await capture.stop()

    @pytest.mark.asyncio
    async def test_stop_marks_completed(self):
        """stop() returns the session marked as completed."""
        capture = self._file_watch_capture()
        await capture.start()
        session = await capture.stop()

        assert session.state == CaptureState.COMPLETED

    @pytest.mark.asyncio
    async def test_pause_and_resume(self):
        """pause() and resume() toggle between PAUSED and STREAMING."""
        capture = self._file_watch_capture()
        await capture.start()
        try:
            await capture.pause()
            assert capture.session.state == CaptureState.PAUSED

            await capture.resume()
            assert capture.session.state == CaptureState.STREAMING
        finally:
            await capture.stop()

    def test_get_session(self):
        """get_session() is None before any capture starts."""
        capture = RealtimeCapture()
        assert capture.get_session() is None

    def test_get_messages_no_session(self):
        """get_messages() is empty without a session."""
        capture = RealtimeCapture()
        assert capture.get_messages() == []

    @pytest.mark.asyncio
    async def test_get_messages_with_session(self):
        """get_messages() reflects messages added to the active session."""
        capture = self._file_watch_capture()
        session = await capture.start()
        try:
            # Add a message directly to the session for testing.
            session.add_message(CapturedMessage(role="user", content="Test"))

            messages = capture.get_messages()
            assert len(messages) >= 1
        finally:
            await capture.stop()

    def test_list_sources(self):
        """list_sources() describes at least the five built-in sources."""
        sources = RealtimeCapture.list_sources()

        assert isinstance(sources, list)
        assert len(sources) >= 5

        source_names = [s["name"] for s in sources]
        assert "claude_api" in source_names
        assert "openai_api" in source_names
        assert "websocket" in source_names

    def test_handlers_mapping(self):
        """Every CaptureSource used here has a registered handler."""
        assert CaptureSource.CLAUDE_API in RealtimeCapture.HANDLERS
        assert CaptureSource.OPENAI_API in RealtimeCapture.HANDLERS
        assert CaptureSource.WEBSOCKET in RealtimeCapture.HANDLERS
        assert CaptureSource.SSE in RealtimeCapture.HANDLERS
        assert CaptureSource.FILE_WATCH in RealtimeCapture.HANDLERS


# ============================================================================
# Checkpoint Tests
# ============================================================================

class TestCheckpointing:
    """Tests for checkpoint saving via RealtimeCapture._save_checkpoint().

    Each test stops the capture in a ``finally`` block so a failed assertion
    does not leave an active session behind.
    """

    @staticmethod
    def _capture_with_checkpoints(checkpoint_dir):
        """Return a file-watch capture that writes checkpoints to *checkpoint_dir*."""
        return RealtimeCapture(
            CaptureConfig(
                source=CaptureSource.FILE_WATCH,
                api_url=".",
                checkpoint_dir=str(checkpoint_dir),
            )
        )

    @pytest.mark.asyncio
    async def test_save_checkpoint(self, tmp_path):
        """A manual checkpoint is recorded on the session."""
        capture = self._capture_with_checkpoints(tmp_path / "checkpoints")
        session = await capture.start()
        try:
            await capture._save_checkpoint()

            assert len(session.checkpoints) >= 1
            assert session.last_checkpoint is not None
        finally:
            await capture.stop()

    @pytest.mark.asyncio
    async def test_checkpoint_creates_directory(self, tmp_path):
        """Saving a checkpoint creates a missing checkpoint directory."""
        checkpoint_dir = tmp_path / "new" / "checkpoints"
        capture = self._capture_with_checkpoints(checkpoint_dir)
        await capture.start()
        try:
            await capture._save_checkpoint()

            assert checkpoint_dir.exists()
        finally:
            await capture.stop()

    @pytest.mark.asyncio
    async def test_checkpoint_file_format(self, tmp_path):
        """The checkpoint file is JSON containing the session id and messages."""
        capture = self._capture_with_checkpoints(tmp_path)
        session = await capture.start()
        try:
            session.add_message(CapturedMessage(role="user", content="Hello"))

            await capture._save_checkpoint()

            # Inspect the checkpoint just written: an earlier one (e.g. from
            # start()) would predate the message added above.
            checkpoint_file = Path(session.checkpoints[-1])
            assert checkpoint_file.exists()

            data = json.loads(checkpoint_file.read_text(encoding="utf-8"))

            assert data["session_id"] == session.session_id
            assert len(data["messages"]) >= 1
        finally:
            await capture.stop()


# ============================================================================
# Singleton Manager Tests
# ============================================================================

class TestGetCaptureManager:
    """Tests for the module-level capture manager singleton.

    An autouse fixture clears the cached ``_capture_manager`` before each test
    and restores the previous value afterwards. The tests therefore do not
    depend on execution order, and a reconfigured manager does not leak into
    other tests.
    """

    @pytest.fixture(autouse=True)
    def _reset_singleton(self, monkeypatch):
        """Reset the module's cached manager for the duration of one test."""
        import ucts.capture.realtime as module

        # raising=False mirrors the original direct assignment, which did not
        # require the attribute to exist beforehand.
        monkeypatch.setattr(module, "_capture_manager", None, raising=False)

    def test_get_manager_creates_instance(self):
        """The first call creates a RealtimeCapture."""
        manager = get_capture_manager()
        assert manager is not None
        assert isinstance(manager, RealtimeCapture)

    def test_get_manager_returns_same_instance(self):
        """Repeated calls without a config return the same object."""
        manager1 = get_capture_manager()
        manager2 = get_capture_manager()
        assert manager1 is manager2

    def test_get_manager_with_config_creates_new(self):
        """Passing a config yields a manager using that config."""
        get_capture_manager()  # ensure a default manager already exists

        new_config = CaptureConfig(source=CaptureSource.OPENAI_API)
        manager2 = get_capture_manager(new_config)

        assert manager2.config.source == CaptureSource.OPENAI_API


# ============================================================================
# Error Handling Tests
# ============================================================================

class TestErrorHandling:
    """Tests for RealtimeCapture error handling."""

    @pytest.mark.asyncio
    async def test_unsupported_source_raises(self):
        """start() raises ValueError when no handler is registered for the source."""
        config = CaptureConfig(source=CaptureSource.FILE_WATCH, api_url=".")
        capture = RealtimeCapture(config)

        # Swap the class-level registry for an empty mapping. patch.object puts
        # the original object back on exit, even if the assertion fails.
        with patch.object(RealtimeCapture, "HANDLERS", {}):
            with pytest.raises(ValueError, match="Unsupported capture source"):
                await capture.start()

    @pytest.mark.asyncio
    async def test_callback_error_doesnt_stop_capture(self):
        """A raising callback does not prevent the capture from streaming."""
        config = CaptureConfig(source=CaptureSource.FILE_WATCH, api_url=".")
        capture = RealtimeCapture(config)

        def bad_callback(msg):
            raise Exception("Callback error")

        capture.on_message(bad_callback)

        # Should not raise.
        session = await capture.start()
        try:
            # NOTE(review): the callback only fires if a message is captured,
            # which this test does not force; the check is therefore weak.
            assert session.state == CaptureState.STREAMING
        finally:
            await capture.stop()


# ============================================================================
# Integration-style Tests
# ============================================================================

class TestIntegration:
    """Integration-style tests covering a start/run/stop lifecycle."""

    @pytest.mark.asyncio
    async def test_full_capture_workflow(self, tmp_path):
        """A capture can be started, run briefly, and stopped cleanly."""
        config = CaptureConfig(
            source=CaptureSource.FILE_WATCH,
            api_url=".",
            checkpoint_dir=str(tmp_path / "checkpoints"),
            checkpoint_interval=1,
        )
        capture = RealtimeCapture(config)

        # NOTE(review): received_messages is not asserted on; the source may
        # emit nothing within the short run below.
        received_messages = []
        capture.on_message(lambda m: received_messages.append(m))

        session = await capture.start("integration-test")
        try:
            assert session.state == CaptureState.STREAMING

            # Let the capture run briefly.
            await asyncio.sleep(0.1)
        finally:
            # Always stop, so a failed assertion does not leave the session active.
            completed = await capture.stop()

        assert completed.state == CaptureState.COMPLETED
        assert completed.session_id == "integration-test"

    @pytest.mark.asyncio
    async def test_analyze_incremental_empty(self):
        """Incremental analysis with no session/messages returns an empty dict."""
        capture = RealtimeCapture()
        result = await capture.analyze_incremental()
        assert result == {}

    @pytest.mark.asyncio
    async def test_pause_prevents_message_processing(self):
        """pause() moves an active capture into the PAUSED state."""
        config = CaptureConfig(source=CaptureSource.FILE_WATCH, api_url=".")
        capture = RealtimeCapture(config)
        await capture.start()
        try:
            await capture.pause()
            assert capture.session.state == CaptureState.PAUSED
        finally:
            await capture.stop()
