#!/usr/bin/env python3
"""Burst-decay test: is word recurrence driven by token distance (working-memory
process) or by line membership (line = record about one referent)?
At matched token distance d, compare P(same word) for same-line vs cross-line pairs.
Controls: the fitted generator (pure distance process) and English (chunked lines)."""
import re, math, random
from collections import Counter, defaultdict

# Base directory: data files are read from its data/ subfolder, topic_test.py from its top level.
BASE = "/Users/arcandledger/taxdome/ancient-texts"


def valid(w):
    """Return True iff *w* consists solely of one or more ASCII letters a-z."""
    return re.fullmatch(r'[a-z]+', w) is not None

def load_B(path=None):
    """Load Currier-B paragraph text as per-page sequences of (line_no, word).

    Reads an IVTFF-style transcription, by default ``{BASE}/data/ZL3b-n.txt``.
    Page-header lines (``<fNNN> ...``) carry ``$K=V`` variables. Only pages
    whose header sets ``$L=B`` are kept, and within them only loci whose type
    code starts with ``P`` or ``p``. Inline markup is stripped and words are
    split on ``.`` (``,`` is treated the same way). Words that fail ``valid()``
    are discarded *before* positions are assigned, so later token distances
    are measured over kept words only. Lines with fewer than two kept words
    are dropped entirely.

    Args:
        path: Transcription file to read. ``None`` selects the default above.

    Returns:
        One list per page, in order of the page's first kept line. Each list
        holds ``(line_no, word)`` tuples. ``line_no`` counts kept lines over
        the whole file, so line ids are unique across pages.
    """
    if path is None:
        path = f"{BASE}/data/ZL3b-n.txt"
    locus_re = re.compile(r'^<(f[0-9a-zA-Z]+)\.([^,>]+),\s*([@+=*~$&!])(\w+?)(\d*)>\s*(.*)$')
    hdr_re = re.compile(r'^<(f[0-9a-zA-Z]+)>')
    page_vars = {}            # page id -> header variables, e.g. {'L': 'B', ...}
    out = defaultdict(list)   # page id -> list of (line_no, word)
    line_no = 0
    with open(path, encoding='utf-8') as fh:   # handle is closed even if parsing raises
        for raw in fh:
            raw = raw.rstrip('\n')
            if not raw or raw.startswith('#'):
                continue
            m = locus_re.match(raw)
            if m is None:
                # Not a text locus; it may be a page header carrying $K=V variables.
                h = hdr_re.match(raw)
                if h:
                    page_vars[h.group(1)] = dict(re.findall(r'\$(\w)=(\w+)', raw))
                continue
            page, _, _, ltype, _, text = m.groups()
            if not ltype.upper().startswith('P'):
                continue
            if page_vars.get(page, {}).get('L') != 'B':
                continue
            t = re.sub(r'<!.*?>', '', text)              # drop <!...> annotations
            t = re.sub(r'<->', '?', t)                   # '?' makes any word it touches fail valid()
            t = re.sub(r'<%>|<\$>|<@\w+>', '', t)        # drop <%>, <$> and <@...> markers
            t = re.sub(r'@\d+;', '?', t)                 # @NNN; codes -> '?', so that word is dropped
            for _ in range(4):                           # resolve [a:b] groups, innermost first, up to 4 deep
                t = re.sub(r'\[([^:\[\]]*):[^\[\]]*\]', r'\1', t)   # keep the first alternative
            t = re.sub(r'[!%]', '', t.replace(',', '.'))  # ',' separates words like '.'; drop ! and %
            ws = [w for w in t.split('.') if w and valid(w)]
            if len(ws) >= 2:
                line_no += 1
                out[page].extend((line_no, w) for w in ws)
    return list(out.values())

def decay_table(pages, dists=(1,2,3,4,6,8,12,16,24), min_pairs=200):
    """Tabulate word-repeat rates at fixed token distances, split by line membership.

    For every distance ``d``, each in-page token pair ``(seq[i], seq[i+d])`` is
    classified as same-line or cross-line. The pair counts as a match when the
    two words are identical. Pairs never span two pages.

    Args:
        pages: Sequence of pages. Each page is a sliceable sequence of
            ``(line_id, word)`` tuples. It is iterated once per distance, so it
            must be re-iterable.
        dists: Token distances to evaluate.
        min_pairs: A rate is reported only when strictly more than this many
            pairs were observed; otherwise it is NaN. Values below 0 are treated
            as 0, so an empty cell is never divided by zero.

    Returns:
        ``res[d][same_line] = (matches per 1000 pairs, or NaN, number of pairs)``,
        where ``same_line`` is True or False.
    """
    threshold = max(min_pairs, 0)
    res = {}
    for d in dists:
        cnt = {True: [0, 0], False: [0, 0]}   # same_line -> [pairs, matches]
        for seq in pages:
            for (l1, w1), (l2, w2) in zip(seq, seq[d:]):
                bucket = cnt[l1 == l2]
                bucket[0] += 1
                bucket[1] += (w1 == w2)
        res[d] = {k: (1000 * hits / n if n > threshold else float('nan'), n)
                  for k, (n, hits) in cnt.items()}
    return res

