from array import array
from pathlib import Path
import os
from time import sleep
from tornado.escape import url_unescape, url_escape

from google.protobuf.empty_pb2 import Empty
from google.protobuf.wrappers_pb2 import Int32Value, StringValue
from efx_dbg.jcf import JtagChainEngine
from efx_dbg.jtag import JtagManager
from efx_pgm.efx_hw_common.manager import EfxFTDIHwConnectionManager
from efx_pgm.usb_resolver import UsbResolver
from efx_pgm.efx_hw_common.boards import EfxHwBoardProfileSelector
from efx_pgm.efx_hw_common.boards import InitProfile
from pyftdi.bits import BitSequence
from pyftdi.jtag import JtagState

from efx_pgm.efx_hw_server.server_pb2 import BitSequenceMessage, Jtag2ApbConfigureRequest, Jtag2ApbUnconfigureRequest, Jtag2ApbWriteRequest, Jtag2ApbReadRequest, JtagReadRequest, JtagWriteRequest, JtagSMPathRequest, JtagSMPath, JtagChangeStateRequest, JtagConnectionRequest, Jtag2ApbConfigureMessage
from efx_pgm.efx_hw_server.server_pb2_grpc import Jtag2ApbEngineHandlerServicer
from efx_pgm.jtag2SpiDriver import Jtag2SpiChainEngine
from efx_dbg.engine import SharedJtagEngine
from efx_dbg.jtag import JtagSession
from efx_pgm.efx_hw_server.handler.DeviceController import DeviceController
from efx_pgm.efx_hw_common.device import EfxHwDevice


def bs_to_message(bseq: BitSequence) -> BitSequenceMessage:
    """Copy the raw bit values of *bseq* into a new BitSequenceMessage.

    The bits are taken from ``bseq.sequence()`` and appended, in order, to the
    message's ``data`` field.
    """
    bits = bseq.sequence()
    message = BitSequenceMessage()
    message.data.extend(bits)
    return message

def str_to_message(msg: str) -> Jtag2ApbConfigureMessage:
    """Wrap a status text in a new Jtag2ApbConfigureMessage.

    Args:
        msg: Status text placed in the message's ``resp`` field. Non-string
            values (e.g. an exception object caught by a handler) are converted
            with ``str()``. Concatenating them directly onto the string field
            raised ``TypeError`` and broke the error reply.

    Returns:
        The populated Jtag2ApbConfigureMessage.
    """
    result = Jtag2ApbConfigureMessage()
    result.resp += str(msg)
    return result

def message_to_bs(message: BitSequenceMessage) -> BitSequence:
    """Rebuild a pyftdi BitSequence from a BitSequenceMessage.

    Each entry of ``message.data`` is appended, in order, to an initially empty
    BitSequence.
    """
    bits = BitSequence()
    payload = message.data
    for value in payload:
        bits.append(value)
    return bits

