#!/usr/bin/env python3
"""
Sora access token tester.

Usage:
  tools/sora-test -at "<ACCESS_TOKEN>"
"""

from __future__ import annotations

import argparse
import base64
import json
import sys
import textwrap
import urllib.error
import urllib.request
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Dict, Optional, Tuple


# Default value of --base-url; also sent as the Origin and (with a trailing
# slash) Referer header on every request, regardless of --base-url.
DEFAULT_BASE_URL = "https://sora.chatgpt.com"
# Default value of --timeout, in seconds.
DEFAULT_TIMEOUT = 20
# User-Agent header sent on every request.
DEFAULT_USER_AGENT = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)"


@dataclass
class EndpointResult:
    """Outcome of one GET probe against a single backend path."""

    # Request path that was probed, e.g. "/backend/me".
    path: str
    # HTTP status code; 0 means the request failed without an HTTP response.
    status: int
    # Value of the "x-request-id" response header, or "" when absent.
    request_id: str
    # Value of the "cf-ray" response header, or "" when absent.
    cf_ray: str
    # First 500 characters of the decoded body with newlines replaced by
    # spaces, or "network_error: ..." when the request itself failed.
    body_preview: str


def parse_args() -> argparse.Namespace:
    """Define the command-line options and parse the process arguments.

    Returns:
        Namespace carrying ``access_token`` (required), ``base_url`` and
        ``timeout`` (integer seconds).
    """
    # Shown verbatim below the option list; RawTextHelpFormatter keeps the
    # line breaks and indentation intact.
    usage_examples = textwrap.dedent(
        """\
        Examples:
          tools/sora-test -at "eyJhbGciOi..."
          tools/sora-test -at "eyJhbGciOi..." --timeout 30
        """
    )
    cli = argparse.ArgumentParser(
        description="Test Sora access token against core backend endpoints.",
        epilog=usage_examples,
        formatter_class=argparse.RawTextHelpFormatter,
    )

    cli.add_argument(
        "-at",
        "--access-token",
        help="Sora/OpenAI access token (JWT)",
        required=True,
    )
    cli.add_argument(
        "--base-url",
        help=f"Base URL for Sora backend (default: {DEFAULT_BASE_URL})",
        default=DEFAULT_BASE_URL,
    )
    cli.add_argument(
        "--timeout",
        help=f"HTTP timeout seconds (default: {DEFAULT_TIMEOUT})",
        default=DEFAULT_TIMEOUT,
        type=int,
    )

    namespace = cli.parse_args()
    return namespace


def mask_token(token: str) -> str:
    """Return a shortened form of *token* suitable for display.

    Tokens of 16 characters or fewer are returned unchanged; longer ones keep
    only their first 10 and last 6 characters, joined by ``...``.
    """
    if len(token) > 16:
        head, tail = token[:10], token[-6:]
        return head + "..." + tail
    return token


def decode_jwt_payload(token: str) -> Optional[Dict]:
    """Decode the payload (claims) segment of a JWT without verifying it.

    The signature is NOT checked; the result is for display only.

    Args:
        token: Compact JWT of the form ``header.payload.signature``.

    Returns:
        The claims as a dict, or ``None`` when the token does not have exactly
        three dot-separated segments, the payload is not decodable
        base64url/JSON, or the decoded JSON is not an object.
    """
    segments = token.split(".")
    if len(segments) != 3:
        return None

    payload = segments[1]
    # JWT segments are base64url with the "=" padding removed; restore it so
    # the decoder accepts the input.
    payload += "=" * (-len(payload) % 4)
    try:
        decoded = base64.urlsafe_b64decode(payload)
        claims = json.loads(decoded.decode("utf-8", errors="replace"))
    except (ValueError, RecursionError):
        # ValueError covers binascii.Error (bad base64), non-ASCII input and
        # json.JSONDecodeError; RecursionError guards absurdly nested JSON.
        return None

    # Valid JSON need not be an object; callers rely on dict methods, so
    # anything else (list, number, string, null) counts as undecodable.
    return claims if isinstance(claims, dict) else None


def ts_to_iso(ts: Optional[int]) -> str:
    """Format a Unix timestamp as an ISO-8601 string in UTC.

    Returns ``"-"`` when *ts* is falsy (``None`` or ``0``) or cannot be
    converted to a datetime.
    """
    placeholder = "-"
    if not ts:
        return placeholder
    try:
        moment = datetime.fromtimestamp(ts, tz=timezone.utc)
        return moment.isoformat()
    except Exception:
        # Out-of-range or non-numeric values are shown as missing.
        return placeholder


def _endpoint_result(path: str, status: int, headers, raw: bytes) -> EndpointResult:
    """Build an EndpointResult from the parts of an HTTP response."""

    def header(name: str) -> str:
        # Tolerate a missing header mapping and a missing header alike.
        return ((headers.get(name) if headers else "") or "").strip()

    body = raw.decode("utf-8", errors="replace")
    return EndpointResult(
        path=path,
        status=status,
        request_id=header("x-request-id"),
        cf_ray=header("cf-ray"),
        # Keep the preview short and on one line.
        body_preview=body[:500].replace("\n", " "),
    )


