"""
Helper module to run Lean commands and get the results.

Uses Lean Interact to run the commands and get the results.
Create an instance of the LeanRunner and call the run method to run a Lean4 code string.

Example:
```
lean_runner = LeanRunner(import_commands="import Mathlib")
response = lean_runner.run("theorem my_theorem : true := by trivial")
print(response)
```
"""
import os
from enum import Enum
import time
import threading
from lean_interact import LeanREPLConfig, LeanServer, LocalProject, Command
from lean_interact.interface import CommandResponse, LeanError, Message, Sorry

# Default Lean preamble used by LeanRunner when the caller passes no import_commands:
# it imports Mathlib and opens the Matrix, Finset, Submodule and Module namespaces.
_import_commands: str = """
import Mathlib

open Matrix
open Finset
open Submodule
open Module
"""

def get_errors(response: CommandResponse) -> list[Message]:
    """
    Returns the messages of the response whose severity is 'error', in their original order.
    """
    return [entry for entry in response.messages if entry.severity == 'error']

def get_unsolved_goals_errors(response: CommandResponse) -> list[Message]:
    """
    Collects the error messages whose text begins with "unsolved goals".

    Usually the result holds either zero or one message.
    """
    goal_prefix = "unsolved goals"
    matching: list[Message] = []
    for error in get_errors(response):
        if error.data.startswith(goal_prefix):
            matching.append(error)
    return matching

def get_unsolved_goals_message(response: CommandResponse) -> str:
    """
    Returns the text of the single "unsolved goals" error of the response.

    Raises an exception when the response has no such error or more than one.
    """
    goal_errors = get_unsolved_goals_errors(response)
    if len(goal_errors) == 1:
        (only_error,) = goal_errors
        return only_error.data
    raise Exception("There wasn't exactly one unsolved goals message")

def get_other_errors_and_sorries(response: CommandResponse) -> list[Message | Sorry]:
    """
    Collects every error whose text does not begin with "unsolved goals",
    followed by the sorries of the response.
    """
    collected: list[Message | Sorry] = []
    for error in get_errors(response):
        if error.data.startswith("unsolved goals"):
            continue
        collected.append(error)
    collected += response.sorries
    return collected

def get_other_errors_and_sorries_message(response: CommandResponse) -> str:
    """
    Renders the non-"unsolved goals" errors and the sorries of the response as one string.

    Messages contribute their text, anything else contributes its str() form;
    the entries are separated by a blank line.
    """
    # TODO: Better formatting for sorries
    rendered = (
        item.data if isinstance(item, Message) else str(item)
        for item in get_other_errors_and_sorries(response)
    )
    return "\n\n".join(rendered)

# Put ~/.elan/bin at the front of PATH for this process and its children
# (presumably so the elan-installed Lean tools are found by Lean Interact -- confirm).
# os.pathsep keeps this correct on platforms whose separator is not ":", and a missing
# PATH no longer raises KeyError; empty pieces are dropped so no empty entry is created.
_elan_bin = os.path.expanduser("~/.elan/bin")
os.environ["PATH"] = os.pathsep.join(
    part for part in (_elan_bin, os.environ.get("PATH", "")) if part
)

