#!/usr/bin/env python3
"""Can a simple, hand-executable procedure reproduce the Voynich's statistical fingerprint?

Model: a scribe with (1) a small morphological table (prefix/mid/suffix slots, with
slot-to-slot association), (2) three habits per word: REUSE a word already written
(frequency-weighted), MUTATE a recent word by re-rolling one slot, or invent a NOVEL
word from the table, plus (3) mechanical layout rules applied at the moment of writing
(line-final 'm', paragraph-initial gallows, t->p k->f in paragraph-opening lines, and an
occasional y/s/d prefix on a short line-initial word).
Fitted to Currier language B; evaluated with the same battery as analyze_voynich.py.
"""
import json
import math
import os
import random
import re
from collections import Counter, defaultdict

# Directory holding ZL3b-n.txt, the transcription read by load_B_lines().
# The default is the author's checkout; set the VOYNICH_DATA environment
# variable to run on another machine without editing the script.
DATA = os.environ.get("VOYNICH_DATA", "/Users/arcandledger/taxdome/ancient-texts/data")

# ---------- shared helpers (same as analyze_voynich.py) ----------
def shannon(counter):
    """Entropy, in bits, of the empirical distribution held in ``counter``.

    ``counter`` maps outcomes to positive counts (e.g. a ``Counter``); each
    count is normalised by the total before taking -sum(p * log2 p).
    """
    total = sum(counter.values())
    probs = (count / total for count in counter.values())
    return -sum(p * math.log2(p) for p in probs)

def cond_entropy_chars(text):
    """Return (H1, H2) in bits for the character sequence ``text``.

    H1 is the unigram character entropy. H2 is the conditional entropy of a
    character given its predecessor, computed as H(pair) - H(first of pair)
    over all adjacent character pairs.
    """
    unigrams = Counter(text)
    bigrams = Counter(zip(text, text[1:]))
    # The joint entropy is exactly shannon() over the bigram counts.
    h_joint = shannon(bigrams)
    # Marginal of the pair's first character: every position but the last.
    h_first = shannon(Counter(text[:-1]))
    return shannon(unigrams), h_joint - h_first

def zipf_slope(freqs, maxrank=1000):
    """Least-squares slope of log(frequency) on log(rank) for the top ranks.

    ``freqs`` is any iterable of counts in any order; they are sorted in
    descending order and ranked 1, 2, ... Only the first ``maxrank`` positive
    counts are used. Non-positive counts are skipped because log(0) is
    undefined. They sort last, so dropping them does not shift any rank.

    Returns ``nan`` when fewer than two positive counts remain, since a line
    needs at least two points. The original version raised
    ZeroDivisionError in that case.
    """
    fr = [f for f in sorted(freqs, reverse=True) if f > 0][:maxrank]
    if len(fr) < 2:
        return float('nan')
    pts = [(math.log(r + 1), math.log(f)) for r, f in enumerate(fr)]
    n = len(pts)
    sx = sum(x for x, _ in pts)
    sy = sum(y for _, y in pts)
    sxx = sum(x * x for x, _ in pts)
    sxy = sum(x * y for x, y in pts)
    return (n * sxy - sx * sy) / (n * sxx - sx * sx)

def levenshtein(a, b, cap=None):
    """Unit-cost edit distance (insert / delete / substitute) between a and b.

    If ``cap`` is None, the exact distance is returned. If ``cap`` is given,
    the exact distance is returned when it is <= cap; any larger distance is
    reported as ``cap + 1``. Callers that only ask "is the distance at most
    k?" can therefore stop early instead of filling the whole DP table.
    """
    if a == b:
        return 0
    la, lb = len(a), len(b)
    # The length difference is a lower bound on the distance.
    if cap is not None and abs(la - lb) > cap:
        return cap + 1
    prev = list(range(lb + 1))
    for i in range(1, la + 1):
        cur = [i] + [0] * lb
        for j in range(1, lb + 1):
            cur[j] = min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (a[i - 1] != b[j - 1]))
        # Row minima never decrease from one row to the next, so once every
        # entry exceeds the cap, the final distance must as well.
        if cap is not None and min(cur) > cap:
            return cap + 1
        prev = cur
    dist = prev[lb]
    return dist if cap is None or dist <= cap else cap + 1

