from __future__ import annotations
from decimal import Decimal
from enum import Enum, auto
from typing import TYPE_CHECKING, Dict, List, Any, Tuple, Optional
from xml.etree.ElementTree import ElementTree

from common_device.writer import SetDelayCommand, SetInputDelayCommand, SetOutputDelayCommand, SetCUCommand, write_cmds
from design.db import PeriDesign
from design.db_item import GenericClockPin, GenericPin
from device.block_definition import BusPort, Port, TimingArc
from device.block_instance import PeripheryBlockInstance
from device.db import PeripheryDevice

from tx180_device.pll.writer.timing import PLLTimingV3Complex

from util.gen_util import override

if TYPE_CHECKING:
    from typing import TextIO
    from device.block_definition import PeripheryBlock
    from tx375_device.fpll.design import EfxFpllV1, EfxFpllV1OutputClock
    from device.timing_model import Model


class EfxFpllV1Timing(PLLTimingV3Complex):
    '''
    Class for writing out the timing file for EfxFpllV1
    '''

    xml_ns = "{http://www.efinixinc.com/peri_device_db}"

    class PathType(Enum):
        input = auto()
        output = auto()

    def __init__(self, bname: str, device: PeripheryDevice, design: PeriDesign, report, sdc, max_model: Model, min_model: Model):
        super().__init__(bname=bname, device=device, design=design, report=report, sdc=sdc)
        self._max_model = max_model
        self._min_model = min_model
        self._mode2arcs = {}

    @override
    def build_arc(self):
        assert self._device is not None
        block = self._device.find_block(self._name)
        assert block is not None
        block: PeripheryBlock

        input_arcs = block.get_all_input_arcs(True) + block.get_all_input_arcs(False)
        output_arcs = block.get_all_output_arcs(True) + block.get_all_output_arcs(False)

        self._mode2arcs[self.PathType.input] = input_arcs
        self._mode2arcs[self.PathType.output] = output_arcs

    def _parse_simple_timing_table(self, table_model, blk_name):
        '''
        Parse the timing model to get the map of
        parameter name to the delay value for a specific block.
        This should be used on a simple block where it has no
        timing parameter variations.
        :param table_model: The timing model (i.e. Max/Min Model)
        :param blk_name: The block section in the model to parse
        :return A map of parameter name to the delay value after
                it is scaled
        '''

        param2delay_map = {}

        try:
            tree = ElementTree(file=table_model.get_filename())
            # print("REading file: {}".format(table_model.get_filename()))

            block_tag = ".//" + self.xml_ns + "block"

            blocksec = None
            for elem in tree.iterfind(block_tag):
                # print("{}: {}".format(elem.tag, elem.attrib))

                block_attrib = elem.attrib
                if block_attrib["type"] == blk_name:
                    blocksec = elem
                    break

            if blocksec is not None:
                # This will be the list of names without any
                # variables
                arc_param_tag = ".//" + self.xml_ns + "arc_parameter"

                for elem in blocksec.iterfind(arc_param_tag):
                    arc_attrib = elem.attrib
                    param_count = 0
                    delay_value = None

                    if "name" in arc_attrib and arc_attrib["name"] != "":
                        pname = arc_attrib["name"]
                        # Get the delay with the expectation that there is
                        # no variable parameter
                        if arc_attrib["pcount"] is not None and\
                                int(arc_attrib["pcount"]) > 0:
                            raise ValueError(
                                '{} variable pcount is expected '
                                'to be 0 instead of {}'.format(
                                    self._name, arc_attrib["pcount"]))

                        # Get the delay (children of arc_parameter)
                        param_tag = ".//" + self.xml_ns + "parameter"

                        for pelem in elem.iterfind(param_tag):
                            param_attrib = pelem.attrib

                            delay_value = float(param_attrib["delay"])
                            param_count += 1

                        if param_count == 1 and pname not in param2delay_map:

                            # We need to scale it
                            scaled_delay = Decimal(
                                delay_value * table_model.get_tscale()) / 1000

                            param2delay_map[pname] = scaled_delay
                            # self.logger.debug("Saving {} parameter {} from {} scaled {} to {}".format(
                            #    self._name, pname, delay_value, table_model.get_tscale(), scaled_delay))

                        elif param_count > 1:
                            raise ValueError(
                                '{} arc_parameter should only have one'
                                ' delay value. Instead found {}'.format(
                                    self._name, param_count))

                        elif pname in param2delay_map:
                            raise ValueError(
                                'Found duplicated arc_parameter {} in {} block'.format(
                                    pname, self._name))

                        else:
                            raise ValueError(
                                'No delay value stated for arc_parameter {}'.format(
                                    pname))

        except Exception as excp:
            # self.logger.error("Error with reading the timing common_models: {}".format(excp))
            raise excp

        return param2delay_map

    def get_arc_applied_pins(self, arc: TimingArc, gp: GenericPin) -> List[str]:
        if gp.is_bus:
            # FIXME: Update this when individual pins of a bus can be configured separately
            return [f"{gp.name}[*]"]
        else:
            return [gp.name]

    def process_setup_arc(self, arc: TimingArc,
                          source_gp: GenericClockPin,
                          sink_gp: GenericPin,
                          max_clk_network_delay: Decimal,
                          min_clk_network_delay: Decimal,
                          max_delay_map: Dict[str, Decimal],
                          min_delay_map: Dict[str, Decimal],
                          ref_pin_name: str | None = None) -> List[SetDelayCommand]:
        pin_list = self.get_arc_applied_pins(arc, sink_gp)

        delay_label = arc.get_delay()
        delay_val = max_delay_map[delay_label] - max_clk_network_delay

        cmd = SetOutputDelayCommand(clock_name=source_gp.name,
                                    pin_list=pin_list,
                                    delay=delay_val,
                                    delay_type='max',
                                    edge_type='rising' if arc.get_edge() is None else str(arc.get_edge().value),
                                    reference_pin_name=ref_pin_name)
        return [cmd]

    def process_hold_arc(self, arc: TimingArc,
                         source_gp: GenericClockPin,
                         sink_gp: GenericPin,
                         max_clk_network_delay: Decimal,
                         min_clk_network_delay: Decimal,
                         max_delay_map: Dict[str, Decimal],
                         min_delay_map: Dict[str, Decimal],
                         ref_pin_name: str | None = None) -> List[SetDelayCommand]:
        """
        Process the hold arc of the instance to generate the timing constraints in SDC
        """
        pin_list = self.get_arc_applied_pins(arc, sink_gp)

        delay_label = arc.get_delay()
        delay_val = -1 * (min_delay_map[delay_label]) - min_clk_network_delay

        cmd = SetOutputDelayCommand(clock_name=source_gp.name,
                                    pin_list=pin_list,
                                    delay=delay_val,
                                    delay_type='min',
                                    edge_type='rising' if arc.get_edge() is None else str(arc.get_edge().value),
                                    reference_pin_name=ref_pin_name)
        return [cmd]

    def process_cu_arc(self, clk_name: str) -> List[SetCUCommand]:
        # TODO: Hard code value for now
        cmd = SetCUCommand(clock_name=clk_name, mode="setup", value=0.55)
        return [cmd]

    def process_clock_to_q_arc(self, arc: TimingArc,
                               source_gp: GenericClockPin,
                               sink_gp: GenericPin,
                               max_clk_network_delay: Decimal,
                               min_clk_network_delay: Decimal,
                               max_delay_map: Dict[str, Decimal],
                               min_delay_map: Dict[str, Decimal],
                               ref_pin_name: str | None = None) -> List[SetDelayCommand]:
        """
        Process the clock to q arc of the instance to generate the timing constraints in SDC
        """
        pin_list = self.get_arc_applied_pins(arc, sink_gp)

        delay_label = arc.get_delay()
        max_delay_val = max_delay_map[delay_label] + max_clk_network_delay
        min_delay_val = min_delay_map[delay_label] + min_clk_network_delay

        max_cmd = SetInputDelayCommand(clock_name=source_gp.name,
                                       pin_list=pin_list,
                                       delay=max_delay_val,
                                       delay_type='max',
                                       edge_type='rising' if arc.get_edge() is None else str(arc.get_edge().value),
                                       reference_pin_name=ref_pin_name)

        min_cmd = SetInputDelayCommand(clock_name=source_gp.name,
                                       pin_list=pin_list,
                                       delay=min_delay_val,
                                       delay_type='min',
                                       edge_type='rising' if arc.get_edge() is None else str(arc.get_edge().value),
                                       reference_pin_name=ref_pin_name)
        return [max_cmd, min_cmd]

    @override
    def write_extra_sdc_constraint(self, sdcfile: TextIO, pll_obj: EfxFpllV1):
        # FIXME: remove hard-coded value
        max_clk_network_delay = Decimal(0)
        min_clk_network_delay = Decimal(0)

        max_delay_map = self._parse_simple_timing_table(self._max_model, self._name)
        min_delay_map = self._parse_simple_timing_table(self._min_model, self._name)

        device_inst = self._device.find_instance(pll_obj.get_device())
        assert device_inst is not None
        device_inst: PeripheryBlockInstance
        device_ports = device_inst.get_block_definition().get_all_ports()

        def is_pin_available_and_configured(pin_name: str) -> bool:
            """
            Check if the pin is available and configured by user
            """
            DYN_CFG_PINS = {
                'CFG_CLK',
                'CFG_DATA_IN',
                'CFG_DATA_OUT',
                'CFG_SEL'
            }
            if pll_obj.dyn_cfg_enable is False or pin_name not in DYN_CFG_PINS:
                return False
            user_name = pll_obj.gen_pin.get_pin_name_by_type(pin_name)
            if user_name == "":
                return False
            return True

        generated_cmds: Dict[Tuple[str, Any], List[SetDelayCommand]] = {}

        all_arcs = self._mode2arcs[self.PathType.input] + self._mode2arcs[self.PathType.output]
        for arc in all_arcs:
            assert isinstance(arc, TimingArc)
            source_pin = arc.get_source()
            sink_pin = arc.get_sink()

            if not is_pin_available_and_configured(source_pin):
                continue

            if not is_pin_available_and_configured(sink_pin):
                continue

            source_gp: Optional[GenericPin] = pll_obj.gen_pin.get_pin_by_type_name(
                source_pin)
            assert source_gp is not None
            sink_gp: Optional[GenericPin] = pll_obj.gen_pin.get_pin_by_type_name(sink_pin)
            assert sink_gp is not None

            device_sink_port = device_ports[sink_pin]
            assert device_sink_port is not None
            device_sink_port: Port | BusPort
            class_name = device_sink_port.get_class()
            assert class_name is not None

            key = (class_name, sink_pin)
            if key not in generated_cmds:
                generated_cmds[key] = []

            ref_pin_name, _ = self._get_instance_ref_pin_clkout_pin_name(
                device_inst, source_gp.type_name, source_gp.name)
            ref_pin_name = None if ref_pin_name == "" else ref_pin_name

            match arc.get_type():
                case TimingArc.TimingArcType.delay:
                    pass
                case TimingArc.TimingArcType.setup:
                    generated_cmds[key] += self.process_setup_arc(arc=arc,
                                                                  source_gp=source_gp,
                                                                  sink_gp=sink_gp,
                                                                  max_clk_network_delay=max_clk_network_delay,
                                                                  min_clk_network_delay=min_clk_network_delay,
                                                                  max_delay_map=max_delay_map,
                                                                  min_delay_map=min_delay_map,
                                                                  ref_pin_name=ref_pin_name)

                case TimingArc.TimingArcType.hold:
                    generated_cmds[key] += self.process_hold_arc(arc=arc,
                                                                 source_gp=source_gp,
                                                                 sink_gp=sink_gp,
                                                                 max_clk_network_delay=max_clk_network_delay,
                                                                 min_clk_network_delay=min_clk_network_delay,
                                                                 max_delay_map=max_delay_map,
                                                                 min_delay_map=min_delay_map,
                                                                 ref_pin_name=ref_pin_name)

                case TimingArc.TimingArcType.clk_to_q:
                    generated_cmds[key] += self.process_clock_to_q_arc(arc=arc,
                                                                       source_gp=source_gp,
                                                                       sink_gp=sink_gp,
                                                                       max_clk_network_delay=max_clk_network_delay,
                                                                       min_clk_network_delay=min_clk_network_delay,
                                                                       max_delay_map=max_delay_map,
                                                                       min_delay_map=min_delay_map,
                                                                       ref_pin_name=ref_pin_name)
                case _:
                    raise NotImplementedError(f'Unsupported arc type {arc.get_type()} {arc.get_type_str()}')

        # Rearrange the command written orders by groupping pins with same category (Which tell by the arc label)
        for _, cmds in sorted(generated_cmds.items(), key=lambda i: i[0]):
            write_cmds(sdcfile, cmds)

    def write_clock_extra_sdc_constraint(self, sdcfile: TextIO, pll_obj: EfxFpllV1,
                                         clk_obj: EfxFpllV1OutputClock):
        # DEVINFRA-908 set CU when fractional mode/ SSC enabled
        if pll_obj.fractional_mode or pll_obj.ssc_mode != "DISABLE":
            write_cmds(
                sdcfile,
                self.process_cu_arc(clk_name=clk_obj.name) # type: ignore
            )

    @override
    def is_output_clk_internal_only(self, pll_obj: EfxFpllV1, outclk_name: str) -> bool:
        clk_obj = pll_obj.get_output_clock(outclk_name)
        if clk_obj is None:
            return True
        if pll_obj.fractional_mode or pll_obj.ssc_mode != "DISABLE":
            if clk_obj.number == 1:
                return True

        # Call Ti180 to check if outclk is ddr clock
        return super().is_output_clk_internal_only(pll_obj, outclk_name)
