#!/usr/bin/env python3
"""Typological matching: if the Voynich is a word-level code of language X,
then X's word-stream statistics (which survive any 1:1 word substitution)
must match Voynichese's. Which candidate language fits best?

Code-invariant metrics only: vocabulary growth (type/token ratio, a two-point
Heaps'-law exponent, top-100-type coverage), immediate word repetition,
word-bigram conditional entropy H(W2|W1), and induced-class syntax (class-bigram
MI in excess of a shuffled baseline, self-class assortativity). All corpora are
truncated to the same token count.
"""
import re, math, random
from collections import Counter, defaultdict

BASE = "/Users/arcandledger/taxdome/ancient-texts"


def valid(w):
    """True when ``w`` is a non-empty run of lowercase ASCII letters only."""
    return re.fullmatch(r'[a-z]+', w) is not None

def load_voynich_lines(lang='B', path=None):
    """Return the cleaned text lines of the Voynich transliteration file.

    Args:
        lang: value a page header's ``$L=`` variable must equal for that page's
            lines to be kept (presumably the Currier language 'A'/'B' —
            confirm). A falsy value keeps every page.
        path: transliteration file to read; defaults to
            ``{BASE}/data/ZL3b-n.txt``.

    Returns:
        A list of lines. Each line is a list of at least 2 words, each matching
        ``[a-z]+``.

    File layout as parsed below:
    - ``#`` lines are comments.
    - ``<fNNN>`` lines are page headers carrying ``$X=Y`` variables.
    - ``<fNNN.locus,@Type>`` lines carry text.

    Only loci whose type starts with 'P'/'p' are used (presumably paragraph
    text — confirm).

    Cleaning:
    - Inline markup is stripped.
    - Alternative readings ``[a:b]`` resolve to their first option.
    - '.' and ',' both separate words.
    - Any word that is not purely ``[a-z]`` is dropped, including words
      carrying the '?' placeholder substituted below.
    """
    if path is None:
        path = f"{BASE}/data/ZL3b-n.txt"
    pages, lines = {}, []
    locus_re = re.compile(r'^<(f[0-9a-zA-Z]+)\.([^,>]+),\s*([@+=*~$&!])(\w+?)(\d*)>\s*(.*)$')
    hdr = re.compile(r'^<(f[0-9a-zA-Z]+)>')
    # `with` guarantees the handle is closed even if parsing raises.
    with open(path, encoding='utf-8') as fh:
        for raw in fh:
            raw = raw.rstrip('\n')
            if not raw or raw.startswith('#'):
                continue
            m = locus_re.match(raw)
            if not m:
                # Not a text locus. If it is a page header, record its $X=Y variables.
                h = hdr.match(raw)
                if h:
                    pages[h.group(1)] = dict(re.findall(r'\$(\w)=(\w+)', raw))
                continue
            page, _, _, ltype, _, text = m.groups()
            if not ltype.upper().startswith('P'):
                continue
            if lang and pages.get(page, {}).get('L') != lang:
                continue
            t = re.sub(r'<!.*?>', '', text)                   # inline <!...> comments
            # '<->' becomes '?', so the word it touches fails `valid` and is dropped.
            t = re.sub(r'<->', '?', t)
            t = re.sub(r'<%>|<\$>|<@\w+>', '', t)
            t = re.sub(r'@\d+;', '?', t)                      # numeric @NNN; codes: word dropped
            for _ in range(4):                                # up to 4 nesting levels of [a:b]
                t = re.sub(r'\[([^:\[\]]*):[^\[\]]*\]', r'\1', t)
            t = t.replace(',', '.')                           # ',' treated as a word break too
            t = re.sub(r'[!%]', '', t)
            ws = [w for w in t.split('.') if w and valid(w)]
            if len(ws) >= 2:
                lines.append(ws)
    return lines

