#!/usr/bin/env python3
"""
LLM Cost Calculator - Calculate and compare API costs across providers.

Usage:
    python cost_calculator.py --input 1000000 --output 100000 --models claude-3-haiku,gpt-4o-mini
    python cost_calculator.py -i 1M -o 100K -m all --format json
    python cost_calculator.py --parse-url "https://example.com/#it=1000000&ot=100000&sel=claude-3-haiku"
"""

import argparse
import json
import sys
from decimal import Decimal, InvalidOperation
from typing import Optional
from urllib.parse import parse_qs, quote, unquote, urlparse

# Pricing data: (input_cost, output_cost, cached_input_cost) per 1M tokens.
#
# Each value is a 3-tuple of prices per 1,000,000 tokens. The CLI prints them
# with a "$" prefix, so presumably USD — confirm against the upstream source.
#   [0] input_cost         - regular (uncached) input/prompt rate
#   [1] output_cost        - output/completion rate
#   [2] cached_input_cost  - cached-input rate, or None when no cached rate is
#                            listed; calculate_cost() then falls back to the
#                            regular input rate.
#
# Keys are the exact, case-sensitive model names accepted by --models.
# These are hard-coded snapshot values with no date attached; verify them
# against each provider's current pricing page before relying on results.
PRICING = {
    # Anthropic
    "claude-opus-4": (15.00, 75.00, 1.50),
    "claude-sonnet-4": (3.00, 15.00, 0.30),
    "claude-haiku-4": (0.80, 4.00, 0.08),
    "claude-3.5-sonnet": (3.00, 15.00, 0.30),
    "claude-3-opus": (15.00, 75.00, 1.50),
    "claude-3-sonnet": (3.00, 15.00, 0.30),
    "claude-3-haiku": (0.25, 1.25, 0.03),
    # OpenAI
    "gpt-4.1": (2.00, 8.00, 0.50),
    "gpt-4.1-mini": (0.40, 1.60, 0.10),
    "gpt-4.1-nano": (0.10, 0.40, 0.025),
    "gpt-4o": (2.50, 10.00, 1.25),
    "gpt-4o-mini": (0.15, 0.60, 0.075),
    "gpt-4-turbo": (10.00, 30.00, None),
    "gpt-3.5-turbo": (0.50, 1.50, None),
    "o1": (15.00, 60.00, 7.50),
    "o1-mini": (3.00, 12.00, 1.50),
    "o3-mini": (1.10, 4.40, 0.55),
    # Google
    "gemini-2.5-pro": (1.25, 10.00, 0.31),
    "gemini-2.5-flash": (0.15, 0.60, 0.0375),
    "gemini-2.5-flash-lite": (0.075, 0.30, 0.01875),
    "gemini-2.0-flash": (0.10, 0.40, 0.025),
    "gemini-1.5-pro": (1.25, 5.00, 0.31),
    "gemini-1.5-flash": (0.075, 0.30, 0.01875),
    "gemini-1.5-flash-8b": (0.0375, 0.15, 0.01),
    # Mistral (no cached-input rate listed)
    "mistral-large": (2.00, 6.00, None),
    "mistral-small": (0.20, 0.60, None),
    "codestral": (0.30, 0.90, None),
    "ministral-8b": (0.10, 0.10, None),
    "ministral-3b": (0.04, 0.04, None),
    # Meta Llama (no cached-input rate listed)
    "llama-3.3-70b": (0.40, 0.40, None),
    "llama-3.1-405b": (3.00, 3.00, None),
    "llama-3.1-70b": (0.35, 0.40, None),
    "llama-3.1-8b": (0.05, 0.08, None),
    # xAI (no cached-input rate listed)
    "grok-3": (3.00, 15.00, None),
    "grok-3-mini": (0.30, 0.50, None),
    "grok-2": (2.00, 10.00, None),
    # DeepSeek
    "deepseek-v3": (0.27, 1.10, 0.07),
    "deepseek-r1": (0.55, 2.19, 0.14),
    # Amazon (no cached-input rate listed)
    "amazon-nova-pro": (0.80, 3.20, None),
    "amazon-nova-lite": (0.06, 0.24, None),
    "amazon-nova-micro": (0.035, 0.14, None),
}


