#!/usr/bin/env python3
"""Does the Voynich manuscript have TOPIC structure beyond what a mindless generator produces?

Test: Jensen-Shannon divergence between word distributions of illustration-defined
sections (herbal/biological/stars/pharma), restricted to Currier language B so that
'language' is controlled. Compared against:
  - within-section page-split nulls (same estimator, same sample size)
  - the fitted stationary generator, cut into the SAME page/section shapes
  - real English cut into the same shapes (a text that genuinely has topics)
  - token-shuffled real B (floor)
Plus G2 (log-likelihood) keyword extraction per section.
"""
import re, math, random
from collections import Counter, defaultdict

DATA = "/Users/arcandledger/taxdome/ancient-texts/data"


def valid(w):
    """Return True when `w` consists solely of one or more lowercase ASCII letters a-z."""
    return re.fullmatch(r'[a-z]+', w) is not None

# ---------------- parse pages with metadata ----------------
def load_pages(path=None):
    """Parse the ZL IVTFF transcription into page metadata and tokens.

    Args:
        path: transcription file to read. Defaults to ``{DATA}/ZL3b-n.txt``,
            so calling ``load_pages()`` with no argument behaves as before.

    Returns:
        ``(pages_meta, order, page_tokens, page_lines)`` where

        * ``pages_meta`` maps page id -> dict of the ``$X=value`` fields found
          on that page's header line (e.g. ``'L'``, ``'I'``);
        * ``order`` lists page ids in the order their header lines appear;
        * ``page_tokens`` (``defaultdict(list)``) maps page id -> tokens that
          pass ``valid``, taken only from loci whose type starts with 'P'
          (case-insensitive);
        * ``page_lines`` (``defaultdict(list)``) maps page id -> one
          ``(words, par_first)`` tuple per such locus, where ``words`` keeps
          every non-empty piece (valid or not) and ``par_first`` is True when
          the raw locus text contained ``<%>``.
    """
    if path is None:
        path = f"{DATA}/ZL3b-n.txt"
    pages_meta, order = {}, []
    page_tokens = defaultdict(list)
    page_lines = defaultdict(list)
    # Locus line: <page.locus,MARKERtypeNUM> text
    locus_re = re.compile(r'^<(f[0-9a-zA-Z]+)\.([^,>]+),\s*([@+=*~$&!])(\w+?)(\d*)>\s*(.*)$')
    # Page header line: <page> followed by $X=value metadata fields.
    hdr_re = re.compile(r'^<(f[0-9a-zA-Z]+)>')
    meta_re = re.compile(r'\$(\w)=(\w+)')
    # [a:b(:c...)] alternative readings -> keep the first alternative.
    alt_re = re.compile(r'\[([^:\[\]]*):[^\[\]]*\]')
    cur = None
    # `with` guarantees the handle is closed, even if parsing raises.
    with open(path, encoding='utf-8') as fh:
        for raw in fh:
            raw = raw.rstrip('\n')
            if not raw or raw.startswith('#'):
                continue
            loc = locus_re.match(raw)
            if loc is None:
                head = hdr_re.match(raw)
                if head:
                    cur = head.group(1)
                    pages_meta[cur] = dict(meta_re.findall(raw))
                    order.append(cur)
                continue
            page, _, _, ltype, _, text = loc.groups()
            if not ltype.upper().startswith('P'):
                continue
            par_first = '<%>' in text
            t = re.sub(r'<!.*?>', '', text)           # drop <!...> annotations
            t = re.sub(r'<->', '?', t)                 # '?' makes the affected word fail valid()
            t = re.sub(r'<%>|<\$>|<@\w+>', '', t)
            t = re.sub(r'@\d+;', '?', t)               # @NNN; codes -> '?' (word fails valid())
            for _ in range(4):                          # repeated passes resolve nested brackets
                t = alt_re.sub(r'\1', t)
            t = t.replace(',', '.')                     # ',' separators also split words
            t = re.sub(r'[!%]', '', t)
            ws = [w for w in t.split('.') if w]
            page_lines[page].append((ws, par_first))
            page_tokens[page].extend(w for w in ws if valid(w))
    return pages_meta, order, page_tokens, page_lines

pages_meta, order, page_tokens, page_lines = load_pages()

