#!/usr/bin/env python3
"""Morphosyntax test: do Voynich suffix alternations function like grammar?

For a suffix pair (s1, s2) and every stem x where both x+s1 and x+s2 occur,
compare the immediate-context distributions of the two forms. If the suffix is
grammatical, the *direction* of the context difference should be consistent across
unrelated stems (split-half replication: per-context log-ratios computed on two
random halves of the stems should correlate). If it is decorative, that
correlation should be ~0.
Positive controls: English -s/-, -ed/-ing, -/-ly; Latin case endings. Negative
control: the fitted generator (suffixes re-rolled at random).
"""
import re, math, random
from collections import Counter, defaultdict

BASE = "/Users/arcandledger/taxdome/ancient-texts"
_WORD_RE = re.compile(r'[a-z]+')


def valid(w):
    """Return True when *w* is a non-empty run of lowercase ASCII letters only."""
    return _WORD_RE.fullmatch(w) is not None

# ---------------- corpora as line-lists ----------------
def load_voynich_lines(lang='B', path=None):
    """Load the Voynich transcription as a list of word lists, one per text line.

    Page-header lines (``<fNNN>`` with no locus) supply ``$X=value`` variables;
    the ``$L`` variable is used as the page's language for filtering. Locus
    lines (``<fNNN.id,Xtype>`` followed by text) are cleaned of inline markup
    and split into words.

    Args:
        lang: keep only pages whose ``$L`` variable equals this value; a falsy
            value disables the language filter.
        path: transcription file to read; defaults to
            ``f"{BASE}/data/ZL3b-n.txt"``.

    Returns:
        A list of lines, each a list of at least two words that pass ``valid``
        (lowercase ASCII letters only). Tokens containing markers turned into
        '?' below are therefore discarded.
    """
    locus_re = re.compile(r'^<(f[0-9a-zA-Z]+)\.([^,>]+),\s*([@+=*~$&!])(\w+?)(\d*)>\s*(.*)$')
    hdr = re.compile(r'^<(f[0-9a-zA-Z]+)>')
    if path is None:
        path = f"{BASE}/data/ZL3b-n.txt"
    pages, lines = {}, []
    with open(path, encoding='utf-8') as fh:
        for raw in fh:
            raw = raw.rstrip('\n')
            if not raw or raw.startswith('#'):
                continue
            loc = locus_re.match(raw)
            if loc is None:
                head = hdr.match(raw)
                if head:
                    # Page header: remember its $X=value variables (e.g. $L).
                    pages[head.group(1)] = dict(re.findall(r'\$(\w)=(\w+)', raw))
                continue
            page, _, _, ltype, _, text = loc.groups()
            # Keep only loci whose type starts with P (presumably paragraph text).
            if not ltype.upper().startswith('P'):
                continue
            if lang and pages.get(page, {}).get('L') != lang:
                continue
            t = re.sub(r'<!.*?>', '', text)                 # drop <!...> annotations
            t = re.sub(r'<->', '?', t)                      # '?' makes the token invalid
            t = re.sub(r'<%>|<\$>|<@\w+>', '', t)
            t = re.sub(r'@\d+;', '?', t)                    # @NNN; codes -> invalid token
            for _ in range(4):
                # [a:b] alternative readings -> keep the first; several passes
                # in case bracketed alternatives are nested.
                t = re.sub(r'\[([^:\[\]]*):[^\[\]]*\]', r'\1', t)
            t = t.replace(',', '.')                         # treat ',' as a word break
            t = re.sub(r'[!%]', '', t)
            ws = [w for w in t.split('.') if w and valid(w)]
            if len(ws) >= 2:
                lines.append(ws)
    return lines

def chunk_lines(tokens, lo=6, hi=12, seed=0):
    """Cut a flat token stream into pseudo-lines of random length.

    Gives the unlineated corpora (English, Latin, generated text) a line
    structure, since context counting never crosses a line boundary.

    Args:
        tokens: sliceable sequence of tokens.
        lo, hi: inclusive bounds on each chunk's length.
        seed: seed for a private ``random.Random`` so chunking is reproducible.

    Returns:
        List of consecutive slices of *tokens*; chunks with fewer than two
        tokens are dropped.

    Raises:
        ValueError: if ``lo < 1`` (a zero-length chunk would never advance the
            cursor, looping forever) or ``hi < lo``.
    """
    if lo < 1 or hi < lo:
        raise ValueError(f"need 1 <= lo <= hi, got lo={lo}, hi={hi}")
    rng = random.Random(seed)
    chunks, i = [], 0
    while i < len(tokens):
        n = rng.randint(lo, hi)
        chunks.append(tokens[i:i + n])
        i += n
    return [c for c in chunks if len(c) >= 2]

