
from __future__ import annotations
from typing import Optional, List, Dict, TYPE_CHECKING, Tuple

from device.block_definition import ModeInterface, Port
from common_device.device_service_interface import BlockService

from design.db_item import (
    PeriDesignRegistry,
    PeriDesignGenPinItem,
    GenericParamGroup,
    GenericPin,
    GenericParamService,
    ParamGroupMonitor
)
from tx375_device.common_quad.dep_graph import QuadCommonDependencyGraph
from tx375_device.common_quad.quad_param_info import build_param_info
from tx375_device.common_quad.design_param_info import build_design_param_info, get_supported_params_by_quad_type, QuadDesignParamInfo
from tx375_device.common_quad.quad_prop_id import CommonQuadConfigParamInfo
from common_device.quad.res_service import QuadResService, QuadType

from tx375_device.lane10g.quad_param_info import get_supported_common_parameters as get_lane_10g_param
from tx375_device.lane1g.quad_param_info import get_supported_common_parameters as get_lane_1g_param
from tx375_device.raw_serdes.quad_param_info import get_supported_common_parameters as get_raw_serdes_param

if TYPE_CHECKING:
    from device.db import PeripheryDevice


class QuadLaneCommon(PeriDesignGenPinItem):
    """
    All lanes in the same Quad will be treated as a single instance.

    E.g.
    Q0_LN1 -> Q0_LN2: UI can use QuadCommon with device name Q0 for display
    """
    def __init__(self, name: str, block_def: str = "", device_db: Optional[PeripheryDevice]=None):
        super().__init__()
        self.__name = name
        self.__block_def = block_def # Quad name
        self._categories = set()
        self._port_map = {}

        self._param_info = self.build_param_info()
        self._param_group = self.build_param()
        self._device_db = device_db

        if device_db is not None:
            self.build_port_info(device_db=device_db)
            self.build_generic_pin()

        self._dep_graph = QuadCommonDependencyGraph(self._param_info, self.gen_pin, self.get_pin_type_by_class)
        self.gp_monitor = ParamGroupMonitor(self._dep_graph)
        self._param_group.register_param_observer(self.gp_monitor)
        self.set_default_setting()

    @property
    def device_port_map(self):
        return self._port_map

    @property
    def name(self):
        return self.__name

    @name.setter
    def name(self, value: str):
        self.__name = value

    @property
    def dependency_graph(self):
        assert self._dep_graph is not None
        return self._dep_graph

    def get_device(self):
        return self.__block_def

    def set_device(self, device_name):
        self.__block_def = device_name

    def get_pin_property_name(self, pin_type_name):
        assert self.device_port_map is not None

        device_port = self.device_port_map.get(pin_type_name, None)
        if device_port is not None:
            self.dp_service.device_port = device_port
            desc = self.dp_service.get_desc()
            if desc != None:
                return desc

        return ""

    def get_pin_class(self, pin_type_name):
        """
        Based on device pin type, get its pin class.

        Pin type is the port name defined in device db.
        Pin class is the port class name defined in device db.

        :param pin_type_name: Device pin type name
        :return: Device pin class if found, else None
        """
        assert self.device_port_map is not None

        device_port = self.device_port_map.get(pin_type_name, None)
        if device_port is not None:
            self.dp_service.device_port = device_port
            return self.dp_service.get_class()

        return None

    def build_port_info(self, device_db: PeripheryDevice, is_mockup: bool = False):
        """
        Build port info for 10G which is mode based.
        This function need to be call only once since it is static shared
        between class.

        :param device_db: Device db instance
        :param is_mockup: True, build a mockup data, else build from device db
        """
        dev_service = self.get_blk_service(device_db)
        # This is common across all protocols
        mode_name_list = ['LN0_10G', 'LN0_1G', 'LN0_RAW_SERDES']
        device_port_map: Dict[str, ModeInterface] = {}

        for mode_name in mode_name_list:
            device_port_map.update(dev_service.get_interface_by_mode(mode_name))

        self.device_port_map.clear()

        skip_port = []
        for name, mode_port in device_port_map.items():
            if name in skip_port:
                continue

            if mode_port.get_type() != Port.TYPE_PAD and self.is_common_class(mode_port.get_class()):

                # All the interface have description except some
                # which are in exlucded list
                assert mode_port.get_description() not in ("", None)
                self.device_port_map[name] = mode_port

    def is_common_class(self, class_str: Optional[str]) -> bool:
        return class_str is not None and \
            class_str.endswith((":SW_COMMON", "Common Parameters"))

    def get_blk_service(self, device_db: PeripheryDevice) -> BlockService:
        from device.db_interface import DeviceDBService
        dbi = DeviceDBService(device_db)

        # TODO: Hard-code Quad0 and Quad2 used by PCIe, should use db information
        if self.__block_def.endswith(("_0", "_2")):
            blk_type = DeviceDBService.BlockType.QUAD_PCIE
        else:
            blk_type = DeviceDBService.BlockType.QUAD

        return dbi.get_block_service(blk_type)

    def build_param_info(self):
        param_info = build_param_info()
        more_param_info = build_design_param_info()

        # Update options
        if self.__block_def != "":
            pll_options = self.get_all_ref_clk_pll_options()
            param_list = [
                QuadDesignParamInfo.Id.pll_ref_clk0,
                QuadDesignParamInfo.Id.pll_ref_clk1,
            ]
            for param_id in param_list:
                prop_info = more_param_info.get_prop_info_by_name(param_id.value)
                assert prop_info is not None
                prop_info.valid_setting = pll_options

        param_info.concat_param_info(more_param_info)
        return param_info

    def get_all_ref_clk_pll_options(self):
        options = []
        if self._device_db is None or self.__block_def == "":
            return [
                "PMA_CMN_REFCLK_PLL_1",
                "PMA_CMN_REFCLK_PLL_2",
                "PMA_CMN_REFCLK1_PLL_1",
                "PMA_CMN_REFCLK1_PLL_2",
            ]

        dev_service = self.get_blk_service(self._device_db)
        mode_name = 'LN0_10G'
        device_port_map: Dict[str, ModeInterface] = dev_service.get_interface_by_mode(mode_name)
        mode_inf_obj = device_port_map.get("REFCLK", None)
        assert mode_inf_obj is not None

        inf2ports = mode_inf_obj.get_interface_to_ports_map()
        for name in inf2ports.values():
            res_name_list, ref_pin = dev_service.get_all_resource_on_ins_pin(
                self.__block_def, name, None)
            if len(res_name_list) <= 0:
                continue
            elif "PLL" not in name:
                continue
            options.append(name)

        return options


    def build_param(self):
        assert self._param_info is not None
        param_group = GenericParamGroup()
        for param_info in self._param_info.get_all_prop():
            param_group.add_param(param_info.name, param_info.default, param_info.data_type)
            self._categories.add(param_info.category)

        return param_group

    def build_generic_pin(self, is_rebuild: bool = False):
        from device.block_definition import Port as DevicePort, PortDir
        new_pin_list = []
        if not is_rebuild:
            self.gen_pin.clear()

        for device_port in self.device_port_map.values():
            self.dp_service.device_port = device_port
            type_name = self.dp_service.get_name()
            if is_rebuild:
                if self.gen_pin.get_pin_by_type_name(type_name) is None:
                    new_pin_list.append(type_name)
                else:
                    continue
            else:
                new_pin_list.append(type_name)

            if self.dp_service.get_type() == DevicePort.TYPE_CLOCK:
                self.gen_pin.add_clock_pin(type_name, "", self.dp_service.is_bus_port())
            else:
                direction = BlockService.get_port_direction_from_core(device_port)
                if direction == PortDir.input:
                    self.gen_pin.add_input_pin(type_name, "", self.dp_service.is_bus_port())
                elif direction == PortDir.output:
                    self.gen_pin.add_output_pin(type_name, "", self.dp_service.is_bus_port())
                else:
                    self.gen_pin.add_pin(type_name, "", self.dp_service.is_bus_port())

        if is_rebuild:
            for pin_type in self.gen_pin.get_all_pin_type_name():
                if pin_type not in self.device_port_map:
                    self.gen_pin.delete_pin_by_type(pin_type)

        return new_pin_list

    def rebuild_generic_pin(self, device_db: PeripheryDevice):
        self.build_port_info(device_db)
        new_pin_type_list = self.build_generic_pin(is_rebuild=self.__block_def != "")
        if len(new_pin_type_list) > 0:
            self.gen_pin.generate_pin_name_by_type(self.__name, new_pin_type_list)

        # Clockout interface and clock input need user to key in name
        self.update_default_pin_name(new_pin_type_list)

    # Overload
    def generate_pin_name_from_inst(self, inst_name: str):
        assert self.gen_pin is not None

        for pin in self.gen_pin.get_all_pin():
            is_available = self._dep_graph.get_pin_attributes(pin.type_name)['is_available']
            if is_available:
                pin.generate_pin_name(inst_name)

        # Clockout interface and clock input need user to key in name
        self.update_default_pin_name()

    def generate_specific_pin_name_from_inst(self, pin_name, is_only_empty_name: bool = False):
        '''

        :param pin_name: The pin type_name to have its pin name generated
        :param is_empty_name: IF true, only regenerate if the pin name is empty.
        :return:
        '''
        assert self.gen_pin is not None

        if self._is_clkout_pin(pin_name) or self._is_input_clk_pin(pin_name) or\
            self._is_hidden_pin(pin_name):
            return

        pin = self.gen_pin.get_pin_by_type_name(pin_name)
        if pin is not None and (not is_only_empty_name or \
                                    (pin.name == "" and is_only_empty_name)):
            pin.generate_pin_name(self.name)

    def update_default_pin_name(self, updated_pin_name_list: Optional[List[str]] = None):
        pin_list: List[GenericPin] = self.gen_pin.get_all_pin()
        for pin in pin_list:
            if updated_pin_name_list is not None and pin.type_name not in updated_pin_name_list:
                continue
            # No auto-generate pin name for CLKOUT and input clock pins
            if self._is_clkout_pin(pin.type_name) or \
                self._is_input_clk_pin(pin.type_name) or \
                self._is_hidden_pin(pin.type_name):
                pin.name = ""

    def generate_pin_name(self):
        self.generate_pin_name_from_inst(self.__name)

    def _is_clkout_pin(self, pin_type_name: str) -> bool:
        return pin_type_name in {
                                 'USER_APB_CLK',
                                 'PMA_CMN_REFCLK_CORE',
                                 'PMA_CMN_REFCLK1_CORE',
                                }

    def _is_input_clk_pin(self, pin_type_name: str) -> bool:
        return pin_type_name in {}

    def _is_hidden_pin(self, pin_type_name: str) -> bool:
        # TODO: This is a common section for common lane but the individual
        # hidden pins list vary between protocols.
        return pin_type_name in {
                                    "CH_RX_FWD_CLK",
                                    "10G_USER_APB_RESET_N",
                                    "PMA_CMN_EXT_REFCLK1_DETECTED",
                                    "PMA_CMN_EXT_REFCLK1_DETECTED_CFG",
                                    "PMA_CMN_EXT_REFCLK1_DETECTED_VALID",
                                    "PMA_CMN_EXT_REFCLK_DETECTED",
                                    "PMA_CMN_EXT_REFCLK_DETECTED_CFG",
                                    "PMA_CMN_EXT_REFCLK_DETECTED_VALID",
                                    "PMA_CMN_REFCLK1_CORE",
                                    "PMA_CMN_REFCLK_CORE",
                                    "USER_APB_PSLVERR",
                                }

    def get_unique_pins_by_quad_type(self, quad_type: QuadType) -> List[str]:
        quad_to_pins = {
            QuadType.lane_1g: [
                "SGMII_USER_APB_RESET_N",
                "LED_TICK_TOGGLE",
                "REM_PRE"
            ],
        }
        return quad_to_pins.get(quad_type, [])

    def get_filter_pins_by_quad_type(self, quad_type: QuadType) -> List[str]:
        '''
        Give a list of pin type name that should be filtered out
        based on the quad type being passed.
        For example, for 10G, we should filter the 1G unique pins. For
        1G, there's nothing to filter
        :return a list of pin type name that is to be filtered
        '''
        if quad_type in [QuadType.lane_10g, QuadType.raw_serdes]:
            return [
                "SGMII_USER_APB_RESET_N",
                "LED_TICK_TOGGLE",
                "REM_PRE"
            ]

        return []

    @staticmethod
    def get_all_precision() -> Dict[QuadDesignParamInfo.Id | CommonQuadConfigParamInfo.Id, int]:
        return {
        }

    @staticmethod
    def get_precision(param_id: QuadDesignParamInfo.Id | CommonQuadConfigParamInfo.Id):
        return QuadLaneCommon.get_all_precision().get(param_id, None)

    def set_default_setting(self, device_db=None):
        param_service = GenericParamService(self._param_group, self._param_info)
        for param_id in CommonQuadConfigParamInfo.Id:
            default_val = self._param_info.get_default(param_id)
            param_service.set_param_value(param_id, default_val)

    def create_chksum(self):
        return super().create_chksum()

    @property
    def param_info(self):
        return self._param_info

    @property
    def param_group(self):
        return self._param_group

    def get_param_info(self):
        return self._param_info

    def get_param_group(self):
        return self._param_group

    def get_available_params(self) -> List[str]:
        param_name_list: List[str] = []
        for param in self._param_group.get_all_param():
            param_name_list.append(param.name)
        return param_name_list

    def get_available_pins(self) -> List[str]:
        assert self._dep_graph is not None
        pin_type_name_list: List[str] = []

        for pin in self.gen_pin.get_all_pin():
            pin_type_name_list.append(pin.type_name)

        return pin_type_name_list

    def get_available_params_by_quad_type(self, quad_type: QuadType, inst_name: str) -> List[str]:
        """
        Get all available parameters by quad type (e.g. 10G, 1G, etc)

        :param quad_type: Quad type
        :type quad_type: QuadType
        :return: List of available parameters
        :rtype: List[str]
        """
        param_name_list: List[str] = []

        for param in self._param_group.get_all_param():
            if self.is_param_available(param.name, quad_type, inst_name):
                param_name_list.append(param.name)

        return param_name_list

    def get_supported_common_parameters(self, quad_type: QuadType):
        supported_parameters = []
        match quad_type:
            case QuadType.lane_10g:
                supported_parameters = get_lane_10g_param()
            case QuadType.lane_1g:
                supported_parameters = get_lane_1g_param()
            case QuadType.raw_serdes:
                supported_parameters = get_raw_serdes_param()

        return supported_parameters

    def get_sw_supported_cmn_parameters(self, quad_type: QuadType):
        return get_supported_params_by_quad_type(quad_type)

    def get_available_pins_by_quad_type(self, quad_type: QuadType, inst_name: str) -> List[str]:
        """
        Get all available pins by quad type (e.g. 10G, 1G, etc)

        :param quad_type: Quad type
        :type quad_type: QuadType
        :return: List of available pins
        :rtype: List[str]
        """
        assert self._dep_graph is not None
        pin_type_name_list: List[str] = []

        for pin in self.gen_pin.get_all_pin():
            if self.is_pin_available(quad_type, inst_name, pin.type_name):
                pin_type_name_list.append(pin.type_name)

        return pin_type_name_list

    def is_param_available(self, param_name: str, quad_type: QuadType, inst_name: str) -> bool:
        if QuadDesignParamInfo.Id.has_member(param_name) and \
            not QuadDesignParamInfo.Id(param_name) in self.get_sw_supported_cmn_parameters(quad_type):
            return False

        elif CommonQuadConfigParamInfo.Id.has_member(param_name) and \
            not param_name in self.get_supported_common_parameters(quad_type):
            return False

        inst2available = self._dep_graph.get_param_attributes(
            param_name)['is_available'].get(quad_type, {})
        is_available = inst2available.get(inst_name, False)
        return is_available

    def is_pin_available(self, quad_type: QuadType, inst_name: str, pin_type_name: str) -> bool:
        is_available = False
        attr = self._dep_graph.get_pin_attributes(pin_type_name)

        if attr is not None:
            inst2available = attr['is_available'].get(quad_type, {})
            is_available = inst2available.get(inst_name, False)

        return is_available

    def register_param_observer(self, quad_type: QuadType, lane_inst_name: str, gp_monitor):
        self.dependency_graph.add_lane_inst(quad_type, lane_inst_name)
        self.param_group.register_param_observer(gp_monitor)

    def unregister_param_observer(self, quad_type: QuadType, lane_inst_name: str, gp_monitor):
        self.dependency_graph.delete_lane_inst(quad_type, lane_inst_name)
        self.param_group.unregister_param_observer(gp_monitor)

    def update_settings_by_others(self, other: QuadLaneCommon):
        private_attr_list = ["name", "block_def"]
        exclude_attr_list = ["_param_group", "_dep_graph"]

        for attr in private_attr_list:
            exclude_attr_list.append(f"_{self.__class__.__name__}__{attr}")

        # Since we have observers in param_group, need special handling for this
        self.copy_common(other, exclude_attr=exclude_attr_list)
        self.param_group.update(other.param_group)

        # Gen pin
        self.refresh_pin_name(other.name, self.name)

    def reset_param_by_id(self, param_id: QuadDesignParamInfo.Id| CommonQuadConfigParamInfo.Id):
        default_val = self.param_info.get_default(param_id)

        if self.param_group.get_param_value(param_id.value) != default_val:
            self.param_group.set_param_value(param_id.value, default_val)

    @staticmethod
    def get_phy_reset_related_pcr_list():
        """
        :return: Ids of the two user phy reset expose parameters (q0, q1)
        """
        config_ids = CommonQuadConfigParamInfo.Id
        return [
            config_ids.pcr_q0_user_phy_reset_n_expose,
            config_ids.pcr_q1_user_phy_reset_n_expose,
        ]


