#!/usr/bin/env python3
"""Shared null-model generators + the Part-10 grammar battery, extracted so
new tests (kipchak_test T2 control, longrange_test, rugg_test) can import them
without re-implementing. Verbatim ports:
  - self-citation slot generator + fitted params BP : topic_test.py (Part 6/5)
  - M1 class-conditioned / M2 word-bigram Markov    : steelman_test.py (Part 32)
  - collect()/consistency() grammar test            : steelman_test.py (Part 10)
The source scripts execute their experiments at import time, hence the copy;
any change here must be mirrored there (and vice versa).
"""
import math, random
from collections import Counter, defaultdict

# ----------------------------------------------------------- suffix classes
# Suffix classes, tried in list order by suf(): the first match wins, so a
# longer suffix must precede any shorter one it ends with (e.g. 'eedy' before
# 'edy' before 'dy'). '-' is the "no suffix" class and is never matched.
SUF = ['eedy', 'edy', 'dy', 'aiin', 'ain', 'eey', 'ey', 'y', 'ol', 'or', 'al',
       'ar', 'am', 'o', 'l', 'r', 'n', 's', '-']

def suf(w):
    """Return the suffix class of word *w*.

    The class is the first entry of SUF (other than '-') that *w* ends with
    while still leaving at least one character in front of it; '-' when no
    entry qualifies.
    """
    candidates = (s for s in SUF
                  if s != '-' and len(w) > len(s) and w.endswith(s))
    return next(candidates, '-')

# ----------------------------------------- self-citation generator (fitted)
PRE = ['qo','ch','sh','da','ol','o','d','y','s','l','q','r','']
END = ['eedy','aiin','aiir','eey','edy','ain','air','am','an','ar','al','dy',
       'ey','ol','or','y','o','n','r','l','s','m','']
# END ordered longest-first, computed once at import. decompose() runs on every
# training token and every mutation, so sorting inside its prefix loop redid
# the same sort once per matching prefix on every call. sorted() is stable, so
# equal-length endings keep their END order. Recompute if END is modified.
_END_LONGEST_FIRST = tuple(sorted(END, key=len, reverse=True))

def decompose(w):
    """Split word *w* into its (prefix, middle, ending) slots.

    Every prefix in PRE that *w* starts with is tried. For each one, the
    longest ending in END that the remainder ends with is taken. The winning
    split maximizes the ending length first and the prefix length second.
    Because PRE and END both contain '', some split always exists. The
    ('', w, '') fallback only applies if '' is removed from one of them.

    >>> decompose('qokeedy')
    ('qo', 'k', 'eedy')
    """
    best = None
    for p in PRE:
        if not w.startswith(p):
            continue
        rest = w[len(p):]
        for e in _END_LONGEST_FIRST:
            if rest.endswith(e):  # endswith already implies len(rest) >= len(e)
                cand = (len(e), len(p), p, rest[:len(rest) - len(e)], e)
                if best is None or cand[:2] > best[:2]:
                    best = cand
                break  # first hit is the longest ending for this prefix
    return ('', w, '') if best is None else (best[2], best[3], best[4])

