#!/usr/bin/env python3
"""
Reproducible simulations for "Context as Compute" (blog.pdj.dev).

Generates two figures (themed SVG, transparent background for the dark site):
  1. context-eviction.svg  - hit rate vs context budget B for LRU, LFU,
                             utility-scored, and Belady-MIN (offline optimal).
  2. context-frontier.svg  - quality-per-cost frontier: keep-everything vs a
                             managed working set; plus a side panel of
                             effective access time (EAT) vs hit rate.

Run:  python3 simulations.py
Deterministic (seeded). Requires numpy, matplotlib. No network, no data files.
"""

from __future__ import annotations

import pathlib
from collections import OrderedDict

import matplotlib
import numpy as np

matplotlib.use("Agg")  # select the non-interactive backend before importing pyplot
import matplotlib.pyplot as plt
import matplotlib.ticker

# One seeded generator is shared by every stochastic step in this script.
SEED = 7
RNG = np.random.default_rng(SEED)

# Figures go to assets/figures under the directory two levels above this
# file's own directory; the folder is created on import if it is missing.
OUT = pathlib.Path(__file__).resolve().parents[2].joinpath("assets", "figures")
OUT.mkdir(parents=True, exist_ok=True)

# ---- dark-site theme -------------------------------------------------------
# Series colours used by the plots.
ACCENT = "#4da3ff"
C2 = "#ff7ab2"
C3 = "#82d9ff"
C4 = "#d2a8ff"
# Colours for text, tick labels, and grid lines / axes edges.
TEXT = "#c9c9cf"
MUTED = "#8a8a90"
GRID = "#2c2c2e"

plt.rcParams.update({
    # no background fill on the figure, the axes, or the saved file
    "figure.facecolor": "none",
    "axes.facecolor": "none",
    "savefig.facecolor": "none",
    # typography
    "font.family": "DejaVu Sans",
    "font.size": 11,
    # foreground colours
    "text.color": TEXT,
    "axes.labelcolor": TEXT,
    "xtick.color": MUTED,
    "ytick.color": MUTED,
    "axes.edgecolor": GRID,
    "grid.color": GRID,
    # only the left and bottom spines are drawn
    "axes.spines.top": False,
    "axes.spines.right": False,
    "legend.frameon": False,
    "figure.dpi": 110,
})

def style(ax):
    """Apply the shared grid, tick-length and spine-width settings to `ax`."""
    ax.grid(True, alpha=0.5, linewidth=0.7)
    ax.tick_params(length=3)
    # Only the two visible spines are adjusted.
    ax.spines["left"].set_linewidth(0.8)
    ax.spines["bottom"].set_linewidth(0.8)

def save(fig, name):
    """Write `fig` as a transparent SVG called `name` under OUT, then close it."""
    target = OUT / name
    fig.tight_layout()
    fig.savefig(target, format="svg", transparent=True, bbox_inches="tight")
    plt.close(fig)
    print(f"wrote {target}")

# ---------------------------------------------------------------------------
# Reference trace with temporal locality: a slowly drifting working set
# (the items a task is actively touching) plus a Zipfian tail of occasional
# accesses elsewhere. This is the agentic analogue of program locality.
# ---------------------------------------------------------------------------
def reference_trace(n_refs=20000, universe=400, working=24, drift=0.012,
                    tail=0.22, zipf_a=1.25, rng=RNG):
    """Build a reference string of `n_refs` item ids in [0, universe).

    Each reference is, with probability `tail`, a draw from a Zipf-like
    distribution over the whole universe (weight of item k proportional to
    (k + 1) ** -zipf_a); otherwise it is drawn uniformly from a window of
    `working` consecutive ids (wrapping modulo `universe`) whose start advances
    by `drift` per reference.

    `rng` defaults to the module-level generator, so successive calls that
    rely on the default consume the same stream and return different traces.
    """
    # Popularity of the cold tail, normalised to a probability vector.
    weights = np.arange(1, universe + 1) ** (-zipf_a)
    tail_p = weights / weights.sum()

    trace = np.empty(n_refs, dtype=np.int64)
    window_start = 0.0
    for step in range(n_refs):
        # The window start moves a little on every reference.
        window_start = (window_start + drift) % universe
        far_access = rng.random() < tail
        if far_access:
            trace[step] = rng.choice(universe, p=tail_p)
            continue
        # Local access: uniform offset inside the current window.
        trace[step] = (int(window_start) + rng.integers(0, working)) % universe
    return trace