# Currier-B pages grouped by illustration type ($I); pages with no valid
# token are left out.
B_pages = [p for p in order if pages_meta.get(p, {}).get('L') == 'B' and page_tokens[p]]
sec_of = {p: pages_meta[p].get('I', '?') for p in B_pages}
sec_tokens = defaultdict(int)
for p in B_pages: sec_tokens[sec_of[p]] += len(page_tokens[p])
# Pages per section, counted in one pass (instead of rescanning B_pages per section).
sec_pages = Counter(sec_of[p] for p in B_pages)
print("Currier-B sections ($I = illustration type): tokens per section")
NAMES = dict(H='herbal', A='astronomical', Z='zodiac', B='biological', C='cosmological',
             P='pharmaceutical', S='stars/recipes', T='text-only')
for s, n in sorted(sec_tokens.items(), key=lambda x: -x[1]):
    print(f"  {s} ({NAMES.get(s, s)}): {n} tokens, {sec_pages[s]} pages")
# Only sections with at least this many valid tokens enter the comparison.
MIN_SECTION_TOKENS = 2500
SECS = [s for s, n in sec_tokens.items() if n >= MIN_SECTION_TOKENS]
print("Using sections:", SECS)
if len(SECS) < 2:
    # Every comparison below needs two sections; stop here with a clear reason
    # instead of failing later inside min() or random.sample().
    raise SystemExit(f"need at least 2 Currier-B sections with >= {MIN_SECTION_TOKENS} tokens, got {SECS}")

# ---------------- JSD machinery ----------------
def jsd(c1, c2):
    """Jensen-Shannon divergence, in bits (log base 2), between two count tables.

    Each argument maps item -> count and is normalised by its own total
    before comparison, so the result lies in [0, 1]. Two empty tables give
    0.0; if exactly one is empty, a ZeroDivisionError is raised.
    """
    total1 = sum(c1.values())
    total2 = sum(c2.values())
    div = 0.0
    for item in set(c1) | set(c2):
        pa = c1.get(item, 0) / total1
        pb = c2.get(item, 0) / total2
        mid = (pa + pb) / 2
        # Half of each side's KL divergence to the midpoint; zero mass adds nothing.
        for share in (pa, pb):
            if share:
                div += 0.5 * share * math.log2(share / mid)
    return div

def sample_tokens(pgs, ptok, N, rng):
    """Pool whole pages in random order until N tokens are reached; return a Counter.

    Whole pages are drawn, not single tokens, so page-level burstiness is
    kept. The pages in ``pgs`` are shuffled on a copy (``pgs`` is left
    untouched), and their tokens are concatenated until at least N have
    been collected. Only the first N pooled tokens are counted. If the pages
    run out first, every available token is counted instead.
    """
    order = pgs[:]
    rng.shuffle(order)
    pooled = []
    for page in order:
        pooled += ptok[page]
        if len(pooled) >= N:
            break
    return Counter(pooled[:N])

def between_within(groups, ptok, N=3500, trials=25, seed=1):
    """Compare mean JSD between sections with mean JSD between halves of one section.

    Args:
        groups: section -> list of page ids.
        ptok: page id -> token list.
        N: tokens per side, drawn by ``sample_tokens`` (whole pages, capped at N;
            a side gets fewer than N if its pages run out).
        trials: number of between-section and of within-section draws.
        seed: seed for the private ``random.Random`` used for all draws.

    Each between trial picks two distinct sections and compares a page sample
    from each. Each within trial picks one section that has at least two
    pages, shuffles its pages and splits them into two halves by page count,
    then compares a sample from each half.

    Returns:
        ``(mean between-JSD, mean within-JSD)`` in bits.

    Raises:
        ValueError: if ``trials < 1``, fewer than two sections are given, or no
            section has at least two pages to split.
    """
    if trials < 1:
        raise ValueError(f"trials must be >= 1, got {trials}")
    secs = list(groups)
    if len(secs) < 2:
        raise ValueError(f"need at least 2 sections for a between-section comparison, got {len(secs)}")
    # A one-page section would leave one half empty, and jsd() of an empty
    # sample divides by zero; such sections can only serve the between draws.
    splittable = [s for s in secs if len(groups[s]) >= 2]
    if not splittable:
        raise ValueError("no section has at least 2 pages to split for the within-section null")
    rng = random.Random(seed)
    bet = []
    for _ in range(trials):
        a, b = rng.sample(secs, 2)
        bet.append(jsd(sample_tokens(groups[a], ptok, N, rng),
                       sample_tokens(groups[b], ptok, N, rng)))
    wit = []
    for _ in range(trials):
        s = rng.choice(splittable)
        pgs = groups[s][:]
        rng.shuffle(pgs)
        half = len(pgs) // 2
        wit.append(jsd(sample_tokens(pgs[:half], ptok, N, rng),
                       sample_tokens(pgs[half:], ptok, N, rng)))
    return sum(bet) / len(bet), sum(wit) / len(wit)