class SlotModel:
    """Tempered three-slot (prefix, middle, ending) word sampler.

    Fitted from ``tokens``. Each token has 'p' mapped to 't' and 'f' to 'k'
    and is split by decompose(). A token whose middle is longer than 6
    characters is ignored. Endings 'm' and 'am' are counted as 'n' and 'an'.
    Middles seen fewer than ``mid_min`` times are left out of the middle
    tables. Every table is stored as (keys, probs), with probs proportional
    to count ** (1 / temperature). ``tau`` is used for prefixes and endings,
    and ``tau_mid`` for middles.

    Attributes:
        pre, mid_marg, end_marg: marginal (keys, probs) tables.
        mid_g: prefix -> middle table, only for prefixes with at least 10
            observations of kept middles.
        end_g: middle -> ending table, only for middles with at least 10
            observations.
    """

    # Endings folded together before counting.
    _END_FOLD = {'m': 'n', 'am': 'an'}

    @staticmethod
    def _temper(counts, t):
        """Return (keys, probs) with probs proportional to count ** (1 / t).

        Non-positive counts are skipped, and an empty table gives ([], []).
        Keys keep the iteration order of *counts*.
        """
        keys, powered = [], []
        for key, count in counts.items():
            if count > 0:
                keys.append(key)
                powered.append(count ** (1 / t))
        if not keys:
            return [], []
        total = sum(powered)
        return keys, [x / total for x in powered]

    def __init__(self, tokens, tau, tau_mid, mid_min):
        n_pre, n_mid, n_end = Counter(), Counter(), Counter()
        mid_after_pre = defaultdict(Counter)
        end_after_mid = defaultdict(Counter)
        for token in tokens:
            pre, mid, end = decompose(token.replace('p', 't').replace('f', 'k'))
            if len(mid) > 6:
                continue
            end = self._END_FOLD.get(end, end)
            n_pre[pre] += 1
            n_mid[mid] += 1
            n_end[end] += 1
            mid_after_pre[pre][mid] += 1
            end_after_mid[mid][end] += 1

        frequent = {mid for mid, c in n_mid.items() if c >= mid_min}

        def only_frequent(counts):
            # Restrict a middle-keyed table to kept middles, preserving order.
            return {mid: c for mid, c in counts.items() if mid in frequent}

        self.pre = self._temper(n_pre, tau)
        self.mid_marg = self._temper(only_frequent(n_mid), tau_mid)
        self.end_marg = self._temper(n_end, tau)

        # Conditional tables are only built with >= 10 supporting
        # observations; otherwise s_mid / s_end fall back to the marginals.
        self.mid_g = {}
        for pre, counts in mid_after_pre.items():
            kept = only_frequent(counts)
            if sum(kept.values()) >= 10:
                self.mid_g[pre] = self._temper(kept, tau_mid)
        self.end_g = {}
        for mid, counts in end_after_mid.items():
            if sum(counts.values()) >= 10:
                self.end_g[mid] = self._temper(counts, tau)

    def s_pre(self, rng):
        """Sample a prefix with ``rng.choices``.

        Unlike s_mid / s_end, this has no fallback for an empty table.
        """
        keys, probs = self.pre
        return rng.choices(keys, probs)[0]

    def s_mid(self, rng, p):
        """Sample a middle given prefix *p*; '' if no middle is available."""
        keys, probs = self.mid_g.get(p, self.mid_marg)
        if not keys:
            return ''
        return rng.choices(keys, probs)[0]

    def s_end(self, rng, m):
        """Sample an ending given middle *m*; 'dy' if no ending is available."""
        keys, probs = self.end_g.get(m, self.end_marg)
        if not keys:
            return 'dy'
        return rng.choices(keys, probs)[0]

# Fitted parameters of gen_selfcitation(); each comment says how it is used.
BP = {
    'p_reuse': 0.415,   # P(word is a verbatim copy from the history)
    'p_mut': 0.166,     # P(word is a mutation of a recent word); else novel
    'window': 64,       # number of most recent words a mutation picks from
    'p_nc': 0.053,      # P(a mutation returns the picked word unchanged)
    'tau': 0.73,        # sharp SlotModel temperature for prefixes / endings
    'tau_mid': 0.991,   # sharp SlotModel temperature for middles
    'tau_flat': 1.227,  # flat SlotModel temperature for every slot
    'q_head': 0.9,      # P(a novel word is drawn from the sharp model)
    'mid_min': 2,       # minimum count for a middle to stay in the tables
    'decay': 0.836,     # mutation pick weight = decay ** age within window
    'p_local': 0.334,   # P(a reuse copies from the last 250 words only)
}

