from __future__ import annotations
import os
import sys
import abc
import enum

from typing import List, Dict, TYPE_CHECKING, Final

import util.gen_util as pt_util
from util.excp import PTInteralAssertException, pt_assert

import common_device.writer as tbi

from common_device.quad.res_service import QuadResService
from tx375_device.common_quad.design import QuadLaneCommon

import device.block_definition as dev_blk
from design.db_item import GenericClockPin


if TYPE_CHECKING:
    import io
    from device.block_definition import BlockMode, PeripheryBlock
    from common_device.quad.lane_design import LaneBasedItem

# Append "<directory of this file>/.." to the module search path.
# NOTE(review): this statement runs after the module-level imports above,
# so it cannot influence how those are resolved - confirm it is still needed.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))


class QuadSERDESTiming(tbi.BlockTimingWithArcs):
    '''
    Builds the timing data for PCIE used for printing out
    timing report and sdc file.
    '''

    xml_ns = "{http://www.efinixinc.com/peri_device_db}"

    class PathType(enum.Enum):
        '''
        Keys of ``_mode2arcs`` that select one of the two arc lists
        collected from a block mode in ``build_mode_based_arc``.
        '''
        # Arcs from get_all_output_arcs(); written through write_sdc_in_delay
        output = 0
        # Arcs from get_all_input_arcs(); written through write_sdc_out_delay
        input = 1

    CLOCK_PRECISION: Final[int] = 3

    def __init__(self, bname, device, design, report, sdc,
                 max_model, min_model):
        '''
        Forward all arguments unchanged to the base-class constructor,
        then reset the arc table and the common-quad registry handle.
        '''
        super().__init__(bname, device, design, report, sdc,
                         max_model, min_model)

        # This is built a bit differently since there is not
        # really a mode for this block. The key is used instead
        # to differentiate between input and output arcs
        # (see PathType and build_mode_based_arc).
        self._mode2arcs = {}

        # Registry queried in write_output() for the common instance of a
        # device. write_output() asserts that it is no longer None.
        self.common_quad_reg = None

    def build_arc(self):
        """
        Build Timing Arcs.

        Intentionally empty in this class: the arc table is filled per
        block mode by ``build_mode_based_arc`` instead.
        """
        # Nothing to be done
        pass

    def get_blk2mode_map(self, blk_list: List):
        # Although each block has its own mode and timing secttion.
        # The general rule is that they are all identical. So, we should
        # use one mode from each block
        #raw_serdes_mode_names = ["LN0_RAW_SERDES"]

        blk2modes_map: Dict[PeripheryBlock, BlockMode] = {}

        for blk in blk_list:
            mode2obj_map: Dict[str, BlockMode] = {}

            mode_name = self.get_mode_name()
            blk_mode = blk.get_mode(mode_name)

            if blk_mode is None:
                raise ValueError(f"Unable to find {blk.get_name()} mode {mode_name}")

            mode2obj_map[mode_name] = blk_mode

            # Just take the first mode
            blk2modes_map[blk] = mode2obj_map[mode_name]

        return blk2modes_map

    def get_instance_based_on_blk(self, blk, ins_name2obj_map):
        blk_ins_name2obj_map = {}

        pcie_quad_names = ["Q0", "Q2"]
        blk_name = blk.get_name()

        for ins, ins_obj in ins_name2obj_map.items():

            pt_assert(ins_obj.get_device() != "", f'{self.get_user_blk_name} instance {ins} with resource unassigned',
                      PTInteralAssertException)

            if (blk_name == "quad_pcie" and ins_obj.get_device().startswith(tuple(pcie_quad_names))) or \
                (blk_name == "quad" and (not ins_obj.get_device().startswith(tuple(pcie_quad_names)))):
                blk_ins_name2obj_map[ins] = ins_obj

        return blk_ins_name2obj_map

    def build_mode_based_arc(self, blk_mode):
        """
        Build Timing Arcs based on specified mode
        """
        self._mode2arcs = {}

        input_arcs = blk_mode.get_all_input_arcs(True)
        output_arcs = blk_mode.get_all_output_arcs(True)

        self._mode2arcs[self.PathType.input] = input_arcs
        self._mode2arcs[self.PathType.output] = output_arcs

    def write_instance(self, sdcfile,
                       max_clk, min_clk, ins_obj,
                       max_param2delay_map,
                       min_param2delay_map,
                       blk_mode, common_inst=None):

        # Write the create_clock constraint
        self.write_create_clock(sdcfile, ins_obj)

        # Get the available pins (only those that are relevant to current setting
        used_gpin_tname_list = ins_obj.get_available_pins()

        # Add up the common instance pins as well
        common_gpin_name_list = []
        if common_inst is not None:
            common_gpin_name_list = common_inst.get_available_pins_by_quad_type(ins_obj.quad_type, ins_obj.name)

        self.write_in_arc_adv(sdcfile, ins_obj, max_param2delay_map,
                              min_param2delay_map, max_clk, min_clk,
                              used_gpin_tname_list, common_inst, common_gpin_name_list, blk_mode)

        self.write_out_arc_adv(sdcfile, ins_obj, max_param2delay_map,
                               min_param2delay_map, max_clk, min_clk,
                               used_gpin_tname_list, common_inst, common_gpin_name_list, blk_mode)

    def check_skip_arc_sink(self, ins_obj, sink_name, clk_name,
                        used_gpin_tname_list, common_gpin_name_list):
        if sink_name not in used_gpin_tname_list and\
            sink_name not in common_gpin_name_list:
            return True

        return False

    def get_arc_info(self, arc, used_gpin_tname_list, common_gpin_name_list,
                     ins_obj, common_ins):
        skip_sink = False

        sink_name = arc.get_sink()
        clk_name = arc.get_source()
        clk_str = ""

        # Skip if:
        # 1) Sink name not part of the available pin (not configured then)
        skip_sink = self.check_skip_arc_sink(ins_obj, sink_name, clk_name,
                        used_gpin_tname_list, common_gpin_name_list)
        
        # if sink_name not in used_gpin_tname_list and\
        #     sink_name not in common_gpin_name_list:
        #     skip_sink = True

        if not skip_sink:

            # If the clock is empty but pin name valid, we
            # still print the constraint but commented out as
            # a template with clock name to be filled by user
            user_clk = self.get_user_gen_pin_name(ins_obj, common_ins, clk_name)
            if user_clk == "":
                # Use a reserve clock template
                user_clk = "<CLOCK>"

            if self.is_clock_inverted(ins_obj, common_ins, clk_name):
                clk_str = "-clock_fall -clock " + user_clk
            else:
                clk_str = "-clock " + user_clk

        return skip_sink, clk_name, clk_str

    def write_in_arc_adv(self, sdcfile, ins_obj, max_param2delay_map,
                         min_param2delay_map, max_clk, min_clk, 
                         used_gpin_tname_list, common_ins, common_gpin_name_list, blk_mode):

        # Writes out the input arcs first (Tsetup/Thold)
        in_arcs = self._mode2arcs[self.PathType.input]

        for arc in in_arcs:

            skip_sink, clk_name, clk_str = self.get_arc_info(
                arc, used_gpin_tname_list, common_gpin_name_list,
                ins_obj, common_ins)
            
            if skip_sink:
                continue

            self.write_sdc_out_delay(sdcfile, ins_obj, arc, max_param2delay_map,
                                     min_param2delay_map, max_clk, min_clk, clk_str,
                                     clk_name, blk_mode, common_ins, common_gpin_name_list)

    def write_out_arc_adv(self, sdcfile, ins_obj, max_param2delay_map,
                          min_param2delay_map, max_clk, min_clk,
                          used_gpin_tname_list, common_ins, common_gpin_name_list, blk_mode):

        out_arcs = self._mode2arcs[self.PathType.output]

        for arc in out_arcs:

            skip_sink, clk_name, clk_str = self.get_arc_info(
                arc, used_gpin_tname_list, common_gpin_name_list,
                ins_obj, common_ins)
            
            if skip_sink:
                continue

            self.write_sdc_in_delay(sdcfile, ins_obj, arc, max_param2delay_map,
                                    min_param2delay_map, max_clk, min_clk, clk_str,
                                    clk_name, blk_mode, common_ins, common_gpin_name_list)

    def get_user_gen_pin_name(self, ins_obj, common_ins, port_name):
        '''
        This could be also a common instance pin name.
        So search the common instance if cannot find at the block.

        :param ins_obj: User instance
        :param port_name: The port name
        :return the User specified pin name. empty string if not
                specified.
        '''
        pin_name = ins_obj.gen_pin.get_pin_name_by_type(port_name)
        if pin_name == "":
            # This could be one of the skip clock ports which
            # has a corresponding gpin to get the name from
            in_clk_type_name = ins_obj.get_clk_gpin_type_name(port_name)

            if in_clk_type_name != "":
                # This could be clock that is from other lane (ie bundle
                # mode in raw serdes))
                pin_name = self.get_user_clock_pin_name_for_timing(ins_obj, in_clk_type_name)
                
            if pin_name == "" and common_ins is not None:
                pin_name = common_ins.gen_pin.get_pin_name_by_type(port_name)

        return pin_name
    
    def get_user_clock_pin_name_for_timing(self, ins_obj, input_clock_port_name: str):
        return ins_obj.gen_pin.get_pin_name_by_type(input_clock_port_name)

    def is_clock_inverted(self, ins_obj, cmn_ins, clk_name):
        assert ins_obj.gen_pin is not None
        # Find the pin with the name
        gen_pin = ins_obj.gen_pin.get_pin_by_type_name(clk_name)

        if gen_pin is not None:
            # CHeck if it is a clock pin
            if isinstance(gen_pin, GenericClockPin):
                return gen_pin.is_inverted
            
        elif cmn_ins is not None:
            cmn_pin = cmn_ins.gen_pin.get_pin_by_type_name(clk_name)
            if cmn_pin is not None:
                # CHeck if it is a clock pin
                if isinstance(cmn_pin, GenericClockPin):
                    return cmn_pin.is_inverted
                
        return False   


    def get_constraint_pin_name(self, sink_name, common_gpin_name_list,
                                common_ins, ins_obj, sink_inf):
        # Check if the sink is a bus
        pin_name = ""

        if sink_name in common_gpin_name_list:
            pin_name = common_ins.gen_pin.get_pin_name_by_type(sink_name)
        else:
            pin_name = ins_obj.gen_pin.get_pin_name_by_type(sink_name)
       
        # If the interface is a bus
        if sink_inf is not None and sink_inf.is_bus_port() and pin_name != "":
            pin_name = "{}[*]".format(pin_name)

        return pin_name
    
    def write_sdc_out_delay(self, sdcfile, ins_obj, arc, max_param2delay_map,
                            min_param2delay_map, max_clk, min_clk, clk_str,
                            clk_port_name, blk_mode, common_ins, common_gpin_name_list):

        param_name = arc.get_delay()
        sink_name = arc.get_sink()
        sink_inf = blk_mode.get_interface_object(sink_name)

        pin_name = self.get_constraint_pin_name(sink_name, common_gpin_name_list,
                                common_ins, ins_obj, sink_inf)

        if pin_name != "":
            self.write_out_delay_constraints(
                sdcfile, ins_obj, arc, max_param2delay_map, min_param2delay_map,
                param_name, max_clk, min_clk, clk_str, pin_name, clk_port_name)

    def write_out_delay_constraints(self, sdcfile, ins_obj, arc, max_param2delay_map, min_param2delay_map,
                                    param_name, max_clk, min_clk, clk_str, pin_name,
                                    clk_port_name):

        # Only and the ref pin if user provided a clock name
        # Get the last string which is the clk name
        tmp_clk_list = clk_str.split(" ")
        assert len(tmp_clk_list) > 0
        user_clk = tmp_clk_list[len(tmp_clk_list) - 1]

        ref_arg_str = self.get_ins_clkout_ref(ins_obj, clk_port_name, user_clk)

        if arc.get_type() == dev_blk.TimingArc.TimingArcType.setup:
            if param_name in max_param2delay_map:
                param_delay = max_param2delay_map[param_name]
                # Subtract clock core delay
                max_delay = "{0:.3f}".format(param_delay - max_clk)

                sdcfile.write(
                    "set_output_delay {}{} -max {} [get_ports {{{}}}]\n".format(
                        clk_str, ref_arg_str, max_delay, pin_name))

        elif arc.get_type() == dev_blk.TimingArc.TimingArcType.hold:
            if param_name in min_param2delay_map:
                param_delay = min_param2delay_map[param_name]
                # Subtract clock core delay
                min_delay = "{0:.3f}".format((-1 * param_delay) - min_clk)

                sdcfile.write(
                    "set_output_delay {}{} -min {} [get_ports {{{}}}]\n".format(
                        clk_str, ref_arg_str, min_delay, pin_name))

    def write_sdc_in_delay(self, sdcfile, ins_obj, arc, max_param2delay_map,
                           min_param2delay_map, max_clk, min_clk, clk_str,
                           clk_port_name, blk_mode, common_ins, common_gpin_name_list):

        param_name = arc.get_delay()
        sink_name = arc.get_sink()
        sink_inf = blk_mode.get_interface_object(sink_name)

        pin_name = self.get_constraint_pin_name(sink_name, common_gpin_name_list,
                                common_ins, ins_obj, sink_inf)

        if pin_name != "" and arc.get_type() == dev_blk.TimingArc.TimingArcType.clk_to_q:
            self.write_in_delay_constraints(sdcfile, ins_obj, arc, max_param2delay_map, min_param2delay_map,
                                            param_name, max_clk, min_clk, clk_str, pin_name,
                                            clk_port_name)

    def write_in_delay_constraints(self, sdcfile, ins_obj, arc, max_param2delay_map, min_param2delay_map,
                                   param_name, max_clk, min_clk, clk_str, pin_name,
                                   clk_port_name):

        # Only and the ref pin if user provided a clock name
        # Get the last string which is the clk name
        tmp_clk_list = clk_str.split(" ")
        assert len(tmp_clk_list) > 0
        user_clk = tmp_clk_list[len(tmp_clk_list) - 1]

        ref_arg_str = self.get_ins_clkout_ref(ins_obj, clk_port_name, user_clk)

        if param_name in max_param2delay_map:
            param_delay = max_param2delay_map[param_name]
            # Subtract clock core delay
            max_delay = "{0:.3f}".format(param_delay + max_clk)

            sdcfile.write(
                "set_input_delay {}{} -max {} [get_ports {{{}}}]\n".format(
                    clk_str, ref_arg_str, max_delay, pin_name))

        if param_name in min_param2delay_map:
            param_delay = min_param2delay_map[param_name]
            # Subtract clock core delay
            min_delay = "{0:.3f}".format(param_delay + min_clk)

            sdcfile.write(
                "set_input_delay {}{} -min {} [get_ports {{{}}}]\n".format(
                    clk_str, ref_arg_str, min_delay, pin_name))

    def get_ins_clkout_ref(self, ins_obj, clk_port_name, user_clk):
        '''
        This is used to filter out if the reference pin is required or
        not. If it is not required than the return string is always empty.
        '''
        ref_pin_arg = self.set_clkout_ref_pin(
            ins_obj, self.get_clockout_block_name(clk_port_name, ins_obj), 
            user_clk, ins_obj.get_device())

        return ref_pin_arg

    def write_output(self, index, ins_to_block_map, clk2delay_map,
                     max_param2delay_map, min_param2delay_map, 
                     ins_name2obj_map, blk_mode, is_start):

        '''
        :param ins_name2obj_map: This only contain 10g instances
                that are block specific
        :return True if there was any instance
                that got printed
        '''
        pt_util.mark_unused(index)

        ins_written = False
        sdcfile = None
        write_successful = None

        try:
            sdcfile = self.get_sdc_file_open()

            write_successful = False

            valid_ins = 0

            # Get the clock network delay
            max_clk = clk2delay_map[self.DelayType.max]
            min_clk = clk2delay_map[self.DelayType.min]

            # While iterating, keep track of the quad resource
            # so that we know which one to iterate through on
            # the common registry
            common_des_instances: List[QuadLaneCommon] = []

            for ins_name in sorted(ins_name2obj_map.keys(),
                                   key=pt_util.natural_sort_key_for_list):
                ins_obj = ins_name2obj_map[ins_name]

                if ins_obj is None:
                    continue

                # Work on instance that are part of device
                if ins_obj.get_device() in ins_to_block_map:
                
                    if valid_ins == 0 and is_start:

                        sdcfile.write("\n# {} Constraints\n".format(self.get_user_blk_name()))
                        sdcfile.write("###########################\n")

                        valid_ins += 1

                    assert self.common_quad_reg is not None
                    common_inst = self.common_quad_reg.get_inst_by_device_name(ins_obj.get_device())
                    assert isinstance(common_inst, QuadLaneCommon)

                    if common_inst in common_des_instances:
                        # Nullfieid it since we already write out the
                        # common instance in the first visit
                        common_inst = None
                    else:
                        common_des_instances.append(common_inst)

                    self.write_instance(
                        sdcfile,
                        max_clk, min_clk, ins_obj,
                        max_param2delay_map,
                        min_param2delay_map,
                        blk_mode, common_inst)
                                       
            ins_written = True

            self.close_files([sdcfile])

            write_successful = True 

        except Exception as excp:
            if write_successful is not None and\
                    not write_successful and sdcfile is not None:
                self.close_files([sdcfile])
            raise excp

        return ins_written
       
    def get_sdc_file_open(self):
        return open(self._sdc_file, 'a')

    def close_files(self, file_list: List[io.TextIOWrapper]):
        for file in file_list:
            file.close()

    def get_clockout_block_name(self, clock_gpin_name, ins_obj=None):
        '''
        Translation of the interface port name to the clkout
        port name so that we can get the -reference_pin on the
        clkout interface.

        :param clock_gpin_name: the clock port type name associated to an arc.
                It could be either the input clock or the clkout type name (based
                on what the arc specified as its clock domain name)
        :return the corresponding pin at the instance based on the mode
                used since the interface port maps to different block port:
                ie PCS_CLK_TX maps to PCS_CLK_TX0 (LN0)
                                      PCS_CLK_TX1 (LN1)
        '''
        clk_pname = clock_gpin_name
        if ins_obj is not None and \
            self.is_pin_type_name_clkout(clock_gpin_name, ins_obj):
            if ins_obj.get_device() != "":

                def get_quad_lane_idx(res_name):
                    _, lane_idx = QuadResService.break_res_name(res_name)

                    return lane_idx
                
                lane_idx = get_quad_lane_idx(ins_obj.get_device())
                if lane_idx != -1:
                    # Add the lane index to the end of the port name
                    # since that's the actual port name used in the block
                    clk_pname = f"{clock_gpin_name}{lane_idx}"
        
        return clk_pname  

    @abc.abstractmethod
    def write_create_clock(self, sdcfile: io.TextIOWrapper, ins_obj: LaneBasedItem):
        '''
        Write out the clocks that are output from the block to the core.

        Called by ``write_instance`` before any arc constraint of the
        instance is written.

        :param sdcfile: The FP for the SDC file
        :param ins_obj: The lane-based design instance
        :return: Nothing, but writes information into sdcfile
        '''
        pass

    @abc.abstractmethod
    def get_quad_type(self):
        '''
        Subclass hook; it is not called by any method of this class.

        NOTE(review): presumably returns the quad type handled by the
        concrete writer - confirm against the subclasses.
        '''
        pass
        
    @abc.abstractmethod
    def get_mode_name(self) -> str:
        '''
        :return: Name of the block mode that ``get_blk2mode_map`` looks up
                on every block through ``blk.get_mode(name)``
        '''
        pass
    
    @abc.abstractmethod
    def get_user_blk_name(self):
        '''
        :return: Block name shown to the user; ``write_output`` puts it
                into the "# <name> Constraints" section header of the
                SDC file
        '''
        pass

    @abc.abstractmethod
    def is_pin_type_name_clkout(self, type_name: str, ins_obj: LaneBasedItem):
        '''
        Tell whether a pin type name is a clkout of the instance.

        ``get_clockout_block_name`` appends the lane index to the name
        only when this returns a true value.

        :param type_name: The pin type name to check
        :param ins_obj: The lane-based design instance
        '''
        pass
    
