258 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			258 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Python
		
	
	
	
from collections import defaultdict, deque
from operator import itemgetter
from struct import pack, unpack
 | 
						|
 | 
						|
"""
Tweaked version of nlzss modified to work with raw data and return bytes instead of operating on whole files.
LZ11 functionality has been removed since it is not necessary for MMBN3.

https://github.com/magical/nlzss
"""
 | 
						|
 | 
						|
def gba_decompress(data: bytearray):
    """Decompress an LZSS (type 0x10) blob that begins with its 4-byte header.

    The header is one type byte, which must be 0x10, followed by the
    decompressed size as a 24-bit little-endian integer. Everything after
    the header is handed to ``decompress_raw_lzss10``.

    Args:
        data: Header followed by the compressed stream.

    Returns:
        A bytearray holding the decompressed bytes.

    Raises:
        DecompressionError: If fewer than 4 header bytes are present, if the
            type byte is not 0x10, or if ``decompress_raw_lzss10`` rejects
            the stream.
    """
    header = data[:4]
    if len(header) < 4:
        # Without this guard an empty or short buffer would surface as an
        # IndexError / struct.error instead of a decompression failure.
        raise DecompressionError("lzss header is truncated")
    if header[0] != 0x10:
        raise DecompressionError("not an lzss-compressed file")

    # Bytes 1..3 hold the output size, least-significant byte first.
    decompressed_size = int.from_bytes(header[1:], "little")
    return decompress_raw_lzss10(data[4:], decompressed_size)
 | 
						|
 | 
						|
 | 
						|
def gba_compress(data: bytearray):
    """Compress ``data`` into an LZSS (type 0x10) blob.

    Output layout:
      * 4-byte little-endian header: type byte 0x10, then ``len(data)`` in
        the upper 24 bits (``pack`` raises ``struct.error`` if the value
        does not fit in 32 bits).
      * Groups of up to eight tokens, each group preceded by one flag byte
        (most-significant bit first, 1 = back-reference, 0 = literal).
        A back-reference is a big-endian 16-bit value: ``count - 3`` in the
        top 4 bits and ``distance - 1`` in the low 12 bits.
      * 0xff bytes so that the body (everything after the header) has a
        length that is a multiple of 4.

    Args:
        data: The bytes to compress.

    Returns:
        A bytearray containing header, body and padding.
    """
    # header
    out = bytearray(pack("<L", (len(data) << 8) + 0x10))
    header_size = len(out)

    # body
    for tokens in chunkit(_compress(data), 8):
        flags = [isinstance(token, tuple) for token in tokens]
        out.append(packflags(flags))

        for token, is_reference in zip(tokens, flags):
            if is_reference:
                count, disp = token
                # _compress yields a negative displacement; the stream
                # stores the positive distance minus one.
                stored_disp = (-disp) - 1
                assert 0 <= stored_disp < 4096
                out.extend(pack(">H", ((count - 3) << 12) | stored_disp))
            else:
                out.append(token)

    # padding
    body_size = len(out) - header_size
    out.extend(b'\xff' * (-body_size % 4))
    return out
 | 
						|
 | 
						|
 | 
						|
class SlidingWindow:
    """Byte-indexed sliding window used to find LZSS back-references.

    ``hash`` maps a byte value to the positions (ascending) at which that
    byte occurs inside the window ``[start, stop)``. ``index`` is the
    position in ``data`` that ``search`` tries to match next.
    """

    # The size of the sliding window
    size = 4096

    # The minimum displacement.
    disp_min = 2

    # The hard minimum — a disp less than this can't be represented in the
    # compressed stream.
    disp_start = 1

    # The minimum length for a successful match in the window
    match_min = 3

    # The maximum length of a successful match, inclusive.
    match_max = 3 + 0xf

    def __init__(self, buf):
        """Create an empty window positioned at the start of ``buf``."""
        self.data = buf
        # Positions are appended on the right and evicted on the left, so a
        # deque keeps eviction O(1) where list.pop(0) would be O(n).
        self.hash = defaultdict(deque)
        self.full = False

        self.start = 0
        self.stop = 0
        self.index = 0

        assert self.match_max is not None

    def next(self):
        """Slide the window forward by one byte.

        Registers ``data[stop]`` in the position index and, once the window
        holds ``size`` bytes, evicts the oldest position on every step.
        """
        if self.index < self.disp_start - 1:
            self.index += 1
            return

        if self.full:
            olditem = self.data[self.start]
            # The oldest recorded position for this byte must be the one
            # that is about to leave the window.
            assert self.hash[olditem][0] == self.start
            self.hash[olditem].popleft()

        item = self.data[self.stop]
        self.hash[item].append(self.stop)
        self.stop += 1
        self.index += 1

        if self.full:
            self.start += 1
        elif self.size <= self.stop:
            self.full = True

    def advance(self, n=1):
        """Advance the window by n bytes"""
        for _ in range(n):
            self.next()

    def search(self):
        """Look for the best back-reference for the bytes at ``index``.

        Candidates are visited oldest first. A candidate is usable when it
        matches at least ``match_min`` bytes and lies at least ``disp_min``
        bytes back. The first usable candidate reaching ``match_max`` is
        returned immediately; otherwise the first candidate with the
        greatest length wins.

        Returns:
            ``(length, -distance)`` for the chosen match, or None when no
            usable candidate exists.
        """
        match_max = self.match_max
        match_min = self.match_min

        best = None
        for candidate in self.hash[self.data[self.index]]:
            matchlen = self.match(candidate, self.index)
            if matchlen < match_min:
                continue
            disp = self.index - candidate
            if disp < self.disp_min:
                continue
            if matchlen >= match_max:
                return (matchlen, -disp)
            # Strict comparison keeps the earliest of equally long matches.
            if best is None or matchlen > best[0]:
                best = (matchlen, -disp)
        return best

    def match(self, start, bufstart):
        """Count how many bytes at ``bufstart`` repeat the data at ``start``.

        The source is read modulo the distance between the two positions,
        so a match may run past ``bufstart`` the same way an overlapping
        copy does during decompression. The count is capped at
        ``match_max`` and at the end of ``data``.
        """
        # Period of the overlapping copy: the distance between the two
        # positions, taken from the arguments rather than from self.index.
        size = bufstart - start

        if size == 0:
            return 0

        matchlen = 0
        limit = min(len(self.data) - bufstart, self.match_max)
        for i in range(limit):
            if self.data[start + (i % size)] != self.data[bufstart + i]:
                break
            matchlen += 1
        return matchlen
 | 
						|
 | 
						|
 | 
						|