def adjacency_stats(tokens, seed=42):
    """Compare each token with the one after it, against a shuffled baseline.

    Returns ``(ratio, identical, near)``:
      ratio     -- mean length-normalised edit distance of adjacent pairs,
                   divided by the same mean after shuffling a copy of
                   ``tokens`` with ``random.Random(seed)``;
      identical -- fraction of adjacent pairs that are the same token;
      near      -- fraction of adjacent pairs exactly one edit apart.
    """
    def mean_norm_dist(pairs):
        total = sum(levenshtein(x, y) / max(len(x), len(y)) for x, y in pairs)
        return total / len(pairs)

    adjacent = list(zip(tokens, tokens[1:]))
    shuffled = tokens[:]
    random.Random(seed).shuffle(shuffled)
    ratio = mean_norm_dist(adjacent) / mean_norm_dist(list(zip(shuffled, shuffled[1:])))

    n_pairs = len(adjacent)
    identical = sum(x == y for x, y in adjacent) / n_pairs
    near = sum(x != y and levenshtein(x, y, 2) <= 1 for x, y in adjacent) / n_pairs
    return ratio, identical, near

def neighbor_connectivity(vocab):
    """Measure how many distinct words in ``vocab`` have an edit-distance-1 neighbour.

    Returns ``(share of words with at least one neighbour, mean neighbour
    count)``, both over the distinct words. Candidates are found through
    single-deletion keys. Each word is indexed under itself and under every
    string formed by deleting one of its characters. Each candidate is then
    confirmed with ``levenshtein(..., 1) == 1``.
    """
    words = set(vocab)

    def deletion_keys(word):
        return [word] + [word[:i] + word[i + 1:] for i in range(len(word))]

    index = defaultdict(set)
    for word in words:
        for key in deletion_keys(word):
            index[key].add(word)

    with_neighbour = 0
    neighbour_total = 0
    for word in words:
        candidates = set()
        for key in deletion_keys(word):
            candidates |= index.get(key, set())
        count = sum(1 for other in candidates
                    if other != word and levenshtein(other, word, 1) == 1)
        with_neighbour += bool(count)
        neighbour_total += count
    return with_neighbour / len(words), neighbour_total / len(words)

def word_stats(tokens):
    """Return (mean, population standard deviation) of the token lengths."""
    lengths = [len(tok) for tok in tokens]
    count = len(lengths)
    avg = sum(lengths) / count
    variance = sum((n_chars - avg) ** 2 for n_chars in lengths) / count
    return avg, math.sqrt(variance)