def gen_selfcitation(n_tokens, B_toks, seed=123, as_lines=False):
    """Fitted 11-parameter self-citation process (topic_test.py).

    Builds lines of 6-12 words. Each word is chosen in one of three ways,
    with probabilities taken from BP:
      - reuse (p_reuse): a verbatim copy of an earlier word. With probability
        p_local it is drawn from the last 250 words, otherwise from the
        whole history.
      - mutation (p_mut): one of the last BP['window'] words, picked with
        weight decay**age and then passed through mutate().
      - otherwise, a novel word from a SlotModel fitted to B_toks. This is
        the sharp model with probability q_head, else the flat one.
    Reuse and mutation need a non-empty history, so the first word is
    always novel.

    Lines are grouped into paragraphs of 3-14 lines. On a paragraph's first
    line:
      - each word, with probability 0.45, has 't'->'p' and 'k'->'f' applied;
      - if the first word does not start with 't', 'k', 'p' or 'f', then
        with probability 0.82 it gets 'p' (probability 0.7) or 't'
        prepended.
    On every line, with probability 0.38, a final 'n', 'r' or 'l' on the
    last word is replaced by 'm'.

    Args:
        n_tokens: number of tokens wanted.
        B_toks: training tokens; both SlotModels are fitted on them.
        seed: seed for the private random.Random instance.
        as_lines: if true, return the list of lines instead of a flat list.

    Returns:
        The first n_tokens generated words as a flat list. If as_lines is
        true, every generated line instead. The last line is not truncated,
        so the lines can hold up to 11 words more than n_tokens.

    NOTE(review): SlotModel.s_pre has no empty-table fallback, so an empty
    (or fully filtered) B_toks presumably makes novel() fail. Confirm before
    relying on it.
    """
    rng = random.Random(seed)
    # "sharp" uses tau / tau_mid; "flat" uses tau_flat for every slot.
    sharp = SlotModel(B_toks, BP['tau'], BP['tau_mid'], BP['mid_min'])
    flat = SlotModel(B_toks, BP['tau_flat'], BP['tau_flat'], BP['mid_min'])
    # weights[i] = decay**i is the pick weight of the i-th most recent word
    # for a mutation (index 0 = newest; see `recent` below).
    weights = [BP['decay'] ** i for i in range(BP['window'])]
    def novel():
        # One model per novel word, then up to 20 draws to get a word of
        # 1-12 characters; otherwise the fixed fallback 'daiin'.
        model = sharp if rng.random() < BP['q_head'] else flat
        for _ in range(20):
            p = model.s_pre(rng); m = model.s_mid(rng, p); e = model.s_end(rng, m)
            w = p + m + e
            if 1 <= len(w) <= 12:
                return w
        return 'daiin'
    def mutate(w):
        # With probability p_nc the word is returned unchanged. Otherwise one
        # slot is resampled from the sharp model: the prefix (35%), the middle
        # given the prefix (30%), or the ending given the middle (35%). A
        # result outside 1-12 characters falls back to the original word.
        if rng.random() < BP['p_nc']:
            return w
        p, m, e = decompose(w)
        r = rng.random()
        if r < 0.35:   p = sharp.s_pre(rng)
        elif r < 0.65: m = sharp.s_mid(rng, p)
        else:          e = sharp.s_end(rng, m)
        w2 = p + m + e
        return w2 if 1 <= len(w2) <= 12 else w
    hist, stream, out_lines = [], [], []
    # Lines left in the current paragraph; 0 means the next line opens a
    # new paragraph.
    par_left = 0
    while len(stream) < n_tokens:
        par_first = par_left == 0
        if par_first:
            par_left = rng.randint(3, 14)
        par_left -= 1
        base = []
        for _ in range(rng.randint(6, 12)):
            r = rng.random()
            if hist and r < BP['p_reuse']:
                pool = hist[-250:] if rng.random() < BP['p_local'] else hist
                w = pool[rng.randrange(len(pool))]
            elif hist and r < BP['p_reuse'] + BP['p_mut']:
                # Newest word first, so it lines up with weights[0].
                recent = hist[-BP['window']:][::-1]
                w = mutate(rng.choices(recent, weights[:len(recent)])[0])
            else:
                w = novel()
            # hist records words before the paragraph-start and line-end
            # rewrites below are applied.
            base.append(w); hist.append(w)
        line = base[:]
        if par_first:
            line = [w.replace('t', 'p').replace('k', 'f')
                    if rng.random() < 0.45 else w for w in line]
            if line[0][0] not in 'tkpf' and rng.random() < 0.82:
                line[0] = ('p' if rng.random() < 0.7 else 't') + line[0]
        # The random draw happens on every line, before the ending test.
        if rng.random() < 0.38 and line[-1][-1] in 'nrl':
            line[-1] = line[-1][:-1] + 'm'
        stream.extend(line)
        out_lines.append(line)
    return out_lines if as_lines else stream[:n_tokens]