def parse_token_count(value: str) -> int:
    """Parse a human-friendly token count such as ``"1500"``, ``"100K"`` or ``"1.5M"``.

    An optional, case-insensitive ``K`` (10^3), ``M`` (10^6) or ``B`` (10^9)
    suffix scales the number. Surrounding whitespace is ignored.

    The number is parsed with :class:`decimal.Decimal` rather than ``float``.
    Binary floating point would turn e.g. ``"1.001K"`` into
    ``1000.9999999999999`` and truncate it to 1000 instead of 1001.

    Args:
        value: The raw token-count string.

    Returns:
        The token count as a non-negative ``int``. A fractional result after
        applying a suffix is truncated toward zero (``"0.0005K"`` -> 0).

    Raises:
        ValueError: If *value* is not a finite, non-negative number, or if it
            has no suffix but is not a whole number (e.g. ``"1.5"``).
    """
    multipliers = {"K": 1_000, "M": 1_000_000, "B": 1_000_000_000}
    text = value.strip().upper()

    multiplier = 1
    if text and text[-1] in multipliers:
        multiplier = multipliers[text[-1]]
        text = text[:-1]

    try:
        number = Decimal(text)
    except InvalidOperation:
        # Decimal signals InvalidOperation (an ArithmeticError); callers expect
        # ValueError, matching what int()/float() raised previously.
        raise ValueError(f"invalid token count: {value!r}") from None

    # is_finite() is checked first so NaN never reaches the < comparison,
    # which would itself signal InvalidOperation.
    if not number.is_finite() or number < 0:
        raise ValueError(
            f"token count must be a finite, non-negative number: {value!r}"
        )
    if multiplier == 1 and number != number.to_integral_value():
        raise ValueError(
            f"token count without a K/M/B suffix must be a whole number: {value!r}"
        )
    return int(number * multiplier)


def calculate_cost(
    model: str,
    input_tokens: int,
    output_tokens: int,
    use_cached: bool = False,
    pricing: Optional[dict] = None,
) -> dict:
    """Calculate the cost of a request volume for a single model.

    Args:
        model: Exact (case-sensitive) key into the price table.
        input_tokens: Number of input/prompt tokens.
        output_tokens: Number of output/completion tokens.
        use_cached: If true and the model has a cached-input rate, bill *all*
            input tokens at that rate; otherwise the regular input rate is used.
        pricing: Optional price table with the same layout as ``PRICING``
            (``model -> (input, output, cached_or_None)`` per 1M tokens).
            Defaults to the module-level ``PRICING``.

    Returns:
        For a known model, a dict with the token counts, the per-1M prices
        used, the rounded (4 d.p.) ``input_cost``/``output_cost``/``total_cost``,
        and ``cached`` telling whether the cached rate was applied. For an
        unknown model, ``{"model": model, "error": "Unknown model: ..."}``.
    """
    table = PRICING if pricing is None else pricing
    if model not in table:
        return {"model": model, "error": f"Unknown model: {model}"}

    input_price, output_price, cached_price = table[model]

    # Use ``is not None`` rather than truthiness, so a legitimately free (0.0)
    # cached rate is honoured and the price used always agrees with the
    # reported "cached" flag.
    cached = use_cached and cached_price is not None
    effective_input_price = cached_price if cached else input_price

    input_cost = (input_tokens / 1_000_000) * effective_input_price
    output_cost = (output_tokens / 1_000_000) * output_price
    total_cost = input_cost + output_cost

    return {
        "model": model,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "input_price_per_1m": effective_input_price,
        "output_price_per_1m": output_price,
        "input_cost": round(input_cost, 4),
        "output_cost": round(output_cost, 4),
        "total_cost": round(total_cost, 4),
        "cached": cached,
    }


