#!/usr/bin/env python3
"""
Reproducible simulations for "The Verifier's Advantage" (blog.pdj.dev).

Generates three figures (themed SVG, transparent background for the dark site):
  1. verifier-coverage.svg   - best-of-n coverage, closed form vs Monte Carlo
  2. verifier-precision.svg  - precision vs stacked verifiers, independent vs correlated
  3. verifier-frontier.svg   - cost-reliability frontier: scale generator vs add verifiers

Run:  python3 simulations.py
Deterministic (seeded). Requires numpy, matplotlib. No network, no data files.
"""

from __future__ import annotations
import pathlib
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

SEED = 7
# One generator shared by every figure, so the Monte Carlo draws depend on the
# order in which the figure functions are called.
RNG = np.random.default_rng(SEED)
# Output directory: assets/figures under the directory two levels above this
# script's own folder. It is created at import time.
OUT = pathlib.Path(__file__).resolve().parents[2].joinpath("assets", "figures")
OUT.mkdir(exist_ok=True, parents=True)

# ---- dark-site theme -------------------------------------------------------
# Series colours.
ACCENT = "#4da3ff"
C2 = "#ff7ab2"
C3 = "#82d9ff"
C4 = "#d2a8ff"
TEXT = "#c9c9cf"    # general text and axis labels
MUTED = "#8a8a90"   # tick labels (and secondary annotations)
GRID = "#2c2c2e"    # axes edges and grid lines
plt.rcParams.update({
    # Transparent backgrounds so the figures sit on the dark site.
    "figure.facecolor": "none",
    "axes.facecolor": "none",
    "savefig.facecolor": "none",
    "font.family": "DejaVu Sans",
    "font.size": 11,
    "text.color": TEXT,
    "axes.labelcolor": TEXT,
    "xtick.color": MUTED,
    "ytick.color": MUTED,
    "axes.edgecolor": GRID,
    "grid.color": GRID,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "legend.frameon": False,
    "figure.dpi": 110,
})

def style(ax):
    """Apply the shared spine, tick and grid styling to the Axes ``ax``."""
    for spine in (ax.spines["left"], ax.spines["bottom"]):
        spine.set_linewidth(0.8)
    ax.tick_params(length=3)
    ax.grid(True, alpha=0.5, linewidth=0.7)

def save(fig, name):
    """Tighten ``fig``'s layout, write it as a transparent SVG to OUT/name, then close it."""
    target = OUT / name
    fig.tight_layout()
    fig.savefig(target, format="svg", transparent=True, bbox_inches="tight")
    plt.close(fig)
    print(f"wrote {target}")

# ---------------------------------------------------------------------------
# Figure 1: best-of-n coverage. Closed form 1-(1-p)^n vs Monte Carlo.
# ---------------------------------------------------------------------------
def fig_coverage():
    """Plot best-of-n coverage ``1 - (1 - p)**n`` for several per-sample success rates.

    Each p gets its closed-form curve over n = 1..40, plus Monte Carlo dots at
    every fourth n. Each dot is estimated from 20000 simulated batches of n
    draws. Writes verifier-coverage.svg.
    """
    trials = 20000
    ns = np.arange(1, 41)
    checkpoints = ns[::4]
    fig, ax = plt.subplots(figsize=(7.0, 3.9))
    for p, color in ((0.10, C2), (0.25, ACCENT), (0.50, C3)):
        ax.plot(ns, 1 - (1 - p) ** ns, color=color, lw=1.8, label=f"p = {p:.2f}")
        # Monte Carlo: fraction of batches of n draws with at least one success.
        hit_rates = [
            (RNG.random((trials, n)) < p).any(axis=1).mean() for n in checkpoints
        ]
        ax.scatter(checkpoints, hit_rates, color=color, s=18, zorder=3, edgecolor="none")
    ax.set_xlabel("samples drawn, n")
    ax.set_ylabel("P(at least one correct)")
    ax.set_ylim(0, 1.02)
    ax.set_xlim(1, 40)
    ax.legend(loc="best")
    style(ax)
    save(fig, "verifier-coverage.svg")

# ---------------------------------------------------------------------------
# Figure 2: precision P(correct | accepted) vs k stacked verifiers.
# Independent false-accepts vs a shared-failure (correlated) model whose effective
# false-accept rate has a floor of rho * a0, so precision plateaus below 1 as k grows.
# Sensitivity s = 1 (verifiers never reject a correct answer) isolates the
# false-accept effect; see the essay for the recall/cost treatment.
# ---------------------------------------------------------------------------
def precision(p, a0, k, rho):
    """Return P(correct | accepted) after ``k`` stacked verifiers.

    ``p`` is the probability the candidate is correct and ``a0`` is a single
    verifier's false-accept rate. With probability ``rho`` the verifiers fail
    together (one shared false-accept draw, rate ``a0``). Otherwise a wrong
    answer passes only if all ``k`` independent verifiers false-accept
    (rate ``a0 ** k``). Correct answers are always accepted.
    """
    independent_rate = a0 ** k
    false_accept = rho * a0 + (1 - rho) * independent_rate
    wrong_and_accepted = (1 - p) * false_accept
    return p / (p + wrong_and_accepted)