def chunk(tokens, seed=0):
    """Split a flat token list into consecutive pseudo-lines of 6-12 tokens.

    Each line length is drawn with ``randint(6, 12)`` from a private
    ``random.Random(seed)``, so the split is reproducible. Only the final line
    may be shorter than 6. Concatenating the result gives back ``tokens``.
    """
    rng = random.Random(seed)
    total = len(tokens)
    pieces = []
    start = 0
    while start < total:
        size = rng.randint(6, 12)
        pieces.append(tokens[start:start + size])
        start += size
    return [piece for piece in pieces if piece]

def trunc_lines(lines, N):
    """Return the leading lines of ``lines`` holding exactly ``min(N, total)`` tokens.

    The line that crosses the budget is cut to fit. Empty lines (or lines cut
    to nothing) are omitted. Input lines are not modified.
    """
    kept = []
    used = 0
    for line in lines:
        room = N - used
        if len(line) > room:
            line = line[:room]
        if line:
            kept.append(line)
            used += len(line)
        if used >= N:
            break
    return kept

# --------- induction (same as Part 11) ---------
def induce(lines, k=8, vmin=13, n_ctx=60, seed=5):
    """Induce ``k`` word classes by k-means over normalized context-count vectors.

    Every word with corpus frequency >= ``vmin`` gets a vector of length
    ``2*n_ctx + 2``:
    - slots 0..n_ctx-1: how often each of the ``n_ctx`` most frequent words
      is its left neighbour;
    - slots n_ctx..2*n_ctx-1: the same for the right neighbour;
    - slot 2*n_ctx: times the word starts a line;
    - slot 2*n_ctx+1: times it ends a line.

    Counts are log1p-damped and L2-normalized. Clustering uses the dot
    product (cosine similarity) as affinity, with centroids renormalized
    after each update.

    Args:
        lines: list of lines, each a list of word tokens.
        k: number of classes.
        vmin: minimum frequency for a word to receive a class.
        n_ctx: number of most-frequent words used as context features.
        seed: seed for picking the initial centroids.

    Returns:
        dict mapping each word of frequency >= vmin to a class id in
        ``range(k)``. Less frequent words are absent, and some ids may be
        unused if their cluster empties.

    Raises:
        ValueError: from ``random.sample`` when fewer than ``k`` words reach
            ``vmin``.
    """
    freq = Counter(w for l in lines for w in l)
    # vocab is in descending frequency order; row i of L belongs to vocab[i].
    vocab = [w for w, c in freq.most_common() if c >= vmin]
    ctx_words = [w for w, _ in freq.most_common(n_ctx)]
    cidx = {w: i for i, w in enumerate(ctx_words)}
    vidx = {w: i for i, w in enumerate(vocab)}
    L = [[0.0]*(2*n_ctx+2) for _ in vocab]
    for l in lines:
        for i, w in enumerate(l):
            if w not in vidx: continue
            v = L[vidx[w]]
            # Left side: line start takes precedence over a left-neighbour feature.
            if i == 0: v[2*n_ctx] += 1
            elif l[i-1] in cidx: v[cidx[l[i-1]]] += 1
            # Right side: line end takes precedence over a right-neighbour feature.
            if i == len(l)-1: v[2*n_ctx+1] += 1
            elif l[i+1] in cidx: v[n_ctx+cidx[l[i+1]]] += 1
    for v in L:
        for j in range(len(v)): v[j] = math.log1p(v[j])
        # An all-zero vector keeps norm 1.0 and stays zero; it then ties at
        # affinity 0 with every centroid and lands in class 0.
        n = math.sqrt(sum(x*x for x in v)) or 1.0
        for j in range(len(v)): v[j] /= n
    rng = random.Random(seed)
    # Initial centroids: copies of k distinct vectors drawn from the 4k most
    # frequent classified words.
    cents = [L[i][:] for i in rng.sample(range(min(len(L), 4*k)), k)]
    assign = [0]*len(L)
    for _ in range(40):  # hard cap on k-means iterations
        changed = False
        # Assignment step: highest dot product wins; strict '>' keeps the
        # lowest-index centroid on ties.
        for i, v in enumerate(L):
            best, bj = -2, 0
            for j, c in enumerate(cents):
                d = sum(a*b for a, b in zip(v, c))
                if d > best: best, bj = d, j
            if assign[i] != bj: assign[i] = bj; changed = True
        # Update step: renormalized member mean. An empty cluster keeps its
        # previous centroid.
        for j in range(k):
            mem = [L[i] for i in range(len(L)) if assign[i] == j]
            if not mem: continue
            c = [sum(col)/len(mem) for col in zip(*mem)]
            n = math.sqrt(sum(x*x for x in c)) or 1.0
            cents[j] = [x/n for x in c]
        # Stop after a full pass with no reassignment.
        if not changed: break
    return {w: assign[vidx[w]] for w in vocab}

