#!/usr/bin/env python3
"""Word-class induction for Voynichese (Currier B): cluster frequent words by
context distributions only, then measure (a) class-transition structure (MI),
(b) alignment of induced classes with suffix morphology, (c) calibration vs
English / Latin / the fitted generator on the identical pipeline.
"""
import re, math, random
from collections import Counter, defaultdict

BASE = "/Users/arcandledger/taxdome/ancient-texts"
_WORD_RE = re.compile(r'[a-z]+')


def valid(w):
    """Return True if ``w`` consists solely of one or more lowercase ASCII letters a-z."""
    return _WORD_RE.fullmatch(w) is not None

def load_lines(lang='B', path=None):
    """Read paragraph text lines from the ZL3b IVTFF transliteration.

    Page-header lines (``<fNNN>`` followed by ``$X=value`` variables) record
    per-page metadata.  Locus lines (``<fNNN.locus,<mark><type><n>> text``)
    carry the text.  Only loci whose type starts with ``P``/``p`` are kept.
    If ``lang`` is truthy, only pages whose ``$L`` variable equals ``lang``
    are kept.

    Text cleanup, in order:
      * inline comments ``<!...>`` are dropped;
      * ``<->`` and ``@NNN;`` codes become ``?``, so tokens containing them
        fail ``valid`` and are discarded;
      * ``<%>``, ``<$>`` and ``<@x>`` markers are dropped;
      * bracketed alternatives ``[a:b]`` are resolved to their first option
        (innermost brackets first, four passes);
      * ``,`` is treated as a word break;
      * ``!`` and ``%`` are removed.

    Args:
        lang: value compared against each page's ``$L`` variable (e.g. 'B');
            falsy keeps every page.
        path: transliteration file; defaults to ``{BASE}/data/ZL3b-n.txt``.

    Returns:
        List of lines, each a list of at least two words accepted by ``valid``.
    """
    if path is None:
        path = f"{BASE}/data/ZL3b-n.txt"
    locus_re = re.compile(r'^<(f[0-9a-zA-Z]+)\.([^,>]+),\s*([@+=*~$&!])(\w+?)(\d*)>\s*(.*)$')
    hdr = re.compile(r'^<(f[0-9a-zA-Z]+)>')
    pages, lines = {}, []
    with open(path, encoding='utf-8') as fh:
        for raw in fh:
            raw = raw.rstrip('\n')
            if not raw or raw.startswith('#'):
                continue
            m = locus_re.match(raw)
            if m is None:
                # Not a text locus: either a page header (record its $X=value
                # variables) or some other line we ignore.
                h = hdr.match(raw)
                if h:
                    pages[h.group(1)] = dict(re.findall(r'\$(\w)=(\w+)', raw))
                continue
            page, _, _, ltype, _, text = m.groups()
            if not ltype.upper().startswith('P'):
                continue
            if lang and pages.get(page, {}).get('L') != lang:
                continue
            t = re.sub(r'<!.*?>', '', text)
            t = re.sub(r'<->', '?', t)
            t = re.sub(r'<%>|<\$>|<@\w+>', '', t)
            t = re.sub(r'@\d+;', '?', t)
            # Repeated passes resolve nested [a:b] alternatives from the inside out.
            for _ in range(4):
                t = re.sub(r'\[([^:\[\]]*):[^\[\]]*\]', r'\1', t)
            t = t.replace(',', '.')
            t = re.sub(r'[!%]', '', t)
            ws = [w for w in t.split('.') if w and valid(w)]
            # Lines with fewer than two words contribute no bigrams.
            if len(ws) >= 2:
                lines.append(ws)
    return lines

def chunk(tokens, seed=0):
    """Split a flat token list into pseudo-lines of random length 6-12.

    Piece widths are drawn from ``random.Random(seed).randint(6, 12)`` (both
    bounds inclusive), so the split is reproducible for a given seed.  The
    trailing remainder forms a final, possibly shorter, piece.  Pieces with
    fewer than two tokens are discarded because they hold no bigram.
    """
    rng = random.Random(seed)
    pieces = []
    start = 0
    while start < len(tokens):
        width = rng.randint(6, 12)
        piece = tokens[start:start + width]
        start += width
        if len(piece) >= 2:
            pieces.append(piece)
    return pieces