# --------------------------------------------- steelman Markov generators
def gen_class_markov(n, B_lines, seed=0, as_lines=False):
    """M1: word ~ P(word | previous word's suffix class).

    Fits two sets of counts from B_lines:
      - line-start word frequencies;
      - for each suffix class suf(a), the frequencies of the words b that
        follow a word a of that class.
    Each generated line starts from a sampled start word and then takes
    rng.randint(6, 11) further steps. It stops early when the current word's
    class was never followed by anything in training.

    Args:
        n: number of tokens wanted.
        B_lines: training corpus as a list of token lists; empty lines are
            skipped.
        seed: seed for the private random.Random instance.
        as_lines: return the generated lines instead of a flat list.

    Returns:
        The first n generated tokens as a flat list. If as_lines is true,
        every generated line instead, with the last one untruncated.

    Raises:
        ValueError: if n > 0 but B_lines has no non-empty line to start from.
    """
    from itertools import accumulate

    by, starts = defaultdict(Counter), Counter()
    for toks in B_lines:
        if not toks:
            continue  # no start word and no bigrams; toks[0] would raise
        starts[toks[0]] += 1
        for a, b in zip(toks, toks[1:]):
            by[suf(a)][b] += 1
    if n > 0 and not starts:
        raise ValueError('B_lines contains no non-empty line')

    def table(counts):
        # Build (population, cumulative weights) once per state instead of
        # once per draw. random.choices converts weights= to cumulative
        # weights internally, so cum_weights= selects exactly as weights=
        # would, without re-listing every successor word on each step.
        return list(counts), list(accumulate(counts.values()))

    nxt = {cls: table(c) for cls, c in by.items()}
    first_words, first_cum = table(starts)
    rng = random.Random(seed)
    out, out_lines = [], []
    while len(out) < n:
        w = rng.choices(first_words, cum_weights=first_cum)[0]
        line = [w]
        for _ in range(rng.randint(6, 11)):
            step = nxt.get(suf(w))
            if step is None:
                break
            w = rng.choices(step[0], cum_weights=step[1])[0]
            line.append(w)
        out.extend(line)
        out_lines.append(line)
    return out_lines if as_lines else out[:n]

def gen_word_bigram(n, B_lines, seed=0, as_lines=False):
    """M2: word-bigram Markov chain restarted per line.

    Each generated line starts from a word drawn by how often it starts a
    line in B_lines. It then takes rng.randint(6, 11) further steps through
    the fitted word -> next-word counts, stopping early at a word that never
    had a successor in training.

    Args:
        n: number of tokens wanted.
        B_lines: training corpus as a list of token lists; empty lines are
            skipped.
        seed: seed for the private random.Random instance.
        as_lines: return the generated lines instead of a flat list.

    Returns:
        The first n generated tokens as a flat list. If as_lines is true,
        every generated line instead, with the last one untruncated.

    Raises:
        ValueError: if n > 0 but B_lines has no non-empty line to start from.
    """
    from itertools import accumulate

    trans, starts = defaultdict(Counter), Counter()
    for toks in B_lines:
        if not toks:
            continue  # no start word and no bigrams; toks[0] would raise
        starts[toks[0]] += 1
        for a, b in zip(toks, toks[1:]):
            trans[a][b] += 1
    if n > 0 and not starts:
        raise ValueError('B_lines contains no non-empty line')

    def table(counts):
        # Build (population, cumulative weights) once per state instead of
        # once per draw. random.choices converts weights= to cumulative
        # weights internally, so cum_weights= selects exactly as weights=
        # would.
        return list(counts), list(accumulate(counts.values()))

    nxt = {a: table(c) for a, c in trans.items()}
    first_words, first_cum = table(starts)
    rng = random.Random(seed)
    out, out_lines = [], []
    while len(out) < n:
        w = rng.choices(first_words, cum_weights=first_cum)[0]
        line = [w]
        for _ in range(rng.randint(6, 11)):
            step = nxt.get(w)
            if step is None:
                break
            w = rng.choices(step[0], cum_weights=step[1])[0]
            line.append(w)
        out.extend(line)
        out_lines.append(line)
    return out_lines if as_lines else out[:n]

# ------------------------------------------------ Part-10 grammar battery
def collect(L, s1, s2, minf=4):
    """Gather right-context counts for stems attested with both suffixes.

    A stem x qualifies when x + s1 and x + s2 each occur at least ``minf``
    times in L, and x is at least two characters long. For each qualifying
    stem the result holds two Counters: the words that follow x + s1 and the
    words that follow x + s2, over all lines of L. A word at the end of its
    line contributes nothing. A word ending in both suffixes is counted under
    both of its stems.

    Args:
        L: corpus as a list of token lists.
        s1, s2: the two suffixes being contrasted.
        minf: minimum frequency required of both x + s1 and x + s2.

    Returns:
        dict stem -> (Counter after stem+s1, Counter after stem+s2). Keys are
        ordered by the first occurrence of stem+s1 in L.
    """
    freq = Counter(w for l in L for w in l)

    def stem(w, s):
        # Stem of w for suffix s, or None; the stem must keep >= 2 characters.
        return w[:-len(s)] if w.endswith(s) and len(w) > len(s) + 1 else None

    # Use a list in first-occurrence (freq key) order rather than a set.
    # Iteration order of a set of strings changes with PYTHONHASHSEED, which
    # made D's key order, and any seeded shuffle a caller derives from it,
    # differ between processes. Each stem x comes only from the word x + s1,
    # so the list has no duplicates.
    stems = [x for w in freq for x in [stem(w, s1)]
             if x and freq.get(x + s1, 0) >= minf and freq.get(x + s2, 0) >= minf]
    D = {x: (Counter(), Counter()) for x in stems}
    for line in L:
        for w, nxt in zip(line, line[1:]):
            for s, k in ((s1, 0), (s2, 1)):
                x = stem(w, s)
                if x in D:
                    D[x][k][nxt] += 1
    return D