def parse_url_state(url: str) -> dict:
    """Parse calculator state stored in a URL fragment (``#it=...&ot=...``).

    Recognised keys: ``it``/``ot`` (input/output token counts, default 0),
    ``ic``/``cic``/``oc`` (optional floats, ``None`` when absent or blank) and
    ``sel`` (comma-separated model names).

    Args:
        url: Full URL whose fragment carries the state.

    Returns:
        Dict with ``input_tokens``, ``output_tokens``,
        ``input_cost_threshold``, ``cached_input_cost``,
        ``output_cost_threshold`` and ``selected_models`` (list of str).

    Raises:
        ValueError: If a numeric field holds a non-numeric value.
    """
    fragment = urlparse(url).fragment

    # Split manually, then percent-decode with unquote(). Unlike parse_qsl,
    # unquote() leaves a literal "+" intact, so "ic=1e+5" still parses.
    # Pairs without "=" are ignored; a repeated key keeps its last value.
    params = {}
    for pair in fragment.split("&"):
        key, sep, raw_value = pair.partition("=")
        if sep:
            params[unquote(key)] = unquote(raw_value)

    def _int_field(key: str) -> int:
        # A missing or blank value ("it=") means 0 rather than crashing int("").
        raw = params.get(key, "").strip()
        return int(raw) if raw else 0

    def _optional_float(key: str) -> Optional[float]:
        raw = params.get(key, "").strip()
        return float(raw) if raw else None

    # Strip names and drop empty entries produced by "a,,b" or a trailing comma.
    selected = [name.strip() for name in params.get("sel", "").split(",")]

    return {
        "input_tokens": _int_field("it"),
        "output_tokens": _int_field("ot"),
        "input_cost_threshold": _optional_float("ic"),
        "cached_input_cost": _optional_float("cic"),
        "output_cost_threshold": _optional_float("oc"),
        "selected_models": [name for name in selected if name],
    }


def generate_url_state(
    input_tokens: int,
    output_tokens: int,
    models: list,
    base_url: str = "https://llm-prices.example.com/",
    ic: Optional[float] = None,
    cic: Optional[float] = None,
    oc: Optional[float] = None
) -> str:
    """Build a shareable URL whose fragment encodes the calculator state.

    Args:
        input_tokens: Value for ``it``.
        output_tokens: Value for ``ot``.
        models: Model names joined with "," into ``sel``; omitted when empty.
        base_url: URL to attach the fragment to. Any existing fragment on it
            is replaced rather than producing a second "#".
        ic, cic, oc: Optional float fields, emitted only when not ``None``.

    Returns:
        ``base_url`` (without its old fragment) + ``#`` + ``key=value`` pairs
        joined by ``&``, in the order it, ot, ic, cic, oc, sel.
    """
    def _encode(value) -> str:
        # Percent-encode characters that would break the fragment ("&", "#",
        # "=", spaces, ...). "," stays literal so the model list remains
        # readable. "+" stays literal so exponent floats such as "1e+16"
        # round-trip through parsers that do not percent-decode.
        return quote(str(value), safe=",+")

    parts = [f"it={_encode(input_tokens)}", f"ot={_encode(output_tokens)}"]
    for key, value in (("ic", ic), ("cic", cic), ("oc", oc)):
        if value is not None:
            parts.append(f"{key}={_encode(value)}")
    if models:
        parts.append(f"sel={_encode(','.join(models))}")

    base = base_url.split("#", 1)[0]
    return f"{base}#{'&'.join(parts)}"


def format_output(results: list, format_type: str) -> str:
    """Render cost results as JSON, a Markdown table or a plain-text table.

    Args:
        results: Dicts as returned by ``calculate_cost``. Each is either a
            success entry (``input_cost``/``output_cost``/``total_cost``) or
            an error entry (``model`` + ``error``).
        format_type: ``"json"`` or ``"markdown"``; any other value selects
            the plain-text table.

    Returns:
        The rendered text. JSON keeps the input order. Both tables are
        sorted cheapest-first, with error entries (no ``total_cost``) last.
    """
    if format_type == "json":
        return json.dumps(results, indent=2)

    # Error entries lack "total_cost", so they sort after every real cost.
    ordered = sorted(results, key=lambda r: r.get("total_cost", float("inf")))

    if format_type == "markdown":
        return _format_markdown_table(ordered)
    return _format_text_table(ordered)


def _money(amount: float) -> str:
    """Format a cost as a dollar string with 4 decimal places."""
    return f"${amount:.4f}"


def _md_cell(text) -> str:
    """Escape "|" so user-supplied text cannot split a Markdown table cell."""
    return str(text).replace("|", "\\|")