# ---------------- the test ----------------
def collect(lines, s1, s2, min_form=4):
    """Gather per-stem context counts for the two forms of a suffix pair.

    A stem ``x`` is kept when ``x+s1`` and ``x+s2`` each occur at least
    *min_form* times AND both words strip back to ``x`` under the stemming
    rule (empty suffix: the word must be longer than 2 letters; non-empty
    suffix: the remaining stem must be at least 2 letters).

    Returns:
        dict mapping stem -> (next-words of form1, next-words of form2,
        previous-words of form1, previous-words of form2), each a Counter.
        Contexts are taken only within a line.
    """
    freq = Counter(w for line in lines for w in line)

    def stem_of(w, s):
        # Strip suffix s from w, or None if w does not qualify as stem+s.
        if s == '':
            return w if len(w) > 2 else None
        return w[:-len(s)] if w.endswith(s) and len(w) > len(s) + 1 else None

    stems = set()
    for w in freq:
        x = stem_of(w, s1)
        if not x:
            continue
        # Round-trip check: without it, e.g. ('s', '') admits stem "it" (from
        # "its") although "it" is too short to be counted as form2 below,
        # leaving that stem's form2 counters permanently empty.
        if (freq.get(x + s1, 0) >= min_form and freq.get(x + s2, 0) >= min_form
                and stem_of(x + s2, s2) == x):
            stems.add(x)

    data = {x: (Counter(), Counter(), Counter(), Counter()) for x in stems}
    for line in lines:
        last = len(line) - 1
        for i, w in enumerate(line):
            for s, k in ((s1, 0), (s2, 1)):
                x = stem_of(w, s)
                if x in data:
                    if i < last:
                        data[x][k][line[i + 1]] += 1
                    if i > 0:
                        data[x][2 + k][line[i - 1]] += 1
    return data

def split_half_consistency(data, side='next', n_splits=24, min_ctx=8, rng=None):
    """Split-half replication of the direction of the form1-vs-form2 context shift.

    Each split shuffles the stems into two halves and pools each half's
    context counts for form1 and form2. For every context seen at least
    *min_ctx* times (both forms combined) in each half, it computes the
    add-0.5 smoothed log-ratio log P(c|form1) - log P(c|form2). It then takes
    the Pearson correlation of these log-ratio vectors between the two halves.

    Args:
        data: stem -> 4-tuple of Counters (next1, next2, prev1, prev2), as
            built by ``collect``.
        side: 'next' uses indices 0 and 1; any other value uses 2 and 3.
        n_splits: number of random half splits to average over.
        min_ctx: minimum pooled count of a context required in *each* half.
        rng: ``random.Random`` to draw splits from; ``None`` means a fresh
            ``Random(42)``. A supplied rng's state is advanced.

    Returns:
        (mean correlation over usable splits, number of stems). The correlation
        is NaN when there are fewer than 8 stems, or when no split had at least
        6 shared contexts with non-zero variance in both halves.
    """
    rng = rng or random.Random(42)
    stems = list(data)
    if len(stems) < 8:
        return float('nan'), len(stems)
    k1, k2 = (0, 1) if side == 'next' else (2, 3)

    def log_ratios(c1, c2, contexts):
        # Totals computed once per half rather than once per context.
        n1, n2 = sum(c1.values()), sum(c2.values())
        return [math.log((c1[c] + 0.5) / (n1 + 1)) - math.log((c2[c] + 0.5) / (n2 + 1))
                for c in contexts]

    cors = []
    for _ in range(n_splits):
        rng.shuffle(stems)
        half = len(stems) // 2
        vecs = []
        for grp in (stems[:half], stems[half:]):
            c1, c2 = Counter(), Counter()
            for x in grp:
                c1 += data[x][k1]
                c2 += data[x][k2]
            vecs.append((c1, c2))
        # Pool both forms once per half (not once per candidate context).
        tot_a = vecs[0][0] + vecs[0][1]
        tot_b = vecs[1][0] + vecs[1][1]
        common = [c for c in tot_a if tot_a[c] >= min_ctx and tot_b[c] >= min_ctx]
        if len(common) < 6:
            continue
        v1 = log_ratios(*vecs[0], common)
        v2 = log_ratios(*vecs[1], common)
        m1, m2 = sum(v1) / len(v1), sum(v2) / len(v2)
        num = sum((a - m1) * (b - m2) for a, b in zip(v1, v2))
        d1 = math.sqrt(sum((a - m1) ** 2 for a in v1))
        d2 = math.sqrt(sum((b - m2) ** 2 for b in v2))
        if d1 > 0 and d2 > 0:
            cors.append(num / (d1 * d2))
    return (sum(cors) / len(cors) if cors else float('nan')), len(stems)