# ---------- load Currier-B running text ----------
def load_B_lines(path=None):
    """Read the IVTFF transcription and return the Currier-B paragraph lines.

    Parameters
    ----------
    path : str, optional
        Transcription file. Defaults to ``f"{DATA}/ZL3b-n.txt"``.

    Returns
    -------
    list of dict
        One dict per kept locus line, in file order, with keys
        ``words`` (list of str, the '.'-separated tokens after markup
        cleanup) and ``par_first`` (True when the raw line contains ``<%>``).

    A page's variables come from its header line (``$X=Y`` pairs). A locus
    line is kept only when its page declares ``$L=B`` and its locus type
    starts with ``P`` (case-insensitive).
    """
    if path is None:
        path = f"{DATA}/ZL3b-n.txt"
    pages, lines = {}, []
    locus_re = re.compile(r'^<(f[0-9a-zA-Z]+)\.([^,>]+),\s*([@+=*~$&!])(\w+?)(\d*)>\s*(.*)$')
    hdr = re.compile(r'^<(f[0-9a-zA-Z]+)>')
    # `with` closes the file even if parsing raises part-way through.
    with open(path, encoding='utf-8') as fh:
        for raw in fh:
            raw = raw.rstrip('\n')
            if not raw or raw.startswith('#'):
                continue
            header = hdr.match(raw)
            if header and not locus_re.match(raw):
                # Page header: remember its $X=Y variables (e.g. $L language).
                pages[header.group(1)] = dict(re.findall(r'\$(\w)=(\w+)', raw))
                continue
            m = locus_re.match(raw)
            if not m:
                continue
            page, _, _, ltype, _, text = m.groups()
            v = pages.get(page, {})
            if v.get('L') != 'B' or not ltype.upper().startswith('P'):
                continue
            par_first = '<%>' in text
            t = re.sub(r'<!.*?>', '', text)               # inline <!...> comments
            # '<->' becomes '?', so the affected token later fails the
            # [a-z]+ filter; <%>, <$> and <@...> tags are simply removed.
            t = re.sub(r'<->', '?', t)
            t = re.sub(r'<%>|<\$>|<@\w+>', '', t)
            t = re.sub(r'@\d+;', '?', t)                  # @NNN; codes -> '?'
            # [x:y] alternative readings -> keep the first one; repeated so
            # nested groups resolve from the inside out.
            for _ in range(4):
                t = re.sub(r'\[([^:\[\]]*):[^\[\]]*\]', r'\1', t)
            t = t.replace(',', '.')                       # ',' also splits words
            t = re.sub(r'[!%]', '', t)
            ws = [w for w in t.split('.') if w]
            lines.append(dict(words=ws, par_first=par_first))
    return lines

def valid(w):
    """Return True if ``w`` is one or more lowercase ASCII letters a-z.

    Tokens still carrying '?' placeholders, uppercase letters, digits or any
    other leftover markup are rejected.
    """
    return re.fullmatch(r'[a-z]+', w) is not None

# ---------- morphology table ----------
# Inventory for the first (prefix) slot; '' allows a word with no prefix.
PRE = [
    'qo', 'ch', 'sh', 'da', 'ol',
    'o', 'd', 'y', 's', 'l', 'q', 'r',
    '',
]
# Inventory for the last (ending) slot; '' allows a word with no ending.
END = [
    'eedy', 'aiin', 'aiir',
    'eey', 'edy', 'ain', 'air',
    'am', 'an', 'ar', 'al', 'dy', 'ey', 'ol', 'or',
    'y', 'o', 'n', 'r', 'l', 's', 'm',
    '',
]

def decompose(w):
    """Split ``w`` into a (prefix, middle, ending) triple using PRE and END.

    Every prefix in PRE that ``w`` starts with is paired with the longest
    END entry that the remainder ends with. The pair with the longest ending
    wins; ties go to the longer prefix, and remaining ties to the earliest
    PRE entry. As long as PRE and END both contain '', some split always
    exists, so the ('', w, '') fallback is only defensive.
    """
    ends = sorted(END, key=len, reverse=True)     # longest candidate first
    best_rank, best_split = None, None
    for pre in PRE:
        if not w.startswith(pre):
            continue
        rest = w[len(pre):]
        end = next((e for e in ends if rest.endswith(e)), None)
        if end is None:
            continue
        rank = (len(end), len(pre))
        if best_rank is None or rank > best_rank:
            best_rank = rank
            best_split = (pre, rest[:len(rest) - len(end)], end)
    return best_split if best_split is not None else ('', w, '')