def _format_markdown_table(ordered: list) -> str:
    """Render already-sorted results as a Markdown table."""
    lines = ["| Model | Input Cost | Output Cost | Total Cost |",
             "|-------|------------|-------------|------------|"]
    for r in ordered:
        if "error" in r:
            lines.append(f"| {_md_cell(r['model'])} | Error | {_md_cell(r['error'])} | - |")
        else:
            lines.append(
                f"| {_md_cell(r['model'])} | {_money(r['input_cost'])} "
                f"| {_money(r['output_cost'])} | {_money(r['total_cost'])} |"
            )
    return "\n".join(lines)


def _format_text_table(ordered: list) -> str:
    """Render already-sorted results as a fixed-width plain-text table."""
    header = f"{'Model':<25} {'Input Cost':>12} {'Output Cost':>12} {'Total Cost':>12}"
    # The separator matches the header width exactly.
    lines = [header, "-" * len(header)]
    for r in ordered:
        if "error" in r:
            lines.append(f"{r['model']:<25} {'ERROR':>12} {r['error']}")
        else:
            # Right-align the whole "$x.xxxx" string in 12 columns, matching
            # the header. A literal "$" plus a 10-wide number is only 11
            # columns and would drift left of the header.
            lines.append(
                f"{r['model']:<25} {_money(r['input_cost']):>12} "
                f"{_money(r['output_cost']):>12} {_money(r['total_cost']):>12}"
            )
    return "\n".join(lines)


def main():
    """Command-line entry point.

    Parses arguments and performs exactly one action, checked in this order:
    ``--list-models``, ``--parse-url``, ``--generate-url``, and otherwise
    printing a cost comparison in the chosen ``--format``. Invalid token
    counts, invalid URL state and an empty ``--models`` list are reported via
    ``parser.error`` (usage message, exit status 2) instead of a traceback.
    """
    parser = argparse.ArgumentParser(description="Calculate LLM API costs")
    parser.add_argument("-i", "--input", type=str, default="1000000",
                        help="Input tokens (supports K/M/B suffixes)")
    parser.add_argument("-o", "--output", type=str, default="100000",
                        help="Output tokens (supports K/M/B suffixes)")
    parser.add_argument("-m", "--models", type=str, default="all",
                        help="Comma-separated model names or 'all'")
    parser.add_argument("-c", "--cached", action="store_true",
                        help="Use cached input pricing where available")
    parser.add_argument("-f", "--format", choices=["table", "json", "markdown"],
                        default="table", help="Output format")
    parser.add_argument("--parse-url", type=str,
                        help="Parse state from URL fragment")
    parser.add_argument("--generate-url", action="store_true",
                        help="Generate shareable URL")
    parser.add_argument("--list-models", action="store_true",
                        help="List all available models")

    args = parser.parse_args()

    if args.list_models:
        print("Available models:")
        for model in sorted(PRICING):
            input_price, output_price, cached_price = PRICING[model]
            # ``is not None`` rather than truthiness, so a 0.0 cached rate is
            # still listed.
            cached = f", cached: ${cached_price}" if cached_price is not None else ""
            print(f"  {model}: input=${input_price}, output=${output_price}{cached}")
        return

    if args.parse_url:
        try:
            state = parse_url_state(args.parse_url)
        except ValueError as exc:
            parser.error(f"argument --parse-url: {exc}")  # exits
        print(json.dumps(state, indent=2))
        return

    def _token_arg(raw: str, flag: str) -> int:
        """Parse a token-count option, turning ValueError into a usage error."""
        try:
            return parse_token_count(raw)
        except ValueError as exc:
            parser.error(f"argument {flag}: invalid token count {raw!r}: {exc}")

    input_tokens = _token_arg(args.input, "-i/--input")
    output_tokens = _token_arg(args.output, "-o/--output")

    if args.models.strip().lower() == "all":
        models = list(PRICING)
    else:
        # Drop blank entries ("a,,b", a trailing comma) and repeated names,
        # keeping the user's order (dict preserves insertion order).
        names = (name.strip() for name in args.models.split(","))
        models = list(dict.fromkeys(name for name in names if name))
        if not models:
            parser.error("argument -m/--models: no model names given")

    if args.generate_url:
        print(generate_url_state(input_tokens, output_tokens, models))
        return

    results = [
        calculate_cost(model, input_tokens, output_tokens, args.cached)
        for model in models
    ]
    print(format_output(results, args.format))


if __name__ == "__main__":
    # Run the CLI only when executed directly, not when imported as a module.
    main()
