from typing import TYPE_CHECKING, Any, Dict, List, Union

from ...utils import (
    get_module_from_name,
    is_accelerate_available,
    is_nvidia_modelopt_available,
    is_torch_available,
    logging,
)
from ..base import DiffusersQuantizer


# `ModelMixin` is imported only while type checking; at runtime it is referenced solely through
# string annotations, so no import is needed.
if TYPE_CHECKING:
    from ...models.modeling_utils import ModelMixin


# torch is imported only when it is available, so this module can still be imported without it.
if is_torch_available():
    import torch
    import torch.nn as nn

# `set_module_tensor_to_device` is only bound when accelerate is available.
if is_accelerate_available():
    from accelerate.utils import set_module_tensor_to_device


# Module-level logger named after this module.
logger = logging.get_logger(__name__)


class NVIDIAModelOptQuantizer(DiffusersQuantizer):
    r"""
    Diffusers Quantizer for Nvidia-Model Optimizer

    Quantization is applied per parameter while weights are loaded (see `create_quantized_param`). The `modelopt`
    package is imported lazily inside the methods because ModelOpt imports diffusers internally and a module-level
    import would be circular.
    """

    use_keep_in_fp32_modules = True
    requires_calibration = False
    required_packages = ["nvidia_modelopt"]

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)

    def validate_environment(self, *args, **kwargs):
        """
        Check that nvidia-modelopt is installed and record whether CPU/disk offload is requested.

        Sets `self.offload` to `True` when `device_map` (read from `kwargs`) is a dict that places anything on `"cpu"`
        or `"disk"` and the model is not pre-quantized; otherwise `self.offload` is `False`.

        Raises:
            ImportError: If nvidia-modelopt is not available.
            ValueError: If CPU/disk offload is requested for a pre-quantized model.
        """
        if not is_nvidia_modelopt_available():
            raise ImportError(
                "Loading an nvidia-modelopt quantized model requires nvidia-modelopt library (`pip install nvidia-modelopt`)"
            )

        self.offload = False

        device_map = kwargs.get("device_map", None)
        if isinstance(device_map, dict):
            if "cpu" in device_map.values() or "disk" in device_map.values():
                if self.pre_quantized:
                    raise ValueError(
                        "You are attempting to perform cpu/disk offload with a pre-quantized modelopt model. "
                        "This is not supported yet. Please remove the CPU or disk device from the `device_map` argument."
                    )
                else:
                    self.offload = True

    def check_if_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ):
        """
        Tell whether `param_name` must be loaded through `create_quantized_param`.

        Returns `True` for every parameter of a pre-quantized checkpoint. Otherwise returns `True` only when the
        owning module is reported as quantized by ModelOpt and the tensor name contains `"weight"`.
        """
        # ModelOpt imports diffusers internally. This is here to prevent circular imports
        from modelopt.torch.quantization.utils import is_quantized

        # The lookup is done before the `pre_quantized` shortcut so that an unresolvable name still surfaces here.
        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized:
            return True
        elif is_quantized(module) and "weight" in tensor_name:
            return True
        return False

    def create_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        *args,
        **kwargs,
    ):
        """
        Create the quantized parameter by calling .calibrate() after setting it to the module.

        For a pre-quantized checkpoint the value is only moved to `target_device` and stored on the owning module.
        Otherwise the value is placed on the module with the dtype given by the `dtype` keyword (default
        `torch.float32`), the module is calibrated with the configured algorithm and forward loop, compressed, and
        its weight is frozen.
        """
        # ModelOpt imports diffusers internally. This is here to prevent circular imports
        import modelopt.torch.quantization as mtq

        dtype = kwargs.get("dtype", torch.float32)
        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized:
            module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
        else:
            set_module_tensor_to_device(model, param_name, target_device, param_value, dtype)
            mtq.calibrate(
                module, self.quantization_config.modelopt_config["algorithm"], self.quantization_config.forward_loop
            )
            mtq.compress(module)
            module.weight.requires_grad = False

    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        """
        Return a copy of `max_memory` with every budget scaled to 90% of its value.

        NOTE(review): the multiplication requires numeric values; presumably string budgets have already been
        converted by the caller -- confirm.
        """
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
        return max_memory

    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
        """Return `torch.float8_e4m3fn` when the configured `quant_type` is `"FP8"`, else `target_dtype` unchanged."""
        if self.quantization_config.quant_type == "FP8":
            target_dtype = torch.float8_e4m3fn
        return target_dtype

    def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype":
        """Return `torch_dtype`, falling back to `torch.float32` (with an info log) when it is `None`."""
        if torch_dtype is None:
            logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.")
            torch_dtype = torch.float32
        return torch_dtype

    def get_conv_param_names(self, model: "ModelMixin") -> List[str]:
        """
        Get parameter names for all convolutional layers in a HuggingFace ModelMixin. Includes Conv1d/2d/3d and
        ConvTranspose1d/2d/3d.

        Each returned name is `"<module name>.<parameter name>"`, in `model.named_modules()` order; only parameters
        owned directly by the convolution module are listed.
        """
        conv_types = (
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nn.ConvTranspose1d,
            nn.ConvTranspose2d,
            nn.ConvTranspose3d,
        )

        return [
            f"{name}.{param_name}"
            for name, module in model.named_modules()
            if isinstance(module, conv_types)
            for param_name, _ in module.named_parameters(recurse=False)
        ]

    def _process_model_before_weight_loading(
        self,
        model: "ModelMixin",
        device_map,
        keep_in_fp32_modules: Union[List[str], None] = None,
        **kwargs,
    ):
        """
        Put `model` into ModelOpt "quantize" mode before its weights are loaded.

        Does nothing for pre-quantized checkpoints. Otherwise the exclusion list is assembled from the configured
        `modules_to_not_convert`, `keep_in_fp32_modules` and, when `disable_conv_quantization` is set, the
        convolution parameter names. Each entry is disabled in the ModelOpt `quant_cfg` through a `*<name>*`
        wildcard key, and the final list is stored back on the quantization config.

        Args:
            model: Model to prepare.
            device_map: Not used by this method.
            keep_in_fp32_modules: Extra names to exclude from quantization; `None` is treated as an empty list.
        """
        # ModelOpt imports diffusers internally. This is here to prevent circular imports
        import modelopt.torch.opt as mto

        if self.pre_quantized:
            return

        configured = self.quantization_config.modules_to_not_convert
        if configured is None:
            modules_to_not_convert = []
        elif isinstance(configured, str):
            modules_to_not_convert = [configured]
        else:
            # Work on a copy so the list object the caller handed to the config is never extended in place.
            modules_to_not_convert = list(configured)

        modules_to_not_convert.extend(keep_in_fp32_modules or [])
        if self.quantization_config.disable_conv_quantization:
            modules_to_not_convert.extend(self.get_conv_param_names(model))

        for module in modules_to_not_convert:
            self.quantization_config.modelopt_config["quant_cfg"]["*" + module + "*"] = {"enable": False}
        self.quantization_config.modules_to_not_convert = modules_to_not_convert
        mto.apply_mode(model, mode=[("quantize", self.quantization_config.modelopt_config)])
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model, **kwargs):
        """
        Remove the ModelOpt state from every submodule except the root `model`, then return `model`.

        Pre-quantized models are returned untouched.
        """
        # ModelOpt imports diffusers internally. This is here to prevent circular imports
        from modelopt.torch.opt import ModeloptStateManager

        if self.pre_quantized:
            return model

        for _, m in model.named_modules():
            if hasattr(m, ModeloptStateManager._state_key) and m is not model:
                ModeloptStateManager.remove_state(m)

        return model

    @property
    def is_trainable(self):
        """Always `True`."""
        return True

    @property
    def is_serializable(self):
        """Run the config's model-patching check for saving, then report `True`."""
        self.quantization_config.check_model_patching(operation="saving")
        return True
