#!/usr/bin/env python3
"""
Tokenization bias test — Groq API.
Measures EN vs ES token overhead using usage.prompt_tokens from Groq completions.

Usage: GROQ_API_KEY=your-key python3 test_groq_tokenizer.py [--model MODEL]

Example models (availability changes; check Groq's model list): llama-3.3-70b-versatile, qwen-qwq-32b, openai/gpt-oss-120b, etc.

Part of: "Your AI charges you up to 67% more for not speaking English"
https://theprivatestack.com/research/tokenization-bias
"""

import argparse
import json
import os
import urllib.request


def count_tokens_groq(text, model, api_key, timeout=60):
    """Return the prompt token count Groq reports for *text* sent as one user message.

    Sends a minimal chat completion (``max_tokens=1``, ``temperature=0``) so the
    request is cheap, and returns ``usage.prompt_tokens`` from the response.
    The value is the full prompt size as reported by the API. It may therefore
    include tokens the server adds around the message; the script estimates
    that overhead separately with a short baseline probe.

    Args:
        text: Content of the single user message.
        model: Groq model identifier (e.g. ``"llama-3.3-70b-versatile"``).
        api_key: Groq API key, sent as a Bearer token.
        timeout: Seconds to wait on the HTTP connection before giving up.
            Without it, ``urlopen`` can block indefinitely on a stalled
            connection.

    Returns:
        int: ``usage.prompt_tokens`` from the completion response.

    Raises:
        urllib.error.HTTPError / urllib.error.URLError / TimeoutError: on API or
            network failure (all ``OSError`` subclasses).
        ValueError: if the response body is not valid JSON.
        KeyError: if the response lacks ``usage.prompt_tokens``.
    """
    url = "https://api.groq.com/openai/v1/chat/completions"
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": text}],
        # A single output token is enough: only the prompt usage is needed.
        "max_tokens": 1,
        "temperature": 0,
    }
    req = urllib.request.Request(url, data=json.dumps(payload).encode(), headers={
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
        "User-Agent": "tokenizer-bias-test/1.0",
    })
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        result = json.loads(resp.read())
    return result["usage"]["prompt_tokens"]


# (English, Spanish) term pairs. Each string is sent on its own as a user
# message, and the two prompt-token counts are compared side by side.
# These exact strings are the test inputs; editing them changes the results.
WORD_PAIRS = [
    ("contract", "contrato"),
    ("agreement", "acuerdo"),
    ("corporation", "corporación"),
    ("shareholders", "accionistas"),
    ("liability", "responsabilidad"),
    ("compliance", "cumplimiento"),
    ("articles of incorporation", "acta constitutiva"),
    ("tax identification number", "registro federal de contribuyentes"),
    ("artificial intelligence", "inteligencia artificial"),
    ("data sovereignty", "soberanía de datos"),
]

# Longer English clauses, compared against CLAUSES_ES for a
# percentage-overhead figure.
CLAUSES_EN = {
    "NDA": "The Receiving Party agrees to hold in strict confidence all Confidential Information disclosed by the Disclosing Party and shall not disclose such information to any third party without prior written consent. This obligation shall survive the termination of this agreement for a period of five years.",
    "Tax": "The taxpayer shall file annual returns with the Internal Revenue Service for each taxable year. All digital tax receipts must comply with current regulations and include the employer identification number assigned by the federal tax authority.",
}

# Spanish counterparts of CLAUSES_EN.
# The tax clause is keyed "Fiscal", not "Tax", so the two dicts cannot be
# joined by key; main() pairs them with an explicit key mapping.
# The tax clauses are adapted rather than literal translations: they name
# different institutions and identifiers in each language.
CLAUSES_ES = {
    "NDA": "La Parte Receptora se obliga a mantener en estricta confidencialidad toda la Información Confidencial revelada por la Parte Divulgante y no divulgará dicha información a terceros sin el consentimiento previo por escrito. Esta obligación sobrevivirá la terminación del presente contrato por un período de cinco años.",
    "Fiscal": "El contribuyente deberá presentar declaraciones anuales ante el Servicio de Administración Tributaria por cada ejercicio fiscal. Todos los comprobantes fiscales digitales por Internet deberán cumplir con la normatividad vigente e incluir el Registro Federal de Contribuyentes asignado por la autoridad fiscal federal.",
}