def fake_pages(tokens, mean_page=270, seed=4, line_len=(6, 12)):
    """Chop a flat token stream into synthetic pages made of synthetic lines.

    Each line gets a length drawn uniformly from ``line_len`` (inclusive).
    Lines are appended to a page until it holds at least ``mean_page`` tokens.
    Despite its name, ``mean_page`` is therefore a lower bound on page length,
    reached by every page except possibly the last. The final line and page
    may be short when the tokens run out. The RNG is seeded from ``seed``, so
    the output is deterministic.

    Args:
        tokens: Flat sequence of words.
        mean_page: Minimum number of tokens per page (except the last page).
        seed: Seed for the private ``random.Random`` instance.
        line_len: Inclusive ``(lo, hi)`` range of line lengths. ``lo`` must be >= 1.

    Returns:
        A list of pages. Each page is a list of ``(line_no, word)`` tuples.
        ``line_no`` keeps counting across pages, so line ids are unique.

    Raises:
        ValueError: If ``line_len`` allows empty lines, which would never
            consume tokens and would loop forever.
    """
    lo, hi = line_len
    if lo < 1:
        raise ValueError(f"line_len lower bound must be >= 1, got {lo}")
    rng = random.Random(seed)
    total = len(tokens)
    pages, pos, line_no = [], 0, 0
    while pos < total:
        page = []
        while len(page) < mean_page and pos < total:
            n = rng.randint(lo, hi)
            line_no += 1
            page.extend((line_no, w) for w in tokens[pos:pos + n])
            pos += n
        pages.append(page)
    return pages

DISTS = (1, 2, 3, 4, 6, 8, 12, 16, 24)   # token distances tabulated and printed
SUMMARY_DISTS = (2, 3, 4, 6, 8)          # distances averaged in the elevation summary

B_pages = load_B()
B_tokens = [w for p in B_pages for _, w in p]

# Load the generator's definitions from topic_test.py without running its main
# section: keep everything before the marker line, or the whole file if the
# marker is absent.
# NOTE(review): exec() of a sibling source file is acceptable only for trusted
# local code and is brittle; importing topic_test as a module would be sturdier.
with open(f"{BASE}/topic_test.py", encoding='utf-8') as fh:
    src = fh.read()
ns = {}
exec(src.partition('# ---------------- main')[0], ns)
gen_tokens = ns['generate'](len(B_tokens), B_tokens, random.Random(21))
with open(f"{BASE}/data/english.txt", encoding='utf-8') as fh:
    # English control, truncated to at most the Voynich-B token count.
    eng_tokens = re.findall(r'[a-z]+', fh.read().lower())[:len(B_tokens)]


def _fmt_rate(x, width):
    """Right-align a per-1000 rate in *width* columns, or '-' when it is NaN."""
    return f"{'-':>{width}}" if math.isnan(x) else f"{x:>{width}.1f}"


print("P(word at distance d is identical), per 1000 — same line vs across lines")
print(f"{'d':>3} | {'Voy same':>9}{'Voy cross':>10} | {'Gen same':>9}{'Gen cross':>10} | {'Eng same':>9}{'Eng cross':>10}")
# Voynich B keeps its transcribed pages/lines; the generator and English tokens
# are flat lists, so fake_pages chops them into synthetic lines and pages.
T = {n: decay_table(p, DISTS) for n, p in [('voy', B_pages), ('gen', fake_pages(gen_tokens)),
                                           ('eng', fake_pages(eng_tokens))]}
for d in DISTS:
    row = f"{d:>3} |"
    for n in ('voy', 'gen', 'eng'):
        s, _ = T[n][d][True]
        c, _ = T[n][d][False]
        row += f"{_fmt_rate(s, 9)}{_fmt_rate(c, 10)} |"
    print(row)

# summary statistic: same-line elevation at matched d (mean over d=2..8 where both defined)
print("\nSame-line elevation ratio (same/cross at matched d, mean d=2..8):")
for n, label in [('voy', 'Voynich B'), ('gen', 'generator'), ('eng', 'English')]:
    ratios = []
    for d in SUMMARY_DISTS:
        s, _ = T[n][d][True]
        c, _ = T[n][d][False]
        if not math.isnan(s) and c > 0:   # a NaN c also fails `c > 0`
            ratios.append(s / c)
    # If no distance yields a defined ratio, report n/a instead of dividing by zero.
    mean = f"{sum(ratios) / len(ratios):.2f}" if ratios else "n/a"
    print(f"  {label:<12} {mean}")