def _compress(input, windowclass=SlidingWindow):
    """Generate the LZSS token stream for ``input``.

    Each token is either a literal byte (int) or a tuple of
    ``(count, displacement)`` as returned by the window's ``search``.

    Args:
        input: Indexable byte sequence to tokenize.
        windowclass: Window type, constructed with ``input``, that provides
            ``search``, ``advance`` and ``next``.
    """
    window = windowclass(input)

    i = 0
    while i < len(input):
        match = window.search()
        if match:
            yield match
            # Skip over every byte covered by the back-reference.
            window.advance(match[0])
            i += match[0]
        else:
            yield input[i]
            window.next()
            i += 1
 | 
						|
 | 
						|
 | 
						|
def packflags(flags):
    """Pack up to eight truthy/falsy flags into one byte, first flag in bit 7.

    Missing flags (fewer than eight given) count as 0; flags beyond the
    eighth are ignored.

    Args:
        flags: Sliceable sequence of flag values.

    Returns:
        An int in the range 0..255.
    """
    n = 0
    for position, flag in enumerate(flags[:8]):
        if flag:
            n |= 0x80 >> position
    return n
 | 
						|
 | 
						|
 | 
						|
def chunkit(it, n):
    """Yield the items of ``it`` grouped into lists of ``n`` items.

    The final list is shorter when the items do not divide evenly, and
    nothing is yielded for an empty iterable. A non-positive ``n`` yields
    every item in its own single-element list.

    Args:
        it: Any iterable.
        n: Target chunk length.
    """
    chunk = []
    for item in it:
        chunk.append(item)
        if len(chunk) >= n:
            yield chunk
            # Start a fresh list so chunks already yielded are not mutated.
            chunk = []
    if chunk:
        yield chunk
 | 
						|
 | 
						|
 | 
						|
def bits(byte):
    """Return the low eight bits of ``byte`` as a tuple, bit 7 first.

    Args:
        byte: Integer whose bits 7..0 are extracted.

    Returns:
        An 8-tuple of 0/1 ints ordered from bit 7 down to bit 0.
    """
    return tuple((byte >> shift) & 1 for shift in range(7, -1, -1))
 | 
						|
 | 
						|
 | 
						|
def decompress_raw_lzss10(indata, decompressed_size, _overlay=False):
    """Decompress a header-less LZSS stream. Returns a bytearray.

    The stream is a sequence of groups: one flag byte followed by up to
    eight tokens, flags read from bit 7 down to bit 0. A 0 flag copies one
    literal byte. A 1 flag reads a big-endian 16-bit value whose top 4 bits
    are ``count - 3`` and whose low 12 bits are the stored displacement;
    ``count`` bytes are then copied one at a time from that far back in the
    output, so the copy may overlap the bytes it is producing.

    Args:
        indata: Iterable of ints (for example bytes or bytearray) holding
            the compressed stream without its header.
        decompressed_size: Number of bytes the stream must produce.
        _overlay: When true, 3 instead of 1 is added to each stored
            displacement.

    Returns:
        A bytearray of exactly ``decompressed_size`` bytes.

    Raises:
        DecompressionError: If the stream ends early, a back-reference
            points before the start of the output, or the output length
            does not equal ``decompressed_size``.
    """
    data = bytearray()
    it = iter(indata)
    disp_extra = 3 if _overlay else 1

    try:
        while len(data) < decompressed_size:
            flag_byte = next(it)
            for shift in range(7, -1, -1):
                if (flag_byte >> shift) & 1:
                    # big-endian
                    high = next(it)
                    low = next(it)
                    sh = (high << 8) | low
                    count = (sh >> 0xc) + 3
                    disp = (sh & 0xfff) + disp_extra

                    if disp > len(data):
                        raise DecompressionError(
                            "back-reference points before the start of the output")
                    # Byte-by-byte so a copy can read bytes it just wrote.
                    for _ in range(count):
                        data.append(data[-disp])
                else:
                    data.append(next(it))

                if decompressed_size <= len(data):
                    break
    except StopIteration:
        # Running out of input must not leak StopIteration to the caller.
        raise DecompressionError("compressed data ended unexpectedly") from None

    if len(data) != decompressed_size:
        raise DecompressionError("decompressed size does not match the expected size")

    return data
 | 
						|
 | 
						|
 | 
						|
class DecompressionError(ValueError):
    """Raised when data is not LZSS-compressed or fails to decompress."""
 |