import re
import os
import logging
from enum import IntEnum
from threading import Lock
from time import time, sleep
import random
import binascii
import enum
import json
from typing import List
from dataclasses import dataclass
from PyQt5.QtCore import Qt, QObject, QAbstractTableModel, QModelIndex, QVariant, QMutex, QWaitCondition, QThreadPool
from PyQt5 import QtCore, QtWidgets
from PyQt5.QtWidgets import QSizePolicy, QSpacerItem, QWidget, QMainWindow, QDialog, QApplication, QPushButton, QVBoxLayout, QHBoxLayout, QSpinBox, QLabel
from pyftdi.jtag import JtagError
from pyftdi.ftdi import FtdiError

import pandas as pd
import numpy as np
from bitstring import BitArray
import matplotlib
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.colors import LinearSegmentedColormap, ListedColormap
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from pyftdi.bits import BitSequence

from worker import Worker, WorkerSignals
from efx_serdes_dbg.gui.util.app_logger import Logger

class PiStepSize(IntEnum):
    """Step-size selector written by ``set_step_size``.

    The member name gives the step; the integer value is the field encoding
    that ``set_step_size`` shifts left by 3 before writing it.
    """

    PI_STEP_1 = 3
    PI_STEP_2 = 0
    PI_STEP_4 = 1
    PI_STEP_6 = 2

class InterpolationStepsPerUI(IntEnum):
    """Interpolation-steps-per-UI selector written by ``set_step_size``.

    The integer value is the field encoding that ``set_step_size`` shifts
    left by 5 before writing it.
    """

    STEPS_32 = 0
    STEPS_64 = 1

class RunMode(enum.Enum):
    """Acquisition state toggled by the Run / Stop buttons.

    Members are only compared by identity in this module; the string values
    are kept as-is.
    """

    CONTINUOUS = "CONTINUOUS"
    PAUSE = "pause"
    STOP = "stop"