class SlotModel:
    """Three-slot word sampler: P(pre), P(mid | pre), P(end | mid).

    Counts come from ``tokens`` after folding p -> t and f -> k (p/f are
    treated as paragraph-initial variants) and splitting each word with
    decompose(). Words whose middle is longer than 6 characters are skipped.
    Endings 'm' and 'am' are counted as 'n' and 'an', because 'm' is treated
    as a purely positional (line-final) variant that the layout rules apply.

    Middles seen fewer than ``mid_min`` times are dropped. A conditional
    table is kept only for contexts with at least 10 (kept) observations;
    other contexts back off to the marginal. Every table is flattened with a
    temperature (count ** (1/tau)): ``tau_mid`` for the middle slot and
    ``tau`` for the prefix and ending slots.
    """
    def __init__(self, tokens, tau, tau_mid, mid_min):
        self.tau = tau
        fold_end = {'m': 'n', 'am': 'an'}
        pre_n, mid_n, end_n = Counter(), Counter(), Counter()
        mid_by_pre = defaultdict(Counter)
        end_by_mid = defaultdict(Counter)
        for token in tokens:
            # p/f = paragraph-initial variants; learn the base letters t/k.
            pre, mid, end = decompose(token.replace('p', 't').replace('f', 'k'))
            if len(mid) > 6:
                continue
            end = fold_end.get(end, end)
            pre_n[pre] += 1
            mid_n[mid] += 1
            end_n[end] += 1
            mid_by_pre[pre][mid] += 1
            end_by_mid[mid][end] += 1

        frequent = {mid for mid, n in mid_n.items() if n >= mid_min}

        def frequent_only(counts):
            return Counter({mid: n for mid, n in counts.items() if mid in frequent})

        self.pre = self._temper(pre_n, tau)
        self.mid_marg = self._temper(frequent_only(mid_n), tau_mid)
        self.end_marg = self._temper(end_n, tau)
        self.mid_g = {}
        for pre, counts in mid_by_pre.items():
            if sum(counts[mid] for mid in counts if mid in frequent) >= 10:
                self.mid_g[pre] = self._temper(frequent_only(counts), tau_mid)
        self.end_g = {}
        for mid, counts in end_by_mid.items():
            if sum(counts.values()) >= 10:
                self.end_g[mid] = self._temper(counts, tau)

    def _temper(self, c, tau):
        """Return (keys, probabilities) with p proportional to count ** (1/tau), skipping counts <= 0."""
        keys, raw = [], []
        for key, count in c.items():
            if count > 0:
                keys.append(key)
                raw.append(count ** (1.0 / tau))
        total = sum(raw) or 1
        return keys, [weight / total for weight in raw]

    def s_pre(self, rng):
        """Draw a prefix from the tempered prefix marginal."""
        keys, weights = self.pre
        return rng.choices(keys, weights)[0]

    def s_mid(self, rng, p):
        """Draw a middle given prefix ``p``; '' if the chosen table is empty."""
        keys, weights = self.mid_g.get(p, self.mid_marg)
        if not keys:
            return ''
        return rng.choices(keys, weights)[0]

    def s_end(self, rng, m):
        """Draw an ending given middle ``m``; 'dy' if the chosen table is empty."""
        keys, weights = self.end_g.get(m, self.end_marg)
        if not keys:
            return 'dy'
        return rng.choices(keys, weights)[0]

