from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Optional

import pt_version

from common_device.quad.dep_graph import LaneBaseDependencyGraph
from common_device.quad.res_service import QuadResService

from tx375_device.raw_serdes.quad_dep_graph import RawSerdesConfigDependencyGraph
from tx375_device.raw_serdes.raw_serdes_pin_dep_graph import RawSerdesPinDependencyGraph
from tx375_device.raw_serdes.raw_serdes_pll_cfg_prop_id import RawSerdesPLLConfigParamInfo as RawSerdesPLLParamInfo
from tx375_device.raw_serdes.raw_serdes_prop_id import RawSerdesConfigParamInfo
from tx375_device.raw_serdes.design_param_info import RawSerdesDesignParamInfo as SWParamInfo
from tx375_device.common_quad.quad_prop_id import CommonQuadConfigParamInfo as CommonQuadParamInfo
from tx375_device.common_quad.design_param_info import QuadDesignParamInfo as CmnDesignParamInfo


if TYPE_CHECKING:
    from common_device.property import PropertyMetaData
    from design.db_item import GenericPinGroup

    from tx375_device.raw_serdes.design import RawSerdes

# Flag passed to get_pll_config_settings() when looking up PLL settings.
# True only when pt_version.PT_DEBUG_VERSION compares equal to True
# (presumably a debug build of the tool). The explicit `== True` is kept on
# purpose: `is True` or plain truthiness would treat other values differently.
IS_SHOW_HIDDEN: bool = bool(pt_version.PT_DEBUG_VERSION == True)  # noqa: E712

