"""
Lean Compiler Wrapper

Provides a clean interface for compiling Lean 4 code and extracting error messages.
Wraps the LeanRunner from helpers and includes logic to parse compilation results.
"""

from typing import Optional, List, Any
from lean_interact.interface import CommandResponse, Message, Sorry, Tactic

from .run_lean import LeanRunner
from pathlib import Path
from datetime import datetime


class LeanCompiler:
    """
    Wrapper around LeanRunner that provides a simple compile interface.

    Handles compilation of Lean code and extraction of error messages,
    filtering out sorries in the axiom section (the lines from the
    ``---- AXIOMS START`` marker through the ``---- AXIOMS END`` marker).
    """

    # Substrings identifying the first and last line of the axiom section.
    # Sorries whose start line falls within [start, end] are not reported.
    AXIOMS_START_MARKER: str = "---- AXIOMS START"
    AXIOMS_END_MARKER: str = "---- AXIOMS END"

    lean_runner: LeanRunner

    def __init__(self, debug: bool = True, imports: str = "import Mathlib") -> None:
        """
        Initialize the compiler and its underlying LeanRunner.

        Args:
            debug: If True, every snippet passed to ``compile`` is also written to
                a timestamped file under ``compiled_code/`` next to this module.
            imports: Import header string handed to ``LeanRunner``.
        """
        self.lean_runner = LeanRunner(imports)
        self.debug = debug

    def compile(self, code: str, ignore_sorries: bool = False) -> str:
        """
        Compile Lean code and return error message if any.

        Args:
            code: Lean 4 code string to compile
            ignore_sorries: If True, ignore sorries and treat them as if proof is complete

        Returns:
            - Empty string if compilation successful (proof complete)
            - Error message string if there are errors (errors take precedence over sorries)
            - "unsolved goals\\n{goal}" for the first sorry outside the axiom section
              (unless ignore_sorries=True)
            - "exception: ..." if the response could not be parsed

        The function filters out sorries in the axiom section (between
        ---- AXIOMS START and ---- AXIOMS END markers).
        """
        # Save before running so the snippet is kept even if the runner raises.
        if self.debug:
            self._save_code(code)

        response = self.lean_runner.run(code)
        return self._parse_goal_state(response, code, ignore_sorries=ignore_sorries)

    def _parse_goal_state(self, response: CommandResponse, code: Optional[str] = None, ignore_sorries: bool = False) -> str:
        """
        Parse goal state from LeanRunner response.

        This is based on the parse_goal_state function from the test notebook.

        Args:
            response: LeanRunner CommandResponse object
            code: Original code string; required whenever the response contains
                errors or sorries (used for line excerpts and axiom filtering)
            ignore_sorries: If True, ignore sorries and treat them as if proof is complete

        Returns:
            - If errors: one formatted entry per error (line number and content,
              goal state before the error when available, then the Lean message),
              joined by a line of 80 '=' characters
            - If unsolved goals (sorries): "unsolved goals\\n" + goal of the first
              sorry outside the axiom section, unless ignore_sorries=True
            - If proof complete, or the response is empty: empty string
            - If anything fails while parsing (including a missing ``code``):
              "exception: " followed by the first 100 characters of the error
        """
        if not response or not hasattr(response, 'messages'):
            return ""

        try:
            # Real errors take precedence over sorries.
            error_messages: List[Message] = [
                msg for msg in (response.messages or []) if msg.severity == 'error'
            ]

            if error_messages:
                if not code:
                    raise ValueError("lean_compiler: code is required to parse error message with line number and content")
                code_lines: List[str] = code.split('\n')
                tactics: List[Tactic] = getattr(response, 'tactics', None) or []
                separator = "\n" + "=" * 80 + "\n"
                return separator.join(
                    self._format_error(error, code_lines, tactics) for error in error_messages
                )

            # If ignore_sorries is True, treat sorries as if proof is complete
            if ignore_sorries:
                return ""

            sorries: List[Sorry] = getattr(response, 'sorries', None) or []
            if not sorries:
                return ""
            if not code:
                raise ValueError("lean_compiler: code is required to filter sorries outside the axiom section")

            # Drop sorries inside the axiom section (only when both markers are present).
            axiom_start, axiom_end = self._find_axiom_section(code.split('\n'))
            if axiom_start is not None and axiom_end is not None:
                sorries = [
                    s for s in sorries
                    if not (axiom_start <= s.start_pos.line <= axiom_end)
                ]

            # Report the first remaining sorry; none left means the proof is complete.
            if sorries:
                return f"unsolved goals\n{sorries[0].goal}"
            return ""

        except Exception as e:
            # Deliberate best-effort: callers receive failures as a string, not an exception.
            return f"exception: {str(e)[:100]}"

    def _format_error(self, error: Message, code_lines: List[str], tactics: List[Tactic]) -> str:
        """Render one Lean error: source line excerpt, optional goal state, then the message."""
        error_line: int = error.start_pos.line
        parts: List[str] = []
        should_show_goal_state = True

        # Lean line numbers are 1-indexed; out-of-range positions get no excerpt.
        if 1 <= error_line <= len(code_lines):
            line_content = code_lines[error_line - 1].strip()
            parts.append(f"Error at line {error_line}: {line_content}")
            # An error on a declaration's ":= by" line (not a `have`) is usually the
            # "unsolved goals" error, whose text already contains the goal state;
            # skip the goal state there to avoid printing it twice.
            if ":= by" in line_content and not line_content.startswith("have "):
                should_show_goal_state = False

        if should_show_goal_state:
            tactic = self._last_tactic_at_or_before(tactics, error_line)
            if tactic and tactic.goals:
                parts.append(f"\n\nGoal state before error:\n{tactic.goals}")

        parts.append(f"\n\n{error.data}")
        return "".join(parts)

    @staticmethod
    def _last_tactic_at_or_before(tactics: List[Tactic], line: int) -> Optional[Tactic]:
        """Return the tactic starting on the latest line <= ``line`` (later entries win ties), or None."""
        best: Optional[Tactic] = None
        for tactic in tactics:
            tactic_line = tactic.start_pos.line
            if tactic_line <= line and (best is None or tactic_line >= best.start_pos.line):
                best = tactic
        return best

    @classmethod
    def _find_axiom_section(cls, code_lines: List[str]) -> tuple[Optional[int], Optional[int]]:
        """Return 1-indexed (start, end) marker lines; the last occurrence of each wins, None if absent."""
        axiom_start: Optional[int] = None
        axiom_end: Optional[int] = None
        for line_number, line in enumerate(code_lines, start=1):
            if cls.AXIOMS_START_MARKER in line:
                axiom_start = line_number
            elif cls.AXIOMS_END_MARKER in line:
                axiom_end = line_number
        return axiom_start, axiom_end

    def extract_goals(self, code: str) -> List[str]:
        """
        Extract theorem statements from extract_goal tactics.

        The extract_goal tactic formats the current proof goal as a standalone
        theorem statement and outputs it as an info message. This method finds
        the info messages that start on a line containing the text
        ``extract_goal`` (including, e.g., commented-out occurrences) and
        returns their text.

        Args:
            code: Lean 4 code with extract_goal tactics

        Returns:
            List of theorem statements (as strings), in message order.
            Example: ["theorem extracted_1 (x y : ℕ) : x + y = y + x := sorry"]
        """
        response = self.lean_runner.run(code)

        if not response or not hasattr(response, 'messages'):
            return []

        # 1-indexed line numbers (matching Lean positions) that mention extract_goal.
        extract_goal_lines = {
            line_number
            for line_number, line in enumerate(code.split('\n'), start=1)
            if 'extract_goal' in line
        }
        if not extract_goal_lines:
            return []

        return [
            message.data
            for message in (response.messages or [])
            if message.severity in ('information', 'info')
            and message.start_pos.line in extract_goal_lines
        ]

    def _save_code(self, code: str) -> Path:
        """
        Save the code to a file with timestamp-based naming.

        The file is written to ``compiled_code/`` next to this module (created if
        missing) as ``compiled_<YYYY-MM-DD_HH-MM-SS-microseconds>.lean``.

        Args:
            code: The Lean code to save

        Returns:
            Path to the saved file
        """
        save_dir = Path(__file__).parent / "compiled_code"
        save_dir.mkdir(parents=True, exist_ok=True)

        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")
        filepath = save_dir / f"compiled_{timestamp}.lean"
        filepath.write_text(code, encoding='utf-8')
        return filepath