# ---------- generator ----------
def generate(params, n_tokens, models, rng, layout=True):
    """Simulate the scribe and return (stream, out_lines).

    params   : dict; this function reads p_reuse, p_mut, window, p_nc, decay,
               q_head and p_local.
    n_tokens : number of tokens to return in the flat stream.
    models   : (sharp, flat) pair of samplers exposing s_pre/s_mid/s_end
               (SlotModel instances in this script). Novel words come from
               `sharp` with probability q_head and otherwise from `flat`.
               Mutations always re-roll a slot from `sharp`.
    rng      : random.Random-like source. Every random draw goes through it,
               so a seeded rng makes the output reproducible.
    layout   : when True, apply the positional rewrites (paragraph t->p/k->f,
               paragraph-initial gallows, line-final m, line-initial y/s/d).

    Returns (stream, out_lines). `stream` is the word sequence truncated to
    n_tokens. `out_lines` is a list of {'words', 'par_first'} dicts for every
    generated line, including the last line, which is not truncated.

    NOTE: the order of rng calls below defines the output for a given seed.
    Restructuring this code, even into an equivalent form, changes the
    generated corpora and therefore the tuned parameters.
    """
    sharp, flat = models
    p_reuse, p_mut, window, p_nc, decay, q_head = (params['p_reuse'], params['p_mut'],
        params['window'], params['p_nc'], params['decay'], params['q_head'])
    # Invent a word: pick the sharp or flat table, then sample pre -> mid|pre
    # -> end|mid. Retry up to 20 times for a length of 1..12, then fall back
    # to 'daiin'.
    def novel():
        model = sharp if rng.random() < q_head else flat
        for _ in range(20):
            p = model.s_pre(rng); m = model.s_mid(rng, p); e = model.s_end(rng, m)
            w = p + m + e
            if 1 <= len(w) <= 12: return w
        return 'daiin'
    # Copy w. With probability p_nc the copy is exact. Otherwise re-roll one
    # slot from the sharp table (prefix 35%, middle 30%, ending 35%). A
    # result outside length 1..12 falls back to w unchanged.
    def mutate(w):
        if rng.random() < p_nc: return w
        p, m, e = decompose(w)
        slot = rng.random()
        if slot < 0.35:   p = sharp.s_pre(rng)
        elif slot < 0.65: m = sharp.s_mid(rng, p)
        else:             e = sharp.s_end(rng, m)
        w2 = p + m + e
        return w2 if 1 <= len(w2) <= 12 else w
    # hist: every base word written so far, in order (the scribe's memory).
    hist, out_lines, stream = [], [], []
    par_lines_left = 0
    # Recency weights for choosing a word to mutate: the i-th most recent
    # word gets weight decay**i.
    weights = [decay**i for i in range(window)]
    while len(stream) < n_tokens:
        # A new paragraph starts when the current one runs out; it lasts 3-14 lines.
        par_first = par_lines_left == 0
        if par_first: par_lines_left = rng.randint(3, 14)
        par_lines_left -= 1
        # Each line holds 6-12 words.
        n_words = rng.randint(6, 12)
        base = []
        for i in range(n_words):
            # One uniform draw chooses between REUSE (r < p_reuse), MUTATE
            # (r < p_reuse + p_mut) and NOVEL. The first word ever written is
            # always novel because hist is empty.
            r = rng.random()
            if hist and r < p_reuse:
                if rng.random() < params['p_local']:      # re-read the current page
                    pool = hist[-250:]
                    w = pool[rng.randrange(len(pool))]
                else:                                      # long-term favourites
                    w = hist[rng.randrange(len(hist))]
            elif hist and r < p_reuse + p_mut:
                recent = hist[-window:][::-1]
                w = mutate(rng.choices(recent, weights[:len(recent)])[0])
            else:
                w = novel()
            base.append(w); hist.append(w)      # scribe's memory keeps base forms
        # Layout rules below change only what is written, never hist.
        line = base[:]
        if layout:
            if par_first:
                # Paragraph-opening line: each word independently gets
                # t->p / k->f with probability 0.45 ...
                line = [w.replace('t','p').replace('k','f') if rng.random() < 0.45 else w
                        for w in line]
                # ... and, if it does not already start with a gallows, the
                # first word gets 'p' (70%) or 't' (30%) prepended with p=0.82.
                if line[0][0] not in 'tkpf' and rng.random() < 0.82:
                    line[0] = ('p' if rng.random() < 0.7 else 't') + line[0]
            # Line-final n/r/l becomes 'm' with probability 0.38.
            if rng.random() < 0.38 and line[-1][-1] in 'nrl':
                line[-1] = line[-1][:-1] + 'm'
            # A short (< 4 chars) line-initial word gets y/s/d prepended with p=0.3.
            # NOTE(review): this runs after the gallows rule, so on a paragraph's
            # first line it can push the just-added gallows into second position
            # (e.g. 'ol' -> 'pol' -> 'ypol'), lowering gal_par. Confirm intended.
            if len(line[0]) < 4 and rng.random() < 0.3:
                line[0] = rng.choice('ysd') + line[0]
        out_lines.append(dict(words=line, par_first=par_first))
        stream.extend(line)
    return stream[:n_tokens], out_lines