class RawSerdesDependencyGraph(LaneBaseDependencyGraph, RawSerdesConfigDependencyGraph, RawSerdesPinDependencyGraph):
    """
    Dependency graph for a raw SERDES lane instance.

    On top of the parameter and pin dependencies built by the base graphs,
    this class registers:

    * PLL parameters that are recomputed whenever the reference clock
      frequency, data rate or SERDES width changes.
    * PCR ``user_phy_reset_n_expose`` flags driven by the common quad's
      ``phy_reset_en`` parameter (PT-2559).
    * The clock-resource enable and the TX/RX clock pin and connection-type
      availability (PT-2558).
    """

    def build_dependency(self, param_info: PropertyMetaData, port_info: GenericPinGroup):
        """
        Build all parameter and pin dependencies for this instance.

        :param param_info: Property metadata, forwarded to the base graphs'
            ``build_param_dependency`` / ``build_pin_dependency``.
        :param port_info: Pin group, forwarded to the same base methods.
        """
        super().build_param_dependency(param_info, port_info)
        super().build_pin_dependency(param_info, port_info)
        self.inst: 'RawSerdes'  # For type checking

        # Any change to one of the PLL preset keys re-evaluates every PLL
        # parameter (see on_pll_config_key_changed / get_preset_key).
        pll_config_keys = [
            CommonQuadParamInfo.Id.ss_raw_refclk_freq,
            RawSerdesConfigParamInfo.Id.ss_raw_data_rate_lane_NID,
            RawSerdesConfigParamInfo.Id.ss_raw_serdes_width_lane_NID,
        ]
        for param_a in pll_config_keys:
            for param_b in RawSerdesPLLParamInfo.Id:
                self.add_param_dependency(
                    param_a_name=param_a.value,
                    param_b_name=param_b.value,
                    change_func=self.on_pll_config_key_changed)

        # PT-2559 Need to enable PCR if the pin is enabled and resource change
        param_list = [
            CommonQuadParamInfo.Id.pcr_q0_user_phy_reset_n_expose,
            CommonQuadParamInfo.Id.pcr_q1_user_phy_reset_n_expose,
        ]
        for param_id in param_list:
            self.add_param_dependency(
                param_a_name=CmnDesignParamInfo.Id.phy_reset_en.value,
                param_b_name=param_id.value,
                change_func=self.on_reset_pin_enable
            )

        # PT-2558 Clock Resource related: the clock-resource enable only
        # applies to bundled (non-x1) lanes.
        self.add_param_dependency(
            param_a_name=RawSerdesConfigParamInfo.Id.ss_raw_bundle_mode_lane_NID.value,
            param_b_name=SWParamInfo.Id.clk_resource_en.value,
            change_func=partial(self.on_parent_param_changed,
                                condition_str='ss_raw_bundle_mode_lane_NID != "x1"')
        )
        self.build_clock_dep()

    def build_clock_dep(self):
        """
        Register the conditions that enable the TX/RX clock pins and their
        connection-type parameters.

        Each is re-evaluated whenever the bundle mode, the clock-resource
        enable or the lane mode changes. All TX registrations are made
        before any RX registration.
        """
        # Tx clock: not in "Rx FIFO" mode, and either an x1 lane or a bundled
        # lane with the clock resource enabled.
        tx_condition = (
            'ss_raw_mode_lane_NID != "Rx FIFO" && (ss_raw_bundle_mode_lane_NID == "x1" || '
            f'(ss_raw_bundle_mode_lane_NID != "x1" && {SWParamInfo.Id.clk_resource_en.value}))'
        )

        # Rx clock
        condition1 = 'ss_raw_bundle_mode_lane_NID == "x1"'
        # Keep a space after "&&": the pieces are joined by implicit string
        # concatenation, and the original text produced "&&ss_raw_mode...".
        condition2 = (
            f'ss_raw_bundle_mode_lane_NID != "x1" && {SWParamInfo.Id.clk_resource_en.value} && '
            'ss_raw_mode_lane_NID != "Tx FIFO, Rx Register"'
        )
        condition3 = 'ss_raw_mode_lane_NID == "Tx FIFO, Rx Register"'
        rx_condition = f'({condition1}) || ({condition2}) || ({condition3})'

        self._add_clock_pin_dep('RAW_SERDES_TX_CLK', SWParamInfo.Id.tx_clk_conn_type, tx_condition)
        self._add_clock_pin_dep('RAW_SERDES_RX_CLK', SWParamInfo.Id.rx_clk_conn_type, rx_condition)

    def _add_clock_pin_dep(self, port_name: str, conn_type_param_id, condition_str: str):
        """
        Gate a clock port and its connection-type parameter on ``condition_str``.

        :param port_name: Port whose enable follows ``condition_str``.
        :param conn_type_param_id: Design-param id (``.value`` is its name) of
            the port's connection-type parameter.
        :param condition_str: Expression string handed to the base graph's
            ``on_enable_changed_pin`` / ``on_parent_param_changed`` callbacks.
        """
        depend_param_list = [
            RawSerdesConfigParamInfo.Id.ss_raw_bundle_mode_lane_NID,
            SWParamInfo.Id.clk_resource_en,
            RawSerdesConfigParamInfo.Id.ss_raw_mode_lane_NID,
        ]
        for param_id in depend_param_list:
            self.add_param_to_port_dependency(
                param_name=param_id.value,
                port_name=port_name,
                change_func=partial(self.on_enable_changed_pin,
                                    condition_str=condition_str)  # type: ignore
            )

            self.add_param_dependency(
                param_a_name=param_id.value,
                param_b_name=conn_type_param_id.value,
                change_func=partial(self.on_parent_param_changed,
                                    condition_str=condition_str)
            )

    def on_pll_config_key_changed(self, graph, param_group, parent_param_name: str, my_param_name: str):
        """
        Recompute one PLL parameter after a preset key changed.

        The settings come from ``self.inst.pll_cfg.get_pll_config_settings``,
        keyed on the current ``(data_rate, ref_clk_freq, data_width)``.
        ``graph`` and ``param_group`` belong to the callback signature and
        are not used here.

        :param parent_param_name: Name of the preset key that changed.
        :param my_param_name: Name of the PLL parameter to update.
        :return: False if there is no common quad instance or no settings
            match the preset key; True otherwise, including when the matched
            settings have no entry for this parameter (left unchanged).
        :raises AssertionError: if ``my_param_name`` is not a known raw SERDES
            config/PLL parameter, or ``self.inst.pll_cfg`` is None.
        """
        self.logger.debug(
            f"{self.inst.name}: Calling on_pll_config_key_changed {parent_param_name}-{my_param_name}")

        target_param_id: Optional[RawSerdesConfigParamInfo.Id | RawSerdesPLLParamInfo.Id] = None

        if RawSerdesConfigParamInfo.Id.has_key(my_param_name):
            target_param_id = RawSerdesConfigParamInfo.Id(my_param_name)
        else:
            assert RawSerdesPLLParamInfo.Id.has_key(my_param_name), \
                f"Invalid parameter name: {my_param_name}"
            target_param_id = RawSerdesPLLParamInfo.Id(my_param_name)

        cmn_inst = self.get_cmn_inst()
        if cmn_inst is None:
            return False

        # Reuse the common instance fetched above rather than looking it up again.
        data_rate, ref_clk_freq, data_width = self.get_preset_key(cmn_inst)

        assert self.inst.pll_cfg is not None
        settings = self.inst.pll_cfg.get_pll_config_settings(
            data_rate, ref_clk_freq, data_width, IS_SHOW_HIDDEN)

        if settings is None:
            return False

        val = settings.get(target_param_id)
        if val is None:
            # Some params do not need to update due to different mode
            return True

        self.inst.param_group.set_param_value(target_param_id.value, val)
        return True

    def get_preset_key(self, cmn_inst=None):
        """
        Return the ``(data_rate, ref_clk_freq, data_width)`` PLL preset key.

        Data rate and SERDES width are read from this lane's parameters. The
        reference clock frequency is read from the common quad instance.

        :param cmn_inst: Common quad instance to read the reference clock
            from. If omitted, it is fetched with ``get_cmn_inst()``.
        :raises AssertionError: if no common quad instance is available.
        """
        if cmn_inst is None:
            cmn_inst = self.get_cmn_inst()
        assert cmn_inst is not None

        data_rate = self.inst.param_group.get_param_value(
            RawSerdesConfigParamInfo.Id.ss_raw_data_rate_lane_NID.value)
        data_width = self.inst.param_group.get_param_value(
            RawSerdesConfigParamInfo.Id.ss_raw_serdes_width_lane_NID.value)
        ref_clk_freq = cmn_inst.param_group.get_param_value(
            CommonQuadParamInfo.Id.ss_raw_refclk_freq.value)

        return data_rate, ref_clk_freq, data_width

    def on_reset_pin_enable(self, graph, param_group, parent_param_name: str, my_param_name: str):
        """
        PT-2559 Update reset pin related PCR value.

        Enable PCR when:
        1) enable reset pin
        2) resource change
            - PCIe resource [Q0, Q2] -> pcr_q0_user_phy_reset_n_expose
            - Quad resource [Q1, Q3] -> pcr_q1_user_phy_reset_n_expose

        ``graph`` and ``param_group`` belong to the callback signature and
        are not used here.

        :param parent_param_name: Must be the ``phy_reset_en`` design param.
        :param my_param_name: The PCR expose parameter to update.
        :return: False if there is no common quad instance or
            ``my_param_name`` is not one of the two PCR expose params;
            True otherwise, including when no quad resource is assigned.
        """
        assert parent_param_name == CmnDesignParamInfo.Id.phy_reset_en.value

        cmn_inst = self.get_cmn_inst()
        if cmn_inst is None:
            return False

        quad_res = cmn_inst.get_device()
        quad_idx, _ = QuadResService.break_res_name(quad_res)

        # Do nothing if no resource
        if quad_idx == -1:
            return True

        is_reset_pin_en: bool = cmn_inst.param_group.get_param_value(parent_param_name)

        # PCR expose param -> quad indices for which it should be enabled.
        expose_quad_indices = {
            CommonQuadParamInfo.Id.pcr_q0_user_phy_reset_n_expose.value: (0, 2),
            CommonQuadParamInfo.Id.pcr_q1_user_phy_reset_n_expose.value: (1, 3),
        }
        valid_indices = expose_quad_indices.get(my_param_name)
        if valid_indices is None:
            return False

        cmn_inst.param_group.set_param_value(my_param_name, is_reset_pin_en and quad_idx in valid_indices)
        return True
