#!/usr/bin/env python3
"""Controls for the morphosyntax result: does suffix-context consistency survive
(A) restriction to a single section (kills topic confound) and
(B) exclusion of line-edge positions (kills layout confound)?
"""
import re, math, random
from collections import Counter, defaultdict

# Root of the local project checkout (absolute, machine-specific path).
BASE = "/Users/arcandledger/taxdome/ancient-texts"
_WORD_RE = re.compile(r'[a-z]+')


def valid(w):
    """Return True when *w* consists solely of one or more lowercase a-z letters."""
    return _WORD_RE.fullmatch(w) is not None


def load_lines(lang='B', section=None):
    """Load paragraph lines from the ZL transcription as lists of words.

    The file ``{BASE}/data/ZL3b-n.txt`` is read line by line.
    - Blank lines and lines starting with ``#`` are ignored.
    - Page-header lines (``<fNNN>``, with no locus suffix) record their
      ``$X=value`` variables for that page.
    - Locus lines (``<fNNN.unit,@TypeN> text``) are kept only when:
      - their type starts with ``P`` (case-insensitive), and
      - their page passes the filters below.

    Args:
        lang: value the page's ``$L`` variable must equal; any falsy value
            disables this filter.
        section: value the page's ``$I`` variable must equal; None (or any
            falsy value) disables this filter.

    Returns:
        A list of lines, each a list of two or more words made only of
        lowercase a-z (see ``valid``). Lines with fewer surviving words are
        skipped.
    """
    locus_re = re.compile(r'^<(f[0-9a-zA-Z]+)\.([^,>]+),\s*([@+=*~$&!])(\w+?)(\d*)>\s*(.*)$')
    hdr_re = re.compile(r'^<(f[0-9a-zA-Z]+)>')
    pages, lines = {}, []
    # Context manager so the file handle is closed deterministically,
    # even if parsing raises part-way through.
    with open(f"{BASE}/data/ZL3b-n.txt", encoding='utf-8') as fh:
        for raw in fh:
            raw = raw.rstrip('\n')
            if not raw or raw.startswith('#'):
                continue
            m = locus_re.match(raw)
            if m is None:
                # Not a locus line: if it is a page header, remember its variables.
                h = hdr_re.match(raw)
                if h:
                    pages[h.group(1)] = dict(re.findall(r'\$(\w)=(\w+)', raw))
                continue
            page, _, _, ltype, _, text = m.groups()
            if not ltype.upper().startswith('P'):
                continue
            v = pages.get(page, {})
            if lang and v.get('L') != lang:
                continue
            if section and v.get('I') != section:
                continue
            t = re.sub(r'<!.*?>', '', text)            # drop <!...> annotations
            t = re.sub(r'<->', '?', t)                 # '?' makes the host word fail valid()
            t = re.sub(r'<%>|<\$>|<@\w+>', '', t)
            t = re.sub(r'@\d+;', '?', t)               # likewise: host word is dropped
            for _ in range(4):                         # [a:b...] -> a, innermost first
                t = re.sub(r'\[([^:\[\]]*):[^\[\]]*\]', r'\1', t)
            t = re.sub(r'[!%]', '', t.replace(',', '.'))  # ',' acts as a word separator
            # NOTE(review): words failing valid() are removed outright, so the
            # words on either side of them become adjacent in the returned list
            # and collect() will count them as a bigram. Confirm this is
            # acceptable for the next-word context statistics.
            ws = [w for w in t.split('.') if w and valid(w)]
            if len(ws) >= 2:
                lines.append(ws)
    return lines

