import sys
from find_equation_id import Equation, all_eqs, all_shapes, all_rhymes, num_eqs, catalan, bell, shape_order, canonicalize_rhyme, shape_lt, shape_dual
import subprocess
import re
import itertools
import functools
from sympy import poly, prem
from sympy.abc import b

def last_depth(shape):
    """Count the steps along the right spine of *shape* until ``None`` is reached.

    A shape is ``None`` (a leaf) or a pair whose index ``1`` is the right
    child; only the right child is followed at every step.
    """
    steps = 0
    node = shape
    while node is not None:
        node = node[1]
        steps += 1
    return steps

def lin_poly(shape, rhyme_indicator, a, b):
    """Evaluate a weighted leaf sum over the tree *shape*.

    A leaf (``None``) is worth 1 when its indicator entry is truthy and 0
    otherwise. An internal node ``(left, right)`` is worth
    ``a * value(left) + b * value(right)``.

    ``rhyme_indicator`` holds one entry per leaf, in left-to-right order. It is
    split between the children at ``shape_order(left) + 1``.
    NOTE(review): this presumes ``shape_order(s) + 1`` is the number of leaves
    of ``s`` (defined in find_equation_id) — confirm there.
    """
    if shape is None:
        return 1 if rhyme_indicator[0] else 0
    left, right = shape[0], shape[1]
    cut = shape_order(left) + 1
    left_value = lin_poly(left, rhyme_indicator[:cut], a, b)
    right_value = lin_poly(right, rhyme_indicator[cut:], a, b)
    return a * left_value + b * right_value

def has_high_power(shape, rhyme):
    """Report whether some subtree of *shape* is a "power" of a single variable.

    The following count as a match:
    * either three-leaf shape, ``(None, (None, None))`` or
      ``((None, None), None)``, when its three rhyme entries are all equal;
    * the balanced four-leaf shape ``((None, None), (None, None))``, when its
      four rhyme entries are all equal.

    Otherwise the search recurses into both children. *rhyme* is split at
    ``shape_order(left) + 1``, the same convention ``lin_poly`` uses.
    """
    if shape is None or shape == (None, None):
        return False

    cubic_shapes = ((None, (None, None)), ((None, None), None))
    if shape in cubic_shapes and rhyme[0] == rhyme[1] == rhyme[2]:
        return True

    if (shape == ((None, None), (None, None))
            and rhyme[0] == rhyme[1] == rhyme[2] == rhyme[3]):
        return True

    left, right = shape[0], shape[1]
    cut = shape_order(left) + 1
    if has_high_power(left, rhyme[:cut]):
        return True
    return has_high_power(right, rhyme[cut:])

# Scan every equation of order 1-4 and print the id and text of each one that
# passes all of the filters below.
# NOTE(review): the mathematical motivation of each filter is not visible in
# this file; the comments only describe what each test computes.

# Reused for every remainder check below, so build it once.
# Its roots are (1 ± sqrt(5)) / 2.
modulus = poly(b**2 - b - 1)

for order in [1, 2, 3, 4]:
    for eq in all_eqs(order):
        eq_orders = eq.orders()

        # The rhyme entry at index eq_orders[0] must equal the final rhyme
        # entry. Presumably eq_orders[0] is the LHS order, which makes this the
        # LHS's last leaf — TODO confirm against Equation.orders().
        if eq.rhyme[eq_orders[0]] != eq.rhyme[-1]:
            continue

        # The right-spine depths of the two sides must have the same parity.
        # The parentheses matter: `%` binds tighter than `-`, so the
        # unparenthesized form reduced only the RHS depth mod 2.
        if (last_depth(eq.lhs_shape) - last_depth(eq.rhs_shape)) % 2 != 0:
            continue

        # Every variable index 0..max(rhyme) must occur at least twice.
        num_vars = max(eq.rhyme) + 1
        if any(eq.rhyme.count(i) == 1 for i in range(num_vars)):
            continue

        lhs_order = shape_order(eq.lhs_shape)
        lhs_rhyme = eq.rhyme[:lhs_order + 1]
        rhs_rhyme = eq.rhyme[lhs_order + 1:]

        # For each variable k, evaluate both sides' weighted leaf sums with
        # (a, b) = (1 - b, b). Skip the equation unless the difference is
        # divisible by b**2 - b - 1 for every k.
        if any(
            prem(
                poly(
                    lin_poly(eq.rhs_shape, [v == k for v in rhs_rhyme], 1 - b, b)
                    - lin_poly(eq.lhs_shape, [v == k for v in lhs_rhyme], 1 - b, b),
                    b,
                ),
                modulus,
            ) != 0
            for k in range(num_vars)
        ):
            continue

        # Reject equations in which either side contains a repeated-variable
        # "power" subtree.
        if (has_high_power(eq.lhs_shape, lhs_rhyme)
                or has_high_power(eq.rhs_shape, rhs_rhyme)):
            continue

        print(eq.id, eq)