class QuadLaneCommonRegistry(PeriDesignRegistry):
    """
    Common Registry for Quad/Lane.

    It will be used for UI display, managing pin names etc.
    """
    def __init__(self):
        self.lane_inst2cmn_inst: Dict[Tuple[QuadType, str], QuadLaneCommon] = {}
        super().__init__()

    def create_instance(self, name, apply_default=True, auto_pin=False):
        with self._write_lock:
            inst = QuadLaneCommon(name, device_db=self.device_db)
            if apply_default:
                inst.set_default_setting(self.device_db)
            if auto_pin:
                inst.generate_pin_name()

            self._register_new_instance(inst)
            return inst

    def connect_lane_inst_2cmn_inst(self, quad_type: QuadType, lane_name: str, inst: QuadLaneCommon):
        if not quad_type.is_lane_based():
            raise ValueError(f"Quad type {quad_type} is not lane based")

        self.lane_inst2cmn_inst[(quad_type, lane_name)] = inst

    def disconnect_lane_inst_2cmn_inst(self, quad_type: QuadType, lane_name: str):
        if not quad_type.is_lane_based():
            raise ValueError(f"Quad type {quad_type} is not lane based")

        self.lane_inst2cmn_inst.pop((quad_type, lane_name), None)

    def rename_lane_inst(self, quad_type: QuadType, current_name: str, new_name: str):
        if not quad_type.is_lane_based():
            raise ValueError(f"Quad type {quad_type} is not lane based")

        inst = self.lane_inst2cmn_inst[(quad_type, current_name)]
        self.lane_inst2cmn_inst[(quad_type, new_name)] = inst
        self.lane_inst2cmn_inst.pop((quad_type, current_name), None)
        inst.dependency_graph.rename_lane_inst(quad_type, current_name, new_name)

    def get_inst_by_lane_name(self, quad_type: QuadType, lane_name: str) -> Optional[QuadLaneCommon]:
        if not quad_type.is_lane_based():
            raise ValueError(f"Quad type {quad_type} is not lane based")

        return self.lane_inst2cmn_inst.get((quad_type, lane_name), None)

    def assign_inst_device(self, inst: QuadLaneCommon, new_dev: str):
        new_dev = self.translate_device_def(new_dev)
        super().assign_inst_device(inst, new_dev)

        if self.device_db is not None:
            inst.rebuild_generic_pin(device_db=self.device_db)

    def apply_device_db(self, device_db: PeripheryDevice):
        assert device_db is not None
        self.device_db = device_db

    def is_quad_based_resource(self, res_name: str):
        quad_num, lane_num = QuadResService.break_res_name(res_name)
        return quad_num != -1 and lane_num == -1

    def is_lane_based_resource(self, res_name: str):
        quad_num, lane_num = QuadResService.break_res_name(res_name)
        return quad_num != -1 and lane_num != -1

    def is_device_valid(self, device_def: str):
        if self.is_quad_based_resource(device_def) or self.is_lane_based_resource(device_def):
            return True
        return False

    def is_device_used(self, device_def: str):
        dev_name = self.translate_device_def(device_def)
        return super().is_device_used(dev_name)

    def translate_device_def(self, device_def: str):
        quad_num, lane_num = QuadResService.break_res_name(device_def)

        # Lane based resource
        if quad_num != -1 and lane_num != -1:
            dev_name = f"QUAD_{quad_num}"
        # None
        elif quad_num == -1 and lane_num == -1:
            dev_name = ""
        else:
            dev_name = device_def

        return dev_name

    def get_inst_by_device_name(self, device_def):
        dev_name = self.translate_device_def(device_def)
        return super().get_inst_by_device_name(dev_name)

    def is_lane_same_quad(self, lane_res_name: str, lane_res_name2) -> bool:
        """
        Check if 2 lane resources are on the same quad.
        """
        quad_num, _ = QuadResService.break_res_name(lane_res_name)
        quad_num_2, _ = QuadResService.break_res_name(lane_res_name2)
        return quad_num == quad_num_2
