"""
Conversation Analytics

Provides analytics for AI conversations:
- Token usage tracking over time
- Cost estimation
- Conversation quality metrics
- Pattern detection (common problems)
"""

import json
import logging
import os
import re
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Token pricing per 1M tokens (approximate as of 2024).
# Each entry maps a model name to its "input" and "output" rate; the cost
# calculation divides token counts by 1,000,000 before applying the rate.
# The "default" entry is the fallback for model names that are not listed.
# NOTE(review): rates are presumably in USD (CostEstimate.currency defaults
# to "USD") — confirm against current provider pricing before relying on them.
TOKEN_PRICING = {
    "claude-3-opus": {"input": 15.0, "output": 75.0},
    "claude-3-sonnet": {"input": 3.0, "output": 15.0},
    "claude-3-haiku": {"input": 0.25, "output": 1.25},
    "gpt-4-turbo": {"input": 10.0, "output": 30.0},
    "gpt-4": {"input": 30.0, "output": 60.0},
    "gpt-3.5-turbo": {"input": 0.5, "output": 1.5},
    "default": {"input": 3.0, "output": 15.0},
}


@dataclass
class AnalyticsConfig:
    """Settings that control how conversations are analyzed.

    An empty ``storage_dir`` is replaced after initialisation with
    ``~/.ucts/analytics`` (with ``~`` expanded to the user's home directory).
    """
    storage_dir: str = ""           # directory holding the analytics history
    track_costs: bool = True
    model: str = "default"          # model name used for the pricing lookup
    detect_patterns: bool = True    # run pattern detection during analysis
    quality_threshold: float = 0.7

    def __post_init__(self):
        """Fill in the default storage directory when none was supplied."""
        if self.storage_dir:
            return
        self.storage_dir = os.path.join(
            os.path.expanduser("~"), ".ucts", "analytics"
        )


@dataclass
class UsageMetrics:
    """Token totals, message count and per-message averages."""
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_tokens: int = 0
    message_count: int = 0
    avg_input_per_message: float = 0.0
    avg_output_per_message: float = 0.0
    tokens_over_time: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return the scalar metrics as a plain dict.

        ``tokens_over_time`` is not part of the exported dict.
        """
        exported = (
            "total_input_tokens",
            "total_output_tokens",
            "total_tokens",
            "message_count",
            "avg_input_per_message",
            "avg_output_per_message",
        )
        return {name: getattr(self, name) for name in exported}


@dataclass
class CostEstimate:
    """Estimated spend for a given amount of token usage."""
    model: str
    input_cost: float = 0.0
    output_cost: float = 0.0
    total_cost: float = 0.0
    currency: str = "USD"
    costs_over_time: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return the estimate as a plain dict with costs rounded to 4 places.

        ``costs_over_time`` is not part of the exported dict.
        """
        summary: Dict[str, Any] = {"model": self.model}
        for key in ("input_cost", "output_cost", "total_cost"):
            summary[key] = round(getattr(self, key), 4)
        summary["currency"] = self.currency
        return summary


@dataclass
class QualityMetrics:
    """Heuristic quality scores for a conversation."""
    clarity_score: float = 0.0      # clarity/specificity of the prompts
    efficiency_score: float = 0.0   # efficiency of the conversation
    success_rate: float = 0.0       # task completion rate
    iteration_count: float = 0.0    # average iterations per task
    code_quality: float = 0.0       # code quality indicators

    # Boolean indicators observed in the prompts
    has_clear_goals: bool = False
    has_examples: bool = False
    has_constraints: bool = False
    has_context: bool = False

    def overall_score(self) -> float:
        """Return the unweighted mean of the four score components.

        ``iteration_count`` is not part of the mean.
        """
        components = (
            self.clarity_score,
            self.efficiency_score,
            self.success_rate,
            self.code_quality,
        )
        if not components:
            return 0.0
        return sum(components) / len(components)

    def to_dict(self) -> Dict[str, Any]:
        """Return all metrics as a plain dict; numeric values rounded to 2 places."""
        numeric_fields = (
            "clarity_score",
            "efficiency_score",
            "success_rate",
            "iteration_count",
            "code_quality",
        )
        flag_fields = (
            "has_clear_goals",
            "has_examples",
            "has_constraints",
            "has_context",
        )

        result: Dict[str, Any] = {
            name: round(getattr(self, name), 2) for name in numeric_fields
        }
        result["overall_score"] = round(self.overall_score(), 2)
        for name in flag_fields:
            result[name] = getattr(self, name)
        return result