# ---------------- induction pipeline ----------------
def induce(lines, k=8, vmin=15, n_ctx=60, seed=5):
    """Induce word classes from immediate-context distributions alone.

    Each word occurring at least ``vmin`` times gets a vector of
    ``2*n_ctx + 2`` counts:
      * how often each of the ``n_ctx`` most frequent words appears directly
        to its left;
      * how often each of those words appears directly to its right;
      * how often the word is line-initial;
      * how often the word is line-final.
    Counts are log1p-smoothed and L2-normalised, then clustered by k-means
    under cosine similarity.  Similarity is the dot product of unit vectors,
    and centroids are re-normalised after every update.  Initial centroids
    are a seeded random sample of the ``4*k`` most frequent clustered words.

    Args:
        lines: sequence of lines (each a list of words); iterated twice.
        k: requested number of classes; capped at the vocabulary size.
        vmin: minimum corpus frequency for a word to be clustered.
        n_ctx: number of most frequent words used as context features.
        seed: seed for the initial-centroid sample.

    Returns:
        ``(vocab, word_class, freq)``:
          * ``vocab`` lists the clustered words by descending frequency;
          * ``word_class`` maps each of them to a class id in
            ``range(min(k, len(vocab)))``;
          * ``freq`` is a Counter over all tokens.
        If no word reaches ``vmin``, returns ``([], {}, freq)``.
    """
    freq = Counter(w for l in lines for w in l)
    vocab = [w for w, c in freq.most_common() if c >= vmin]
    if not vocab:
        # Nothing to cluster; the centroid sample below would otherwise fail.
        return vocab, {}, freq
    ctx_words = [w for w, _ in freq.most_common(n_ctx)]
    cidx = {w: i for i, w in enumerate(ctx_words)}
    vidx = {w: i for i, w in enumerate(vocab)}
    # Layout: left ctx [0, n_ctx) | right ctx [n_ctx, 2*n_ctx) | line-initial | line-final
    init_col, final_col = 2 * n_ctx, 2 * n_ctx + 1
    feats = [[0.0] * (2 * n_ctx + 2) for _ in vocab]
    for l in lines:
        last = len(l) - 1
        for i, w in enumerate(l):
            if w not in vidx:
                continue
            v = feats[vidx[w]]
            if i == 0:
                v[init_col] += 1
            elif l[i - 1] in cidx:
                v[cidx[l[i - 1]]] += 1
            if i == last:
                v[final_col] += 1
            elif l[i + 1] in cidx:
                v[n_ctx + cidx[l[i + 1]]] += 1
    # log-smooth + L2 normalize
    for v in feats:
        v[:] = [math.log1p(x) for x in v]
        norm = math.sqrt(sum(x * x for x in v)) or 1.0
        v[:] = [x / norm for x in v]
    # k-means (cosine via dot on unit vectors). Never request more clusters
    # than there are words: random.sample would raise ValueError.
    k_eff = min(k, len(feats))
    rng = random.Random(seed)
    cents = [feats[i][:] for i in rng.sample(range(min(len(feats), 4 * k)), k_eff)]
    assign = [0] * len(feats)
    for _ in range(40):
        changed = False
        for i, v in enumerate(feats):
            best, bj = -2, 0
            for j, c in enumerate(cents):
                d = sum(a * b for a, b in zip(v, c))
                if d > best:
                    best, bj = d, j
            if assign[i] != bj:
                assign[i] = bj
                changed = True
        for j in range(k_eff):
            members = [feats[i] for i, a in enumerate(assign) if a == j]
            if not members:
                # Empty cluster keeps its previous centroid.
                continue
            c = [sum(col) / len(members) for col in zip(*members)]
            norm = math.sqrt(sum(x * x for x in c)) or 1.0
            cents[j] = [x / norm for x in c]
        if not changed:
            break
    word_class = {w: assign[i] for i, w in enumerate(vocab)}
    return vocab, word_class, freq

def class_mi(lines, word_class, k, shuffle_seed=None):
    """Mutual information (bits) between classes of adjacent words.

    Only bigrams inside a line whose two words both appear in ``word_class``
    are counted; pairs never span a line boundary.  MI is the plug-in
    estimate from the empirical joint distribution of (class, next class).

    With ``shuffle_seed`` set, a null baseline is computed instead.  All
    tokens are pooled and shuffled with ``random.Random(shuffle_seed)``.
    They are then re-split into lines of exactly the same lengths as the
    input.  The baseline therefore has the same number of within-line
    bigram slots (and comparable estimator bias) as the real data; only the
    word order is destroyed.

    Args:
        lines: sequence of lines (lists of words); not modified.
        word_class: mapping word -> class id.
        k: number of classes (unused; kept for interface compatibility).
        shuffle_seed: if not None, compute the shuffled baseline.

    Returns:
        ``(mi, n)``: MI in bits and the number of counted class bigrams.
        ``(0.0, 0)`` when no bigram qualifies.
    """
    if shuffle_seed is not None:
        toks = [w for l in lines for w in l]
        random.Random(shuffle_seed).shuffle(toks)
        shuffled, start = [], 0
        for l in lines:
            shuffled.append(toks[start:start + len(l)])
            start += len(l)
        lines = shuffled
    pairs = Counter()
    for l in lines:
        for a, b in zip(l, l[1:]):
            if a in word_class and b in word_class:
                pairs[(word_class[a], word_class[b])] += 1
    n = sum(pairs.values())
    pa, pb = Counter(), Counter()
    for (a, b), c in pairs.items():
        pa[a] += c
        pb[b] += c
    mi = 0.0
    for (a, b), c in pairs.items():
        mi += c / n * math.log2((c / n) / ((pa[a] / n) * (pb[b] / n)))
    return mi, n