class EyeDiagramViewModel(QObject):
    """View model for the eye-surf (eye diagram) page.

    Programs the eye-surf registers through ``jtag2cdb``, sweeps the comparator
    voltage (vt) and the pi offset on a QThreadPool worker, plots the returned
    accumulator values with matplotlib and shows the measured eye opening on
    two labels.
    """

    # Emitted by the acquisition worker when a sweep starts / has finished.
    event_eyesurf_start = QtCore.pyqtSignal()
    event_eyesurf_stop = QtCore.pyqtSignal()

    # Fixed imshow aspect ratio; a falsy value derives it from the grid size instead.
    aspect_ratio_override = 0.25

    # Middle address byte (OR-ed with 0x40) per lane; any other lane value falls back to 1.
    _LANE_ADDR_OFFSETS = {0: 1, 1: 3, 2: 5, 3: 7}

    # Low address byte of the registers accessed by this class.
    _REG_STEP_SIZE = 0x61       # written by set_step_size
    _REG_EYESURF = 0xA0         # used by set_duration, set_pi_value, read_eyesurf_status and burst reads
    _REG_COMPARATOR_VT = 0xA1   # written by set_comparator_voltage

    def __init__(self, ui, eye_diagram_config_view_model, figureCanvas, btn_run, btn_stop, lb_horizontal_open, lb_vertical_open):
        """Store the collaborating widgets and initialise acquisition state.

        Args:
            ui: Stored as ``self.ui``; not otherwise used in this class.
            eye_diagram_config_view_model: Provides the sweep settings via ``getEyeSurfConfig()``.
            figureCanvas: Matplotlib canvas redrawn while data arrives.
            btn_run: Start / Pause / Continue button.
            btn_stop: Stop button.
            lb_horizontal_open: Label showing the horizontal eye opening.
            lb_vertical_open: Label showing the vertical eye opening.
        """
        super().__init__()
        self.logger = Logger.get_logger()
        # Serialises every jtag2cdb call made by this class (GUI thread and workers).
        self._lock = Lock()
        self.ui = ui
        # Mutex/condition the acquisition worker blocks on while paused.
        self.mutex = QMutex()
        self.cond = QWaitCondition()
        self.threadPool = QThreadPool()
        self.worker = None
        self.draw_worker = None
        self.run_mode = RunMode.STOP
        self.eye_diagram_config_view_model = eye_diagram_config_view_model
        self.canvas = figureCanvas
        self.btn_run = btn_run
        self.btn_stop = btn_stop
        self.lb_horizontal_open = lb_horizontal_open
        self.lb_vertical_open = lb_vertical_open
        self.jtag2cdb = None  # assigned by onPostJtagConfig

        self.addr_offset = 1  # lane 0
        self.vt = 0
        self.pi = 0
        self.max_allowed_wait = 2  # Simply set a 2s timeout (currently unused in this class)

        # Saved by on_save_result_pressed: {'vt_<vt>': {'pi_<pi>': {'accum_a': str, 'accum_b': str}}}
        self.accumulation_result = {}

        self.mask_patch = None

        self.im = None
        self.nx = 0
        self.ny = 0

        self.horizontal_open = -1
        self.vertical_open = -1

        self.colormap = self.gen_color_map()

    def gen_color_map(self):
        """Return a 1024-entry yellow -> darkorange -> red colormap whose first entry is black.

        The lowest values (e.g. the 0 cells that find_eye_open_by_raw treats as
        inside the eye) therefore render black.
        """
        base = LinearSegmentedColormap.from_list("mycmap", ["yellow", "darkorange", "red"])
        newcolors = base.resampled(1024)(np.linspace(0, 1, 1024))
        newcolors[0, :] = np.array([0, 0, 0, 1])
        return ListedColormap(newcolors)

    def average(self, lst):
        """Return the arithmetic mean of a non-empty sequence."""
        return sum(lst) / len(lst)

    def find_eye_open_by_raw(self, raw_arr, scale_x = 1, scale_y = 1):
        """Estimate the eye opening from a 2-D grid of accumulator values.

        A cell equal to 0 is "black" (inside the eye). The top of the eye is the
        first row with any black cell. The bottom is the row just before the
        first fully non-black row lying more than 10 rows below the top. The
        horizontal opening is the black-cell count of the row with the most
        black cells (the middle one when several rows tie).

        Args:
            raw_arr: 2-D iterable of rows (``self.data`` in this class).
            scale_x, scale_y: Factors applied to the polygon coordinates.

        Returns:
            ``(polygon, horizontal_open, vertical_open)``, where polygon is four
            ``[x, y]`` points (top, right, bottom, left). Returns
            ``(None, None, None)`` when no eye can be located.
        """
        # Column indices of the black cells, per row.
        black_regions = [[idx_x for idx_x, col in enumerate(row) if col == 0] for row in raw_arr]
        black_spot_counts = [len(spots) for spots in black_regions]

        top = None
        bottom = None
        for idx, spots in enumerate(black_regions):
            if top is None:
                # `is None` (not truthiness) so an eye starting on row 0 is still found.
                if spots:
                    top = idx
            elif not spots and idx - top > 10 and black_regions[idx - 1]:
                # The preceding row must contain black cells, otherwise averaging it would divide by zero.
                bottom = idx - 1
                break

        # Invalid data, no point to continue
        if top is None or bottom is None:
            return None, None, None

        vertical_open = bottom - top + 1

        # Middle of the rows holding the maximum black-cell count (assumes those rows are contiguous).
        max_count = max(black_spot_counts)
        first_max = black_spot_counts.index(max_count)
        max_black_spot_position = first_max + black_spot_counts.count(max_count) // 2
        if not black_regions[max_black_spot_position]:
            # The max rows were not contiguous and the offset hit an empty row; use the first max row.
            max_black_spot_position = first_max

        middle_row = black_regions[max_black_spot_position]
        horizontal_open = len(middle_row)
        left = middle_row[0]
        right = middle_row[-1]

        polygon = [
            [int(self.average(black_regions[top])) * scale_x, top * scale_y],            # Top
            [right * scale_x, max_black_spot_position * scale_y],                      # Right
            [int(self.average(black_regions[bottom])) * scale_x, bottom * scale_y],      # Bottom
            [left * scale_x, max_black_spot_position * scale_y],                       # Left
        ]

        return polygon, horizontal_open, vertical_open

    def bs_to_str(self, bs: BitSequence):
        """Return the upper-case hex string of a BitSequence's bytes."""
        return binascii.hexlify(bs.tobytes()).decode('ascii').upper()

    def bytes_to_str(self, bytes: bytes):
        """Return the upper-case hex string of a bytes object."""
        return binascii.hexlify(bytes).decode('ascii').upper()

    def twos_comp(self, value, bitWidth):
        """Encode ``value`` as a ``bitWidth``-bit two's-complement unsigned integer.

        Raises:
            ValueError: if ``value >= 2**bitWidth``.
        """
        if value >= 2**bitWidth:
            # This catches when someone tries to give a value that is out of range
            raise ValueError("Value: {} out of range of {}-bit value.".format(value, bitWidth))
        return ((value - int((value << 1) & 2**bitWidth)) & int('1'*bitWidth, 2))

    def convert_2s_to_decimal(self, val, bitWidth=7):
        """Decode a ``bitWidth``-bit two's-complement value into a signed int."""
        if val & (1 << (bitWidth - 1)):  # High bit set indicates it's a negative value
            return -(2**bitWidth - val)
        return val  # Positive value.

    def get_vt_pos_in_graph(self, vt_ranges: list, vt):
        """Return the two image rows for ``vt``: mirrored above and below the centre line."""
        mid = len(vt_ranges)
        idx = vt_ranges.index(vt)
        positive_vt = mid - idx - 1
        negative_vt = mid + idx
        return positive_vt, negative_vt

    def get_pi_pos_in_graph(self, pi_ranges, pi):
        """Return the image column of the pi code ``pi``."""
        return pi_ranges.index(pi)

    def _reg_addr(self, reg: int) -> bytes:
        """Build the 3-byte register address of ``reg`` for the current lane."""
        return bytes((0x20, 0x40 | self.addr_offset, reg))

    def _write_register(self, addr: bytes, data: BitSequence) -> bool:
        """Write ``data`` to ``addr`` under the JTAG lock and log the response.

        Returns:
            False (and writes nothing) when no jtag2cdb is configured, else True.
        """
        with self._lock:
            if self.jtag2cdb is None:
                return False

            write_resp = self.jtag2cdb.write(address=addr, wrdata=bytes(data.tobytes()))

            msg = "Write:: Addr: 0x{0}, Data: 0x{1}, Response: 0x{2}".format(self.bytes_to_str(addr), self.bs_to_str(data), self.bs_to_str(write_resp))
            self.logger.debug(msg)
        return True

    def read_eyesurf_status(self):
        """Return True when bit 8 of the eye-surf register reads 1; False if disconnected."""
        addr = self._reg_addr(self._REG_EYESURF)

        with self._lock:
            if self.jtag2cdb is None:
                return False

            read_resp = self.jtag2cdb.read(address=addr)

        msg = "Read:: Addr: 0x{0}, Data: {1}".format(self.bytes_to_str(addr), read_resp)
        self.logger.debug(msg)

        # Ready when bit-8 is high
        return read_resp.sequence()[8] == 1

    def set_step_size(self, step_size : PiStepSize, step_size_per_ui : InterpolationStepsPerUI = InterpolationStepsPerUI.STEPS_64, write = True) -> BitSequence:
        """Build (and optionally write) the step-size register value.

        The layout is ``step_size << 3``, ``step_size_per_ui << 5`` and the
        constant 3 in the low bits.

        Returns:
            The 32-bit value, returned whether or not it was written.
        """
        data = BitSequence(0, length=32)
        addr = self._reg_addr(self._REG_STEP_SIZE)

        # step size
        data = data | BitSequence(step_size << 3, length=32)

        # Interpolation Steps
        data = data | BitSequence(step_size_per_ui << 5, length=32)

        # DPI and EPI control setting time. Take the default value from programmer gui now.
        data = data | BitSequence(3, length=32)

        if write:
            self._write_register(addr, data)

        return data

    def set_duration(self, duration, enable_eye_surf_mode = False, write = True) -> BitSequence:
        """Build (and optionally write) the eye-surf duration / mode value.

        Args:
            duration: Accumulation duration, clamped to 31 (5-bit field at bit 11).
            enable_eye_surf_mode: Sets bit 9 when True.
            write: Write the value to the eye-surf register when True.

        Returns:
            The 32-bit value; callers reuse it as ``in_value`` for later writes.
        """
        data = BitSequence(0, length=32)
        addr = self._reg_addr(self._REG_EYESURF)

        # 5 bits, so maximum value is 31
        data = data | BitSequence(min(duration, 31) << 11, length=32)

        if enable_eye_surf_mode:
            data = data | BitSequence(1 << 9, length=32)

        if write:
            self._write_register(addr, data)

        return data

    def set_comparator_voltage(self, vt, write = True) -> BitSequence:
        """Build (and optionally write) the comparator-voltage value (clamped to 0x7F)."""
        data = BitSequence(0, length=32)
        addr = self._reg_addr(self._REG_COMPARATOR_VT)

        # 7 bits, so maximum is 0x7F
        data = data | BitSequence(min(vt, 0x7F), length=32)

        if write:
            self._write_register(addr, data)

        return data

    def set_pi_value(self, pi, in_value : BitSequence = None, start_eyesurf = True, write = True) -> BitSequence:
        """Build (and optionally write) the pi value, optionally followed by a start bit.

        ``pi`` (clamped to 0x7F) is OR-ed into ``in_value``. When
        ``start_eyesurf`` is set, bit 10 is added and the register is written a
        second time. If the first write finds no device, the value is returned
        without bit 10.
        """
        data = BitSequence(0, length=32)
        addr = self._reg_addr(self._REG_EYESURF)

        if in_value is not None:
            data = data | in_value

        # 7 bits, so maximum is 0x7F
        data = data | BitSequence(min(pi, 0x7F), length=32)

        if write and not self._write_register(addr, data):
            return data

        # TODO: See if we can combine set pi value with start eye surf
        if start_eyesurf:
            data = data | BitSequence(1 << 10, length=32)

            if write:
                self._write_register(addr, data)

        return data

    def update_eye_open(self, clear = False):
        """Show the measured eye opening on the labels, or "-" when ``clear`` is set.

        Nothing is shown when either opening is falsy (e.g. None after invalid data).
        """
        if clear:
            self.lb_horizontal_open.setText("-")
            self.lb_vertical_open.setText("-")
            return

        if self.horizontal_open and self.vertical_open:
            config = self.eye_diagram_config_view_model.getEyeSurfConfig()
            # Earlier variant also showed physical units:
            #   horizontal: self.horizontal_open / 64 * 62.5 ps, vertical: self.vertical_open * 3 mV
            self.lb_horizontal_open.setText("{0} Steps".format(self.horizontal_open * config.piStep))
            self.lb_vertical_open.setText("{0} Steps".format(self.vertical_open * config.dacStep))

    def draw_mask(self, scale_x, scale_y):
        """Add a blue diamond patch centred on the plot and return it.

        The diamond is 19 x 5 steps, scaled by ``scale_x`` / ``scale_y``,
        around the centre of ``self.data``.
        """
        mask_horizontal_steps = 19
        mask_vertical_steps = 5

        mid_x = len(self.data[0]) / 2
        mid_y = len(self.data) / 2
        half_width = (mask_horizontal_steps / 2) * scale_x
        half_height = (mask_vertical_steps / 2) * scale_y
        eye_path = [
            [int(mid_x), int(mid_y - half_height)],    # Top
            [int(mid_x + half_width), int(mid_y)],     # Right
            [int(mid_x), int(mid_y + half_height)],    # Bottom
            [int(mid_x - half_width), int(mid_y)],     # Left
        ]

        ax = plt.gca()
        pol = patches.Polygon(eye_path, linewidth=1, edgecolor='blue', facecolor='blue')
        ax.add_patch(pol)

        return pol

    def start_acquisition(self):
        """Reset the plot for the configured sweep and start the worker threads.

        Steps:
        - Select the lane's register offset.
        - Build the vt / pi sweep ranges.
        - Recreate the image when the grid size changed.
        - Start the acquisition worker and the redraw worker on the thread pool.
        """
        config = self.eye_diagram_config_view_model.getEyeSurfConfig()
        self.addr_offset = self._LANE_ADDR_OFFSETS.get(config.get_selected_lane(), 1)

        # Clear calculation
        self.update_eye_open(True)

        self.vt_ranges = list(range(0, config.dacValue, config.dacStep))
        # pi offsets are programmed as 7-bit two's-complement codes.
        self.pi_ranges = [self.twos_comp(i, 7) for i in range(-config.piValue, config.piValue + 1, config.piStep)]

        nx = len(self.pi_ranges)
        # Each vt occupies two rows, mirrored around the centre (see get_vt_pos_in_graph).
        ny = len(self.vt_ranges) * 2

        self.data = np.zeros((ny, nx))
        # Presumably seeds the autoscaled colour range (0..1024) when imshow creates the image.
        self.data[0, 0] = 1024

        if (self.nx != nx) or (self.ny != ny):
            self.nx = nx
            self.ny = ny

            if self.aspect_ratio_override:
                self.aspect_ratio = self.aspect_ratio_override
            else:
                self.aspect_ratio = self.nx / self.ny

            is_first_image = self.im is None
            self.im = plt.imshow(self.data, origin='upper', aspect=self.aspect_ratio, cmap=self.colormap)
            if is_first_image:
                plt.colorbar()

        ax = plt.gca()
        ax.set_facecolor('k')
        plt.xticks([0, nx // 2, nx - 1], [str(self.convert_2s_to_decimal(self.pi_ranges[0])) + " Steps", "0", str(self.convert_2s_to_decimal(self.pi_ranges[-1])) + " Steps"])
        plt.yticks([0, ny - 1], [str(config.dacValue) + " Steps", str(-config.dacValue) + " Steps"])

        # Remove previous mask (and forget it, so it is never removed twice)
        if self.mask_patch:
            self.mask_patch.remove()
            self.mask_patch = None

        # Mask overlay is disabled for now; to re-enable:
        #   self.mask_patch = self.draw_mask(1 / config.piStep, 1 / config.dacStep)

        self.data.fill(0)
        self.im.set_data(self.data)
        self.canvas.draw()

        self.worker = Worker(self.worker_collect_eyesurf, mtx=self.mutex, cond=self.cond)
        self.worker.signals.progress.connect(self.runprogress_update)
        self.run_mode = RunMode.CONTINUOUS
        self.threadPool.start(self.worker)

        self.draw_worker = Worker(self.worker_draweye)
        self.threadPool.start(self.draw_worker)

    # noinspection PyArgumentList
    @QtCore.pyqtSlot()
    def on_run_pressed(self):
        """Start a new acquisition, or toggle pause / continue of the running one."""
        if self.worker is None or self.run_mode == RunMode.STOP:
            # Start new acquisition
            self.start_acquisition()
        elif self.run_mode == RunMode.PAUSE:
            # Continue the paused acquisition
            self.run_mode = RunMode.CONTINUOUS
            self.cond.wakeAll()
        elif self.run_mode == RunMode.CONTINUOUS:
            # Pause the acquisition; the worker notices on its next pi step.
            self.run_mode = RunMode.PAUSE
        else:
            raise Exception("Unknown run_mode!")    # TODO: Handle this

        self.update_buttons(self.run_mode)

    @QtCore.pyqtSlot()
    def on_stop_pressed(self):
        """Request the acquisition to stop and wake the worker if it is paused."""
        self.run_mode = RunMode.STOP
        self.cond.wakeAll()
        self.update_buttons(self.run_mode)

    @QtCore.pyqtSlot()
    def on_save_result_pressed(self):
        """Dump ``accumulation_result`` as JSON to ``result_<timestamp>.json`` in the CWD."""
        result_file_path = 'result_{0}.json'.format(str(time()))

        if self.accumulation_result is not None:
            with open(result_file_path, "w", encoding="utf-8") as outfile:
                json.dump(self.accumulation_result, outfile, indent=4)

    def update_buttons(self, mode: RunMode, disconnect = False):
        """Set the Run/Stop button text and enabled state for ``mode``.

        With ``disconnect`` set in STOP mode, both buttons are disabled.
        """
        if mode == RunMode.CONTINUOUS:
            self.btn_run.setText("Pause")
            self.btn_run.setEnabled(True)
            self.btn_stop.setEnabled(True)
        elif mode == RunMode.PAUSE:
            self.btn_run.setText("Continue")
            self.btn_run.setEnabled(True)
            self.btn_stop.setEnabled(True)
        elif mode == RunMode.STOP:
            self.btn_run.setText("Start")
            self.btn_run.setEnabled(not disconnect)
            self.btn_stop.setEnabled(False)

    def onPostJtagConfig(self, jtag2cdb):
        """Attach the JTAG-to-CDB interface and refresh the buttons."""
        self.jtag2cdb = jtag2cdb
        self.update_buttons(self.run_mode)

    def onPostJtagUnconfig(self):
        """Stop any acquisition and disable the buttons after the JTAG link goes away."""
        self.run_mode = RunMode.STOP
        self.cond.wakeAll()
        self.update_buttons(self.run_mode, disconnect= True)

    def _burst_read_accums(self, duration_bs):
        """Run one burst read at the current ``self.pi`` and decode both accumulators.

        Returns:
            ``(accum_a, accum_b)`` as 34-bit BitSequences, or None when no
            jtag2cdb is configured.
        """
        with self._lock:
            if self.jtag2cdb is None:
                return None
            burst = self.jtag2cdb.burst_read(address=self._reg_addr(self._REG_EYESURF), pi=self.pi, in_value=duration_bs)

        # The accumulator bits are sliced out of the textual form of the burst
        # response; assumes a fixed str() layout -- TODO confirm against burst_read.
        text = str(burst)
        accum_a = BitSequence(int(text[3:42].replace(" ", ""), 2), length=34)
        accum_b = BitSequence(int(("01" + text[43:80]).replace(" ", ""), 2), length=34)
        return accum_a, accum_b

    def _record_point(self, progress_callback, vt, pi, accum_a, accum_b):
        """Store one (vt, pi) result in its own per-vt row and publish it."""
        row = self.accumulation_result.setdefault('vt_{0}'.format(vt), {})
        row['pi_{0}'.format(pi)] = {
            'accum_a': str(accum_a),
            'accum_b': str(accum_b),
        }
        progress_callback.emit(vt, pi, accum_a, accum_b)

    def generate_eyesurf(self, progress_callback, duration_bs, mtx, cond):
        """Sweep every (vt, pi) point and publish its accumulator values.

        Each burst read is attributed to the point programmed on the previous
        read, so results are published one step late. One extra read after the
        sweep flushes the last point. The caller must hold ``mtx``, which is
        released while paused.
        """
        has_pending = False  # a point has been programmed but not yet read back
        pending_vt = 0
        pending_pi = 0

        for vt in self.vt_ranges:
            self.vt = vt
            self.set_comparator_voltage(vt=self.vt, write=True)

            for pi in self.pi_ranges:
                # Poll with a timeout: the GUI thread changes run_mode and calls
                # wakeAll() without holding mtx, so a wake-up could be lost.
                while self.run_mode == RunMode.PAUSE:
                    cond.wait(mtx, 100)

                if self.run_mode == RunMode.STOP:
                    break

                self.pi = pi
                accums = self._burst_read_accums(duration_bs)
                if accums is None:
                    return
                accum_a, accum_b = accums
                self.logger.debug('vt: {0}, pi: {1}, accum_a: 0x{2}, accum_b: 0x{3}'.format(vt, self.pi, accum_a, accum_b))

                if has_pending:
                    self._record_point(progress_callback, pending_vt, pending_pi, accum_a, accum_b)

                pending_vt, pending_pi = self.vt, self.pi
                has_pending = True

            if self.run_mode == RunMode.STOP:
                break

        # Flush the last programmed point; skip when nothing was ever read.
        if not has_pending:
            return
        accums = self._burst_read_accums(duration_bs)
        if accums is None:
            return
        self._record_point(progress_callback, pending_vt, pending_pi, *accums)

    def _run_eyesurf_sequence(self, progress_callback, mtx, cond):
        """Program the eye-surf registers, run the sweep and leave eye-surf mode."""
        # This is to exit eye surf mode, in case previous run didn't end properly
        self.set_duration(
            duration = self.eye_diagram_config_view_model.getEyeSurfConfig().accumulationPeriod,
            enable_eye_surf_mode = False,
            write = True
        )

        # Set high resolution and step size of 1
        self.set_step_size(
            step_size = PiStepSize.PI_STEP_1,
            step_size_per_ui = InterpolationStepsPerUI.STEPS_64)

        # Read back the step-size register (logged only)
        with self._lock:
            if self.jtag2cdb is None:
                return
            read_resp = self.jtag2cdb.read(address=self._reg_addr(self._REG_STEP_SIZE))
            self.logger.debug(read_resp)

        # Set eye surf duration and enter eye surf mode
        duration_bs = self.set_duration(
            duration = self.eye_diagram_config_view_model.getEyeSurfConfig().accumulationPeriod,
            enable_eye_surf_mode = True,
            write = True
        )

        self.generate_eyesurf(progress_callback, duration_bs, mtx, cond)

        self.set_duration(
            duration = self.eye_diagram_config_view_model.getEyeSurfConfig().accumulationPeriod,
            enable_eye_surf_mode = False,
            write = True
        )

    def worker_collect_eyesurf(self, progress_callback, mtx, cond):
        """Acquisition worker entry point (runs on the thread pool).

        Runs the sweep. FTDI/JTAG/assertion errors are treated as a lost
        connection. The ``finally`` block always unlocks ``mtx``, emits
        ``event_eyesurf_stop``, returns to STOP and refreshes the buttons, even
        on early exit or unexpected errors.

        NOTE(review): update_eye_open/update_buttons touch Qt widgets from this
        worker thread; consider routing them through a queued signal.
        """
        global_start = time()
        self.accumulation_result = {}

        self.event_eyesurf_start.emit()

        jtag_error = False
        mtx.lock()
        try:
            try:
                self._run_eyesurf_sequence(progress_callback, mtx, cond)
            except (FtdiError, JtagError, AssertionError):
                self.logger.debug("Communication with device has been interrupted")
                jtag_error = True

            _, self.horizontal_open, self.vertical_open = self.find_eye_open_by_raw(self.data, 1, 1)
            self.update_eye_open()
        finally:
            mtx.unlock()
            self.event_eyesurf_stop.emit()
            self.run_mode = RunMode.STOP
            self.update_buttons(self.run_mode, disconnect = jtag_error)

        self.logger.debug('Eyesurf finished in %d seconds' % int(time() - global_start))

    def worker_draweye(self, progress_callback):
        """Redraw the canvas every 0.2 s while the acquisition worker runs, then once more.

        NOTE(review): canvas.draw() is called from a pool thread; Qt widgets are
        normally only safe to use from the GUI thread.
        """
        # Some delay, if the data acquisition has not started yet
        if not self.worker.is_running():
            sleep(1)

        while self.worker.is_running():
            self.canvas.draw()
            sleep(0.2)

        self.canvas.draw()

    @QtCore.pyqtSlot(int, int, BitSequence, BitSequence)
    def runprogress_update(self, vt, pi, accum_a, accum_b):
        """Write one point's accumulators into both mirrored image rows."""
        self.logger.debug("vt: {0}, pi: {1}".format(vt, pi))
        col = int(self.get_pi_pos_in_graph(self.pi_ranges, pi))
        row_pos, row_neg = self.get_vt_pos_in_graph(self.vt_ranges, vt)
        # Skip the first 6 characters of str(BitSequence) and read the rest as a signed binary int.
        value_a = BitArray(bin=str(accum_a)[6:].replace(" ", "")).int
        value_b = BitArray(bin=str(accum_b)[6:].replace(" ", "")).int

        self.data[row_pos, col] = value_a
        self.data[row_neg, col] = value_b

        self.im.set_data(self.data)

    def mass_command_started(self):
        """Disable both buttons while an external batch command runs."""
        self.btn_run.setEnabled(False)
        self.btn_stop.setEnabled(False)

    def mass_command_stopped(self):
        """Restore the buttons for the current run mode."""
        self.update_buttons(self.run_mode)