def fig_precision():
    """Plot precision P(correct | accepted) against the number of stacked verifiers k.

    For each correlation level rho, draws the closed-form ``precision`` curve.
    For rho > 0 it also draws a dotted line at the curve's k -> infinity limit.
    Monte Carlo estimates are added at every k. Writes verifier-precision.svg.
    """
    p, a0 = 0.5, 0.35
    ks = np.arange(1, 9)
    trials = 60000
    settings = (
        (0.0, ACCENT, "independent  (rho = 0)"),
        (0.05, C4, "weakly correlated  (rho = 0.05)"),
        (0.15, C2, "correlated  (rho = 0.15)"),
    )

    def simulate(k, rho):
        # Each trial is correct with prob p. With prob rho, all k verifiers share
        # one false-accept draw. Otherwise a wrong answer passes only if every
        # independent verifier false-accepts. Correct answers always pass.
        is_correct = RNG.random(trials) < p
        is_shared = RNG.random(trials) < rho
        all_fa = (RNG.random((trials, k)) < a0).all(axis=1)
        one_fa = RNG.random(trials) < a0
        accepted = is_correct | np.where(is_shared, one_fa, all_fa)
        return is_correct[accepted].mean()

    fig, ax = plt.subplots(figsize=(7.0, 3.9))
    for rho, color, label in settings:
        curve = [precision(p, a0, k, rho) for k in ks]
        ax.plot(ks, curve, color=color, lw=1.8, marker="o", ms=4, label=label)
        if rho > 0:
            # Limit as k -> infinity: the false-accept rate bottoms out at rho * a0.
            ceiling = p / (p + (1 - p) * rho * a0)
            ax.axhline(ceiling, color=color, lw=0.8, ls=":", alpha=0.7)
        estimates = [simulate(k, rho) for k in ks]
        ax.scatter(ks, estimates, color=color, s=14, zorder=3, edgecolor="none", alpha=0.9)
    ax.set_xlabel("stacked verifiers, k")
    ax.set_ylabel("precision  P(correct | accepted)")
    ax.set_ylim(0.6, 1.005)
    ax.legend(loc="best")
    style(ax)
    save(fig, "verifier-precision.svg")

# ---------------------------------------------------------------------------
# Figure 3: cost-reliability frontier. Spend on the generator vs on verifiers.
# Cost per verified-correct output = (c_g + k c_v) / (p * P_accept_correct).
# ---------------------------------------------------------------------------
def fig_frontier():
    """Plot the cost-reliability frontier for two ways of spending compute.

    Strategy A raises the generator's success rate p, at generator cost
    ``c0 / (1 - p)``, and keeps a single verifier. Strategy B keeps a cheap
    generator at fixed p and stacks k independent verifiers. Reliability for
    both comes from the shared ``precision`` model with rho = 0. Correct
    answers are assumed always accepted, so the cost per verified-correct
    output is ``(generator cost + k * cv) / p``. Writes verifier-frontier.svg.
    """
    # a0: per-verifier false-accept rate; c0: generator cost scale; cv: cost per verifier.
    a0, c0, cv = 0.35, 1.0, 0.05
    # Strategy A: scale the generator (raise p), single verifier.
    psA = np.linspace(0.30, 0.95, 60)
    cgA = c0 / (1 - psA)                     # pushing p toward 1 is expensive
    relA = precision(psA, a0, 1, 0.0)        # k = 1, independent verifiers
    costA = (cgA + cv) / psA                 # cost per correct-accepted output
    # Strategy B: cheap generator (p fixed), add verifiers.
    pB = 0.40
    cgB = c0 / (1 - pB)
    ks = np.arange(1, 8)
    relB = np.array([precision(pB, a0, k, 0.0) for k in ks])
    costB = (cgB + ks * cv) / pB
    fig, ax = plt.subplots(figsize=(7.0, 3.9))
    ax.plot(costA, relA, color=C2, lw=1.8, label="scale generator (raise p)")
    # Derive the legend text from pB so it cannot go stale if pB is changed.
    ax.plot(costB, relB, color=ACCENT, lw=1.8, marker="o", ms=4,
            label=f"add verifiers (p = {pB:.2f})")
    for k, x, y in zip(ks, costB, relB):
        if k in (1, 3, 5, 7):
            ax.annotate(f"k={k}", (x, y), textcoords="offset points",
                        xytext=(6, -10), color=MUTED, fontsize=9)
    ax.set_xlabel("expected cost per verified-correct output")
    ax.set_ylabel("reliability  (precision)")
    ax.set_ylim(0.6, 1.005)
    ax.legend(loc="best")
    style(ax)
    save(fig, "verifier-frontier.svg")

if __name__ == "__main__":
    # Order matters: figures share one seeded RNG stream.
    for make_figure in (fig_coverage, fig_precision, fig_frontier):
        make_figure()
    print("done.")