class LeanRunner:
    """
    Runs Lean 4 code strings on a Lean Interact server.

    Construction starts the server and prepares a base environment (imports, a linter
    option and two sanity checks). Every call to `run` is issued against that base
    environment id; when the server does not answer a health probe it is reinitialized
    before the command is sent.
    """

    def __init__(self, import_commands: str = _import_commands):
        """
        Start the Lean server and prepare the base environment.

        Args:
            import_commands: Lean source run first on a fresh server, and again on every restart.

        Raises:
            Exception: If an initialization command yields a LeanError or error messages.
                Anything raised by Lean Interact while starting the server propagates as well.
        """
        self._server: LeanServer | None = None
        self._initialized_env_id: int | None = None
        # Held while restarting, so restarts triggered from several threads run one at a time.
        self._initialization_lock = threading.Lock()
        self._server_healthy = False
        self._import_commands = import_commands
        self._initialize_server(import_commands)

    def _initialize_server(self, import_commands: str):
        """
        Initialize or reinitialize the Lean server.

        Starts a new LeanServer for the project located one directory above this file,
        runs `import_commands`, disables the unused-variables linter and checks that
        `Nat` and `Real` resolve. On success the resulting environment id is stored and
        the server is marked healthy.

        Raises:
            Exception: If any of the initialization commands reports an error.
        """
        # Drop the state of any previous server first: if initialization fails part-way,
        # a stale environment id must not stay paired with the new server.
        self._server_healthy = False
        self._initialized_env_id = None

        print("Initializing Lean environment. Configuring Lean Interact...")
        project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "./.."))
        lean_interact_config = LeanREPLConfig(
            project=LocalProject(directory=project_root), verbose=True, build_repl=False
        )
        print("Starting the Lean Interact server...")
        self._server = LeanServer(lean_interact_config)

        print("Running imports...")
        env = self._run_command_and_crash_if_error(import_commands, None)

        print("Setting Lean options...")
        env = self._run_command_and_crash_if_error("set_option linter.unusedVariables false", env)

        print("Checking to make sure Mathlib is available (this serves to check if imports succeeded)...")
        # Both checks run on top of `env`. The first is only executed for its pass/fail
        # outcome; the environment returned by the second becomes the base for `run`.
        self._run_command_and_crash_if_error("#check Nat", env)
        self._initialized_env_id = self._run_command_and_crash_if_error("#check Real", env)

        self._server_healthy = True
        print("Lean environment initialized")

    def _check_server_health(self) -> bool:
        """
        Check if the server is responsive by running a simple command.

        Returns:
            True if `#check Nat` in the base environment returns something other than a
            LeanError; False if there is no server or base environment yet, if the probe
            returns a LeanError, or if it raises.
        """
        try:
            # Compare against None explicitly: environment id 0 is falsy but still an id.
            if self._server is None or self._initialized_env_id is None:
                return False

            # Try a simple health check command
            result = self._server.run(Command(cmd="#check Nat", env=self._initialized_env_id))
            return not isinstance(result, LeanError)
        except Exception:
            # Deliberately broad: any failure while probing counts as "unhealthy".
            return False

    def _restart_server(self):
        """
        Restart the server if it's unhealthy.

        Raises:
            Exception: Re-raises whatever `_initialize_server` raised when the restart fails.
        """
        with self._initialization_lock:
            # Double-check health after acquiring lock
            if self._check_server_health():
                return

            print("⚠️ Lean server unhealthy, restarting...")
            self._server_healthy = False
            try:
                self._initialize_server(self._import_commands)
                print("✅ Lean server restarted successfully")
            except Exception as e:
                print(f"❌ Failed to restart Lean server: {e}")
                raise

    def _run_command_and_crash_if_error(self, command: str, env: int | None) -> int:
        """
        Run `command` in environment `env` and fail loudly on any error.

        Primarily intended for initialization.

        Args:
            command: Lean source to execute.
            env: Environment id to run in, or None to pass no environment.

        Returns:
            The resulting environment id.

        Raises:
            Exception: If the server returns a LeanError or the response contains error messages.
        """
        assert self._server is not None
        result: CommandResponse | LeanError = self._server.run(Command(cmd=command, env=env))
        if isinstance(result, LeanError):
            raise Exception("Lean interact gave a LeanError: ", result)
        errors = get_errors(result)
        if errors:
            raise Exception("Lean interact gave errors: ", errors)
        return result.env

    def run(self, command: str, max_retries: int = 2) -> CommandResponse:
        """
        Run a command with automatic retry and server recovery on failure.

        Args:
            command: The Lean command to execute
            max_retries: Maximum number of retry attempts (default: 2)

        Returns:
            CommandResponse from the Lean server

        Raises:
            Exception: If the command could not be completed. The message states how many
                attempts were actually made and the last error, which is also chained as
                the cause. Failures whose text does not look connection-related are not
                retried.
        """
        last_exception: Exception | None = None
        attempts_made = 0

        for attempt in range(max_retries + 1):
            attempts_made = attempt + 1
            try:
                # Check server health before running command
                if not self._server_healthy or not self._check_server_health():
                    if attempt == 0:
                        print(f"🔄 Server unhealthy, attempting restart (attempt {attempt + 1})")
                    self._restart_server()

                # Execute the command with all_tactics=True to get goal states
                assert self._server is not None
                result: CommandResponse | LeanError = \
                    self._server.run(Command(cmd=command, env=self._initialized_env_id, all_tactics=True))

                # Handle LeanErrors
                if isinstance(result, LeanError):
                    error_msg = f"Encountered LeanError: {result}"
                    error_text = str(result).lower()

                    # If it's a server connection issue, try to restart
                    if "server" in error_text or "connection" in error_text:
                        if attempt < max_retries:
                            print(f"🔄 LeanError suggests connection issue, retrying (attempt {attempt + 1})")
                            self._server_healthy = False
                            time.sleep(0.1 * (attempt + 1))  # Small linear backoff
                            continue

                    # Any other LeanError is raised here and classified by the handler
                    # below, which decides between one more attempt and giving up.
                    raise Exception(error_msg)

                # Success - mark server as healthy and return result
                self._server_healthy = True
                return result

            except Exception as e:
                last_exception = e

                # If this was the last attempt, stop and report below
                if attempt == max_retries:
                    break

                # For connection-related errors, try to restart the server
                error_str = str(e).lower()
                if any(keyword in error_str for keyword in ["server", "connection", "interact", "timeout"]):
                    print(f"🔄 Connection error detected, retrying (attempt {attempt + 1}): {str(e)[:100]}")
                    self._server_healthy = False
                    time.sleep(0.2 * (attempt + 1))  # Linear backoff: the delay grows by 0.2s per attempt
                else:
                    # For non-connection errors, give up without further attempts
                    break

        # Report the number of attempts that really happened (an early give-up makes fewer
        # than max_retries + 1) and keep the original error as the cause.
        raise Exception(
            f"Failed after {attempts_made} attempts. Last error: {last_exception}"
        ) from last_exception
