#!/usr/bin/env python3
"""
Cleanup script for Aristotle-generated Lean proofs.

Applies Mathlib style fixes to Lean files, verifying the file still builds
after each transformation.
"""

import argparse
import re
import subprocess
import sys
from pathlib import Path
from typing import Callable


def find_lake_root(file_path: Path) -> Path | None:
    """
    Find the Lake project root by searching for lakefile.toml or lakefile.lean
    in parent directories.

    Returns the directory containing the lakefile, or None if not found.
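
    For example (hypothetical layout): for a file at
    /work/Gorenstein/Gorenstein/Ch1_Theorem_2_14.lean with a lakefile.toml in
    /work/Gorenstein, this returns Path("/work/Gorenstein").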
    """
    file_path = file_path.resolve()
    current = file_path.parent

    while current != current.parent:  # Stop at filesystem root
        if (current / "lakefile.toml").exists() or (current / "lakefile.lean").exists():
            return current
        current = current.parent

    # Check root as well
    if (current / "lakefile.toml").exists() or (current / "lakefile.lean").exists():
        return current

    return None


def file_to_module_name(file_path: Path, project_dir: Path) -> str:
    """
    Convert a file path to a Lean module name.

    E.g., Gorenstein/Ch1_Theorem_2_14.lean -> Gorenstein.Ch1_Theorem_2_14
    """
    rel_path = file_path.resolve().relative_to(project_dir)
    # Drop the .lean extension and join the path components with dots
    # (using .parts instead of replacing "/" also handles Windows paths)
    module = ".".join(rel_path.with_suffix("").parts)
    return module


def run_build(file_path: Path, project_dir: Path, module_name: str, verbose: bool = False) -> tuple[bool, str, str]:
    """
    Run a build and return (success, stdout, stderr).
    """
    cmd = ["lake", "build", module_name]
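    # e.g. ["lake", "build", "Gorenstein.Ch1_Theorem_2_14"], run from the project root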

    if verbose:
        print(f"  Running: {' '.join(cmd)} (in {project_dir})")

    result = subprocess.run(
        cmd,
        cwd=project_dir,
        capture_output=True,
        text=True
    )

    if verbose and result.returncode != 0:
        print(f"  Build failed: {result.stderr[:200]}")

    return result.returncode == 0, result.stdout, result.stderr


def verify_build(file_path: Path, project_dir: Path, module_name: str, verbose: bool = False) -> bool:
    """
    Verify that a Lean file builds successfully.

    Returns True if the build succeeds, False otherwise.
    """
    success, _, _ = run_build(file_path, project_dir, module_name, verbose)
    return success


def parse_unused_simp_args(build_output: str, target_file: Path) -> list[tuple[int, str]]:
    """
    Parse build output for unused simp argument warnings.

    Returns list of (line_number, argument_name) tuples.
    Line numbers are 1-indexed (as reported by Lean).
    """
    unused_args = []
    target_filename = target_file.name

    # Pattern: warning: <file>:<line>:<col>: This simp argument is unused:
    #   <argument_name>
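    # For instance (hypothetical names):
    #   warning: ./Gorenstein/Ch1_Theorem_2_14.lean:42:2: This simp argument is unused:
    #     mul_comm
    # parses to (42, "mul_comm").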
    pattern = re.compile(
        r'warning: [^:]*' + re.escape(target_filename) + r':(\d+):\d+: This simp argument is unused:\s*\n\s*(\S+)',
        re.MULTILINE
    )

    for match in pattern.finditer(build_output):
        line_num = int(match.group(1))
        arg_name = match.group(2)
        unused_args.append((line_num, arg_name))

    return unused_args


def parse_try_this_suggestions(build_output: str, target_file: Path) -> list[tuple[int, int, str]]:
    """
    Parse build output for "Try this:" suggestions from exact?, simp?, simp_all?, etc.

    Returns list of (line_number, column, suggestion) tuples.
    Line numbers are 1-indexed (as reported by Lean). Only the first line of a
    multi-line suggestion is captured.
    """
    suggestions = []
    target_filename = target_file.name

    # Pattern: info: <file>:<line>:<col>: Try this:
    #   <suggestion>
    # The suggestion is on the next line, indented
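    # For instance (hypothetical names):
    #   info: ./Gorenstein/Ch1_Theorem_2_14.lean:57:4: Try this:
    #     exact mul_comm a b
    # parses to (57, 4, "exact mul_comm a b").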
    pattern = re.compile(
        r'info: [^:]*' + re.escape(target_filename) + r':(\d+):(\d+): Try this:\s*\n\s*(.+)',
    )

    for match in pattern.finditer(build_output):
        line_num = int(match.group(1))
        col = int(match.group(2))
        suggestion = match.group(3).strip()
        suggestions.append((line_num, col, suggestion))

    return suggestions


def replace_tactic_with_suggestion(line: str, col: int, suggestion: str) -> str:
    """
    Replace a tactic (exact?, simp?, simp_all?, etc.) at the given column with the suggestion.

    The column points to the start of the tactic. We need to find and replace
    the tactic call (e.g., `exact?` or `simp?`) with the suggestion.
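
    For example (hypothetical): with line = "  exact?", col = 3 and
    suggestion = "exact add_comm a b", the result is "  exact add_comm a b".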
    """
    # Find what tactic is at/near the column
    # Common patterns: exact?, simp?, simp_all?, decide?, omega?
    # Note: \b doesn't work after ? so we don't use word boundaries
    tactic_pattern = re.compile(r'(exact\?|simp_all\?|simp\?|decide\?|omega\?|aesop\?)')

    # Find the tactic nearest to the column
    best_match = None
    best_distance = float('inf')

    for match in tactic_pattern.finditer(line):
        # Column is 0-indexed in our string, but Lean reports 1-indexed
        distance = abs(match.start() - (col - 1))
        if distance < best_distance:
            best_distance = distance
            best_match = match

    if best_match and best_distance < 20:  # Allow some tolerance
        # Replace the tactic with the suggestion
        return line[:best_match.start()] + suggestion + line[best_match.end():]

    return line


def remove_simp_arg_from_line(line: str, arg_name: str) -> str:
    """
    Remove a specific argument from a simp/simp_all call in a line.

    Handles patterns like:
    - simp [foo, bar, baz] -> simp [foo, baz]  (remove bar)
    - simp [foo] -> simp  (remove only arg)
    - simp_all [foo, bar] -> simp_all [foo]  (remove bar)
    """
    escaped = re.escape(arg_name)

    # Try ", arg" (arg after a comma); the lookahead stops us from matching
    # a proper prefix of a longer argument name
    new_line = re.sub(r',\s*' + escaped + r'(?=[\s,\]])', '', line)
    if new_line != line:
        return new_line

    # Try "arg, " (arg before a comma); the lookbehind stops us from matching
    # a proper suffix of a longer argument name
    new_line = re.sub(r'(?<![\w.])' + escaped + r'\s*,\s*', '', line)
    if new_line != line:
        return new_line

    # Try "[arg]" (only arg in the list) - remove the whole bracket
    new_line = re.sub(r'\[\s*' + escaped + r'\s*\]', '', line)
    if new_line != line:
        return new_line

    return line


# =============================================================================
# Rules
# =============================================================================

def rule_remove_trailing_semicolons(content: str) -> str:
    """
    Remove unnecessary trailing semicolons at end of tactic sequences.

    Matches patterns like:
    - `aesop;` at end of line (but not `aesop; next_tactic`)
    - `exact foo;` before newline
    """
    # Remove ; followed by newline (possibly with trailing whitespace)
    # But be careful not to remove ; that's followed by another tactic on the same line
    content = re.sub(r';\s*$', '', content, flags=re.MULTILINE)
    return content


def split_one_semicolon(content: str) -> tuple[str, bool]:
    """
    Split ONE semicolon-joined tactic pair onto separate lines.

    Returns (new_content, changed).

    E.g., `ext x; aesop` becomes:
    ```
    ext x
    aesop
    ```
    with the second tactic indented the same as the first.

    Does NOT split:
    - `<;>` (tactic combinator)
    - Lines starting with `·` (bullet points need special indent handling)
    - Lines with `by ... ;` (semicolon inside inline by block)
    """
    lines = content.split('\n')

    for i, line in enumerate(lines):
        # Skip lines with <;> combinator
        if '<;>' in line:
            continue

        # Skip lines that start with a bullet point (complex indent handling)
        if re.match(r'^\s*·', line):
            continue

        # Skip lines with inline `by` blocks containing semicolons
        # e.g., `exact fun a => by rw [foo] ; exact bar`
        if re.search(r'\bby\b.*;\s*\S', line):
            continue

        # Skip lines where semicolon is inside parentheses
        # e.g., `(have := foo ; bar)`
        # Simple heuristic: if there's a `(` before the `;` without a matching `)`
        semi_pos = line.find(';')
        if semi_pos != -1:
            before_semi = line[:semi_pos]
            open_parens = before_semi.count('(') - before_semi.count(')')
            if open_parens > 0:
                continue

        # Find the indentation of this line
        indent_match = re.match(r'^(\s*)', line)
        indent = indent_match.group(1) if indent_match else ''

        # Look for ` ; ` or `; ` pattern (but not `<;>`)
        # We want to split on the FIRST semicolon only
        match = re.search(r'(?<![<]);\s*(?=\S)', line)
        if match:
            before = line[:match.start()].rstrip()
            after = line[match.end():].rstrip()

            if before.strip() and after.strip():
                # Replace this line with two lines
                lines[i] = before
                lines.insert(i + 1, indent + after)
                return '\n'.join(lines), True

    return content, False


def rule_normalize_whitespace(content: str) -> str:
    """
    Remove excessive whitespace in tactics.

    - `( foo )` -> `(foo)`
    - `[ foo ]` -> `[foo]`  (space after [ and before ], but preserve space before [)
    - `‹ foo ›` -> `‹foo›`
    - `⟨ foo ⟩` -> `⟨foo⟩`

    Only spaces and tabs are collapsed, so multi-line terms are never joined.
    """
    # Handle parentheses: ( foo ) -> (foo)
    content = re.sub(r'\([ \t]+', '(', content)
    content = re.sub(r'[ \t]+\)', ')', content)

    # Handle square brackets: remove space after [ and before ]
    # But preserve space before [ (e.g., `simp [foo]` not `simp[foo]`)
    content = re.sub(r'\[[ \t]+', '[', content)
    content = re.sub(r'[ \t]+\]', ']', content)

    # Handle anonymous-hypothesis brackets: ‹ foo › -> ‹foo›
    content = re.sub(r'‹[ \t]+', '‹', content)
    content = re.sub(r'[ \t]+›', '›', content)

    # Handle anonymous-constructor brackets: ⟨ foo ⟩ -> ⟨foo⟩
    content = re.sub(r'⟨[ \t]+', '⟨', content)
    content = re.sub(r'[ \t]+⟩', '⟩', content)

    return content


def rule_replace_refine_prime(content: str) -> str:
    """
    Replace deprecated `refine'` with `refine`.
    """
    return content.replace('refine\'', 'refine')


def _remove_simp_flag(content: str, flag: str) -> str:
    """
    Remove a `+<flag>` option (e.g. `+decide`) from simp/simp_all calls.
    """
    escaped = re.escape(flag)
    # Match ` +flag` (with leading space) or `+flag` inside bracket content
    content = re.sub(r'\s+\+' + escaped + r'\b', '', content)
    content = re.sub(r'\[\+' + escaped + r',\s*', '[', content)
    content = re.sub(r',\s*\+' + escaped + r'\]', ']', content)
    content = re.sub(r'\[\+' + escaped + r'\]', '', content)
    return content


def rule_remove_plus_decide(content: str) -> str:
    """Remove all `+decide` flags from simp/simp_all calls."""
    return _remove_simp_flag(content, 'decide')


def rule_remove_plus_contextual(content: str) -> str:
    """Remove all `+contextual` flags from simp/simp_all calls."""
    return _remove_simp_flag(content, 'contextual')


def rule_remove_plus_zetaDelta(content: str) -> str:
    """Remove all `+zetaDelta` flags from simp/simp_all calls."""
    return _remove_simp_flag(content, 'zetaDelta')
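
# Illustrative example (hypothetical lemma name):
#   rule_remove_plus_decide("  simp +decide [Nat.add_comm]")
#   returns "  simp [Nat.add_comm]"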


def find_lines_matching(content: str, pattern: str) -> list[tuple[int, str]]:
    """
    Find all lines matching a pattern.

    Returns list of (line_number, line_content) tuples.
    Line numbers are 0-indexed.
    """
    matches = []
    for i, line in enumerate(content.split('\n')):
        if re.search(pattern, line):
            matches.append((i, line))
    return matches


def remove_line(content: str, line_num: int) -> str:
    """
    Remove a specific line from content.
    """
    lines = content.split('\n')
    del lines[line_num]
    return '\n'.join(lines)


# =============================================================================
# Main cleanup logic
# =============================================================================

class CleanupRule:
    """A cleanup rule that can be applied to file content."""

    def __init__(self, name: str, apply_fn: Callable[[str], str], requires_build_check: bool = True):
        self.name = name
        self.apply_fn = apply_fn
        self.requires_build_check = requires_build_check

    def apply(self, content: str) -> str:
        return self.apply_fn(content)


# Style rules that are expected to preserve the build (still verified after
# each application)
SAFE_RULES = [
    CleanupRule("Remove trailing semicolons", rule_remove_trailing_semicolons, requires_build_check=True),
    CleanupRule("Normalize whitespace", rule_normalize_whitespace, requires_build_check=True),
    CleanupRule("Replace refine' with refine", rule_replace_refine_prime, requires_build_check=True),
]

# More aggressive rules that are likelier to break the build (also verified
# after each application)
SPECULATIVE_RULES = [
    CleanupRule("Remove +decide flags", rule_remove_plus_decide, requires_build_check=True),
    CleanupRule("Remove +contextual flags", rule_remove_plus_contextual, requires_build_check=True),
    CleanupRule("Remove +zetaDelta flags", rule_remove_plus_zetaDelta, requires_build_check=True),
]
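
# New rules are plain `str -> str` transformations registered in one of the
# lists above, e.g. (hypothetical):
#   CleanupRule("Replace 'exact trivial' with 'trivial'",
#               lambda c: c.replace("exact trivial", "trivial"))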


def cleanup_file(
    file_path: Path,
    dry_run: bool = False,
    verbose: bool = False
) -> bool:
    """
    Apply all cleanup rules to a file.

    Returns True if any changes were made.
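
    The pipeline runs in order: safe style rules, semicolon splitting,
    speculative rules, set_option removal, open scoped removal, unused simp
    argument removal, and "Try this:" replacements. Each step is verified with
    a build and rolled back if the build fails.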
    """
    # Find the Lake project root
    project_dir = find_lake_root(file_path)
    if project_dir is None:
        print(f"Error: Could not find lakefile.toml or lakefile.lean for {file_path}")
        return False

    module_name = file_to_module_name(file_path, project_dir)

    if verbose:
        print(f"Project root: {project_dir}")
        print(f"Module: {module_name}")

    # Read the file
    content = file_path.read_text()

    # Verify initial build
    print("Verifying initial build...")
    if not verify_build(file_path, project_dir, module_name, verbose):
        print("Error: File doesn't build initially. Aborting.")
        return False
    print("Initial build OK")

    changes_made = False

    # Apply safe rules
    for rule in SAFE_RULES:
        new_content = rule.apply(content)
        if new_content != content:
            if dry_run:
                print(f"[DRY RUN] Would apply: {rule.name}")
            else:
                # Write and verify
                file_path.write_text(new_content)
                if verify_build(file_path, project_dir, module_name, verbose):
                    print(f"Applied: {rule.name}")
                    content = new_content
                    changes_made = True
                else:
                    print(f"Rolled back: {rule.name} (build failed)")
                    file_path.write_text(content)

    # Split semicolon-joined tactics one at a time
    if not dry_run:
        max_splits = 100  # Safety limit
        for _ in range(max_splits):
            new_content, changed = split_one_semicolon(content)
            if not changed:
                break

            file_path.write_text(new_content)
            if verify_build(file_path, project_dir, module_name, verbose):
                print("Split a semicolon-joined tactic onto separate lines")
                content = new_content
                changes_made = True
            else:
                print("Kept semicolon (split would break build)")
                file_path.write_text(content)
                break  # Stop trying to split if one fails

    # Apply speculative rules
    for rule in SPECULATIVE_RULES:
        new_content = rule.apply(content)
        if new_content != content:
            if dry_run:
                print(f"[DRY RUN] Would try: {rule.name}")
            else:
                file_path.write_text(new_content)
                if verify_build(file_path, project_dir, module_name, verbose):
                    print(f"Applied: {rule.name}")
                    content = new_content
                    changes_made = True
                else:
                    print(f"Rolled back: {rule.name} (build failed)")
                    file_path.write_text(content)

    # Try removing each set_option line
    set_option_lines = find_lines_matching(content, r'^\s*set_option\s+')
    for line_num, line in reversed(set_option_lines):  # Reverse to preserve line numbers
        new_content = remove_line(content, line_num)
        if dry_run:
            print(f"[DRY RUN] Would try removing: {line.strip()}")
        else:
            file_path.write_text(new_content)
            if verify_build(file_path, project_dir, module_name, verbose):
                print(f"Removed: {line.strip()}")
                content = new_content
                changes_made = True
            else:
                print(f"Kept (needed): {line.strip()}")
                file_path.write_text(content)

    # Try removing each open scoped line
    open_scoped_lines = find_lines_matching(content, r'^\s*open\s+scoped\s+')
    for line_num, line in reversed(open_scoped_lines):
        new_content = remove_line(content, line_num)
        if dry_run:
            print(f"[DRY RUN] Would try removing: {line.strip()}")
        else:
            file_path.write_text(new_content)
            if verify_build(file_path, project_dir, module_name, verbose):
                print(f"Removed: {line.strip()}")
                content = new_content
                changes_made = True
            else:
                print(f"Kept (needed): {line.strip()}")
                file_path.write_text(content)

    # Remove unused simp arguments based on warnings
    # We need to iterate because removing one arg may reveal more warnings
    if not dry_run:
        max_iterations = 20  # Safety limit
        for iteration in range(max_iterations):
            file_path.write_text(content)
            success, stdout, stderr = run_build(file_path, project_dir, module_name, verbose)
            if not success:
                print("Warning: Build failed during unused simp arg removal")
                break

            build_output = stdout + stderr
            unused_args = parse_unused_simp_args(build_output, file_path)

            if not unused_args:
                break

            # Process each unused arg
            lines = content.split('\n')
            removed_any = False
            for line_num, arg_name in unused_args:
                # line_num is 1-indexed
                idx = line_num - 1
                if 0 <= idx < len(lines):
                    old_line = lines[idx]
                    new_line = remove_simp_arg_from_line(old_line, arg_name)
                    if new_line != old_line:
                        lines[idx] = new_line
                        print(f"Removed unused simp arg '{arg_name}' from line {line_num}")
                        removed_any = True

            if not removed_any:
                break

            # Verify the removals before committing them, so a failed build
            # never leaves a broken file on disk
            new_content = '\n'.join(lines)
            file_path.write_text(new_content)
            if verify_build(file_path, project_dir, module_name, verbose):
                content = new_content
                changes_made = True
            else:
                print("Rolled back unused simp arg removal (build failed)")
                file_path.write_text(content)
                break

    # Replace exact?, simp?, simp_all? with their "Try this:" suggestions
    if not dry_run:
        max_iterations = 20  # Safety limit
        for iteration in range(max_iterations):
            file_path.write_text(content)
            success, stdout, stderr = run_build(file_path, project_dir, module_name, verbose)
            if not success:
                print("Warning: Build failed during Try this replacement")
                break

            build_output = stdout + stderr
            suggestions = parse_try_this_suggestions(build_output, file_path)

            if not suggestions:
                break

            # Process each suggestion
            lines = content.split('\n')
            replaced_any = False
            for line_num, col, suggestion in suggestions:
                # line_num is 1-indexed
                idx = line_num - 1
                if 0 <= idx < len(lines):
                    old_line = lines[idx]
                    new_line = replace_tactic_with_suggestion(old_line, col, suggestion)
                    if new_line != old_line:
                        lines[idx] = new_line
                        print(f"Replaced tactic with '{suggestion}' at line {line_num}")
                        replaced_any = True

            if not replaced_any:
                break

            # Verify the replacement before committing it; this shouldn't fail
            # since the suggestions come from Lean itself, but roll back if it does
            new_content = '\n'.join(lines)
            file_path.write_text(new_content)
            if verify_build(file_path, project_dir, module_name, verbose):
                content = new_content
                changes_made = True
            else:
                print("Warning: Build failed after Try this replacement, rolling back")
                file_path.write_text(content)
                break

    if dry_run:
        print(f"\n[DRY RUN] No changes made to {file_path}")
        return True  # Dry run is always "successful"
    elif changes_made:
        print(f"\nFile modified: {file_path}")
    else:
        print(f"\nNo changes made to {file_path}")

    return changes_made


def main():
    parser = argparse.ArgumentParser(
        description="Cleanup Aristotle-generated Lean proofs for Mathlib style"
    )
    parser.add_argument(
        "file",
        type=Path,
        help="Target Lean file to clean up"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be changed without modifying the file"
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Show detailed output"
    )

    args = parser.parse_args()

    if not args.file.exists():
        print(f"Error: File not found: {args.file}")
        sys.exit(1)

    if args.file.suffix != ".lean":
        print(f"Error: Expected .lean file, got: {args.file}")
        sys.exit(1)

    success = cleanup_file(args.file, dry_run=args.dry_run, verbose=args.verbose)
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()