def http_get(base_url: str, path: str, access_token: str, timeout: int) -> EndpointResult:
    """Send an authenticated GET to ``base_url + path`` and summarize the reply.

    Never raises for HTTP or network-level failures; the outcome is encoded in
    the returned result instead.

    Args:
        base_url: Backend origin; a trailing ``/`` is ignored.
        path: Request path, starting with ``/``.
        access_token: Bearer token sent in the ``Authorization`` header.
        timeout: Timeout in seconds passed to ``urlopen``.

    Returns:
        An EndpointResult. ``status`` is the HTTP status code (including 4xx
        and 5xx), or 0 when no HTTP response was obtained, in which case
        ``body_preview`` holds ``network_error: <exception>``.
    """
    url = base_url.rstrip("/") + path
    req = urllib.request.Request(url=url, method="GET")
    req.add_header("Authorization", f"Bearer {access_token}")
    req.add_header("Accept", "application/json, text/plain, */*")
    # Origin/Referer always name the default site, even when base_url differs.
    req.add_header("Origin", DEFAULT_BASE_URL)
    req.add_header("Referer", DEFAULT_BASE_URL + "/")
    req.add_header("User-Agent", DEFAULT_USER_AGENT)

    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return _endpoint_result(path, resp.getcode(), resp.headers, resp.read())
    except urllib.error.HTTPError as e:
        # HTTPError doubles as the response object and owns the connection;
        # close it explicitly instead of leaving the socket to the GC.
        try:
            raw = e.read()
        except OSError:
            # The error body could not be read; the status is still useful.
            raw = b""
        finally:
            e.close()
        return _endpoint_result(path, e.code, e.headers, raw)
    except Exception as e:
        # Top-level boundary for DNS, TLS, proxy, timeout and similar errors.
        return EndpointResult(
            path=path,
            status=0,
            request_id="",
            cf_ray="",
            body_preview=f"network_error: {e}",
        )


def classify(me_status: int) -> Tuple[str, int]:
    """Map the ``/backend/me`` status to a summary line and process exit code.

    Returns:
        ``(message, exit_code)``: 0 for HTTP 200, 2 for 401, 3 for 403, 4 for
        status 0 (no HTTP response), and 5 for any other status.
    """
    verdicts = (
        (200, ("AT looks valid for Sora (/backend/me == 200).", 0)),
        (401, ("AT is invalid or expired (/backend/me == 401).", 2)),
        (403, ("AT may be blocked by policy/challenge or lacks permission (/backend/me == 403).", 3)),
        (0, ("Request failed before reaching Sora (network/proxy/TLS issue).", 4)),
    )
    for known_status, verdict in verdicts:
        if me_status == known_status:
            return verdict
    return f"Unexpected status on /backend/me: {me_status}", 5


def main() -> int:
    """Run the token test end to end and return the process exit code.

    Prints the masked token, the JWT ``iat``/``exp``/scope-count claims when
    the payload decodes, one line per probed endpoint (plus a body preview),
    and a final summary.

    Returns:
        1 when the supplied token is empty after stripping whitespace;
        otherwise the exit code chosen by ``classify`` from the
        ``/backend/me`` status.
    """
    cli = parse_args()
    access_token = cli.access_token.strip()
    if access_token == "":
        print("ERROR: empty access token")
        return 1

    claims = decode_jwt_payload(access_token)
    print("=== Sora AT Test ===")
    print(f"token: {mask_token(access_token)}")
    if not claims:
        print("jwt: payload decode failed (token may not be JWT)")
    else:
        issued_at = claims.get("iat")
        expires_at = claims.get("exp")
        scope_claim = claims.get("scp")
        # Only a list-valued "scp" claim is counted; anything else reports 0.
        scope_count = len(scope_claim) if isinstance(scope_claim, list) else 0
        print(f"jwt.iat: {issued_at} ({ts_to_iso(issued_at)})")
        print(f"jwt.exp: {expires_at} ({ts_to_iso(expires_at)})")
        print(f"jwt.scope_count: {scope_count}")

    probe_paths = (
        "/backend/me",
        "/backend/nf/check",
        "/backend/project_y/invite/mine",
        "/backend/billing/subscriptions",
    )

    print("\n--- endpoint checks ---")
    outcomes = []
    for probe in probe_paths:
        outcome = http_get(cli.base_url, probe, access_token, cli.timeout)
        outcomes.append(outcome)
        shown_request_id = outcome.request_id or "-"
        shown_cf_ray = outcome.cf_ray or "-"
        print(f"{outcome.path} -> status={outcome.status} request_id={shown_request_id} cf_ray={shown_cf_ray}")
        if outcome.body_preview:
            print(f"  body: {outcome.body_preview}")

    # The verdict depends solely on /backend/me; the other probes are
    # informational. Status 0 is used if that probe is somehow missing.
    me_status = 0
    for outcome in outcomes:
        if outcome.path == "/backend/me":
            me_status = outcome.status
            break

    verdict, exit_code = classify(me_status)
    print("\n--- summary ---")
    print(verdict)
    return exit_code


# Run the tester only when executed as a script; the value returned by main()
# becomes the process exit code.
if __name__ == "__main__":
    sys.exit(main())