# ---------- evaluation ----------
def fast_stats(tokens):
    """Compute the cheap statistics used both as targets and as the tuning objective.

    Keys, in order: types (distinct tokens), zipf (rank/frequency slope),
    h1/h2 (character entropies of the tokens joined by single spaces),
    wlen/wsd (token length mean and SD), adj/same/near (adjacency
    statistics over the first 12000 tokens), and top100 (share of tokens
    covered by the 100 most frequent types).
    """
    counts = Counter(tokens)
    h1, h2 = cond_entropy_chars(' '.join(tokens))
    wlen, wsd = word_stats(tokens)
    adj, same, near = adjacency_stats(tokens[:12000])
    covered = sum(freq for _, freq in counts.most_common(100))
    return {
        'types': len(counts),
        'zipf': zipf_slope(list(counts.values())),
        'h1': h1,
        'h2': h2,
        'wlen': wlen,
        'wsd': wsd,
        'adj': adj,
        'same': same,
        'near': near,
        'top100': covered / len(tokens),
    }

def score(g, t, n_scale):
    """Weighted sum of relative errors of generated stats ``g`` against targets ``t``.

    Lower is better. The 'types' target is multiplied by ``n_scale`` first,
    because type counts depend on sample size. Each denominator is floored
    at 1e-9 so that zero targets do not divide by zero.
    """
    weights = {'types': 2, 'zipf': 3, 'h2': 3, 'wlen': 2, 'wsd': 1,
               'adj': 2, 'same': 1, 'near': 2, 'top100': 3}
    total = 0
    for key, weight in weights.items():
        target = t[key]
        if key == 'types':
            target = target * n_scale
        total += weight * abs(g[key] - target) / max(abs(target), 1e-9)
    return total

# ---------- main ----------
print("Loading Currier B ...")
B_lines = load_B_lines()
# Only purely alphabetic tokens take part in the statistics.
B_tokens = [tok for rec in B_lines for tok in rec['words'] if valid(tok)]
N = len(B_tokens)
print(f"B: {N} tokens")

print("Computing targets on real B ...")
T = fast_stats(B_tokens)
# Connectivity is reported alongside the targets, but it is not one of the
# statistics score() weighs during tuning.
conn_t, mnb_t = neighbor_connectivity(list(Counter(B_tokens)))
T['conn'] = conn_t
T['mnb'] = mnb_t
print("targets:", {key: round(val, 4) for key, val in T.items()})

