from __future__ import annotations
from typing import TYPE_CHECKING, Dict, List, Optional
from xml.etree.ElementTree import SubElement

from common_device.writer import LogicalPeriphery
from common_device.quad.res_service import QuadResService
from device.excp import BlockPCRDoesNotExistException
from device.pcr_device import PCRMap
from device.db_interface import DeviceDBService
from util.excp import PTInteralAssertException, pt_assert

from tx375_device.serdes_wrap.device_service import PCRDefnType, PCRModeType, SerdesWrapDeviceService


if TYPE_CHECKING:
    from xml.etree.ElementTree import Element
    from design.db import PeriDesign
    from device.block_definition import PeripheryBlock
    from tx375_device.quad_pcie.design import QuadPCIE


class SerdesWrapLogicalPeriphery(LogicalPeriphery):

    def __init__(self, device_db, name):
        '''
        Constructor
        '''

        super().__init__(device_db, name)

        # Save the PMA_PCR return value in this writer object
        self.pma_pcr_map = {}

    def generate_pcr_user_instance(self,
                                   blk_ins: str,
                                   parent: Element,
                                   design: PeriDesign,
                                   block_def: PeripheryBlock,
                                   user_ins: str):
        raise NotImplementedError

    def set_pma_pcr_map(self, pma_pcr_map):
        # Always clear before setting since this writer
        # is shared with all serdes wrap instances
        self.clear_pma_pcr_map()

        # Convert the key to all caps
        for key, value in pma_pcr_map.items():
            key_caps = key.upper()  
            if key_caps not in self.pma_pcr_map:
                if isinstance(value, str):
                    # Change to all caps
                    self.pma_pcr_map[key_caps] = value.upper()
                else:                    
                    self.pma_pcr_map[key_caps] = value 
         
    def clear_pma_pcr_map(self):
        '''
        Delete the content since it may be called multiple time
        on the same object but of different wrap instance.
        '''
        self.pma_pcr_map = {}

    def get_quad_used_resources(self, design):
        dev_ins_name_used: List[str] = []

        if design.quad_pcie_reg is not None and \
            design.quad_pcie_reg.get_inst_count() > 0:
            dev_ins_name_used.extend(design.quad_pcie_reg.get_all_device())

        # Para
        lane_reg = [design.lane_10g_reg, design.lane_1g_reg, design.raw_serdes_reg]
        for reg in lane_reg:
            if reg is not None and reg.get_inst_count() >  0:

                all_lane_inst = reg.get_all_inst()

                for lane_ins in all_lane_inst:
                    if lane_ins.get_device() != "":
                        quad_ins = QuadResService.translate_lane_res2quad_res_name(lane_ins.get_device())

                        if quad_ins != "" and quad_ins not in dev_ins_name_used:
                            dev_ins_name_used.append(quad_ins)
                        
        return dev_ins_name_used
        
    def is_quad_configured(self, design, blk_ins):
        '''
        If there is a quad configured then pcr_pma_nisoen has to be set to 1
        '''
        is_found = False

        # SERDES_WRAP is associated to quad pair:
        # 0 - Q0,Q1
        # 1 - Q2,Q3
        dev_ins_name_used = self.get_quad_used_resources(design)
        
        if len(dev_ins_name_used) > 0:
            if blk_ins == "SERDES_WRAP_0":
                if "QUAD_0" in dev_ins_name_used or "QUAD_1" in dev_ins_name_used:
                    is_found = True

            else:
                if "QUAD_2" in dev_ins_name_used or "QUAD_3" in dev_ins_name_used:
                    is_found = True

        return is_found
        
    def generate_default_pcr(self, blk_ins: str, parent: Element, design: PeriDesign, block_def: PeripheryBlock):
        pcr_block = block_def.get_block_pcr()
        pt_assert(pcr_block is not None, f'Block {self._name} does not have PCR defined', BlockPCRDoesNotExistException)

        ins_element = SubElement(parent, 'efxpt:instance', id=blk_ins, type=self._name)
        pcr_map = pcr_block.get_pcr_list()
        for pcr_name, pcr_obj in sorted(pcr_map.items()):
            # Use default unless it is part of the returned
            # parameter from the ICD script
            if pcr_name in self.pma_pcr_map:
                # We don't check for bitvec since it's not expected to be from bitvec
                pt_assert(pcr_obj.get_type() != PCRMap.PCRMapType.bitvec,
                          f'Block PCR {blk_ins} has {pcr_name} with unexpected type PCRMapType.bitvec',
                          PTInteralAssertException)

                # Get the value from map and if it's integer
                # check that it's within range and convert to string
                pcr_value = self.pma_pcr_map[pcr_name]

                if not isinstance(pcr_value, str):
                    # This could be integer. 
                    if pcr_obj.get_type() == PCRMap.PCRMapType.int:
                        min_val = pcr_obj.get_min()
                        max_val = pcr_obj.get_max()

                        if pcr_value > max_val or pcr_value < min_val:
                            raise PTInteralAssertException(
                                f"PCR for {blk_ins} Parameter {pcr_name} out-of-range(min={min_val}, max={max_val}")
                        
                        # Convert to string
                        value = str(pcr_value)

                    else:
                        # PT-2148: This is if we found that we should
                        # set the PCR mode type. The ICD script would return
                        # binary represenatation and we need to map it to the
                        # mode name in the lpf file
                        mode2val_map = pcr_obj.get_all_modes_map()
                        if pcr_value not in mode2val_map:
                            # Need to convert the value to the mode name
                            conv_value = self.get_pcr_mode_name_str(pcr_obj, pcr_value, mode2val_map)
                            if conv_value != "":
                                value = conv_value
                            else:
                                raise PTInteralAssertException(
                                    f"PCR for {blk_ins} Parameter {pcr_name} has unmapped PCR Mode value {pcr_value}")

                else:
                    value = pcr_value

                    if pcr_obj.get_type() == PCRMap.PCRMapType.mode:
                        pt_assert(pcr_obj.is_mode_exists(value),
                                  f'Invalid PCR for {blk_ins} Mode {pcr_name} with value {value}',
                                  PTInteralAssertException)

            elif pcr_name == PCRDefnType.PCR_PMA_NISOEN.value and\
                    self.is_quad_configured(design, blk_ins):
                
                mode_value = PCRModeType.DISABLE.value 
                pt_assert(pcr_obj.is_mode_exists(mode_value),
                          f'Invalid PCR for {blk_ins} Mode {pcr_name} with value {mode_value}',
                          PTInteralAssertException)
                value = mode_value

            # For PCIe resource
            elif pcr_name == PCRDefnType.PCR_Q0_PLL_REFCLK_SRC.value and \
                self.is_quad_configured(design, blk_ins):
                value = self.get_pcr_q0_pll_ref_clk_src_value(design, blk_ins)

            else:
                value = pcr_obj.get_default()
            
            pt_assert(value is not None, f'Block {self._name} PCR {pcr_name} does not have default value defined', PTInteralAssertException)
            pt_assert(isinstance(value, str), f'Block {self._name} PCR {pcr_name} default value not str, got {type}', PTInteralAssertException)
            SubElement(ins_element, 'efxpt:parameter', name=pcr_name, value=value)

    def get_pcr_q0_pll_ref_clk_src_value(self, design: PeriDesign, blk_ins: str):
        # Check PCIe resource only
        check_resource = "QUAD_0" if blk_ins == "SERDES_WRAP_0" else "QUAD_2"
        pcie_reg = design.quad_pcie_reg
        value = PCRModeType.DISABLE.value

        if pcie_reg is None:
            return value

        pcie_inst: Optional[QuadPCIE] = pcie_reg.get_inst_by_device_name(check_resource)
        if pcie_inst is None:
            return value

        dbi = DeviceDBService(design.device_db)
        dev_service = dbi.get_block_service(DeviceDBService.BlockType.QUAD_PCIE)
        pll_inst, refclk_name = pcie_inst.get_external_refclk_related_pll(design, dev_service)

        if pll_inst is None or refclk_name == "":
            return value

        match refclk_name:
            case "PMA_CMN_REFCLK_PLL_1":
                value = PCRModeType.I1.value
            case "PMA_CMN_REFCLK_PLL_2":
                value = PCRModeType.I2.value

        return value

    def get_pcr_mode_name_str(self, pcr_obj: PCRMap, pcr_value: int, pcr_map: Dict[str, str]):
        '''
        If the pcr value retruned by the ICD script is not matching the
        pcr mode name, then we convert it to the mode name by comparing
        the value with the binary value of the PCR Mode.  This is only
        applicable to PCR of type PCRMapType.mode only (pre-requisite).

        :param pcr_obj: PCRMap instance
        :param pcr_value: PCR value returned by ICD script
        :return a string of PCR mode name. Or an empty string if
                it cannot map the pcr_value to the mode value
        '''
        mode_str = ""
        pt_assert(isinstance(pcr_value, int), 
                  f'Block {self._name} PCR {pcr_obj.get_name()} passed value {pcr_value} is not integer', PTInteralAssertException)
            
        for mode_name, mode_val in pcr_map.items():
            pt_assert(self.is_string_binary(mode_val),
                      f'Block {self._name} PCR {pcr_obj.get_name()} with {mode_name}:{mode_val} - value is not binary', PTInteralAssertException)
            
            val_int = int(mode_val, 2)

            if val_int == pcr_value:
                mode_str = mode_name
                break

        return mode_str

    def is_string_binary(self, check_str: str):
        try:
            # this will raise value error if
            # string is not of base 2
            int(check_str, 2)
        except ValueError:
            return False
        return True