def suffix_class(w):
    """Return the first listed ending that ``w`` ends with and is longer than, else '-'.

    Endings are tried in priority order so that longer variants win over
    their tails (e.g. 'eedy' before 'edy' before 'dy' before 'y').  A word
    equal to an ending does not count as having it.
    """
    endings = ('eedy', 'edy', 'dy', 'eey', 'ey', 'y', 'aiin', 'ain',
               'ol', 'or', 'al', 'ar', 'am', 'o', 'l', 'r', 'n', 's')
    return next((e for e in endings if w.endswith(e) and len(w) > len(e)), '-')

def run(name, lines, k=8):
    """Induce classes on ``lines`` and report class-transition MI vs. a shuffled null.

    The null is the mean ``class_mi`` over three shuffles (seeds 11, 12, 13).
    Prints a two-line summary.

    Returns:
        ``(vocab, word_class, freq, excess_mi)`` where ``excess_mi`` is the
        observed MI minus the mean shuffled MI.
    """
    vocab, wc, freq = induce(lines, k=k)
    observed, n_pairs = class_mi(lines, wc, k)
    seeds = (11, 12, 13)
    baseline = sum(class_mi(lines, wc, k, shuffle_seed=s)[0] for s in seeds) / len(seeds)
    excess = observed - baseline
    print(f"\n{name}: vocab={len(vocab)} clustered words, {n_pairs} class bigrams")
    print(f"  class-transition MI = {observed:.3f} bits   shuffled = {baseline:.3f}   excess = {excess:.3f}")
    return vocab, wc, freq, excess

# ---------------- corpora ----------------
voy = load_lines('B')
B_toks = [w for l in voy for w in l]


def _text_tokens(path, limit):
    """Return the first ``limit`` lowercase ``[a-z]+`` tokens of the text file at ``path``.

    Undecodable bytes are replaced; this cannot create letters, so the
    ``[a-z]+`` token stream is unaffected.
    """
    with open(path, encoding='utf-8', errors='replace') as fh:
        return re.findall(r'[a-z]+', fh.read().lower())[:limit]


# Natural-language controls are truncated to the Voynich B token count so
# they really are "(same size)", like the generated control below.
eng = chunk(_text_tokens(f"{BASE}/data/english.txt", len(B_toks)))
lat = chunk(_text_tokens(f"{BASE}/data/latin.txt", len(B_toks)))
# generator
with open(f"{BASE}/topic_test.py", encoding='utf-8') as fh:
    src = fh.read()
ns = {}
# NOTE(review): exec() runs code from a local sibling script; acceptable only
# because that file is trusted.  Everything before the "main" marker is
# executed; if the marker is missing, the whole file runs.
exec(src[:src.index('# ---------------- main')] if '# ---------------- main' in src else src, ns)
gen = chunk(ns['generate'](len(B_toks), B_toks, random.Random(7)), seed=2)

vocab, wc, freq, _ = run("VOYNICH B", voy)
run("ENGLISH (same size)", eng)
run("LATIN (same size)", lat)
run("GENERATED control", gen)

# ---------------- describe Voynich classes ----------------
k = 8
print("\n=== INDUCED VOYNICH WORD CLASSES (context-only clustering) ===")
# One pass over the Voynich lines, counting clustered tokens only:
#   pos[c] -> [line-initial, line-final, total] token counts for class c
#   trans  -> (class, next class) counts for adjacent clustered words
#   ct[c]  -> transitions leaving class c (denominator for trans)
pos = defaultdict(lambda: [0, 0, 0])
trans = Counter()
ct = Counter()
for line in voy:
    last = len(line) - 1
    for idx, word in enumerate(line):
        if word in wc:
            counts = pos[wc[word]]
            counts[2] += 1
            if idx == 0:
                counts[0] += 1
            if idx == last:
                counts[1] += 1
    for left, right in zip(line, line[1:]):
        if left in wc and right in wc:
            from_cls = wc[left]
            trans[from_cls, wc[right]] += 1
            ct[from_cls] += 1

for c in range(k):
    # Most frequent first; sort stability keeps ties in wc order.
    members = sorted([w for w, cls in wc.items() if cls == c],
                     key=lambda w: freq[w], reverse=True)
    if not members:
        continue
    tot = pos[c][2] or 1
    denom = max(1, ct[c])
    nxt = sorted([(trans[c, d] / denom, d) for d in range(k)], reverse=True)[:3]
    sfx = Counter(suffix_class(w) for w in members)
    init_pct = 100 * pos[c][0] / tot
    final_pct = 100 * pos[c][1] / tot
    print(f"\nClass {c}: {len(members)} words | line-initial {init_pct:.0f}% | line-final {final_pct:.0f}%")
    print(f"  members: {', '.join(members[:10])}")
    print(f"  endings: {dict(sfx.most_common(4))}")
    print("  follows -> " + ', '.join(f"C{d}({100 * p:.0f}%)" for p, d in nxt))