def bytes2binstr(b, n=None) -> str:
    """Format bytes as space-separated 8-bit binary groups, most significant bit first.

    Args:
        b: Iterable of byte values (``bytes``, ``bytearray`` or ints 0-255).
        n: If given, keep only the first *n* bits. The separating spaces are not
            counted as bits. ``n=0`` yields ``''``.

    Returns:
        The binary string, e.g. ``'00000001 10000000'``.

    >>> bytes2binstr(bytes([1, 128]))
    '00000001 10000000'
    >>> bytes2binstr(bytes([1, 128]), 9)
    '00000001 1'
    """
    s: str = ' '.join(f'{x:08b}' for x in b)
    if n is None:
        return s
    # n bits occupy n characters plus one separator for every complete 8-bit
    # group that precedes the n-th bit. max() keeps n=0 from producing a
    # negative slice bound, which would chop the last character instead.
    return s[:n + max(n - 1, 0) // 8]


class Jtag2ApbEngineHandler(Jtag2ApbEngineHandlerServicer):
    """gRPC servicer performing APB accesses through a JTAG user data register.

    ``configure`` opens (or joins) a server-side JTAG connection for the
    requested URL and attaches a ``JtagSession`` to the connection's shared
    engine. ``write``, ``read`` and ``burst_read`` then shift one APB
    transaction through the data register. ``unconfigure`` detaches the session
    and releases the connection once the shared engine has no users left.
    """

    # JTAG user registers accepted by configure(); also used to initialise the
    # per-user flags stored in a new connection context.
    JTAG_USERS = ('USER1', 'USER2', 'USER3', 'USER4')

    # Number of bits read back from the DR after each transaction type.
    # NOTE(review): presumably status bits plus 0/1/2 32-bit data words -- confirm
    # against the JTAG-to-APB bridge RTL.
    _WRITE_RESP_BITS = 2
    _READ_RESP_BITS = 34
    _BURST_READ_RESP_BITS = 66

    jtag_session = None

    def __init__(self, server):
        super().__init__()
        self._server = server
        # Populated by configure() and consumed by unconfigure().
        self.user = None
        self.conn_id = None

    def configure(self, request: Jtag2ApbConfigureRequest, context):
        """Open or join the JTAG connection for ``request.url`` and attach a session.

        If the server already has an active JTAG connection for the URL, its
        shared engine is reused. Otherwise a new connection is acquired, the
        FTDI controller is configured for the ``dev_url`` query parameter, and a
        shared engine is created for ``request.tap`` / ``request.user``.

        Returns:
            A Jtag2ApbConfigureMessage whose ``resp`` is ``'Connected'`` on
            success, or a description of the failure otherwise.
        """
        try:
            if request.user not in self.JTAG_USERS:
                return str_to_message(f'User {request.user} not valid')

            self.user = request.user
            dev_url = url_unescape(self._get_param_from_url(request.url, 'dev_url'))

            self.conn_id = self._server.find_active_connection(request.url, 'jtag', {})
            if self.conn_id is not None:
                shared_engine = self._get_shared_engine(self.conn_id)
            else:
                self.conn_id = self._server.acquire_connection(request.url, conn_type='jtag')
                shared_engine = self._setup_new_connection(self.conn_id, dev_url, request)

            self.jtag_session = JtagSession(url=request.url, jtag_user=request.user, tap=request.tap)
            self.jtag_session.attach_engine(shared_engine)

        except Exception as e:
            # Pass the text, not the exception object: the message field is a
            # string and concatenating an exception onto it raises TypeError.
            return str_to_message(str(e))
        return str_to_message('Connected')

    def unconfigure(self, request: Jtag2ApbUnconfigureRequest, context):
        """Detach this handler's JTAG session from its shared engine.

        If the shared engine still has users, the connection stays open, the
        engine is stored back in the connection context, and this user's flag is
        cleared. Otherwise the connection is released.
        """
        if not self.jtag_session:
            return str_to_message('Encountered error: jtag session not founded')
        try:
            shared_engine = self.jtag_session.detach_engine()
            if shared_engine.has_users():
                conn_context = self._server.get_connection_context(self.conn_id)
                conn_context['shared_engine'] = shared_engine
                conn_context[self.user] = False
                self._server.set_connection_context(self.conn_id, conn_context)
            else:
                self._server.release_connection(self.conn_id)

            self.jtag_session = None
            return str_to_message('Disconnected')
        except Exception as e:
            return str_to_message(f'Encountered exception: {e}')

    def write(self, request: Jtag2ApbWriteRequest, context):
        """Write ``request.data`` to APB ``request.address``.

        Returns the DR response bits, or the bit pattern ``'10'`` when no
        session has been configured.
        """
        if self.jtag_session is None:
            return bs_to_message(BitSequence('10'))

        wrparity: BitSequence = message_to_bs(request.wrparity)
        # parity[3:0] (from request) + burst_size[3:0]=0000 + burst_type[1:0]=00 + read/write=1
        tail = BitSequence(wrparity + BitSequence("000000") + BitSequence("1"), length=11)
        dr = self._build_dr(bytes(tuple(request.address)) + bytes(tuple(request.data)), tail)
        return self._transfer(dr, self._WRITE_RESP_BITS, settle=False)

    def read(self, request: Jtag2ApbReadRequest, context):
        """Read one word from APB ``request.address``.

        The data field of the request vector is zero-filled. Returns the DR
        response bits, or ``'10'`` when no session has been configured.
        """
        if self.jtag_session is None:
            return bs_to_message(BitSequence('10'))

        # parity[3:0]=0000 + burst_size[3:0]=0000 + burst_type[1:0]=00 + read/write=0
        tail = BitSequence("00000000000", length=11)
        dr = self._build_dr(bytes(tuple(request.address)) + bytes((0, 0, 0, 0)), tail)
        return self._transfer(dr, self._READ_RESP_BITS, settle=True)

    def burst_read(self, request: Jtag2ApbWriteRequest, context):
        """Issue a burst read starting at APB ``request.address``.

        Returns the DR response bits, or ``'10'`` when no session has been
        configured.
        """
        if self.jtag_session is None:
            return bs_to_message(BitSequence('10'))

        # parity bit[3:0] + burst_size[3:0] + burst_type[1:0] + read/write
        # NOTE(review): parity is hard-coded to 1111 and request.wrparity is
        # ignored (it was previously decoded but never used) -- confirm intent.
        tail = BitSequence("1111" + "1000" + "10" + "0", length=11)
        dr = self._build_dr(bytes(tuple(request.address)) + bytes(tuple(request.data)), tail)
        # Reads two data payloads.
        return self._transfer(dr, self._BURST_READ_RESP_BITS, settle=True)

    #------------------------------------------------------------------------------------------------------
    # helper

    @staticmethod
    def _build_dr(byte_data: bytes, tail_bits: BitSequence) -> BitSequence:
        """Return *byte_data* as an MSB-first BitSequence followed by *tail_bits*."""
        bit_string = bytes2binstr(byte_data).replace(" ", "")
        return BitSequence(bit_string, True) + tail_bits

    def _transfer(self, dr: BitSequence, resp_bits: int, settle: bool):
        """Shift *dr* in, read back *resp_bits* bits, return to idle and wrap the result."""
        with self.jtag_session.get_engine() as jtag_engine:
            jtag_engine.write_dr(dr)
            if settle:
                # NOTE(review): presumably lets the APB read complete before the
                # result is shifted out -- confirm required delay.
                sleep(0.001)
            resp = jtag_engine.read_dr(resp_bits)
            jtag_engine.go_idle()
            return bs_to_message(resp)

    @staticmethod
    def _get_param_from_url(url: str, param_name: str) -> str:
        """Return the raw value of query parameter *param_name* in *url*.

        Raises:
            ValueError: if the parameter is absent.
        """
        prefix = param_name + "="
        query = url.split("?", 1)[-1]
        for item in query.split("&"):
            if item.startswith(prefix):
                return item.split("=")[-1]
        raise ValueError(f"URL has no '{param_name}' parameter")

    @staticmethod
    def _resolve_init_profile(dev_url: str) -> InitProfile:
        """Return the board InitProfile for the USB device exposing *dev_url*.

        Falls back to an InitProfile with empty ``val``/``dir`` when no USB
        device lists *dev_url* or no board profile matches it.
        """
        board_profile_selector = EfxHwBoardProfileSelector([Path(os.environ["EFXPGM_HOME"], 'bin', 'efx_pgm', 'efx_hw_common', 'boards')])
        dev_list = UsbResolver(console=None, mixed_backend=True).get_usb_connections()
        best = next((dev for dev in dev_list if dev_url in dev.URLS), None)

        init_val = ""
        init_dir = ""
        if best is not None:
            best_profile = board_profile_selector.get_best_profile(best)
            if best_profile:
                init_val = best_profile.init.val
                init_dir = best_profile.init.dir
        return InitProfile(val=init_val, dir=init_dir)

    def _setup_new_connection(self, conn_id, dev_url, request):
        """Configure the FTDI JTAG controller for a freshly acquired connection.

        Stores the controller and *dev_url* in the connection context, then
        returns a new shared engine for ``request.tap`` / ``request.user``.
        On any failure the connection is released before the error propagates.
        Without this, later configure() calls would find an active connection
        that has no shared engine and fail permanently.
        """
        try:
            engine_provider = EfxFTDIHwConnectionManager.get_controller(dev_url, EfxFTDIHwConnectionManager.ConnectionType.JTAG, trst=request.trst, frequency=request.frequency)
            engine_provider.configure(dev_url, self._resolve_init_profile(dev_url))

            self._server.set_connection_context(conn_id, {
                'engine': engine_provider,
                'dev_url': dev_url,
            })

            return self._create_shared_jtag_engine(conn_id, request.tap, request.user)
        except Exception:
            try:
                self._server.release_connection(conn_id)
            except Exception:
                pass  # best effort: report the original failure, not the cleanup one
            raise

    def _tap_to_instr(self, tap: str):
        """Map a TAP name ('efx', 'efx_ti', 'xlnx') to its JtagManager instruction set, or None."""
        return {
            'efx': JtagManager.EFX_INSTR,
            'efx_ti': JtagManager.EFX_TI_INSTR,
            'xlnx': JtagManager.XIL_INSTR
        }.get(tap)

    def _get_shared_engine(self, conn_id):
        """Return the shared engine stored in *conn_id*'s connection context.

        Raises:
            RuntimeError: if the context holds no shared engine.
        """
        conn_context = self._server.get_connection_context(conn_id)
        engine = conn_context.get('shared_engine')
        if engine is None:
            # Explicit raise instead of assert, which is stripped under python -O.
            raise RuntimeError(f'Connection {conn_id} has no shared JTAG engine')
        return engine

    def _create_shared_jtag_engine(self, conn_id, tap, user):
        """Create a SharedJtagEngine over the connection's chain engine and record it.

        Sets up *user* on the engine, then stores the engine, the TAP name and
        per-user flags (only *user* marked True) in the connection context.
        """
        conn_context = self._server.get_connection_context(conn_id)
        chain_engine = conn_context.get('engine')
        # NOTE(review): uses JtagManager's own _tap_to_instr rather than
        # self._tap_to_instr; kept as-is since the two may differ -- confirm.
        instr_set = JtagManager()._tap_to_instr(tap)
        shared_engine = SharedJtagEngine(chain_engine, instr_set)
        shared_engine.setup_jtag_user(user)

        conn_context['shared_engine'] = shared_engine
        conn_context['tap'] = tap
        for jtag_user in self.JTAG_USERS:
            conn_context[jtag_user] = False
        conn_context[user] = True
        self._server.set_connection_context(conn_id, conn_context)

        return shared_engine