groups_real = {s: [p for p in B_pages if sec_of[p] == s] for s in SECS}
# Per-side sample size: at most 3500 tokens and roughly half of the smallest
# section (minus a 100-token margin). This does NOT guarantee that every
# within-section half reaches N: halves are split by page count, not token
# count, and sample_tokens returns fewer tokens when a half runs short.
N = min(3500, min(sec_tokens[s] for s in SECS)//2 - 100)
if N <= 0:
    # A non-positive N would make sample_tokens slice tokens nonsensically.
    raise SystemExit(f"sample size N={N} is not positive; the sections in SECS are too small")
print(f"\nSample size per side: {N} tokens, 25 trials")

b_real, w_real = between_within(groups_real, page_tokens, N)

# ---------------- generated corpus, cut into the same shapes ----------------
# Candidate prefixes and endings for the prefix + middle + ending split.
# '' appears in both lists, so every word admits at least ('', word, '').
PRE = ['qo', 'ch', 'sh', 'da', 'ol', 'o', 'd', 'y', 's', 'l', 'q', 'r', '']
END = ['eedy', 'aiin', 'aiir', 'eey', 'edy', 'ain', 'air', 'am', 'an', 'ar',
       'al', 'dy', 'ey', 'ol', 'or', 'y', 'o', 'n', 'r', 'l', 's', 'm', '']


def decompose(w):
    """Split `w` into (prefix, middle, ending).

    For every prefix in PRE that `w` starts with, the longest ending in END
    that the remainder ends with is chosen (ties between equal-length endings
    go to the one listed first). Across prefixes, the split with the longest
    ending wins, then the one with the longest prefix; on a full tie the
    earlier prefix in PRE is kept. The middle is what lies between.
    """
    endings = sorted(END, key=len, reverse=True)
    best_key, best_split = None, None
    for prefix in PRE:
        if not w.startswith(prefix):
            continue
        rest = w[len(prefix):]
        ending = next((e for e in endings if rest.endswith(e)), None)
        if ending is None:
            continue
        key = (len(ending), len(prefix))
        if best_key is None or key > best_key:
            best_key = key
            best_split = (prefix, rest[:len(rest) - len(ending)], ending)
    if best_split is None:
        return ('', w, '')
    return best_split

class SlotModel:
    """Three-slot (prefix + middle + ending) word model fitted to a token list.

    Each token is folded ('p'->'t', 'f'->'k') and split by ``decompose`` into
    (prefix, middle, ending). Tokens whose middle is longer than 6 characters
    are skipped, and the endings 'm'/'am' are folded into 'n'/'an'. The model
    keeps:

    * ``pre``      -- prefix distribution;
    * ``mid_marg`` -- marginal middle distribution over middles seen at least
      ``mid_min`` times;
    * ``end_marg`` -- marginal ending distribution;
    * ``mid_g``    -- prefix -> middle distribution, only for prefixes with at
      least 10 tokens whose middle is kept;
    * ``end_g``    -- middle -> ending distribution, only for middles with at
      least 10 tokens.

    Every distribution is a ``(keys, probabilities)`` pair with probabilities
    proportional to ``count ** (1/t)``. ``t < 1`` sharpens toward frequent
    items and ``t > 1`` flattens. ``tau`` is used for prefixes and endings,
    ``tau_mid`` for middles.
    """

    def __init__(self, tokens, tau, tau_mid, mid_min):
        cp, cm, ce = Counter(), Counter(), Counter()
        jm, je = defaultdict(Counter), defaultdict(Counter)
        for w in tokens:
            w = w.replace('p', 't').replace('f', 'k')
            p, m, e = decompose(w)
            if len(m) > 6:
                continue
            if e == 'm':
                e = 'n'
            elif e == 'am':
                e = 'an'
            cp[p] += 1; cm[m] += 1; ce[e] += 1
            jm[p][m] += 1; je[m][e] += 1
        keep = {m for m, c in cm.items() if c >= mid_min}
        T = self._tempered
        self.pre = T(cp, tau)
        self.mid_marg = T(Counter({m: c for m, c in cm.items() if m in keep}), tau_mid)
        self.end_marg = T(ce, tau)
        self.mid_g = {p: T(Counter({m: c for m, c in cnt.items() if m in keep}), tau_mid)
                      for p, cnt in jm.items() if sum(cnt[m] for m in cnt if m in keep) >= 10}
        self.end_g = {m: T(cnt, tau) for m, cnt in je.items() if sum(cnt.values()) >= 10}

    @staticmethod
    def _tempered(counts, t):
        """Return ``(keys, probs)`` with probs proportional to ``count ** (1/t)``.

        Only keys with a positive count are kept, in the mapping's iteration
        order. The powers and their normalizing sum are computed once, so the
        cost is linear in the number of keys. An empty mapping gives
        ``([], [])``.
        """
        exponent = 1 / t
        keys, powered = [], []
        for k, v in counts.items():
            if v > 0:
                keys.append(k)
                powered.append(v ** exponent)
        total = sum(powered)
        return keys, [x / total for x in powered]

    def s_pre(self, rng):
        """Sample a prefix."""
        return rng.choices(*self.pre)[0]

    def s_mid(self, rng, p):
        """Sample a middle given prefix ``p`` (marginal if ``p`` has no table); '' if none known."""
        d = self.mid_g.get(p, self.mid_marg)
        return rng.choices(*d)[0] if d[0] else ''

    def s_end(self, rng, m):
        """Sample an ending given middle ``m`` (marginal if ``m`` has no table); 'dy' if none known."""
        d = self.end_g.get(m, self.end_marg)
        return rng.choices(*d)[0] if d[0] else 'dy'

# Fitted generator hyper-parameters, as consumed by generate():
BP = {
    'p_reuse': 0.415,   # probability a word is copied verbatim from the history
    'p_mut': 0.166,     # probability a word is a mutated copy of a recent word
    'window': 64,       # how many most-recent words a mutation may start from
    'p_nc': 0.053,      # probability a "mutation" returns the word unchanged
    'tau': 0.73,        # sharp model: exponent 1/tau for prefix and ending weights
    'tau_mid': 0.991,   # sharp model: exponent 1/tau_mid for middle weights
    'tau_flat': 1.227,  # flat model: exponent 1/tau_flat for every slot
    'q_head': 0.9,      # probability a novel word comes from the sharp model
    'mid_min': 2,       # minimum count for a middle to be kept by SlotModel
    'decay': 0.836,     # geometric recency weight when picking a word to mutate
    'p_local': 0.334,   # probability a verbatim copy comes from the last 250 words only
}

def generate(n_tokens, B_toks, rng):
    """Generate ``n_tokens`` of pseudo-Voynich text from slot models fitted to ``B_toks``.

    Two SlotModels are fitted to ``B_toks``. The "sharp" one uses
    ``BP['tau']``/``BP['tau_mid']``; the "flat" one uses ``BP['tau_flat']``
    for every slot. Text is emitted as lines of 6-12 words, grouped into
    paragraphs of 3-14 lines. Each word is one of:

    * a verbatim copy from the words generated so far;
    * a mutated copy of a recent word;
    * a novel word from a slot model.

    The process takes no section or topic input; it only sees its own
    history. Paragraph-first lines and line-final words then get surface
    decorations (see the comments below).

    Args:
        n_tokens: number of tokens to return.
        B_toks: training tokens for both SlotModels.
        rng: random source; every random choice is drawn from it (``random``,
            ``randint``, ``randrange``, ``choices``), so output is fixed by its state.

    Returns:
        A list of exactly ``n_tokens`` words (empty when ``n_tokens <= 0``).
    """
    sharp = SlotModel(B_toks, BP['tau'], BP['tau_mid'], BP['mid_min'])
    flat = SlotModel(B_toks, BP['tau_flat'], BP['tau_flat'], BP['mid_min'])
    # Recency weights for picking a word to mutate: the most recent word gets
    # weight 1, older ones decay geometrically by BP['decay'].
    weights = [BP['decay']**i for i in range(BP['window'])]
    # Novel word: sharp model with probability q_head, else flat. Up to 20
    # attempts to get a 1-12 character word, then fall back to 'daiin'.
    def novel():
        model = sharp if rng.random() < BP['q_head'] else flat
        for _ in range(20):
            p = model.s_pre(rng); m = model.s_mid(rng, p); e = model.s_end(rng, m)
            w = p + m + e
            if 1 <= len(w) <= 12: return w
        return 'daiin'
    # Mutation: with probability p_nc the word is returned unchanged. Otherwise
    # one slot is resampled from the sharp model: prefix (35%), middle given
    # prefix (30%) or ending given middle (35%). The result is kept only if it
    # is 1-12 characters long.
    def mutate(w):
        if rng.random() < BP['p_nc']: return w
        p, m, e = decompose(w)
        r = rng.random()
        if r < 0.35:   p = sharp.s_pre(rng)
        elif r < 0.65: m = sharp.s_mid(rng, p)
        else:          e = sharp.s_end(rng, m)
        w2 = p + m + e
        return w2 if 1 <= len(w2) <= 12 else w
    # hist: every undecorated word generated so far (source for copies and
    # mutations); stream: the decorated output.
    hist, stream = [], []
    # Lines left in the current paragraph; 0 means the next line opens a new
    # paragraph of 3-14 lines.
    par_left = 0
    while len(stream) < n_tokens:
        par_first = par_left == 0
        if par_first: par_left = rng.randint(3, 14)
        par_left -= 1
        base = []
        # 6-12 words per line. Each is a verbatim copy (p_reuse; drawn from the
        # last 250 words with probability p_local, else from all history), a
        # mutation of a recency-weighted word among the last `window` (p_mut),
        # or a novel word. The first word of all is always novel (hist is empty).
        for _ in range(rng.randint(6, 12)):
            r = rng.random()
            if hist and r < BP['p_reuse']:
                pool = hist[-250:] if rng.random() < BP['p_local'] else hist
                w = pool[rng.randrange(len(pool))]
            elif hist and r < BP['p_reuse'] + BP['p_mut']:
                recent = hist[-BP['window']:][::-1]
                w = mutate(rng.choices(recent, weights[:len(recent)])[0])
            else:
                w = novel()
            base.append(w); hist.append(w)
        # Decorations are applied to a copy, so hist keeps the undecorated words.
        line = base[:]
        if par_first:
            # Paragraph-first line: each word has t->p and k->f with probability
            # 0.45. If the first word does not start with t/k/p/f, prepend 'p'
            # (70%) or 't' (30%) with probability 0.82.
            line = [w.replace('t','p').replace('k','f') if rng.random() < 0.45 else w for w in line]
            if line[0][0] not in 'tkpf' and rng.random() < 0.82:
                line[0] = ('p' if rng.random() < 0.7 else 't') + line[0]
        # With probability 0.38, a line-final n/r/l is replaced by m.
        if rng.random() < 0.38 and line[-1][-1] in 'nrl':
            line[-1] = line[-1][:-1] + 'm'
        stream.extend(line)
    return stream[:n_tokens]

# Every Currier-B token in page order: the generator's training data, and the
# length of pseudo-text it is asked to produce.
B_toks_all = [tok for page in B_pages for tok in page_tokens[page]]
total_B = len(B_toks_all)
print("\nGenerating stationary pseudo-Voynich of the same size ...")
gen_stream = generate(total_B, B_toks_all, random.Random(123))  # fixed seed -> reproducible run

def cut_into_shape(stream):
    """Cut a token stream into pseudo-pages mirroring the real B page sizes/sections.

    Consecutive slices of ``stream`` are assigned to the real Currier-B pages
    in ``B_pages`` order. Each slice is as long as that page's real token
    count, and each pseudo-page inherits the real page's section label.

    Returns:
        ``(ptok, groups)``: page id -> token slice, and section -> page ids for
        the sections in ``SECS``. Tokens beyond the total real B length are
        ignored.

    Raises:
        ValueError: if ``stream`` is shorter than the total real B length.
            Otherwise the trailing pseudo-pages would silently come out
            truncated or empty and distort every comparison made on them.
    """
    needed = sum(len(page_tokens[p]) for p in B_pages)
    if len(stream) < needed:
        raise ValueError(
            f"stream has {len(stream)} tokens but the Currier-B page shapes need {needed}")
    ptok, groups = {}, defaultdict(list)
    start = 0
    for p in B_pages:
        end = start + len(page_tokens[p])
        ptok[p] = stream[start:end]
        groups[sec_of[p]].append(p)
        start = end
    return ptok, {s: groups[s] for s in SECS}

ptok_gen, groups_gen = cut_into_shape(gen_stream)
b_gen, w_gen = between_within(groups_gen, ptok_gen, N)

# English control: same shapes. The file is read with an explicit encoding
# (as the transcription is) instead of the platform default, and closed promptly.
with open(f"{DATA}/english.txt", encoding='utf-8') as fh:
    eng = re.findall(r'[a-z]+', fh.read().lower())[:total_B]
ptok_eng, groups_eng = cut_into_shape(eng)
b_eng, w_eng = between_within(groups_eng, ptok_eng, N)

# shuffled-real floor: token order (and any topic grouping) is destroyed.
shuf = B_toks_all[:]
random.Random(5).shuffle(shuf)
ptok_sh, groups_sh = cut_into_shape(shuf)
b_sh, w_sh = between_within(groups_sh, ptok_sh, N)

print("\n=== TOPIC-STRUCTURE TEST (JSD between word distributions, bits) ===")
print(f"{'corpus':<22}{'between-section':>16}{'within-section':>16}{'ratio':>8}")
# One row per corpus: mean between-section JSD, mean within-section JSD and
# their ratio (above 1 means sections differ more than halves of one section).
for name, b, w in (
    ('real Voynich B', b_real, w_real),
    ('generated (stationary)', b_gen, w_gen),
    ('English, same shapes', b_eng, w_eng),
    ('shuffled real B', b_sh, w_sh),
):
    print(f"{name:<22}{b:>16.4f}{w:>16.4f}{b/w:>8.3f}")

# ---------------- G2 keywords per section (real B) ----------------
def g2_keywords(groups, ptok, topn=8):
    """Rank each section's over-represented words by log-likelihood G2.

    Args:
        groups: section -> list of page ids.
        ptok: page id -> token list.
        topn: maximum number of words kept per section.

    For every word seen at least 5 times in a section, G2 is computed from
    two cells: the word's count in the section vs. elsewhere, against
    counts expected from the section's share of all tokens. Only words whose
    relative frequency in the section exceeds their overall relative
    frequency are kept.

    Returns:
        section -> list of ``(G2, word, count_in_section)`` tuples, highest G2
        first, at most ``topn`` long.
    """
    per_section = {}
    corpus = Counter()
    for sec, pages in groups.items():
        counts = Counter(tok for page in pages for tok in ptok[page])
        per_section[sec] = counts
        corpus += counts
    grand_total = sum(corpus.values())

    ranked = {}
    for sec, counts in per_section.items():
        sec_total = sum(counts.values())
        rows = []
        for word, here in counts.items():
            if here < 5:
                continue
            overall = corpus[word]
            elsewhere = overall - here
            expected_here = overall * sec_total / grand_total
            expected_elsewhere = overall * (grand_total - sec_total) / grand_total
            loglik = here * math.log(here / expected_here)
            if elsewhere > 0:
                loglik += elsewhere * math.log(elsewhere / expected_elsewhere)
            # Keep only words that are over-represented in this section.
            if here / sec_total > overall / grand_total:
                rows.append((2 * loglik, word, here))
        ranked[sec] = sorted(rows, reverse=True)[:topn]
    return ranked

# G2 keywords per section: real Currier B first, then the stationary generated
# text as a control.
for _title, _groups, _ptok in (
    ("\nMost section-distinctive words (G2 log-likelihood), real Currier B:", groups_real, page_tokens),
    ("\nSame measure on the stationary generated text (any 'keywords' here are noise):", groups_gen, ptok_gen),
):
    print(_title)
    for s, words in g2_keywords(_groups, _ptok).items():
        print(f"  {NAMES.get(s,s):<14}: " + ', '.join(f"{w}({g:.0f})" for g, w, _ in words))

# count of strongly section-linked types (G2 > 10.83 ~ p<0.001) per corpus
def strong_count(groups, ptok):
    """Count (section, word) pairs whose over-representation G2 exceeds 10.83.

    10.83 is the threshold the script uses for p < 0.001. Words come from
    ``g2_keywords`` with ``topn=10**6``, so in practice no section's list is
    truncated.
    """
    per_section = g2_keywords(groups, ptok, topn=10**6)
    hits = 0
    for rows in per_section.values():
        hits += sum(g > 10.83 for g, _, _ in rows)
    return hits
print("\nWord types significantly tied to a section (G2>10.83):")
print("  " + "   ".join(
    f"{label}: {strong_count(grp, pt)}"
    for label, grp, pt in (("real B", groups_real, page_tokens),
                           ("generated", groups_gen, ptok_gen),
                           ("english", groups_eng, ptok_eng),
                           ("shuffled", groups_sh, ptok_sh))))
