# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function
from abc import abstractmethod, ABC
from dataclasses import dataclass
from typing import Sequence, Tuple
from collections import OrderedDict
from enum import Enum
import importlib_resources
import platform

import logging
logger = logging.getLogger('prophet.models')

PLATFORM = "win" if platform.platform().startswith("Win") else "unix"

class TrendIndicator(Enum):
    LINEAR = 0
    LOGISTIC = 1
    FLAT = 2

@dataclass
class ModelInputData:
    T: int
    S: int
    K: int
    tau: float
    trend_indicator: int
    y: Sequence[float]  # length T
    t: Sequence[float]  # length T
    cap: Sequence[float]  # length T
    t_change: Sequence[float]  # length S
    s_a: Sequence[int]  # length K
    s_m: Sequence[int]  # length K
    X: Sequence[Sequence[float]]  # shape (T, K)
    sigmas: Sequence[float]  # length K

@dataclass
class ModelParams:
    k: float
    m: float
    delta: Sequence[float]  # length S
    beta: Sequence[float]  # length K
    sigma_obs: float


class IStanBackend(ABC):
    def __init__(self):
        self.model = self.load_model()
        self.stan_fit = None
        self.newton_fallback = True

    def set_options(self, **kwargs):
        """
        Specify model options as kwargs.
         * newton_fallback [bool]: whether to fallback to Newton if L-BFGS fails
        """
        for k, v in kwargs.items():
            if k == 'newton_fallback':
                self.newton_fallback = v
            else:
                raise ValueError(f'Unknown option {k}')


    @staticmethod
    @abstractmethod
    def get_type():
        pass

    @abstractmethod
    def load_model(self):
        pass

    @abstractmethod
    def fit(self, stan_init, stan_data, **kwargs) -> dict:
        pass

    @abstractmethod
    def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict:
        pass


class CmdStanPyBackend(IStanBackend):
    CMDSTAN_VERSION = "2.33.1"
    def __init__(self):
        import cmdstanpy
        # this must be set before super.__init__() for load_model to work on Windows
        local_cmdstan = importlib_resources.files("prophet") / "stan_model" / f"cmdstan-{self.CMDSTAN_VERSION}"
        if local_cmdstan.exists():
            cmdstanpy.set_cmdstan_path(str(local_cmdstan))
        super().__init__()

    @staticmethod
    def get_type():
        return StanBackendEnum.CMDSTANPY.name

    def load_model(self):
        import cmdstanpy
        model_file = importlib_resources.files("prophet") / "stan_model" / "prophet_model.bin"
        return cmdstanpy.CmdStanModel(exe_file=str(model_file))

    def fit(self, stan_init, stan_data, **kwargs):
        if 'inits' not in kwargs and 'init' in kwargs:
            stan_init = self.sanitize_custom_inits(stan_init, kwargs['init'])
            del kwargs['init']

        inits_list, data_list = self.prepare_data(stan_init, stan_data)
        args = dict(
            data=data_list,
            inits=inits_list,
            algorithm='Newton' if data_list['T'] < 100 else 'LBFGS',
            iter=int(1e4),
        )
        args.update(kwargs)

        try:
            self.stan_fit = self.model.optimize(**args)
        except RuntimeError as e:
            # Fall back on Newton
            if not self.newton_fallback or args['algorithm'] == 'Newton':
                raise e
            logger.warning('Optimization terminated abnormally. Falling back to Newton.')
            args['algorithm'] = 'Newton'
            self.stan_fit = self.model.optimize(**args)
        params = self.stan_to_dict_numpy(
            self.stan_fit.column_names, self.stan_fit.optimized_params_np)
        for par in params:
            params[par] = params[par].reshape((1, -1))
        return params

    def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict:
        if 'inits' not in kwargs and 'init' in kwargs:
            stan_init = self.sanitize_custom_inits(stan_init, kwargs['init'])
            del kwargs['init']

        inits_list, data_list = self.prepare_data(stan_init, stan_data)
        args = dict(
            data=data_list,
            inits=inits_list,
        )
        if 'chains' not in kwargs:
            kwargs['chains'] = 4
        iter_half = samples // 2
        kwargs['iter_sampling'] = iter_half
        if 'iter_warmup' not in kwargs:
            kwargs['iter_warmup'] = iter_half
        args.update(kwargs)

        self.stan_fit = self.model.sample(**args)
        res = self.stan_fit.draws()
        (samples, c, columns) = res.shape
        res = res.reshape((samples * c, columns))
        params = self.stan_to_dict_numpy(self.stan_fit.column_names, res)

        for par in params:
            s = params[par].shape
            if s[1] == 1:
                params[par] = params[par].reshape((s[0],))

            if par in ['delta', 'beta'] and len(s) < 2:
                params[par] = params[par].reshape((-1, 1))

        return params

    @staticmethod
    def sanitize_custom_inits(default_inits, custom_inits):
        """Validate that custom inits have the correct type and shape, otherwise use defaults."""
        sanitized = {}
        for param in ['k', 'm', 'sigma_obs']:
            try:
                sanitized[param] = float(custom_inits.get(param))
            except Exception:
                sanitized[param] = default_inits[param]
        for param in ['delta', 'beta']:
            if default_inits[param].shape == custom_inits[param].shape:
                sanitized[param] = custom_inits[param]
            else:
                sanitized[param] = default_inits[param]
        return sanitized

    @staticmethod
    def prepare_data(init, data) -> Tuple[dict, dict]:
        """Converts np.ndarrays to lists that can be read by cmdstanpy."""
        cmdstanpy_data = {
            'T': data['T'],
            'S': data['S'],
            'K': data['K'],
            'tau': data['tau'],
            'trend_indicator': data['trend_indicator'],
            'y': data['y'].tolist(),
            't': data['t'].tolist(),
            'cap': data['cap'].tolist(),
            't_change': data['t_change'].tolist(),
            's_a': data['s_a'].tolist(),
            's_m': data['s_m'].tolist(),
            'X': data['X'].to_numpy().tolist(),
            'sigmas': data['sigmas']
        }

        cmdstanpy_init = {
            'k': init['k'],
            'm': init['m'],
            'delta': init['delta'].tolist(),
            'beta': init['beta'].tolist(),
            'sigma_obs': init['sigma_obs']
        }
        return (cmdstanpy_init, cmdstanpy_data)

    @staticmethod
    def stan_to_dict_numpy(column_names: Tuple[str, ...], data: 'np.array'):
        import numpy as np

        output = OrderedDict()

        prev = None

        start = 0
        end = 0
        two_dims = len(data.shape) > 1
        for cname in column_names:
            parsed = cname.split(".") if "." in cname else cname.split("[")
            curr = parsed[0]
            if prev is None:
                prev = curr

            if curr != prev:
                if prev in output:
                    raise RuntimeError(
                        "Found repeated column name"
                    )
                if two_dims:
                    output[prev] = np.array(data[:, start:end])
                else:
                    output[prev] = np.array(data[start:end])
                prev = curr
                start = end
            end += 1
        if prev in output:
            raise RuntimeError(
                "Found repeated column name"
            )
        if two_dims:
            output[prev] = np.array(data[:, start:end])
        else:
            output[prev] = np.array(data[start:end])
        return output




class StanBackendEnum(Enum):
    CMDSTANPY = CmdStanPyBackend

    @staticmethod
    def get_backend_class(name: str) -> IStanBackend:
        try:
            return StanBackendEnum[name].value
        except KeyError as e:
            raise ValueError(f"Unknown stan backend: {name}") from e