def _describe_error(exc):
    """Return a readable message for *exc*, including Groq's response body for HTTP errors."""
    import urllib.error

    if isinstance(exc, urllib.error.HTTPError):
        # The JSON error body explains *why* the request failed (bad model
        # name, invalid key, rate limit); str(exc) alone only gives the status.
        try:
            body = exc.read().decode("utf-8", errors="replace").strip()
        except OSError:
            body = ""
        return f"{exc}: {body}" if body else str(exc)
    return str(exc)


def _print_word_pairs(model, api_key):
    """Print raw prompt-token counts for each WORD_PAIRS entry and the ES-minus-EN difference."""
    width = 95  # matches the header row below, including the Diff column
    print("=" * width)
    print(f"  GROQ ({model}) — EN vs ES")
    print("=" * width)
    print(f"{'English':<38} {'Tok':>5}  {'Spanish':<38} {'Tok':>5} {'Diff':>4}")
    print("-" * width)

    for en, es in WORD_PAIRS:
        # Raw (uncorrected) counts: any fixed per-request overhead cancels
        # out in the difference column.
        t_en = count_tokens_groq(en, model, api_key)
        t_es = count_tokens_groq(es, model, api_key)
        print(f"{en:<38} {t_en:>5}  {es:<38} {t_es:>5} {t_es - t_en:>+4}")


def _print_clause_comparison(model, api_key, baseline):
    """Print baseline-corrected counts and the Spanish overhead percentage per clause pair."""
    print(f"\n  Clause comparison (baseline-corrected, -{baseline} tokens):")

    # The tax clause uses different keys in CLAUSES_EN ("Tax") and CLAUSES_ES ("Fiscal").
    clause_pairs = [("NDA", "NDA"), ("Tax", "Fiscal")]
    for en_key, es_key in clause_pairs:
        t_en = count_tokens_groq(CLAUSES_EN[en_key], model, api_key) - baseline
        t_es = count_tokens_groq(CLAUSES_ES[es_key], model, api_key) - baseline
        if t_en > 0:
            overhead = f"{((t_es / t_en) - 1) * 100:>+.1f}%"
        else:
            # The ratio is undefined (or meaningless) if the baseline
            # consumes the whole English count.
            overhead = "n/a"
        print(f"    {en_key:<20} EN: {t_en:>4} tokens  ES: {t_es:>4} tokens  Overhead: {overhead}")


def main():
    """Run the EN vs ES tokenization comparison against the Groq API.

    Reads ``--model`` and ``--api-key``; the key falls back to the
    ``GROQ_API_KEY`` environment variable. Measures a one-word probe prompt
    as a baseline, then prints a word-pair table and a baseline-corrected
    clause comparison.

    Exits with status 1 and a message on stderr if no API key is available
    or if any Groq request fails.
    """
    parser = argparse.ArgumentParser(description="Test tokenizer bias via Groq API")
    parser.add_argument("--model", default="llama-3.3-70b-versatile",
                        help="Groq model (default: llama-3.3-70b-versatile)")
    parser.add_argument("--api-key", default=os.environ.get("GROQ_API_KEY", ""),
                        help="Groq API key (or set GROQ_API_KEY env var)")
    args = parser.parse_args()

    if not args.api_key:
        # SystemExit with a string prints it to stderr and exits with status 1.
        raise SystemExit("Set GROQ_API_KEY or pass --api-key")

    model, api_key = args.model, args.api_key

    # Failures count_tokens_groq can raise: network/HTTP/timeout (OSError
    # subclasses), malformed JSON (ValueError), or a missing usage field (KeyError).
    api_errors = (OSError, ValueError, KeyError)

    try:
        # Prompt size of a one-word message, used as the overhead estimate.
        # NOTE: this also counts the probe word's own token(s), so the
        # "corrected" clause counts slightly understate the text tokens.
        baseline = count_tokens_groq("test", model, api_key)
        print(f"Connected to Groq — model: {model}")
        print(f"System prompt baseline: {baseline} tokens\n")

        _print_word_pairs(model, api_key)
        _print_clause_comparison(model, api_key, baseline)
    except api_errors as e:
        raise SystemExit(f"Groq error: {_describe_error(e)}") from e

    print(f"\n  Note: raw counts include ~{baseline} tokens of system prompt overhead.")
    print("  Corrected counts subtract that baseline to isolate text tokenization.\n")


if __name__ == "__main__":
    main()