def collect(lines, s1, s2, min_form=3, medial_only=False):
    """Gather next-word context counts for stems attested with both suffixes.

    A stem ``x`` qualifies when both ``x+s1`` and ``x+s2`` occur at least
    *min_form* times anywhere in *lines*. For every occurrence of a
    qualifying form that has a following word on its line, that following
    word is counted.

    Args:
        lines: list of lines, each a list of words.
        s1, s2: the two suffixes to contrast. An empty suffix means "the bare
            stem" and accepts words longer than 2 characters.
        min_form: minimum corpus frequency of each of ``x+s1`` and ``x+s2``.
        medial_only: if True, skip line-initial words (i == 0), and skip
            positions whose following word would be line-final
            (i >= len(line) - 2).

    Returns:
        dict mapping stem -> (Counter of words after ``stem+s1``,
        Counter of words after ``stem+s2``). Keys are inserted in sorted
        order.
    """
    freq = Counter(w for line in lines for w in line)

    def stem_of(w, s):
        # Remove suffix s, requiring a stem of at least 2 characters.
        if s == '':
            return w if len(w) > 2 else None
        return w[:-len(s)] if w.endswith(s) and len(w) > len(s) + 1 else None

    stems = set()
    for w in freq:
        x = stem_of(w, s1)
        if x and freq.get(x + s1, 0) >= min_form and freq.get(x + s2, 0) >= min_form:
            stems.add(x)

    # Sorted: iterating a set of strings follows hash order, which changes
    # between interpreter runs (hash randomisation). Downstream seeded
    # shuffles consume this order, so an unsorted dict made results vary
    # from run to run.
    data = {x: (Counter(), Counter()) for x in sorted(stems)}
    for line in lines:
        for i, w in enumerate(line):
            if medial_only and (i < 1 or i >= len(line) - 2):
                continue
            if i + 1 >= len(line):
                continue  # no following word, so no context to record
            for s, k in ((s1, 0), (s2, 1)):
                x = stem_of(w, s)
                if x in data:
                    data[x][k][line[i + 1]] += 1
    return data

def consistency(data, n_splits=24, min_ctx=6, rng=None):
    """Mean split-half correlation of suffix-conditioned context preferences.

    1. Stems with at least one context count after both suffix forms are
       randomly split into two halves, *n_splits* times.
    2. Within each half, the s1-form and s2-form Counters are pooled.
    3. For every context word seen at least *min_ctx* times in both halves,
       the smoothed log ratio
       ``log((c1+0.5)/(n1+1)) - log((c2+0.5)/(n2+1))`` is computed.
    4. The Pearson correlation between the two halves' log-ratio vectors is
       recorded when at least 6 such contexts exist and both vectors have
       non-zero spread.

    Args:
        data: stem -> (Counter after s1-form, Counter after s2-form).
        n_splits: number of random half-splits to average over.
        min_ctx: per-half minimum count for a context word to be compared.
        rng: random.Random used for the splits; a Random(42) is used if omitted.

    Returns:
        (mean correlation, number of usable stems). The mean is nan when
        fewer than 8 stems are usable or no split produced a correlation.
    """
    rng = rng or random.Random(42)
    # Sorted so the seeded shuffles give the same splits regardless of the
    # dict's insertion order. That order need not be stable: a dict built
    # from a set of strings is ordered by hash and varies between runs.
    stems = sorted(x for x in data
                   if sum(data[x][0].values()) > 0 and sum(data[x][1].values()) > 0)
    if len(stems) < 8:
        return float('nan'), len(stems)

    def log_ratios(pair, ctxs):
        # Smoothed log relative-frequency ratio (s1 form vs s2 form) per context.
        c1, c2 = pair
        n1, n2 = sum(c1.values()), sum(c2.values())
        return [math.log((c1[c] + 0.5) / (n1 + 1)) - math.log((c2[c] + 0.5) / (n2 + 1))
                for c in ctxs]

    cors = []
    for _ in range(n_splits):
        rng.shuffle(stems)
        half = len(stems) // 2
        halves = []
        for grp in (stems[:half], stems[half:]):
            c1, c2 = Counter(), Counter()
            for x in grp:
                c1 += data[x][0]
                c2 += data[x][1]
            halves.append((c1, c2))
        # Per-half totals built once per split (previously rebuilt per context).
        tot_a = halves[0][0] + halves[0][1]
        tot_b = halves[1][0] + halves[1][1]
        common = [c for c in tot_a if tot_a[c] >= min_ctx and tot_b[c] >= min_ctx]
        if len(common) < 6:
            continue
        v1 = log_ratios(halves[0], common)
        v2 = log_ratios(halves[1], common)
        m1, m2 = sum(v1) / len(v1), sum(v2) / len(v2)
        num = sum((a - m1) * (b - m2) for a, b in zip(v1, v2))
        d1 = math.sqrt(sum((a - m1) ** 2 for a in v1))
        d2 = math.sqrt(sum((b - m2) ** 2 for b in v2))
        if d1 > 0 and d2 > 0:
            cors.append(num / (d1 * d2))
    return (sum(cors) / len(cors) if cors else float('nan')), len(stems)

