"""
Container for the information used in a generic CmdStan run,
such as file locations
"""

import os
import re
import shutil
import tempfile
from datetime import datetime
from time import time
from typing import List, Optional

from cmdstanpy import _TMPDIR
from cmdstanpy.cmdstan_args import CmdStanArgs, Method
from cmdstanpy.utils import get_logger


class RunSet:
    """
    Encapsulates the configuration and results of a call to any CmdStan
    inference method. Records the method return code and locations of
    all console, error, and output files.

    RunSet objects are instantiated by the CmdStanModel class inference
    methods, which validate all inputs; therefore the ``__init__`` method
    skips input checks.
    """

    def __init__(
        self,
        args: CmdStanArgs,
        chains: int = 1,
        *,
        chain_ids: Optional[List[int]] = None,
        time_fmt: str = "%Y%m%d%H%M%S",
        one_process_per_chain: bool = True,
    ) -> None:
        """Initialize object (no input arg checks)."""
        self._args = args
        self._chains = chains
        self._one_process_per_chain = one_process_per_chain
        if one_process_per_chain:
            self._num_procs = chains
        else:
            self._num_procs = 1
        self._retcodes = [-1 for _ in range(self._num_procs)]
        self._timeout_flags = [False for _ in range(self._num_procs)]
        if chain_ids is None:
            chain_ids = [i + 1 for i in range(chains)]
        self._chain_ids = chain_ids

        if args.output_dir is not None:
            self._output_dir = args.output_dir
        else:
            # make a per-run subdirectory of our master temp directory
            self._output_dir = tempfile.mkdtemp(
                prefix=args.model_name, dir=_TMPDIR
            )

        # output files prefix: ``<model_name>-<YYYYMMDDHHMM>_<chain_id>``
        self._base_outfile = (
            f'{args.model_name}-{datetime.now().strftime(time_fmt)}'
        )
        # per-process outputs
        self._stdout_files = [''] * self._num_procs
        self._profile_files = [''] * self._num_procs  # optional
        if one_process_per_chain:
            for i in range(chains):
                self._stdout_files[i] = self.file_path("-stdout.txt", id=i)
                if args.save_profile:
                    self._profile_files[i] = self.file_path(
                        ".csv", extra="-profile", id=chain_ids[i]
                    )
        else:
            self._stdout_files[0] = self.file_path("-stdout.txt")
            if args.save_profile:
                self._profile_files[0] = self.file_path(
                    ".csv", extra="-profile"
                )

        # per-chain output files
        self._csv_files: List[str] = [''] * chains
        self._diagnostic_files = [''] * chains  # optional

        if chains == 1:
            self._csv_files[0] = self.file_path(".csv")
            if args.save_latent_dynamics:
                self._diagnostic_files[0] = self.file_path(
                    ".csv", extra="-diagnostic"
                )
        else:
            for i in range(chains):
                self._csv_files[i] = self.file_path(".csv", id=chain_ids[i])
                if args.save_latent_dynamics:
                    self._diagnostic_files[i] = self.file_path(
                        ".csv", extra="-diagnostic", id=chain_ids[i]
                    )

    def __repr__(self) -> str:
        repr = 'RunSet: chains={}, chain_ids={}, num_processes={}'.format(
            self._chains, self._chain_ids, self._num_procs
        )
        repr = '{}\n cmd (chain 1):\n\t{}'.format(repr, self.cmd(0))
        repr = '{}\n retcodes={}'.format(repr, self._retcodes)
        repr = f'{repr}\n per-chain output files (showing chain 1 only):'
        repr = '{}\n csv_file:\n\t{}'.format(repr, self._csv_files[0])
        if self._args.save_latent_dynamics:
            repr = '{}\n diagnostics_file:\n\t{}'.format(
                repr, self._diagnostic_files[0]
            )
        if self._args.save_profile:
            repr = '{}\n profile_file:\n\t{}'.format(
                repr, self._profile_files[0]
            )
        repr = '{}\n console_msgs (if any):\n\t{}'.format(
            repr, self._stdout_files[0]
        )
        return repr

    @property
    def model(self) -> str:
        """Stan model name."""
        return self._args.model_name

    @property
    def method(self) -> Method:
        """CmdStan method used to generate this fit."""
        return self._args.method

    @property
    def num_procs(self) -> int:
        """Number of processes run."""
        return self._num_procs

    @property
    def one_process_per_chain(self) -> bool:
        """
        When True, for each chain, call CmdStan in its own subprocess.
        When False, use CmdStan's `num_chains` arg to run parallel chains.
        Always True if CmdStan < 2.28.
        For CmdStan 2.28 and up, `sample` method determines value.
        """
        return self._one_process_per_chain

    @property
    def chains(self) -> int:
        """Number of chains."""
        return self._chains

    @property
    def chain_ids(self) -> List[int]:
        """Chain ids."""
        return self._chain_ids

    def cmd(self, idx: int) -> List[str]:
        """
        Assemble CmdStan invocation.
        When running parallel chains from single process (2.28 and up),
        specify CmdStan arg `num_chains` and leave chain idx off CSV files.
        """
        if self._one_process_per_chain:
            return self._args.compose_command(
                idx,
                csv_file=self.csv_files[idx],
                diagnostic_file=self.diagnostic_files[idx]
                if self._args.save_latent_dynamics
                else None,
                profile_file=self.profile_files[idx]
                if self._args.save_profile
                else None,
            )
        else:
            return self._args.compose_command(
                idx,
                csv_file=self.file_path('.csv'),
                diagnostic_file=self.file_path(".csv", extra="-diagnostic")
                if self._args.save_latent_dynamics
                else None,
                profile_file=self.file_path(".csv", extra="-profile")
                if self._args.save_profile
                else None,
            )

    @property
    def csv_files(self) -> List[str]:
        """List of paths to CmdStan output files."""
        return self._csv_files

    @property
    def stdout_files(self) -> List[str]:
        """
        List of paths to transcript of CmdStan messages sent to the console.
        Transcripts include config information, progress, and error messages.
        """
        return self._stdout_files

    def _check_retcodes(self) -> bool:
        """Returns ``True`` when all chains have retcode 0."""
        for code in self._retcodes:
            if code != 0:
                return False
        return True

    @property
    def diagnostic_files(self) -> List[str]:
        """List of paths to CmdStan hamiltonian diagnostic files."""
        return self._diagnostic_files

    @property
    def profile_files(self) -> List[str]:
        """List of paths to CmdStan profiler files."""
        return self._profile_files

    # pylint: disable=invalid-name
    def file_path(
        self, suffix: str, *, extra: str = "", id: Optional[int] = None
    ) -> str:
        if id is not None:
            suffix = f"_{id}{suffix}"
        file = os.path.join(
            self._output_dir, f"{self._base_outfile}{extra}{suffix}"
        )
        return file

    def _retcode(self, idx: int) -> int:
        """Get retcode for process[idx]."""
        return self._retcodes[idx]

    def _set_retcode(self, idx: int, val: int) -> None:
        """Set retcode at process[idx] to val."""
        self._retcodes[idx] = val

    def _set_timeout_flag(self, idx: int, val: bool) -> None:
        """Set timeout_flag at process[idx] to val."""
        self._timeout_flags[idx] = val

    def get_err_msgs(self) -> str:
        """Checks console messages for each CmdStan run."""
        msgs = []
        for i in range(self._num_procs):
            if (
                os.path.exists(self._stdout_files[i])
                and os.stat(self._stdout_files[i]).st_size > 0
            ):
                if self._args.method == Method.OPTIMIZE:
                    msgs.append('console log output:\n')
                    with open(self._stdout_files[0], 'r') as fd:
                        msgs.append(fd.read())
                else:
                    with open(self._stdout_files[i], 'r') as fd:
                        contents = fd.read()
                        # pattern matches initial "Exception" or "Error" msg
                        pat = re.compile(r'^E[rx].*$', re.M)
                        errors = re.findall(pat, contents)
                        if len(errors) > 0:
                            msgs.append('\n\t'.join(errors))
        return '\n'.join(msgs)

    def save_csvfiles(self, dir: Optional[str] = None) -> None:
        """
        Moves CSV files to specified directory.

        :param dir: directory path

        See Also
        --------
        cmdstanpy.from_csv
        """
        if dir is None:
            dir = os.path.realpath('.')
        test_path = os.path.join(dir, str(time()))
        try:
            os.makedirs(dir, exist_ok=True)
            with open(test_path, 'w'):
                pass
            os.remove(test_path)  # cleanup
        except (IOError, OSError, PermissionError) as exc:
            raise RuntimeError('Cannot save to path: {}'.format(dir)) from exc

        for i in range(self.chains):
            if not os.path.exists(self._csv_files[i]):
                raise ValueError(
                    'Cannot access CSV file {}'.format(self._csv_files[i])
                )

            to_path = os.path.join(dir, os.path.basename(self._csv_files[i]))
            if os.path.exists(to_path):
                raise ValueError(
                    'File exists, not overwriting: {}'.format(to_path)
                )
            try:
                get_logger().debug(
                    'saving tmpfile: "%s" as: "%s"', self._csv_files[i], to_path
                )
                shutil.move(self._csv_files[i], to_path)
                self._csv_files[i] = to_path
            except (IOError, OSError, PermissionError) as e:
                raise ValueError(
                    'Cannot save to file: {}'.format(to_path)
                ) from e

    def raise_for_timeouts(self) -> None:
        if any(self._timeout_flags):
            raise TimeoutError(
                f"{sum(self._timeout_flags)} of {self.num_procs} processes "
                "timed out"
            )