# ---------------------------------------------------------------------------
# Cache policies. Each returns the hit rate over the reference string for a
# cache of capacity B (number of resident items). Belady-MIN is the offline
# optimal ceiling; LRU and LFU are online; "utility" scores recency+frequency.
# ---------------------------------------------------------------------------
def hit_rate_lru(refs, B):
    """Return the hit rate of a least-recently-used cache over `refs`.

    Args:
        refs: Sequence of hashable item ids in access order.
        B: Cache capacity (maximum number of resident items).

    Returns:
        The fraction of references that found their item resident, as a
        float. An empty trace or a non-positive capacity yields 0.0 (nothing
        can ever be served from the cache) instead of raising.
    """
    n_refs = len(refs)
    if n_refs == 0 or B <= 0:
        return 0.0

    # Keys are the resident items, ordered from least to most recently used,
    # so the eviction victim is always the first key.
    resident = OrderedDict()
    hits = 0
    for r in refs:
        if r in resident:
            hits += 1
            resident.move_to_end(r)          # r is now the most recent
        else:
            if len(resident) >= B:
                resident.popitem(last=False)  # drop the least recent
            resident[r] = None
    return hits / n_refs

def hit_rate_lfu(refs, B):
    """Return the hit rate of a least-frequently-used cache over `refs`.

    Frequency counts cover every reference seen so far, including those made
    while the item was not resident; they are never reset on eviction. Ties on
    frequency are broken by evicting the least recently referenced item.

    Args:
        refs: Sequence of hashable item ids in access order.
        B: Cache capacity (maximum number of resident items).

    Returns:
        The fraction of references that found their item resident, as a
        float. An empty trace or a non-positive capacity yields 0.0 instead
        of raising.
    """
    n_refs = len(refs)
    if n_refs == 0 or B <= 0:
        return 0.0

    cache, freq, recency, hits = set(), {}, {}, 0
    for t, r in enumerate(refs):
        freq[r] = freq.get(r, 0) + 1
        if r in cache:
            hits += 1
        else:
            if len(cache) >= B:
                # Evict the least frequent item; break ties by least recent.
                victim = min(cache, key=lambda x: (freq[x], recency[x]))
                cache.discard(victim)
            cache.add(r)
        recency[r] = t
    return hits / n_refs

def hit_rate_scored(refs, B, half_life=64.0):
    """Return the hit rate of a decayed-score eviction policy over `refs`.

    Every item carries an access score that halves every `half_life`
    references and gains 1.0 each time the item is referenced. On a miss with
    a full cache, the resident item with the lowest score (decayed to the
    current step) is evicted. This is the practical stand-in for a learned
    utility estimator.

    Args:
        refs: Sequence of hashable item ids in access order.
        B: Cache capacity (maximum number of resident items).
        half_life: Number of references over which a score decays by half.

    Returns:
        The fraction of references that found their item resident, as a
        float. An empty trace or a non-positive capacity yields 0.0 instead
        of raising.
    """
    n_refs = len(refs)
    if n_refs == 0 or B <= 0:
        return 0.0

    decay = np.log(2.0) / half_life
    cache, score, last, hits = set(), {}, {}, 0

    def decayed(item, now):
        # Stored score of `item`, decayed from its last reference to `now`.
        return score[item] * np.exp(-decay * (now - last[item]))

    for t, r in enumerate(refs):
        if r in cache:
            hits += 1
        # Scores are tracked for every item ever referenced, resident or not:
        # decay the old score to the current step, then add this reference.
        score[r] = decayed(r, t) + 1.0 if r in score else 1.0
        last[r] = t
        if r not in cache:
            if len(cache) >= B:
                victim = min(cache, key=lambda x: decayed(x, t))
                cache.discard(victim)
            cache.add(r)
    return hits / n_refs