def null_z(data, observed, n_perm=30, seed=9, *, n_splits=8):
    """Z-score of *observed* against a within-stem permutation null.

    Each of the *n_perm* permutations works as follows:
    - For every stem, the s1-form and s2-form context tokens are pooled and
      shuffled.
    - The pool is re-split into pieces of the original sizes. This breaks
      any suffix-context link while keeping each stem's token counts.
    - ``consistency`` is run on the permuted data.

    Args:
        data: stem -> (Counter, Counter), as returned by ``collect``.
        observed: the statistic to score, normally ``consistency(data)[0]``.
        n_perm: number of permutations.
        seed: seed for the permutation RNG, which is also passed on to
            ``consistency`` for its splits.
        n_splits: split-halves per permuted dataset. Defaults to 8, while
            ``consistency`` defaults to 24 for the observed value. An average
            over fewer splits is noisier, so with the default the null spread
            is likely wider and z conservative. Pass 24 to match.

    Returns:
        (observed - mean) / sd over the non-nan permutation values, using the
        population sd (floored at 1e-9). Returns nan if every permutation
        gave nan.
    """
    rng = random.Random(seed)
    vals = []
    for _ in range(n_perm):
        fake = {}
        # Visit stems in sorted order so the RNG draws assigned to each stem
        # do not depend on the dict's insertion order, which may follow hash
        # order and vary between runs.
        for x in sorted(data):
            a, b = data[x]
            pool = list(a.elements()) + list(b.elements())
            rng.shuffle(pool)
            n1 = sum(a.values())
            fake[x] = (Counter(pool[:n1]), Counter(pool[n1:]))
        v, _ = consistency(fake, n_splits=n_splits, rng=rng)
        if not math.isnan(v):
            vals.append(v)
    if not vals:
        return float('nan')
    m = sum(vals) / len(vals)
    sd = math.sqrt(sum((v - m) ** 2 for v in vals) / len(vals)) or 1e-9  # avoid division by zero
    return (observed - m) / sd

PAIRS = [('edy','ey'), ('dy','y')]


def _report(pad, s1, s2, data):
    """Print split-half r and permutation z for one suffix pair, or a too-few-stems note."""
    r, ns = consistency(data)
    if math.isnan(r):
        print(f"{pad}{s1}/{s2}: too few stems ({ns})")
        return
    z = null_z(data, r)
    print(f"{pad}{s1}/{s2}: stems={ns}  r={r:.3f}  z={z:.1f}")


# Control A: restrict to one section at a time.
print("CONTROL A — single section only (topic confound removed):")
for sec_name, sec in [('stars/recipes (S)', 'S'), ('biological (B)', 'B')]:
    lines = load_lines('B', sec)
    token_total = sum(len(line) for line in lines)
    print(f"\n  section {sec_name}: {token_total} tokens")
    for s1, s2 in PAIRS:
        _report("    ", s1, s2, collect(lines, s1, s2))

# Control B: whole language-B corpus, skipping line-edge positions.
print("\nCONTROL B — full Currier B, line-edge positions excluded (layout confound removed):")
lines = load_lines('B')
for s1, s2 in PAIRS:
    _report("  ", s1, s2, collect(lines, s1, s2, min_form=4, medial_only=True))

print("\nBASELINE (uncontrolled, for comparison): edy/ey r=0.554, dy/y r=0.560")