# Type counts grow sublinearly with corpus size (Heaps' law, V ~ N**beta).
# V(N) / V(N/2) = 2**beta gives beta from B itself. It is then used to
# shrink the full-B type target to what an n_eval-token sample should show.
half_types = len(Counter(B_tokens[:N // 2]))
heaps_beta = math.log(T['types'] / half_types) / math.log(2)
n_eval = 15000
n_scale = (n_eval / N) ** heaps_beta
print(f"Heaps beta={heaps_beta:.3f}; type target at {n_eval} tokens ~ {T['types']*n_scale:.0f}")

print("\nTuning (seeded hill-climb) ...")
# Starting point of the search. 'window' and 'mid_min' are discrete and have
# no LIMITS entry; every other knob is perturbed inside its LIMITS range.
SEED_P = {
    'p_reuse': 0.401, 'p_mut': 0.166, 'window': 8, 'p_nc': 0.053,
    'tau': 0.73, 'tau_mid': 1.07, 'tau_flat': 1.58, 'q_head': 0.787,
    'mid_min': 2, 'decay': 0.944, 'p_local': 0.25,
}
# Clamp range per continuous knob. Key order matters, because the hill-climb
# samples from list(LIMITS); reordering the keys changes the search path.
LIMITS = {
    'p_reuse': (0.30, 0.75), 'p_mut': (0.05, 0.40), 'p_nc': (0.02, 0.30),
    'tau': (0.55, 1.2), 'tau_mid': (0.7, 1.5), 'tau_flat': (1.2, 2.6),
    'q_head': (0.3, 0.9), 'decay': (0.7, 0.97), 'p_local': (0.0, 0.7),
}
# Cache of (sharp, flat) SlotModel pairs, filled by get_model().
models = {}
def get_model(p):
    """Return the cached (sharp, flat) SlotModel pair for parameter set ``p``.

    The three temperatures are rounded to 2 decimals, and the models are
    built with those rounded values, so near-identical hill-climb candidates
    share one pair. The flat model uses tau_flat for both its tau and tau_mid.
    """
    tau, tau_mid, tau_flat = (round(p[name], 2) for name in ('tau', 'tau_mid', 'tau_flat'))
    key = (tau, tau_mid, tau_flat, p['mid_min'])
    if key not in models:
        sharp = SlotModel(B_tokens, tau, tau_mid, p['mid_min'])
        flat = SlotModel(B_tokens, tau_flat, tau_flat, p['mid_min'])
        models[key] = (sharp, flat)
    return models[key]
def evaluate(p):
    """Tuning objective: score() of an n_eval-token sample generated with seed 7 (lower is better)."""
    sample, _lines = generate(p, n_eval, get_model(p), random.Random(7))
    stats = fast_stats(sample)
    return score(stats, T, n_scale)
# Seeded hill-climb: perturb the current best parameters and accept a
# candidate only if it strictly lowers the evaluate() score. The search rng
# `rs` is separate from the fixed generation seed used inside evaluate().
# NOTE: the sequence of rs calls defines the search path; keep it unchanged.
rs = random.Random(2026)
best_p = dict(SEED_P)
best_s = evaluate(best_p)
print(f"  seed score {best_s:.3f}")
for it in range(50):
    cand = dict(best_p)
    # Scale 1 knob (prob 2/3) or 2 knobs (prob 1/3) by a factor drawn from
    # [0.78, 1.28] and clamp to LIMITS. Sampling from list(LIMITS) makes the
    # result depend on LIMITS' key order.
    for k in rs.sample([k for k in LIMITS], rs.choice([1,1,2])):
        lo, hi = LIMITS[k]
        cand[k] = min(hi, max(lo, cand[k] * rs.uniform(0.78, 1.28)))
    # The discrete knobs are occasionally redrawn outright.
    if rs.random() < 0.2: cand['window'] = rs.choice([8,16,32,64])
    if rs.random() < 0.15: cand['mid_min'] = rs.choice([1,2])
    # Skip candidates that leave less than 8% probability for NOVEL words
    # (generate() invents a word only when r >= p_reuse + p_mut).
    if cand['p_reuse'] + cand['p_mut'] > 0.92: continue
    sc = evaluate(cand)
    if sc < best_s:
        best_s, best_p = sc, cand
        print(f"  iter {it:>2} score {sc:.3f}  {dict((k,(round(v,3) if isinstance(v,float) else v)) for k,v in cand.items())}")

print("\nFull-size run with best params ...")
model = get_model(best_p)
# One corpus as long as real B, generated with its own fixed seed (99).
toks, glines = generate(best_p, N, model, random.Random(99))
G = fast_stats(toks)
conn_g, mnb_g = neighbor_connectivity(list(Counter(toks)))
G['conn'] = conn_g
G['mnb'] = mnb_g

def layout_stats(lines):
    """Measure positional (layout) effects over a list of line records.

    ``lines`` is a list of dicts with ``words`` (list of str) and
    ``par_first`` (bool), as produced by load_B_lines() and generate().
    Only valid() words are counted, and lines with fewer than 3 of them are
    skipped. Returns these rates, in this key order:

    - m_final / m_other : share of line-final / other words ending in 'm'
    - pf_par / pf_oth   : share of words containing 'p' or 'f' in
                          paragraph-first lines / other lines
    - gal_par / gal_oth : share of paragraph-first / other lines whose first
                          word starts with t, k, p or f (a gallows)

    A rate whose denominator is zero (e.g. no paragraph-first lines survive
    the filter) is reported as nan. The original version raised
    ZeroDivisionError in that case.
    """
    def rate(num, den):
        return num / den if den else float('nan')

    m_last = n_last = m_inner = n_inner = 0
    pf_par = words_par = gal_par = lines_par = 0
    pf_oth = words_oth = gal_oth = lines_oth = 0
    for rec in lines:
        ws = [w for w in rec['words'] if valid(w)]
        if len(ws) < 3:
            continue
        m_last += ws[-1].endswith('m')
        n_last += 1
        for w in ws[:-1]:
            m_inner += w.endswith('m')
            n_inner += 1
        n_pf = sum(1 for w in ws if 'p' in w or 'f' in w)
        opens_gallows = ws[0][0] in 'tkpf'
        if rec['par_first']:
            pf_par += n_pf
            words_par += len(ws)
            gal_par += opens_gallows
            lines_par += 1
        else:
            pf_oth += n_pf
            words_oth += len(ws)
            gal_oth += opens_gallows
            lines_oth += 1
    return dict(m_final=rate(m_last, n_last), m_other=rate(m_inner, n_inner),
                pf_par=rate(pf_par, words_par), pf_oth=rate(pf_oth, words_oth),
                gal_par=rate(gal_par, lines_par), gal_oth=rate(gal_oth, lines_oth))

LS_g = layout_stats(glines)
LS_r = layout_stats(B_lines)

print("\n=== REAL Currier B  vs  GENERATED ===")
rows = [('tokens', N, len(toks)), ('types', T['types'], G['types']),
        ('zipf slope', T['zipf'], G['zipf']), ('h1', T['h1'], G['h1']),
        ('h2 cond entropy', T['h2'], G['h2']), ('word len mean', T['wlen'], G['wlen']),
        ('word len sd', T['wsd'], G['wsd']), ('top-100 coverage', T['top100'], G['top100']),
        ('adjacency ratio', T['adj'], G['adj']), ('P(next identical)', T['same'], G['same']),
        ('P(next dist<=1)', T['near'], G['near']), ('connectivity', T['conn'], G['conn']),
        ('mean neighbours', T['mnb'], G['mnb'])]
print(f"{'metric':<20}{'real B':>10}{'generated':>11}")
for name, a, b in rows:
    # Integer rows (token/type counts) print as-is; float rows to 3 decimals.
    print(f"{name:<20}{a:>10.3f}{b:>11.3f}" if isinstance(a, float) else f"{name:<20}{a:>10}{b:>11}")
print("\nLayout effects            real B   generated")
for k in LS_r:
    print(f"{k:<24}{LS_r[k]:>8.3f}{LS_g[k]:>11.3f}")

print("\nBest params:", {k: round(v,3) if isinstance(v,float) else v for k,v in best_p.items()})
print("\nTop 10 generated words:", Counter(toks).most_common(10))
print("Top 10 real B words:   ", Counter(B_tokens).most_common(10))
print("\nSample of generated 'Voynichese' (first 4 lines):")
for rec in glines[:4]:
    print(("  P> " if rec['par_first'] else "     ") + '.'.join(rec['words']))

# A context manager guarantees the JSON is flushed and the file closed, even
# if serialisation fails part-way through.
with open('/Users/arcandledger/taxdome/ancient-texts/generation_results.json', 'w') as fh:
    json.dump(dict(targets=T, generated=G, layout_real=LS_r, layout_gen=LS_g, params=best_p),
              fh, indent=1)
print("\ngeneration_results.json written")