def hit_rate_belady(refs, B):
    """Return the hit rate of Belady's offline-optimal policy over `refs`.

    On a miss with a full cache, the resident item whose next reference lies
    farthest in the future is evicted; an item that is never referenced again
    counts as farthest. The policy needs the whole trace in advance.

    Args:
        refs: Indexable sequence of hashable item ids in access order.
        B: Cache capacity (maximum number of resident items).

    Returns:
        The fraction of references that found their item resident, as a
        float. An empty trace or a non-positive capacity yields 0.0 instead
        of raising.
    """
    n = len(refs)
    if n == 0 or B <= 0:
        return 0.0

    # nxt[t] is the index of the next reference to refs[t] after position t,
    # or n when there is none. Filled by one backward pass over the trace.
    nxt = np.full(n, n, dtype=np.int64)
    upcoming = {}
    for t in range(n - 1, -1, -1):
        r = refs[t]
        nxt[t] = upcoming.get(r, n)
        upcoming[r] = t

    cache, next_use, hits = set(), {}, 0
    for t, r in enumerate(refs):
        if r in cache:
            hits += 1
        else:
            if len(cache) >= B:
                victim = max(cache, key=lambda x: next_use[x])   # farthest future
                cache.discard(victim)
            cache.add(r)
        # Hit or miss, r's next use is now measured from this position.
        next_use[r] = nxt[t]
    return hits / n

# ---------------------------------------------------------------------------
# Figure 1: eviction policy quality. Hit rate vs context budget B.
# Belady-MIN is the ceiling; good online policies approach it; LRU is
# competitive because the trace has locality.
# ---------------------------------------------------------------------------
def fig_eviction():
    """Plot hit rate against context budget for each policy and save the SVG.

    One fresh reference trace is generated and replayed through every policy
    at every budget; the result is written to ``context-eviction.svg``.
    """
    trace = reference_trace()
    budgets = np.array([8, 12, 16, 24, 32, 48, 64, 96, 128])

    # (legend label, policy function, line colour, line style)
    curves = (
        ("Belady-MIN (offline optimal)", hit_rate_belady, C3, "--"),
        ("utility-scored", hit_rate_scored, ACCENT, "-"),
        ("LRU", hit_rate_lru, C4, "-"),
        ("LFU", hit_rate_lfu, C2, "-"),
    )

    fig, ax = plt.subplots(figsize=(7.0, 3.9))
    for label, policy, colour, linestyle in curves:
        rates = [policy(trace, capacity) for capacity in budgets.tolist()]
        ax.plot(budgets, rates, color=colour, lw=1.8, ls=linestyle,
                marker="o", ms=4, label=label)

    ax.set_xlabel("context budget B  (resident items)")
    ax.set_ylabel("hit rate  P(served from L0)")

    # Log-2 x axis with a tick at every budget, labelled as plain numbers.
    ax.set_xscale("log", base=2)
    ax.set_xticks(budgets)
    ax.xaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())

    ax.set_ylim(0, 1.02)
    ax.legend(loc="best")
    style(ax)
    save(fig, "context-eviction.svg")