def permutation_null(data, side='next', n_perm=30, seed=9):
    """Null distribution of split-half consistency under shuffled form labels.

    For every stem, the form1 and form2 context tokens on *side* are pooled,
    shuffled, and re-split so that form1 keeps its original token count. This
    preserves per-stem and per-form totals but destroys any systematic
    form/context association. The consistency statistic is then recomputed
    (8 splits, sharing this function's rng) for each of *n_perm* permutations.

    Returns:
        (mean, population standard deviation) of the non-NaN permuted
        consistencies, or (nan, nan) if every permutation was NaN. A zero
        standard deviation is reported as 1e-9 so callers can divide by it.
    """
    rand = random.Random(seed)
    first, second = (0, 1) if side == 'next' else (2, 3)

    def relabel(counts):
        form1, form2 = counts[first], counts[second]
        tokens = [*form1.elements(), *form2.elements()]
        rand.shuffle(tokens)
        cut = sum(form1.values())
        # Shuffled counts go in the 'next' slots; the 'prev' slots stay empty.
        return Counter(tokens[:cut]), Counter(tokens[cut:]), Counter(), Counter()

    samples = []
    for _ in range(n_perm):
        shuffled = {stem: relabel(counts) for stem, counts in data.items()}
        score, _n = split_half_consistency(shuffled, 'next', n_splits=8, rng=rand)
        if math.isnan(score):
            continue
        samples.append(score)

    if not samples:
        return float('nan'), float('nan')
    mean = sum(samples) / len(samples)
    var = sum((s - mean) ** 2 for s in samples) / len(samples)
    return mean, math.sqrt(var) or 1e-9

def run(name, lines, pairs):
    """Print one results table for *lines*, one row per suffix pair.

    Columns:
    - number of qualifying stems;
    - split-half consistency on following words ("next r");
    - mean of the label-permutation null for that statistic ("null");
    - z-score of the observed value against that null;
    - split-half consistency on preceding words ("prev r"; no null is
      computed for it).
    """
    print(f"\n=== {name} ===")
    print(f"{'suffix pair':<16}{'stems':>6}{'next r':>8}{'null':>7}{'z':>7}{'prev r':>8}")
    for s1, s2 in pairs:
        label = s1 + '/' + s2
        data = collect(lines, s1, s2)
        rn, n_stems = split_half_consistency(data, 'next')
        if math.isnan(rn):
            # NaN means either < 8 stems (split_half_consistency's cutoff) or
            # no split had enough shared contexts; report which one it was.
            reason = "too few stems" if n_stems < 8 else "too few shared contexts"
            print(f"{label:<16}{n_stems:>6}   ({reason})")
            continue
        rp, _ = split_half_consistency(data, 'prev')
        nm, nsd = permutation_null(data, 'next')
        z = (rn - nm) / nsd if not math.isnan(nm) else float('nan')
        print(f"{label:<16}{n_stems:>6}{rn:>8.3f}{nm:>7.3f}{z:>7.1f}{rp:>8.3f}")

# ---------------- corpora ----------------
voy_lines = load_voynich_lines('B')
print(f"Voynich B: {sum(map(len, voy_lines))} tokens in {len(voy_lines)} lines")

# Control corpora: lowercase alphabetic tokens, truncated, then cut into
# pseudo-lines.
with open(f"{BASE}/data/english.txt", encoding='utf-8') as fh:
    eng_tokens = re.findall(r'[a-z]+', fh.read().lower())[:120000]
eng_lines = chunk_lines(eng_tokens)
with open(f"{BASE}/data/latin.txt", encoding='utf-8') as fh:
    lat_tokens = re.findall(r'[a-z]+', fh.read().lower())[:35000]
lat_lines = chunk_lines(lat_tokens)

# generated control (same generator as before, quick re-gen via topic_test's machinery)
import importlib.util, sys, io
# topic_test runs everything on import; instead exec only the part of its source
# that precedes its main section / first "Loading" print (the definitions).
# NOTE(review): exec of a local project file - never point this at untrusted input.
with open(f"{BASE}/topic_test.py", encoding='utf-8') as fh:
    src = fh.read()
ns = {}
head = src.partition("# ---------------- main")[0]
head = head.partition("print(\"Loading")[0]
exec(head, ns)
B_toks = [w for l in voy_lines for w in l]
gen_stream = ns['generate'](len(B_toks), B_toks, random.Random(77))
gen_lines = chunk_lines(gen_stream, seed=3)
print(f"Generated control: {sum(map(len, gen_lines))} tokens")

VPAIRS = [('edy','ey'), ('dy','y'), ('aiin','ain'), ('ol','or'), ('ar','or'), ('al','ol'), ('ey','eey')]
run("VOYNICH B", voy_lines, VPAIRS)
run("GENERATED (negative control)", gen_lines, VPAIRS)
run("ENGLISH (positive control)", eng_lines, [('s',''), ('ed','ing'), ('','ly')])
run("LATIN (positive control)", lat_lines, [('us','um'), ('a','am'), ('is','ibus'), ('em','es')])
