from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, List
from xml.etree.ElementTree import SubElement

from common_device.writer import LogicalPeriphery
from design.db_item import GenericParamService

from tx375_device.soc.design_param_info import SOCParamId
from tx375_device.soc.device_service import PCRDefnType, PCRModeType
from device.excp import BlockPCRDoesNotExistException, ConfigurationInvalidException
from util.excp import PTInteralAssertException, pt_assert
from util.bin_util import is_intel_hex, load_generic_bin, load_intel_hex

if TYPE_CHECKING:
    from xml.etree.ElementTree import Element
    from design.db import PeriDesign
    from device.block_definition import PeripheryBlock
    from device.pcr_device import BlockPCR


class SOCLogicalPeriphery(LogicalPeriphery):
    """
    For generating SOC's PCR (LPF)

    Two sequence groups are handled for a configured SOC user instance:

    * ``SOC_PCR``: mode PCRs (DDR enable, clock selects, AXI pipeline enables,
      power isolation) derived from the user instance parameters.
    * ``SOC_CFG``: On-Chip-RAM image loaded from the file named by the
      ``OCR_FILE_PATH`` parameter and bit-sliced across ``PCR_UNSUPPORTED<n>``.

    Any other sequence group is rejected by ``generate_pcr_user_instance``.
    """

    # NOTE(review): the ``SOC`` annotation used below is not imported in this
    # module (not even under TYPE_CHECKING); it only works because
    # ``from __future__ import annotations`` keeps annotations as lazy strings.

    def __get_checked_pcr_default(self, pcr_name: str, pcr_obj) -> str:
        """
        Return the default value of *pcr_obj*, asserting it is defined and a str.

        :raises PTInteralAssertException: the default is missing or not a str.
        """
        value = pcr_obj.get_default()
        pt_assert(value is not None,
                  f'Block {self._name} PCR {pcr_name} does not have default value defined',
                  PTInteralAssertException)
        pt_assert(isinstance(value, str),
                  f'Block {self._name} PCR {pcr_name} default value not str, got {type(value).__name__}',
                  PTInteralAssertException)
        return value

    def __lookup_pcr_mode(self, param_svc: GenericParamService, param_id: SOCParamId,
                          mapping: dict, desc: str):
        """
        Translate the user-instance parameter *param_id* into a PCR value via *mapping*.

        :raises ConfigurationInvalidException: the parameter value has no entry
            in *mapping* (previously surfaced as a bare KeyError).
        """
        design_value = param_svc.get_param_value(param_id)
        pt_assert(design_value in mapping,
                  f"User instance's {desc} has unsupported value {design_value!r}",
                  ConfigurationInvalidException)
        return mapping[design_value]

    def __generate_user_instance_pcr_seq_group(self, blk_ins: str,
                                               parent: Element,
                                               design: PeriDesign,
                                               block_def: PeripheryBlock,
                                               pcr_block: BlockPCR,
                                               user_design_obj: SOC):
        """
        Emit the ``SOC_PCR`` instance for a configured SOC user instance.

        Every PCR of the group is written. Selected PCRs are overridden from the
        user instance parameters; all others keep their declared default.
        ``design`` and ``block_def`` are accepted for signature parity and unused.

        :raises PTInteralAssertException: a PCR has no usable default.
        :raises ConfigurationInvalidException: a clock source is invalid, or a
            parameter value has no PCR mode mapping.
        """
        type_name = f"{self._name}:{self._seq_group_name_to_print}"
        pcr_map = pcr_block.get_pcr_list(self._seq_group_name_to_print)

        ins_element = SubElement(parent, 'efxpt:instance', id=blk_ins, type=type_name)
        param_svc = GenericParamService(user_design_obj.param_group, user_design_obj.param_info)

        # Shared by the memory- and system-clock select PCRs.
        clk_source_to_mode = {
            "Unassign": PCRModeType.DISABLE.value,
            "Clock 0": PCRModeType.PLL0.value,
            "Clock 1": PCRModeType.PLL1.value,
            "Clock 2": PCRModeType.PLL2.value,
        }
        pipe_en_to_mode = {
            False: PCRModeType.DISABLE.value,
            True: PCRModeType.ENABLE.value,
        }
        pipe_pcrs = (PCRDefnType.PCR_DDR2SOC_PIPE.value, PCRDefnType.PCR_SOC2DDR_PIPE.value)

        for pcr_name, pcr_obj in sorted(pcr_map.items()):
            value = self.__get_checked_pcr_default(pcr_name, pcr_obj)

            if pcr_name == PCRDefnType.PCR_SOC_DDR_EN.value:
                # Always enable when SOC is configured
                value = PCRModeType.ENABLE.value
            elif pcr_name == PCRDefnType.PCR_MEMORYCLK_SEL.value:
                pt_assert(param_svc.check_param_valid(SOCParamId.MEM_CLK_SOURCE) == (True, ""),
                          "User instance's memory clock source invalids", ConfigurationInvalidException)
                value = self.__lookup_pcr_mode(param_svc, SOCParamId.MEM_CLK_SOURCE,
                                               clk_source_to_mode, "memory clock source")
            elif pcr_name == PCRDefnType.PCR_SYSTEMCLK_SEL.value:
                pt_assert(param_svc.check_param_valid(SOCParamId.SYS_CLK_SOURCE) == (True, ""),
                          "User instance's system clock source invalids", ConfigurationInvalidException)
                value = self.__lookup_pcr_mode(param_svc, SOCParamId.SYS_CLK_SOURCE,
                                               clk_source_to_mode, "system clock source")
            elif pcr_name in pipe_pcrs:
                # Both AXI pipeline directions follow the same user parameter.
                value = self.__lookup_pcr_mode(param_svc, SOCParamId.PIPELINE_SOC_AXI_MEM_INTERFACE_EN,
                                               pipe_en_to_mode, "AXI memory interface pipeline enable")
            elif pcr_name == PCRDefnType.PCR_SOC_NISOEN.value:
                # always disable power isolation when SOC is configured
                value = PCRModeType.DISABLE.value

            SubElement(ins_element, 'efxpt:parameter', name=pcr_name, value=value)

    def bytearray_to_pcr_segments(self, bdata: bytearray, chain_length: int, num_chain: int) -> List[bytearray]:
        """
        Convert bytearray to PCR segments (bit-slice the data across chains).

        Bit ``c`` of every input byte goes to chain ``c``. Each run of
        ``num_chain`` consecutive input bytes packs into one byte per chain,
        with the first byte of the run landing in the most significant bit.

        Example, Data = 0xAB = 0b1010_1011 (bit 0 first):
        PCR_UNSUPPORTED0..7 receive 1, 1, 0, 1, 0, 1, 0, 1.

        NOTE(review): with 8-bit input bytes this only makes sense for
        ``num_chain <= 8``. Larger values produce bit positions >= 8, which a
        bytearray cannot store.

        :param bdata: input data; must fit in ``chain_length * num_chain`` bytes
            wherever it is non-zero.
        :param chain_length: length in bytes of each output chain.
        :param num_chain: number of chains to produce.
        :returns: ``num_chain`` bytearrays of ``chain_length`` bytes each.
        """
        segments_data = [bytearray(chain_length) for _ in range(num_chain)]

        for bin_idx, byte in enumerate(bdata):
            if not byte:
                # Zero bytes set no bits; skipping them changes nothing.
                continue
            byte_pos, slot = divmod(bin_idx, num_chain)
            bit_mask = 0x1 << (num_chain - slot - 1)
            for chain_idx, segment in enumerate(segments_data):
                if byte & (0x1 << chain_idx):
                    segment[byte_pos] |= bit_mask

        return segments_data

    def __generate_user_instance_cfg_seq_group(self, blk_ins: str,
                                               parent: Element,
                                               design: PeriDesign,
                                               block_def: PeripheryBlock,
                                               pcr_block: BlockPCR,
                                               user_design_obj: SOC):
        """
        Emit the ``SOC_CFG`` instance carrying the On-Chip-RAM image.

        When no OCR file is configured, the PCR defaults are emitted instead.
        Otherwise the file (Intel HEX or generic binary) is loaded, zero-padded
        to the RAM size, and bit-sliced into one upper-case hex string per
        ``PCR_UNSUPPORTED<n>`` chain.

        :raises ConfigurationInvalidException: the OCR file cannot be found,
            or it holds non-zero data beyond the RAM capacity.
        :raises PTInteralAssertException: a PCR has no usable default, or is not
            one of the chain PCRs.
        """
        num_chains = 8
        bytes_per_chain = 2048
        ocr_size = num_chains * bytes_per_chain  # 16 KiB

        param_svc = GenericParamService(user_design_obj.param_group, user_design_obj.param_info)
        ocr_file_path = param_svc.get_param_value(SOCParamId.OCR_FILE_PATH)
        if not ocr_file_path:
            # No RAM image configured (empty or unset): emit the PCR defaults.
            self.generate_default_pcr(blk_ins, parent, design, block_def)
            return

        type_name = f"{self._name}:{self._seq_group_name_to_print}"
        pcr_map = pcr_block.get_pcr_list(self._seq_group_name_to_print)

        ocr_path = Path(ocr_file_path)
        if not ocr_path.is_absolute() and not ocr_path.exists():
            # Relative paths are tried against the CWD first, then the design location.
            ocr_path = Path(design.location) / ocr_path
        pt_assert(ocr_path.exists(), 'On-Chip-Ram file not exist', ConfigurationInvalidException)

        if is_intel_hex(ocr_path):
            bdata = load_intel_hex(ocr_path, count=ocr_size)
        else:
            bdata = load_generic_bin(ocr_path)
        # Copy into a mutable buffer so trimming/padding work whatever the loader returns.
        bdata = bytearray(bdata)

        # Non-zero bytes past the capacity cannot be placed in any chain, so
        # reject them explicitly. Trailing zero bytes carry nothing and are dropped.
        pt_assert(not any(bdata[ocr_size:]),
                  f'On-Chip-Ram file has non-zero data beyond {ocr_size} bytes',
                  ConfigurationInvalidException)
        del bdata[ocr_size:]
        # Fill the rest with zero
        bdata.extend(bytearray(ocr_size - len(bdata)))

        segments_data = self.bytearray_to_pcr_segments(bdata, bytes_per_chain, num_chains)

        # Chain n of the RAM image is carried by PCR_UNSUPPORTED<n>.
        chain_of_pcr = {
            PCRDefnType.PCR_UNSUPPORTED0.value: 0,
            PCRDefnType.PCR_UNSUPPORTED1.value: 1,
            PCRDefnType.PCR_UNSUPPORTED2.value: 2,
            PCRDefnType.PCR_UNSUPPORTED3.value: 3,
            PCRDefnType.PCR_UNSUPPORTED4.value: 4,
            PCRDefnType.PCR_UNSUPPORTED5.value: 5,
            PCRDefnType.PCR_UNSUPPORTED6.value: 6,
            PCRDefnType.PCR_UNSUPPORTED7.value: 7,
        }

        ins_element = SubElement(parent, 'efxpt:instance', id=blk_ins, type=type_name)
        for pcr_name, pcr_obj in sorted(pcr_map.items()):
            # The default is not emitted here, but every PCR must still declare a valid one.
            self.__get_checked_pcr_default(pcr_name, pcr_obj)
            pt_assert(pcr_name in chain_of_pcr,
                      f'Block {self._name} PCR {pcr_name} is not an On-Chip-Ram chain PCR',
                      PTInteralAssertException)
            value = segments_data[chain_of_pcr[pcr_name]].hex().upper()
            SubElement(ins_element, 'efxpt:parameter', name=pcr_name, value=value)

    def generate_pcr_user_instance(self,
                                   blk_ins: str,
                                   parent: Element,
                                   design: PeriDesign,
                                   block_def: PeripheryBlock,
                                   user_ins: str):
        """
        Emit the PCR instance for SOC user instance *user_ins*.

        Dispatches on the sequence group being printed (``SOC_PCR`` or ``SOC_CFG``).

        :raises PTInteralAssertException: no SOC registry, unknown user
            instance, or unhandled sequence group.
        :raises BlockPCRDoesNotExistException: the block has no PCR definition.
        """
        soc_reg = design.soc_reg
        pt_assert(soc_reg is not None, "No SOC registry in this design", PTInteralAssertException)

        pcr_block = block_def.get_block_pcr()
        pt_assert(pcr_block is not None, f'Block {self._name} does not have PCR defined', BlockPCRDoesNotExistException)

        design_obj = soc_reg.get_inst_by_name(user_ins)
        pt_assert(design_obj is not None, f'Cannot find User instance {user_ins}', PTInteralAssertException)

        if self._seq_group_name_to_print == "SOC_PCR":
            self.__generate_user_instance_pcr_seq_group(blk_ins, parent, design, block_def, pcr_block, design_obj)
        elif self._seq_group_name_to_print == "SOC_CFG":
            self.__generate_user_instance_cfg_seq_group(blk_ins, parent, design, block_def, pcr_block, design_obj)
        else:
            raise PTInteralAssertException(f"Unhandled seq group: {self._seq_group_name_to_print}")

    def generate_default_pcr(self, blk_ins: str, parent: Element, design: PeriDesign, block_def: PeripheryBlock):
        """
        Emit an instance whose PCRs all carry their declared default values.

        When a sequence group is being printed, only that group's PCRs are
        emitted and the instance type is ``<name>:<group>``. Otherwise all PCRs
        are emitted under the plain block name.

        :raises BlockPCRDoesNotExistException: the block has no PCR definition.
        :raises PTInteralAssertException: a PCR has no usable default.
        """
        pcr_block = block_def.get_block_pcr()
        pt_assert(pcr_block is not None, f'Block {self._name} does not have PCR defined', BlockPCRDoesNotExistException)

        if self._seq_group_name_to_print != "":
            type_name = f"{self._name}:{self._seq_group_name_to_print}"
            pcr_map = pcr_block.get_pcr_list(self._seq_group_name_to_print)
        else:
            type_name = self._name
            pcr_map = pcr_block.get_pcr_list()

        ins_element = SubElement(parent, 'efxpt:instance', id=blk_ins, type=type_name)
        for pcr_name, pcr_obj in sorted(pcr_map.items()):
            value = self.__get_checked_pcr_default(pcr_name, pcr_obj)
            SubElement(ins_element, 'efxpt:parameter', name=pcr_name, value=value)