# ---------------------------------------------------------------------------
# Figure 2: cost of context. Per-step compute is proportional to resident
# tokens (attention/KV read). Quality is the fraction of currently-needed
# items resident. "Keep-everything" grows residency (and cost) without bound;
# a "managed working set" evicts to a budget. We plot quality vs expected
# per-step cost (a quality-per-cost frontier) and, in a second panel beside
# it, effective access time EAT(h) as a function of hit rate h.
# ---------------------------------------------------------------------------
def fig_frontier():
    """Draw the quality-per-cost frontier and the EAT panel, then save the SVG.

    Left panel: hit rate of the utility-scored policy against its budget, with
    keep-everything shown as a single reference point. Right panel: effective
    access time as a function of hit rate, with Monte Carlo estimates at three
    hit rates. The figure is written to ``context-frontier.svg``.
    """
    trace = reference_trace()

    # Keep-everything: residency at each step is the number of distinct items
    # referenced so far. Mark the first occurrence of every item and take the
    # running total; the mean of that running total is the per-step cost.
    # Quality is fixed at 1.0 for this policy.
    _, first_seen_at = np.unique(trace, return_index=True)
    newly_seen = np.zeros(len(trace), dtype=np.int64)
    newly_seen[first_seen_at] = 1
    keep_cost = np.cumsum(newly_seen).mean()
    keep_quality = 1.0

    # Managed working set: the utility-scored policy at budget B.
    # Cost is B itself; quality is the policy's hit rate on the trace.
    budgets = np.array([8, 12, 16, 24, 32, 48, 64, 96, 128])
    man_cost = budgets.astype(float)
    man_quality = np.array(
        [hit_rate_scored(trace, capacity) for capacity in budgets.tolist()]
    )

    fig, (ax_front, ax_eat) = plt.subplots(
        1, 2, figsize=(7.4, 3.9), gridspec_kw={"width_ratios": [1.45, 1.0]}
    )

    # ---- left panel: quality-per-cost frontier ----
    ax_front.plot(man_cost, man_quality, color=ACCENT, lw=1.8, marker="o", ms=4,
                  label="managed working set")
    labelled_budgets = {8, 24, 64, 128}   # only these points get a "B=" tag
    for capacity, cost, quality in zip(budgets.tolist(), man_cost, man_quality):
        if capacity not in labelled_budgets:
            continue
        ax_front.annotate(f"B={capacity}", (cost, quality),
                          textcoords="offset points", xytext=(6, -11),
                          color=MUTED, fontsize=9)
    ax_front.scatter([keep_cost], [keep_quality], color=C2, s=42, zorder=4,
                     edgecolor="none", label="keep-everything")
    ax_front.annotate("keep-everything", (keep_cost, keep_quality),
                      textcoords="offset points", xytext=(8, -14),
                      color=C2, fontsize=9, ha="left")
    ax_front.set_xlabel("expected per-step cost  (resident tokens)")
    ax_front.set_ylabel("task quality  (fraction needed resident)")
    ax_front.set_ylim(0, 1.04)
    ax_front.set_xlim(0, keep_cost * 1.18)
    ax_front.legend(loc="lower right")
    style(ax_front)

    # ---- right panel: effective access time vs hit rate ----
    # EAT = h*t_fast + (1-h)*t_slow, in units where t_fast = 1.
    t_fast, t_slow = 1.0, 60.0
    hs = np.linspace(0.0, 1.0, 101)
    eat = hs * t_fast + (1 - hs) * t_slow
    ax_eat.plot(hs, eat, color=C3, lw=1.8)

    # Monte Carlo check: sample hit/miss outcomes at three hit rates and plot
    # the mean access time of each sample on top of the analytic line.
    trials = 40000
    for h in (0.3, 0.6, 0.9):
        was_hit = RNG.random(trials) < h
        access_time = np.where(was_hit, t_fast, t_slow)
        ax_eat.scatter([h], [access_time.mean()], color=C4, s=22, zorder=4,
                       edgecolor="none")
    ax_eat.set_xlabel("hit rate  h")
    ax_eat.set_ylabel("EAT  (units of t_fast)")
    ax_eat.set_xlim(0, 1)
    ax_eat.set_ylim(0, t_slow * 1.04)
    style(ax_eat)

    save(fig, "context-frontier.svg")

if __name__ == "__main__":
    # Order is fixed: both builders draw from the shared module-level RNG.
    for build_figure in (fig_eviction, fig_frontier):
        build_figure()
    print("done.")