def _class_pair_counts(lines, wc):
    """Count adjacent within-line (wc[a], wc[b]) pairs, skipping pairs with an unclassed word."""
    pairs = Counter()
    for l in lines:
        for a, b in zip(l, l[1:]):
            if a in wc and b in wc:
                pairs[(wc[a], wc[b])] += 1
    return pairs


def _pair_mi(pairs):
    """Return (plug-in MI in bits, total n, left marginals, right marginals) of a non-empty pair Counter."""
    n = sum(pairs.values())
    left, right = Counter(), Counter()
    for (a, b), c in pairs.items():
        left[a] += c
        right[b] += c
    mi = sum(c/n * math.log2((c/n)/((left[a]/n)*(right[b]/n))) for (a, b), c in pairs.items())
    return mi, n, left, right


def class_stats(lines, wc):
    """Class-level syntax statistics of ``lines`` under the word-class map ``wc``.

    Only adjacent pairs inside a line where both words appear in ``wc`` are
    counted.

    Returns:
        ``(mi_excess, assort)``.

        ``mi_excess`` is the plug-in mutual information (bits) of the
        class-bigram table minus that of a baseline. The baseline shuffles all
        tokens (fixed seed 3) and pours them back into the SAME line-length
        layout. Both estimates therefore see the same number of adjacent
        positions, so their small-sample upward bias roughly cancels.

        ``assort`` is the observed share of same-class pairs divided by the
        share expected if left and right classes were independent. This
        assumes some class occurs on both sides of a pair; otherwise the
        division fails.

    Raises:
        ValueError: if there is no adjacent pair of classed words at all.
    """
    pairs = _class_pair_counts(lines, wc)
    if not pairs:
        raise ValueError("class_stats: no adjacent pair of classed words in `lines`")
    mi, n, pa, pb = _pair_mi(pairs)
    same = sum(c for (a, b), c in pairs.items() if a == b)/n
    exp_same = sum((pa[c]/n)*(pb[c]/n) for c in pa)

    # Shuffled baseline, keeping the original line lengths. Re-chunking into
    # random line lengths would change the number of adjacent positions, and
    # with it the MI estimator's bias.
    toks = [w for l in lines for w in l]
    random.Random(3).shuffle(toks)
    shuffled, pos = [], 0
    for l in lines:
        shuffled.append(toks[pos:pos + len(l)])
        pos += len(l)
    sp = _class_pair_counts(shuffled, wc)
    mi0 = _pair_mi(sp)[0] if sp else 0.0
    return mi - mi0, same/exp_same