@dataclass
class PatternMetrics:
    """Patterns detected across one or more conversations."""
    common_issues: List[str] = field(default_factory=list)
    successful_patterns: List[str] = field(default_factory=list)
    topic_distribution: Dict[str, int] = field(default_factory=dict)
    language_distribution: Dict[str, int] = field(default_factory=dict)
    error_patterns: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return every field keyed by its name.

        The values are the instance's own containers, not copies.
        """
        names = (
            "common_issues",
            "successful_patterns",
            "topic_distribution",
            "language_distribution",
            "error_patterns",
        )
        return {name: getattr(self, name) for name in names}


@dataclass
class AnalyticsReport:
    """Full analytics result: usage, costs, quality and detected patterns."""
    generated_at: datetime = field(default_factory=datetime.now)
    period_start: Optional[datetime] = None
    period_end: Optional[datetime] = None
    session_count: int = 0
    usage: UsageMetrics = field(default_factory=UsageMetrics)
    costs: CostEstimate = field(default_factory=lambda: CostEstimate(model="default"))
    quality: QualityMetrics = field(default_factory=QualityMetrics)
    patterns: PatternMetrics = field(default_factory=PatternMetrics)

    def to_dict(self) -> Dict[str, Any]:
        """Return the report as a nested dict with ISO-formatted timestamps.

        ``period_start`` / ``period_end`` are ``None`` when unset.
        """
        start, end = self.period_start, self.period_end
        return {
            "generated_at": self.generated_at.isoformat(),
            "period_start": None if not start else start.isoformat(),
            "period_end": None if not end else end.isoformat(),
            "session_count": self.session_count,
            "usage": self.usage.to_dict(),
            "costs": self.costs.to_dict(),
            "quality": self.quality.to_dict(),
            "patterns": self.patterns.to_dict(),
        }

    def to_markdown(self) -> str:
        """Render the report as a Markdown document.

        The period line and the language / issue / success sections are
        emitted only when the corresponding data is present; issue and
        success lists are capped at 10 entries.
        """
        usage = self.usage
        costs = self.costs
        quality = self.quality
        patterns = self.patterns
        day_format = '%Y-%m-%d'

        def bullet_section(title, entries):
            # A heading, a blank line, one bullet per entry, a blank line.
            return [title, "", *(f"- {entry}" for entry in entries), ""]

        out = [
            "# UCTS Conversation Analytics Report",
            "",
            f"Generated: {self.generated_at.strftime('%Y-%m-%d %H:%M:%S')}",
            "",
        ]

        if self.period_start and self.period_end:
            first_day = self.period_start.strftime(day_format)
            last_day = self.period_end.strftime(day_format)
            out += [f"Period: {first_day} to {last_day}", ""]

        out += [
            "## Summary",
            "",
            f"- Sessions analyzed: {self.session_count}",
            f"- Total tokens: {usage.total_tokens:,}",
            f"- Estimated cost: ${costs.total_cost:.2f}",
            f"- Overall quality: {quality.overall_score():.0%}",
            "",
        ]

        out += [
            "## Token Usage",
            "",
            "| Metric | Value |",
            "|--------|-------|",
            f"| Input tokens | {usage.total_input_tokens:,} |",
            f"| Output tokens | {usage.total_output_tokens:,} |",
            f"| Messages | {usage.message_count:,} |",
            f"| Avg input/message | {usage.avg_input_per_message:.0f} |",
            f"| Avg output/message | {usage.avg_output_per_message:.0f} |",
            "",
        ]

        out += [
            "## Cost Breakdown",
            "",
            "| Component | Cost |",
            "|-----------|------|",
            f"| Input | ${costs.input_cost:.4f} |",
            f"| Output | ${costs.output_cost:.4f} |",
            f"| **Total** | **${costs.total_cost:.4f}** |",
            "",
        ]

        out += [
            "## Quality Metrics",
            "",
            "| Metric | Score |",
            "|--------|-------|",
            f"| Clarity | {quality.clarity_score:.0%} |",
            f"| Efficiency | {quality.efficiency_score:.0%} |",
            f"| Success rate | {quality.success_rate:.0%} |",
            f"| Code quality | {quality.code_quality:.0%} |",
            "",
        ]

        if patterns.language_distribution:
            # Most frequent language first; ties keep their original order.
            ranked = sorted(
                patterns.language_distribution.items(),
                key=lambda pair: pair[1],
                reverse=True,
            )
            out += bullet_section(
                "## Language Distribution",
                (f"{lang}: {count}" for lang, count in ranked),
            )

        if patterns.common_issues:
            out += bullet_section(
                "## Common Issues Detected", patterns.common_issues[:10]
            )

        if patterns.successful_patterns:
            out += bullet_section(
                "## Successful Patterns", patterns.successful_patterns[:10]
            )

        return "\n".join(out)


class ConversationAnalytics:
    """
    Analyze AI conversations for usage, cost, and quality metrics.

    Features:
    - Token usage tracking
    - Cost estimation by model
    - Quality scoring
    - Pattern detection
    - Historical trends
    """

    def __init__(self, config: Optional[AnalyticsConfig] = None):
        """Prepare storage and load any previously saved history.

        Falls back to a default ``AnalyticsConfig`` when ``config`` is
        omitted, creates the storage directory (including parents) if it
        does not exist, then loads the saved history from it.
        """
        if not config:
            config = AnalyticsConfig()
        self.config = config

        storage_root = Path(config.storage_dir)
        self._storage_path = storage_root
        storage_root.mkdir(parents=True, exist_ok=True)

        self._history: List[Dict[str, Any]] = []
        self._load_history()

    def _load_history(self):
        """Load analytics history"""
        history_file = self._storage_path / "history.json"
        if history_file.exists():
            try:
                with open(history_file) as f:
                    self._history = json.load(f)
            except Exception as e:
                logger.warning(f"Failed to load history: {e}")

    def _save_history(self):
        """Save analytics history"""
        history_file = self._storage_path / "history.json"
        with open(history_file, 'w') as f:
            json.dump(self._history[-1000:], f, indent=2, default=str)

    def estimate_tokens(self, text: str) -> int:
        """Estimate token count for text"""
        # Rough estimation: ~4 characters per token for English
        return len(text) // 4

    def analyze_session(self, session: "Session") -> AnalyticsReport:
        """Analyze a single conversation session.

        Computes usage and quality metrics, plus cost and pattern metrics
        when enabled in the config, then records a summary of the session in
        the history (which is saved as part of recording).

        Args:
            session: The conversation to analyze.

        Returns:
            A report with ``session_count`` set to 1. When
            ``config.track_costs`` is false the cost estimate is left at
            zero and labelled with the configured model; when
            ``config.detect_patterns`` is false the pattern metrics are
            left empty.
        """
        report = AnalyticsReport()
        report.session_count = 1

        report.usage = self._calculate_usage(session)

        if self.config.track_costs:
            report.costs = self._calculate_costs(report.usage)
        else:
            # Honour the opt-out: keep a zero-cost placeholder instead of
            # pricing the tokens.
            report.costs = CostEstimate(model=self.config.model)

        report.quality = self._calculate_quality(session)

        if self.config.detect_patterns:
            report.patterns = self._detect_patterns(session)

        # Recording appends to the in-memory history and saves it.
        self._record_session(session, report)

        return report

    def analyze_sessions(self, sessions: List["Session"]) -> AnalyticsReport:
        """Analyze several sessions and merge the results into one report.

        Each session is analyzed individually and the per-session reports
        are then aggregated. An empty input yields a default report.
        """
        if sessions:
            per_session = []
            for session in sessions:
                per_session.append(self.analyze_session(session))
            return self._aggregate_reports(per_session)

        return AnalyticsReport()

    def _calculate_usage(self, session: "Session") -> UsageMetrics:
        """Estimate token usage for a session.

        Messages with role ``"user"`` count as input; every other role
        counts as output. Code block content is added to the output total.
        Averages are per user message (input) and per non-user message
        (output), and stay at 0.0 when there are no such messages.
        """
        usage = UsageMetrics()
        user_count = 0
        other_count = 0

        for message in session.messages:
            estimate = self.estimate_tokens(message.content)
            if message.role == "user":
                usage.total_input_tokens += estimate
                user_count += 1
            else:
                usage.total_output_tokens += estimate
                other_count += 1
        usage.message_count = user_count + other_count

        # NOTE(review): if code_blocks are extracted from message content,
        # their tokens are counted twice here — confirm against Session.
        usage.total_output_tokens += sum(
            self.estimate_tokens(block.content) for block in session.code_blocks
        )

        usage.total_tokens = usage.total_input_tokens + usage.total_output_tokens

        if user_count > 0:
            usage.avg_input_per_message = usage.total_input_tokens / user_count
        if other_count > 0:
            usage.avg_output_per_message = usage.total_output_tokens / other_count

        return usage

    def _calculate_costs(self, usage: UsageMetrics) -> CostEstimate:
        """Price the given usage with the configured model's rates.

        Rates come from ``TOKEN_PRICING`` (per one million tokens); a model
        name that is not listed there is priced with the ``"default"`` entry.
        """
        model = self.config.model
        fallback = TOKEN_PRICING["default"]
        rates = TOKEN_PRICING[model] if model in TOKEN_PRICING else fallback

        per_million = 1_000_000
        input_cost = (usage.total_input_tokens / per_million) * rates["input"]
        output_cost = (usage.total_output_tokens / per_million) * rates["output"]

        return CostEstimate(
            model=model,
            input_cost=input_cost,
            output_cost=output_cost,
            total_cost=input_cost + output_cost,
        )

    def _calculate_quality(self, session: "Session") -> QualityMetrics:
        """Score a session with keyword-based quality heuristics.

        All scores are capped at 1.0:

        - clarity: indicator hits in user messages (goal, example and
          constraint phrases, plus length over 100 characters) divided by
          twice the number of user messages;
        - efficiency: mean message length over all messages divided by 500,
          computed only when there is at least one user message;
        - success rate: share of non-user messages containing a completion
          phrase;
        - code quality: documentation, error-handling and type-hint markers
          found per code block, out of three possible per block.

        ``iteration_count`` is the number of user messages. A session
        without messages yields default metrics.
        """
        quality = QualityMetrics()
        messages = session.messages
        if not messages:
            return quality

        prompts = [m for m in messages if m.role == "user"]
        replies = [m for m in messages if m.role != "user"]

        def mentions(text, phrases):
            # Plain substring test, so a phrase also matches inside longer
            # words (e.g. "try" inside "entry").
            return any(phrase in text for phrase in phrases)

        goal_phrases = ("please", "create", "build", "make", "help me", "i want to", "i need")
        example_phrases = ("example", "like this", "such as", "for instance")
        constraint_phrases = ("should", "must", "require", "constraint", "limit")

        clarity_hits = 0
        for prompt in prompts:
            lowered = prompt.content.lower()

            if mentions(lowered, goal_phrases):
                quality.has_clear_goals = True
                clarity_hits += 1

            if mentions(lowered, example_phrases):
                quality.has_examples = True
                clarity_hits += 1

            if mentions(lowered, constraint_phrases):
                quality.has_constraints = True
                clarity_hits += 1

            # Longer prompts are taken as a sign of added context.
            if len(lowered) > 100:
                quality.has_context = True
                clarity_hits += 1

        if prompts:
            quality.clarity_score = min(1.0, clarity_hits / (len(prompts) * 2))

            mean_length = sum(len(m.content) for m in messages) / len(messages)
            quality.efficiency_score = min(1.0, mean_length / 500)

        if replies:
            success_phrases = ("complete", "done", "finished", "here's", "created")
            successes = sum(
                1 for reply in replies
                if mentions(reply.content.lower(), success_phrases)
            )
            quality.success_rate = min(1.0, successes / len(replies))

        blocks = session.code_blocks
        if blocks:
            marker_groups = (
                ('"""', "'''", "/**"),                    # documentation
                ("try", "catch", "except"),               # error handling
                ("->", ": str", ": int", ": number"),     # type hints
            )
            marker_hits = sum(
                1
                for block in blocks
                for group in marker_groups
                if mentions(block.content, group)
            )
            quality.code_quality = min(1.0, marker_hits / (len(blocks) * 3))

        quality.iteration_count = len(prompts)

        return quality

    def _detect_patterns(self, session: "Session") -> PatternMetrics:
        """Detect languages, issues, success patterns and topics in a session.

        Languages are counted from the code blocks (``"unknown"`` when a
        block has no language). Everything else is found by searching the
        lower-cased, space-joined message text for fixed phrases.

        NOTE(review): matching is by plain substring, so short keywords also
        hit inside longer words (e.g. "ui" in "build", "rest" in
        "interest"); topic and issue results may therefore over-report.
        """
        patterns = PatternMetrics()

        patterns.language_distribution = dict(
            Counter(block.language or "unknown" for block in session.code_blocks)
        )

        text = " ".join(m.content for m in session.messages).lower()

        def found(phrases):
            return any(phrase in text for phrase in phrases)

        # Every error-related phrase that occurs is recorded individually.
        error_phrases = (
            "error", "bug", "issue", "problem", "not working", "doesn't work",
            "failed", "crash", "broken", "fix",
        )
        patterns.error_patterns.extend(p for p in error_phrases if p in text)

        # An issue is reported when any of its trigger phrases occurs.
        issue_triggers = (
            ("unclear requirements", ("what do you mean", "can you clarify", "not sure what")),
            ("missing context", ("what is", "where is", "which")),
            ("scope creep", ("also", "and another thing", "one more")),
            ("iteration loops", ("try again", "that's not right", "wrong")),
        )
        patterns.common_issues.extend(
            label for label, phrases in issue_triggers if found(phrases)
        )

        success_triggers = (
            ("clear specification", ("specific", "exactly", "precise")),
            ("examples provided", ("example", "like this", "such as")),
            ("incremental approach", ("first", "then", "step by step")),
            ("feedback loop", ("looks good", "perfect", "that works")),
        )
        patterns.successful_patterns.extend(
            label for label, phrases in success_triggers if found(phrases)
        )

        # Basic keyword extraction: each topic counts at most once per session.
        topic_keywords = (
            ("api", ("api", "endpoint", "rest", "graphql")),
            ("database", ("database", "sql", "query", "postgres", "mysql", "mongo")),
            ("frontend", ("react", "vue", "angular", "css", "html", "ui")),
            ("backend", ("server", "backend", "express", "flask", "django")),
            ("testing", ("test", "spec", "unit test", "coverage")),
            ("deployment", ("deploy", "docker", "kubernetes", "aws", "cloud")),
            ("security", ("auth", "security", "token", "password", "encryption")),
        )
        patterns.topic_distribution.update(
            {topic: 1 for topic, keywords in topic_keywords if found(keywords)}
        )

        return patterns

    def _aggregate_reports(self, reports: List[AnalyticsReport]) -> AnalyticsReport:
        """Combine several reports into one.

        - Token totals, message counts and costs are summed.
        - Per-message averages are recomputed per role (input per user
          message, output per non-user message), matching how a
          single-session report defines them.
        - Quality scores and the iteration count are averaged over the
          reports; each boolean quality indicator is set if any report has it.
        - Language and topic counts are summed.
        - Issues, success patterns and error patterns are ranked by how many
          times they occur across reports and capped at 10 entries each.

        Args:
            reports: Reports to merge.

        Returns:
            The merged report, or a default report when ``reports`` is empty.
        """
        if not reports:
            return AnalyticsReport()

        def role_counts(usage):
            # A report stores only totals and per-role averages, so recover
            # the per-role message counts as total / average. An average of
            # 0 means the count cannot be recovered that way; fall back to
            # the remainder of message_count when the other role is known.
            users = others = None
            if usage.avg_input_per_message > 0:
                users = round(usage.total_input_tokens / usage.avg_input_per_message)
            if usage.avg_output_per_message > 0:
                others = round(usage.total_output_tokens / usage.avg_output_per_message)
            if users is None and others is None:
                return 0, 0
            if users is None:
                users = max(usage.message_count - others, 0)
            if others is None:
                others = max(usage.message_count - users, 0)
            return users, others

        report_count = len(reports)
        aggregated = AnalyticsReport()
        aggregated.session_count = report_count

        # Usage totals
        usage = aggregated.usage
        usage.total_input_tokens = sum(r.usage.total_input_tokens for r in reports)
        usage.total_output_tokens = sum(r.usage.total_output_tokens for r in reports)
        usage.total_tokens = sum(r.usage.total_tokens for r in reports)
        usage.message_count = sum(r.usage.message_count for r in reports)

        user_messages = 0
        other_messages = 0
        for r in reports:
            users, others = role_counts(r.usage)
            user_messages += users
            other_messages += others

        if user_messages > 0:
            usage.avg_input_per_message = usage.total_input_tokens / user_messages
        if other_messages > 0:
            usage.avg_output_per_message = usage.total_output_tokens / other_messages

        # Costs
        aggregated.costs = CostEstimate(model=self.config.model)
        aggregated.costs.input_cost = sum(r.costs.input_cost for r in reports)
        aggregated.costs.output_cost = sum(r.costs.output_cost for r in reports)
        aggregated.costs.total_cost = sum(r.costs.total_cost for r in reports)

        # Quality: average the scores, OR the indicator flags
        quality = aggregated.quality
        quality.clarity_score = sum(r.quality.clarity_score for r in reports) / report_count
        quality.efficiency_score = sum(r.quality.efficiency_score for r in reports) / report_count
        quality.success_rate = sum(r.quality.success_rate for r in reports) / report_count
        quality.code_quality = sum(r.quality.code_quality for r in reports) / report_count
        quality.iteration_count = sum(r.quality.iteration_count for r in reports) / report_count
        quality.has_clear_goals = any(r.quality.has_clear_goals for r in reports)
        quality.has_examples = any(r.quality.has_examples for r in reports)
        quality.has_constraints = any(r.quality.has_constraints for r in reports)
        quality.has_context = any(r.quality.has_context for r in reports)

        # Patterns
        issues = Counter()
        successes = Counter()
        errors = Counter()
        languages = Counter()
        topics = Counter()
        for r in reports:
            issues.update(r.patterns.common_issues)
            successes.update(r.patterns.successful_patterns)
            errors.update(r.patterns.error_patterns)
            languages.update(r.patterns.language_distribution)
            topics.update(r.patterns.topic_distribution)

        aggregated.patterns.language_distribution = dict(languages)
        aggregated.patterns.topic_distribution = dict(topics)
        aggregated.patterns.common_issues = [
            item for item, _ in issues.most_common(10)
        ]
        aggregated.patterns.successful_patterns = [
            item for item, _ in successes.most_common(10)
        ]
        # Ranked like the lists above so the result is deterministic (set
        # iteration order for strings varies between interpreter runs).
        aggregated.patterns.error_patterns = [
            item for item, _ in errors.most_common(10)
        ]

        return aggregated

    def _record_session(self, session: "Session", report: AnalyticsReport):
        """Record session analytics in history"""
        record = {
            "timestamp": datetime.now().isoformat(),
            "source": session.source,
            "tokens": report.usage.total_tokens,
            "cost": report.costs.total_cost,
            "quality": report.quality.overall_score(),
            "messages": report.usage.message_count,
        }

        self._history.append(record)
        self._save_history()

    def get_historical_report(
        self,
        days: int = 30,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None
    ) -> Dict[str, Any]:
        """Get historical analytics"""
        if start_date is None:
            start_date = datetime.now() - timedelta(days=days)
        if end_date is None:
            end_date = datetime.now()

        # Filter history
        filtered = []
        for record in self._history:
            try:
                ts = datetime.fromisoformat(record["timestamp"])
                if start_date <= ts <= end_date:
                    filtered.append(record)
            except Exception:
                pass

        if not filtered:
            return {"message": "No data for period", "records": []}

        # Aggregate
        total_tokens = sum(r.get("tokens", 0) for r in filtered)
        total_cost = sum(r.get("cost", 0) for r in filtered)
        avg_quality = sum(r.get("quality", 0) for r in filtered) / len(filtered)

        # Daily breakdown
        daily = defaultdict(lambda: {"tokens": 0, "cost": 0, "sessions": 0})
        for record in filtered:
            try:
                day = datetime.fromisoformat(record["timestamp"]).strftime("%Y-%m-%d")
                daily[day]["tokens"] += record.get("tokens", 0)
                daily[day]["cost"] += record.get("cost", 0)
                daily[day]["sessions"] += 1
            except Exception:
                pass

        return {
            "period": {
                "start": start_date.isoformat(),
                "end": end_date.isoformat(),
            },
            "summary": {
                "total_sessions": len(filtered),
                "total_tokens": total_tokens,
                "total_cost": round(total_cost, 4),
                "avg_quality": round(avg_quality, 2),
            },
            "daily": dict(daily),
        }

    def get_cost_projection(
        self,
        current_usage: UsageMetrics,
        projection_days: int = 30
    ) -> Dict[str, Any]:
        """Extrapolate token and cost totals from ``current_usage``.

        ``current_usage`` is treated as one day's usage: its token total and
        its cost under the configured model are multiplied by
        ``projection_days``, by 30 (monthly) and by 365 (yearly). The daily
        cost is rounded to 4 places and the projected costs to 2.
        """
        tokens_per_day = current_usage.total_tokens
        cost_per_day = self._calculate_costs(current_usage).total_cost

        def scaled(day_count, token_key):
            # Token and cost totals for a horizon of `day_count` days.
            return {
                token_key: tokens_per_day * day_count,
                "cost": round(cost_per_day * day_count, 2),
            }

        custom = scaled(projection_days, "total_tokens")

        return {
            "current_daily": {
                "tokens": tokens_per_day,
                "cost": round(cost_per_day, 4),
            },
            "projection": {
                "days": projection_days,
                "total_tokens": custom["total_tokens"],
                "total_cost": custom["cost"],
            },
            "monthly_estimate": scaled(30, "tokens"),
            "yearly_estimate": scaled(365, "tokens"),
        }


# Singleton instance. Created on the first get_analytics() call and replaced
# whenever get_analytics() is called with an explicit config.
_analytics: Optional[ConversationAnalytics] = None


def get_analytics(config: Optional[AnalyticsConfig] = None) -> ConversationAnalytics:
    """Return the shared ``ConversationAnalytics`` instance.

    The instance is created on first use. Passing a ``config`` always
    builds a new instance with it and makes that the shared one.
    """
    global _analytics
    needs_new_instance = config is not None or _analytics is None
    if needs_new_instance:
        _analytics = ConversationAnalytics(config)
    return _analytics