# Example usage: compile one failing snippet and one passing snippet,
# printing either the reported error or a success message for each.
if __name__ == "__main__":
    compiler = LeanCompiler()

    # Snippet with an error; the sorry inside the axiom section should be ignored.
    code_with_error = """---- AXIOMS START
theorem example_axiom (x : ℝ) : 0 ≤ x → 0 ≤ x := by
  sorry 
-- we expect this sorry to be skipped since its within the axiom section
---- AXIOMS END

theorem test_thm (x : ℝ) : 0 ≤ x := by
  apply example_axiom
  -- Missing argument, will cause error

theorem test_thm_2 (x : ℕ) : x + y = y + x:= by
  induction x
  rw [add_zero]
  rw [zero_add]
  rw [add_succ]
"""

    # Snippet whose only sorry lives in the axiom section.
    code_correct = """---- AXIOMS START
theorem example_axiom (x : ℝ) : 0 ≤ x → 0 ≤ x := by
  sorry
---- AXIOMS END

theorem test_thm (x : ℝ) (h : 0 ≤ x) : 0 ≤ x := by
  exact h
"""

    demos = [
        ("Testing code with error:", code_with_error),
        ("Testing correct code:", code_correct),
    ]
    for position, (heading, snippet) in enumerate(demos):
        # Divider between consecutive demos only.
        if position > 0:
            print("\n" + "=" * 50 + "\n")
        print(heading)
        outcome = compiler.compile(snippet)
        print(f"Error found:\n{outcome}" if outcome else "Proof complete!")
