from pathlib import Path


class HexBinValidationError(Exception):
    """Raised by the validators in this module when a hex/bin file is rejected."""


def is_intel_hex(p: Path) -> bool:
    """
    Return True when the file at *p* starts with the Intel HEX start code ':'.

    Only the first character is inspected. An empty file, or content that
    cannot be decoded as UTF-8, yields False. Errors raised while opening the
    file (for example a missing path) are not handled here.
    """
    try:
        with open(p, mode='r', encoding='utf-8') as stream:
            leading = stream.read(1)
    except UnicodeDecodeError:
        # Not decodable as text, so it cannot be an Intel HEX file
        return False
    return leading == ':'


# Characters allowed after the ':' start code. Checked explicitly because
# int(text, 16) also accepts '_', signs and surrounding whitespace.
_INTEL_HEX_DIGITS = frozenset('0123456789abcdefABCDEF')

# Shortest possible record, e.g. the end-of-file record ':00000001FF':
# start code (1) + byte count (2) + address (4) + record type (2) + check sum (2)
_INTEL_HEX_MIN_RECORD_CHARS = 11

# Record type must be [0x00 - 0x05]
_INTEL_HEX_MAX_RECORD_TYPE = 0x05


def _validate_intel_hex_record(line: str) -> None:
    """
    Validate a single Intel HEX record whose line terminator was already removed.

    Raises:
        HexBinValidationError: if the start code is missing (this includes a
            blank line), the record is too short, it contains a non-hex
            character, the record type is above 0x05, the length disagrees
            with the declared byte count, or the checksum does not match.
    """
    # startswith() instead of line[0] so a blank line is reported as an
    # invalid record rather than escaping as IndexError
    if not line.startswith(':'):
        raise HexBinValidationError(f"Invalid Intel Hex record: '{line}', missing start code ':'")

    record_length = len(line)
    if record_length < _INTEL_HEX_MIN_RECORD_CHARS:
        raise HexBinValidationError(f"Invalid Intel Hex record: '{line}', length too short")

    if not _INTEL_HEX_DIGITS.issuperset(line[1:]):
        raise HexBinValidationError(f"Invalid character in Intel Hex record: '{line}'")

    # Extract the fields from the record
    byte_count = int(line[1:3], 16)
    record_type = int(line[7:9], 16)

    if record_type > _INTEL_HEX_MAX_RECORD_TYPE:
        raise HexBinValidationError(f"Invalid record type: '{record_type}' in Intel HEX record: '{line}'")

    # start code (1) + byte count (2) + address (4) + record type (2) + data (bytes * 2) + check sum (2)
    # The length must match exactly: trailing characters beyond the declared
    # byte count are not part of the record.
    exp_record_length = _INTEL_HEX_MIN_RECORD_CHARS + byte_count * 2
    if record_length != exp_record_length:
        raise HexBinValidationError(f"Invalid Intel Hex record: '{line}', invalid length {exp_record_length} vs {record_length}")

    # The checksum is the two's complement of the low byte of the sum of all
    # preceding record bytes. Compare numerically so that lower-case hex
    # digits in the file are accepted.
    record_bytes = bytes.fromhex(line[1:])
    computed_checksum = (-sum(record_bytes[:-1])) & 0xFF
    if computed_checksum != record_bytes[-1]:
        raise HexBinValidationError(f"Checksum mismatch in Intel HEX record: '{line}', {computed_checksum:02X} vs {line[-2:]}")


def validate_intel_hex(p: Path, max_pcr_bit_length: int) -> None:
    """
    Raise an exception if the Intel HEX file is invalid.

    Args:
        p: path of the file to check.
        max_pcr_bit_length: size budget in bits; it is divided by 8 to get the
            maximum number of data bytes, from which the file size limit is
            derived.

    Raises:
        HexBinValidationError: if the file is larger than the derived limit or
            any line is not a well-formed Intel HEX record.
        UnicodeDecodeError: propagated unchanged if the file is not valid UTF-8.
    """
    # worst case: all records are data records with 8 data bytes
    data_max_bytes = max_pcr_bit_length // 8
    # Each byte is represented by 2 ascii
    # start code (1) + byte count (2) + address (4) + record type (2) + data (2 * 8) + check sum (2) + new line (2)
    # newline can be \n or \r\n
    # NOTE(review): this bound does not budget for the end-of-file record or
    # extended address records, and rounds down when data_max_bytes is not a
    # multiple of 8 -- confirm this is the intended limit.
    max_file_bytes = (data_max_bytes // 8) * 29
    file_size = p.stat().st_size
    if file_size > max_file_bytes:
        raise HexBinValidationError(f'File too large, max size is {max_file_bytes} bytes, got = {file_size}')

    with open(p, mode='r', encoding='utf-8') as fp:
        for line in fp:
            _validate_intel_hex_record(line.strip('\r\n'))


def validate_generic_hex(p: Path, max_size: int):
    """
    Raise an exception if the generic hex file is invalid.

    Placeholder: the file is only opened as UTF-8 text and closed again, so
    the sole failure mode is an error from open(). No content is read and
    *max_size* is not used yet.
    """
    handle = open(p, mode='r', encoding='utf-8')
    handle.close()


def validate_generic_bin(p: Path, max_bytes: int):
    """
    Raise an exception if the generic bin file is invalid.

    The only check is the on-disk size: a file of at most *max_bytes* bytes is
    accepted, anything larger raises HexBinValidationError.
    """
    size_on_disk = p.stat().st_size
    if size_on_disk <= max_bytes:
        return
    raise HexBinValidationError(f'File too large, max size is {max_bytes} bytes')


def load_intel_hex(p: Path, count: int | None = None) -> bytearray:
    """
    Assume validate_intel_hex() is called
    """
    with open(p, mode='r', encoding='utf-8') as fp:
        offset = 0
        buffer = {}
        for line in fp:
            line = line.strip('\r\n')
            import binascii
            bdata = binascii.unhexlify(line[1:-2])

            byte_count = int(bdata[0])
            address = int(bdata[1]) * 256 + int(bdata[2])
            record_type = int(bdata[3])
            # print(byte_count, address, record_type)

            if record_type == 0:
                # data record
                address += offset
                for i in range(4, 4 + byte_count):
                    buffer[address] = bdata[i]
                    address += 1
            elif record_type == 1:
                # end of file record
                pass
            elif record_type == 2:
                # Extended 8086 Segment Record
                assert byte_count == 2 and address == 0
                offset = (bdata[4] * 256 + bdata[5]) * 16
            elif record_type == 4:
                # Extended Linear address record
                assert byte_count == 2 and address == 0
                offset = (bdata[4] * 256 + bdata[5]) * 65536
            else:
                # TODO: Handle record type 3 and 5
                pass

        start_addr = min(buffer)
        end_addr = start_addr + count - 1
        results = bytearray()
        # Remarks: Assume the data always aligned from 0
        for i in range(start_addr, end_addr + 1):
            results.append(buffer.get(i, 0))
        return results

# Remarks: For debugging only, using third-party library as cross-reference
# def load_intel_hex(p: Path, count: int | None = None) -> bytearray:
#     from intelhex import IntelHex
#     ih = IntelHex()
#     ih.padding = 0x00
#     ih.loadhex(str(p))
#     return ih.tobinarray()


def load_generic_bin(p: Path) -> bytearray:
    """
    Read the whole file at *p* in binary mode and return it as a bytearray.

    Assumes validate_generic_bin() has already been called.
    """
    with open(p, mode='rb') as stream:
        payload = stream.read()
    return bytearray(payload)
