'''
Copyright (C) 2013-2020 Efinix Inc. All rights reserved.

No portion of this code may be reused, modified or
distributed in any way without the expressed written
consent of Efinix Inc.

Created on Nov 9, 2020

@author: maryam
'''

import re
from enum import Enum

from typing import Dict, List, Tuple

import util.gen_util
from util.singleton_logger import Logger

from device.block_connectivity import ConnVertex, MuxConnGraph


class TopLevelClockMuxGraph:
    '''
    Class for routing at the top-level where it sees all instances of
    many clock muxes at once.

    Each per-side clock mux graph is registered with add_graph(). The
    graphs are combined into one top-level graph (see _create_top_graph)
    in which a node of graph ``gname`` is addressed as ``"<gname>.<node>"``.
    The top-level inputs are then connected to those nodes and routed.
    '''

    # Routed vertex of the form "<top clkmux>.<port>" (a block pin).
    # The separator is a literal '.', matching the "{}.{}" names built in
    # route_inputs().
    _PORT_NODE_RE = re.compile(r'^([A-Za-z0-9_\[\]]+)\.([A-Za-z0-9_\[\]]+)$')

    # Routed vertex of the form "<top clkmux>.<internal mux>:<select index>"
    _MUX_NODE_RE = re.compile(
        r'^([A-Za-z0-9_\[\]]+)\.([A-Za-z0-9_\[\]]+):([A-Za-z0-9_\[\]]+)$')

    def __init__(self):
        # A map of graph name to the Graph object. Each graph
        # object represents the clock mux of one side
        self.graph_map = {}

        # Combined graph at the top (built lazily by route_inputs)
        self.top_graph = None

        # List of input vertices that were routed (overwritten each time
        # route_inputs is called), sorted by input name
        self.inputs_assigned = []

        # A map of input vertex names to the vertex object itself
        self.inputs_used = {}

        # A map of graph name to a map of {input name: [dest node names]}
        self.graph_to_inputs_conn = {}

        # Shared logger from util.singleton_logger
        self._logger = Logger

        # Inputs with dependencies on other inputs,
        # which requires that they all use the same CLKMUX
        self.clk_dependencies_same_mux: Dict[str, List[str]] = {}

    def reset_graph(self):
        '''
        Clear the per-routing state so that a new set of inputs can be routed.

        The used/assigned inputs, the input connections and the clock
        dependencies are cleared, and every registered side graph is reset.
        The combined top-level graph is discarded as well, because it holds
        the input edges added for the previous routing. It is rebuilt on the
        next route_inputs() call.
        '''
        self.inputs_used = {}
        self.inputs_assigned = []
        self.graph_to_inputs_conn = {}
        self.clk_dependencies_same_mux = {}

        for graph_side in self.graph_map.values():
            if graph_side is not None:
                # Reset the node variables that are used in routing
                graph_side.reset_graph()

        # Drop the top graph so it is recreated (without the stale input
        # edges) the next time route_inputs() is called
        self.top_graph = None

    def add_graph(self, name, graph_obj):
        '''
        Register a side clock mux graph under ``name``.

        If a graph is already registered under that name, the existing one is
        kept (the first registration wins).

        :param name: the graph instance name (used as the node-name prefix
            in the top-level graph)
        :param graph_obj: the graph object to register
        '''
        if name not in self.graph_map:
            self.graph_map[name] = graph_obj

    def add_input_connection_to_graph(self, gname, in_src, graph_vdest):
        '''
        Record that top-level input ``in_src`` connects to node ``graph_vdest``
        of graph ``gname``. A duplicate connection is recorded only once.

        :param gname: the graph instance name
        :param in_src: The top-level input name
        :param graph_vdest: destination node name in graph
        :return: None
        '''
        input2nodes_map = self.graph_to_inputs_conn.setdefault(gname, {})
        dest_list = input2nodes_map.setdefault(in_src, [])
        if graph_vdest not in dest_list:
            dest_list.append(graph_vdest)

    def _create_top_graph(self):
        '''
        Create the top-level (device) clock mux graph by copying every
        registered graph into it with its name as prefix. This does nothing
        if the top graph already exists.
        '''
        if self.top_graph is None:
            self.top_graph = MuxConnGraph()

            for gname, gobj in self.graph_map.items():
                self.top_graph.copy_graph_with_prefix(gname, gobj)

    def _print_input_mapping(self):
        '''
        Debug function used to print out the input assignment, showing
        the connectivity of the top-level input to which mux it connects to.
        '''
        for gname, in2nodes_map in self.graph_to_inputs_conn.items():
            self._logger.debug("Input mapping of graph {}".format(gname))

            for in_name, dest_list in in2nodes_map.items():
                self._logger.debug("Input name: {}".format(in_name))
                self._logger.debug(
                    "Dest list: {}".format(",".join(dest_list)))

    def route_inputs(self, is_use_both=False):
        '''
        Function that does the routing of all inputs in the top-level graph.

        :param is_use_both: To indicate whether we should use:
                    a) True = 2 algorithms in sequence
                    (BFS and then Pathfinder if BFS still cannot route all)
                    b) False = Pathfinder only
                    * Doing this so that we maintain earlier designs which
                    were already routed with the same result.

        :return:
            route_results: A map of the input name to a list that indicates the
                        routing path that the input has to go through to output
            unrouted_input: Name of inputs that were not routed
        :raises ValueError: if a destination node is missing in the top graph,
            or an input vertex/edge cannot be added to it
        '''
        if self.top_graph is None:
            # Duplicate all the graphs into one combined top graph
            self._create_top_graph()

        # After creating the top, we connect the inputs.
        # The original list in graph_to_inputs_conn might look like:
        #   left: [pll_br0: pll0[0]]
        #   right: [pll_br0: pll1[2]]
        # so in the top graph the input fans out to the uniquified nodes:
        #   pll_br0 -> left.pll0[0], right.pll1[2]
        uniquified_dep_clocks: Dict[str, Tuple[List[str], List[str]]] = {}

        for gname, inputs_to_node in self.graph_to_inputs_conn.items():
            for input_node, next_nodes in inputs_to_node.items():
                # Other inputs that must share the same CLKMUX as this one
                other_dep_inputs = self.clk_dependencies_same_mux.get(
                    input_node, [])

                gclk_mux_inputs: List[str] = []

                # Uniquify the connection at the top to a hierarchical
                # name of the mux combined with the block pin name
                for next_name in next_nodes:
                    new_next_name = "{}.{}".format(gname, next_name)

                    if self.top_graph.get_vertex(new_next_name) is None:
                        raise ValueError(
                            "Unable to find node {} in top graph".format(new_next_name))

                    input_vertex = self.top_graph.get_vertex(input_node)
                    if input_vertex is None:
                        # First connection of this input: create its vertex
                        input_vertex = ConnVertex(
                            input_node, ConnVertex.VertexType.input)
                        if not self.top_graph.add_vertex(input_vertex, True):
                            raise ValueError(
                                "Input node {}, has already been created".format(input_node))

                    if not self.top_graph.add_edge(input_node, new_next_name):
                        raise ValueError("Unable to add edge from {} to {}".format(
                            input_node, new_next_name))

                    if input_node not in self.inputs_used:
                        self.inputs_used[input_node] = input_vertex

                    if other_dep_inputs:
                        gclk_mux_inputs.append(new_next_name)

                if gclk_mux_inputs:
                    if input_node in uniquified_dep_clocks:
                        cur_gclk_in, cur_other_dep = uniquified_dep_clocks[input_node]
                        # Internal invariant: both entries were looked up with
                        # the same key, so the dependency lists must agree
                        assert cur_other_dep == other_dep_inputs
                        gclk_mux_inputs = cur_gclk_in + gclk_mux_inputs

                    uniquified_dep_clocks[input_node] = (gclk_mux_inputs, other_dep_inputs)
                    self._logger.debug(f"clk dependencies on: {input_node}: {gclk_mux_inputs} dep_list: {other_dep_inputs}")

        if uniquified_dep_clocks:
            self.top_graph.input_node_dependencies = uniquified_dep_clocks

        # Route the inputs in a deterministic order (sorted by input name)
        vertex_input_list = [
            self.inputs_used[in_name] for in_name in sorted(self.inputs_used)]
        self.inputs_assigned = vertex_input_list

        if is_use_both:
            # Start with the BFS
            route_results, unrouted_input = self.top_graph.route_input_bfs(
                vertex_input_list)

            if unrouted_input:
                # Try another approach (Pathfinder) if it is still unrouted
                route_results, unrouted_input = self.top_graph.route_input(
                    vertex_input_list)
        else:
            # Only use the Pathfinder algorithm
            route_results, unrouted_input = self.top_graph.route_input(
                vertex_input_list)

        return route_results, unrouted_input

    def iterate_result(self, route_results, unrouted_input):
        '''
        Used to figure out the mux setting based on the routed
        input path.

        Each routed path must start with the input name itself. It is followed
        by "<top clkmux>.<port>" nodes and "<top clkmux>.<mux>:<index>" nodes.
        The first port on the path becomes the setting's input name, and the
        last port becomes its output name. Each internal mux node records its
        select index. An input whose routing path is empty is skipped.

        :param route_results: A map of the input name to a list that indicates the
                        routing path that the input has to go through to output
        :param unrouted_input: Name of inputs that were not routed
        :return: top_mux_to_setting_list: A map of the clock mux design object
            name to the list of settings associated to it
             (ClockMuxIOSettingAdv)
        :raises ValueError: if a routing path does not start at its input,
            contains an unrecognized node or goes through the same internal
            mux twice, or if an input has no result and is not listed as
            unrouted
        '''
        # Imported locally; presumably to avoid an import cycle (TODO confirm)
        from tx60_device.clock_mux.clkmux_design_adv import ClockMuxIOSettingAdv

        top_mux_to_setting_list = {}

        for in_vertex in sorted(self.inputs_assigned, key=lambda vertex: vertex.name):
            in_name = in_vertex.name

            if in_name not in route_results:
                if in_name not in unrouted_input:
                    raise ValueError("No results on input {}".format(in_name))
                continue

            route_path = route_results[in_name]

            start = ""
            end_name = ""
            top_clkmux = ""
            io_mux_obj = None
            internal_mux_names = set()

            for vertex_name in route_path:
                if vertex_name == in_name and start == "":
                    start = in_name
                    self._logger.debug("Start: {}".format(start))
                    io_mux_obj = ClockMuxIOSettingAdv(start)
                    continue

                if io_mux_obj is None:
                    # Any other node before the input means a malformed path
                    raise ValueError(
                        "Routing path of input {} does not start at the input: {}".format(
                            in_name, vertex_name))

                port_match = self._PORT_NODE_RE.match(vertex_name)
                if port_match:
                    top_clkmux, port = port_match.groups()

                    # The last block pin seen on the path is its end
                    if start != "":
                        end_name = port
                    self._logger.debug(
                        "{} - {}".format(top_clkmux, port))
                    # The first block pin seen on the path is the mux input
                    if io_mux_obj.input_name == "":
                        io_mux_obj.input_name = port
                    continue

                mux_match = self._MUX_NODE_RE.match(vertex_name)
                if not mux_match:
                    raise ValueError(
                        "Unable to tell the nodes in the routing path: {}".format(vertex_name))

                top_clkmux, mux_name, mux_index = mux_match.groups()

                # Only nodes whose index starts with a digit carry a
                # select value
                if re.match(r'(\d+)', mux_index):
                    self._logger.debug(
                        "{} - mux: {}, index: {}".format(top_clkmux, mux_name, mux_index))
                    if mux_name in internal_mux_names:
                        raise ValueError(
                            "Error routing input {} due to repeated mux {}".format(start, mux_name))

                    io_mux_obj.mux_assignment[mux_name] = int(mux_index)
                    internal_mux_names.add(mux_name)

            if io_mux_obj is None:
                # Empty routing path: there is no setting to record
                continue

            self._logger.debug(
                "End at {}: {}".format(top_clkmux, end_name))
            io_mux_obj.output_name = end_name
            io_mux_obj.set_output_no()

            self._logger.debug(
                "Adding {} to mux {} setting list".format(start, top_clkmux))
            top_mux_to_setting_list.setdefault(top_clkmux, []).append(io_mux_obj)

        return top_mux_to_setting_list

    def print_setting_results(self, top_mux_to_setting_list):
        '''
        Debug function that prints out the mux setting assignment
        :param top_mux_to_setting_list: map of
            clock mux name to the list of ClockMuxIOSettingAdv
        '''
        for mux_name, io_setting_list in top_mux_to_setting_list.items():
            self._logger.debug("Reading {} results: ".format(mux_name))

            for io_set_obj in io_setting_list:
                self._logger.debug(io_set_obj)