def metrics(lines):
    """Compute the code-invariant word-stream metrics for one corpus.

    ``lines`` is a list of lines, each a list of word tokens. The stream
    metrics run on all lines concatenated, so bigrams and repeats span line
    boundaries. The class metrics come from ``class_stats``, which counts
    within-line pairs only.

    Returns a dict:
        ttr     type/token ratio V/N
        heaps   log2(V(N) / V(N/2)), a two-point Heaps'-law exponent estimate
        top100  share of tokens covered by the 100 most frequent types
        h2w     H(W1,W2) - H(W1) = H(W2|W1) in bits over adjacent token pairs
        rep     adjacent identical-word pairs per 1000 adjacent pairs
        mi      class-bigram MI excess over a shuffled baseline (class_stats)
        assort  observed/expected same-class adjacency (class_stats)

    Raises:
        ValueError: if the corpus has fewer than 2 tokens (no bigram exists).
    """
    toks = [w for l in lines for w in l]
    N = len(toks)
    if N < 2:
        raise ValueError("metrics: need at least 2 tokens")
    c = Counter(toks)
    half = len(Counter(toks[:N//2]))
    heaps = math.log(len(c)/half)/math.log(2)
    bi = Counter(zip(toks, toks[1:]))
    n2 = sum(bi.values())  # == N - 1
    h12 = -sum(v/n2*math.log2(v/n2) for v in bi.values())
    # H(W1) must be taken over the same N-1 left positions as the bigrams and
    # normalized by the same total n2. Dividing by N would leave probabilities
    # that do not sum to 1 and bias h2w.
    h1 = -sum(v/n2*math.log2(v/n2) for v in Counter(toks[:-1]).values())
    rep = sum(1 for a, b in zip(toks, toks[1:]) if a == b)/(N-1)
    wc = induce(lines)
    mi_x, assort = class_stats(lines, wc)
    return dict(ttr=len(c)/N, heaps=heaps, top100=sum(f for _, f in c.most_common(100))/N,
                h2w=h12-h1, rep=rep*1000, mi=mi_x, assort=assort)

# --------- corpora at matched size ---------
import unicodedata
def lang_lines(path, N):
    """Load a plain-text corpus as random 6-12-token pseudo-lines totalling <= N tokens.

    Steps:
    1. Tokens are maximal runs of letters (no digits or underscore),
       lower-cased.
    2. Only tokens whose first character is in the dominant script are kept.
       A character's script is the first word of its Unicode name (e.g.
       LATIN, GREEK), and the dominant script is judged on the first 4000
       tokens. This drops prefaces/footnotes in other scripts.
    3. The survivors are split by ``chunk`` and cut to N tokens by
       ``trunc_lines``.

    Returns [] when the file contains no letter tokens.
    """
    with open(path, encoding='utf-8', errors='replace') as fh:
        text = fh.read()
    # Python's \w (hence [^\W\d_]) does not match nonspacing combining marks
    # (category Mn), so any word written with such marks would be cut into
    # fragments at each mark. NFC folds base+accent sequences into precomposed
    # letters; the marks NFC cannot absorb (e.g. Hebrew vowel points, Arabic
    # harakat) are then dropped so the word stays whole.
    text = unicodedata.normalize('NFC', text)
    marks = {ord(ch): None for ch in set(text) if unicodedata.category(ch) == 'Mn'}
    if marks:
        text = text.translate(marks)
    # NOTE: str.lower() is locale-independent, so Turkish 'I' maps to 'i', not 'ı'.
    toks = [w.lower() for w in re.findall(r'[^\W\d_]+', text)]
    if not toks:
        return []

    def block(w):
        """Script label of w's first character: first word of its Unicode name, or '?' if unnamed."""
        try:
            return unicodedata.name(w[0]).split()[0]
        except ValueError:
            return '?'

    # keep only the dominant script (drops prefaces/footnotes in other languages)
    dom = Counter(block(w) for w in toks[:4000]).most_common(1)[0][0]
    toks = [w for w in toks if block(w) == dom]
    return trunc_lines(chunk(toks), N)

voy_lines = load_voynich_lines('B')
# Every corpus is cut to the Voynich B token count (capped at 19000) so all
# statistics are compared at equal sample size.
N = min(19000, sum(len(l) for l in voy_lines))
corpora = {'Voynich B': trunc_lines(voy_lines, N)}
SOURCES = [('Latin','latin.txt'), ('English','english.txt'), ('Italian','it.txt'),
           ('Spanish','es.txt'), ('German','de_clean.txt'), ('Finnish','fi_clean.txt'),
           ('Hungarian','typo_hu.txt'), ('Turkish','typo_tr.txt'), ('Hebrew','typo_he.txt'),
           ('Arabic','typo_ar.txt'), ('Persian','typo_fa.txt'), ('Greek','typo_el.txt'),
           ('Czech','typo_cs.txt')]
import os
for name, fn in SOURCES:
    p = f"{BASE}/data/{fn}"
    if not os.path.exists(p):
        # Two corpora can be rebuilt from raw downloads in /tmp. Only the text
        # between the '*** START' and '*** END' markers is kept (presumably
        # Project Gutenberg boilerplate — confirm).
        src = {'de_clean.txt': 'de2', 'fi_clean.txt': 'fi'}.get(fn)
        raw_path = f"/tmp/try_{src}.txt" if src else None
        if raw_path is None or not os.path.exists(raw_path):
            print(f"skip {name} (no file)"); continue
        with open(raw_path, encoding='utf-8', errors='replace') as fh:
            t = fh.read()
        s = t.find('*** START'); e = t.find('*** END')
        # Body starts on the line after the START marker. A missing marker
        # means "from the beginning" / "to the end" rather than a mis-slice.
        start = t.find(chr(10), s) + 1 if s != -1 else 0
        end = e if e != -1 else len(t)
        # Explicit UTF-8: the file is read back as UTF-8 by lang_lines, so the
        # platform's locale encoding must not be used when writing it.
        with open(p, 'w', encoding='utf-8') as fh:
            fh.write(t[start:end])
    lines = lang_lines(p, N)
    if sum(len(l) for l in lines) < N - 500:
        print(f"skip {name} (too short after filtering)"); continue
    corpora[name] = lines

print(f"Matched corpus size: {N} tokens\n")
# Column order of the printed table (also reused by the distance ranking).
cols = ['ttr', 'heaps', 'top100', 'h2w', 'rep', 'mi', 'assort']
print(f"{'corpus':<11}" + "".join(format(c, ">9") for c in cols))
# M[corpus name] -> metrics dict; one table row per corpus, in insertion order.
M = {}
for name, lines in corpora.items():
    m = M[name] = metrics(lines)
    print(f"{name:<11}" + "".join(format(m[c], ">9.3f") for c in cols))

# --------- distance ranking ---------
# Every measured corpus except the target is a candidate language.
langs = [k for k in M if k != 'Voynich B']
# Per-metric mean and population SD over ALL corpora, Voynich B included. A
# zero SD is replaced by 1e-9 so the z-scaling below is always defined.
# NOTE(review): counting the target in its own spread tends to inflate sd on
# the metrics where Voynich B is an outlier, down-weighting them in the
# distance. Confirm this is intended rather than normalizing over `langs` only.
mu = {c: sum(M[l][c] for l in M)/len(M) for c in cols}
sd = {c: (sum((M[l][c]-mu[c])**2 for l in M)/len(M))**0.5 or 1e-9 for c in cols}
print("\nDistance from Voynich B (z-normalized, lower = closer):")
# Euclidean distance, in per-metric SD units, from each candidate to Voynich B.
dists = [(math.sqrt(sum(((M[l][c]-M['Voynich B'][c])/sd[c])**2 for c in cols)), l)
         for l in langs]
for d, l in sorted(dists):
    print(f"  {l:<10} {d:.2f}")
print("\nPer-metric nearest language to Voynich:")
for c in cols:
    # Smallest absolute gap on this one metric; the first candidate wins ties.
    best = min(langs, key=lambda l, c=c: abs(M[l][c]-M['Voynich B'][c]))
    print(f"  {c:<8} Voy={M['Voynich B'][c]:.3f}  nearest: {best} ({M[best][c]:.3f})")