def consistency(D, nsplit=24, minctx=6, seed=42):
    """Split-half reliability of the s1-vs-s2 context contrast in ``D``.

    ``D`` maps stem -> (Counter of contexts after stem+s1, Counter of
    contexts after stem+s2), as returned by collect(). Stems with an empty
    Counter on either side are ignored. If fewer than 8 remain, the result
    is (nan, count).

    Otherwise, ``nsplit`` times:
      - shuffle the stems and cut them into two halves;
      - pool each half's Counters;
      - keep the contexts occurring at least ``minctx`` times (s1 and s2
        combined) in both halves, skipping the split if fewer than 6 remain;
      - per half, score each kept context c by the smoothed log-ratio
            log((n1[c] + 0.5) / (N1 + 1)) - log((n2[c] + 0.5) / (N2 + 1));
      - record the Pearson correlation of the two halves' score vectors,
        skipping it if either vector is constant.

    Returns:
        (mean correlation over the recorded splits, or nan if none;
         number of stems used).
    """
    rng = random.Random(seed)
    # Sort before the seeded shuffles. shuffle() output depends on the
    # starting order, and D's key order need not be reproducible: a dict
    # built by iterating a set of strings varies with PYTHONHASHSEED. Sorting
    # makes the result depend only on D's contents. Assumes stems are
    # mutually comparable (e.g. str). Per the module docstring, mirror this
    # change in steelman_test.py.
    stems = sorted(x for x in D if sum(D[x][0].values()) and sum(D[x][1].values()))
    if len(stems) < 8:
        return float('nan'), len(stems)

    def log_ratio(k1, n1, k2, n2, c):
        # Smoothed log relative frequency of context c after s1, minus after s2.
        return math.log((k1[c] + 0.5) / (n1 + 1)) - math.log((k2[c] + 0.5) / (n2 + 1))

    cors = []
    for _ in range(nsplit):
        rng.shuffle(stems)
        h = len(stems) // 2
        V = []
        for grp in (stems[:h], stems[h:]):
            c1, c2 = Counter(), Counter()
            for x in grp:
                c1 += D[x][0]; c2 += D[x][1]
            V.append((c1, c2))
        # Combined per-half totals are built once per split. Building them
        # inside the comprehension cost two Counter additions per context,
        # i.e. quadratic time in the context vocabulary.
        tot0 = V[0][0] + V[0][1]
        tot1 = V[1][0] + V[1][1]
        common = [c for c in tot0 if tot0[c] >= minctx and tot1[c] >= minctx]
        if len(common) < 6:
            continue
        # Per-suffix totals are fixed within a half, so compute them once.
        (a1, a2), (b1, b2) = V
        na1, na2 = sum(a1.values()), sum(a2.values())
        nb1, nb2 = sum(b1.values()), sum(b2.values())
        v1 = [log_ratio(a1, na1, a2, na2, c) for c in common]
        v2 = [log_ratio(b1, nb1, b2, nb2, c) for c in common]
        m1, m2 = sum(v1) / len(v1), sum(v2) / len(v2)
        num = sum((a - m1) * (b - m2) for a, b in zip(v1, v2))
        d1 = math.sqrt(sum((a - m1) ** 2 for a in v1))
        d2 = math.sqrt(sum((b - m2) ** 2 for b in v2))
        if d1 * d2:
            cors.append(num / (d1 * d2))
    return (sum(cors) / len(cors) if cors else float('nan')), len(stems)

def h2w(stream):
    """Word-level conditional entropy H(w2|w1) in bits.

    Estimated from the adjacent pairs of *stream*, a sliceable sequence of
    tokens:

        H = -sum over pairs (a, b) of P(a, b) * log2(count(a, b) / count(a))

    Here count(a) counts a in every position but the last. A stream with
    fewer than two tokens has no pairs and gives 0.
    """
    pair_counts = Counter(zip(stream, stream[1:]))
    left_counts = Counter(stream[:-1])
    n_pairs = sum(pair_counts.values())
    terms = (k / n_pairs * math.log2(k / left_counts[first])
             for (first, _), k in pair_counts.items())
    return -sum(terms)
