258 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			258 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Python
		
	
	
	
from collections import defaultdict, deque
from operator import itemgetter
from struct import pack, unpack
 | |
| 
 | |
| """
 | |
| Tweaked version of nlzss modified to work with raw data and return bytes instead of operating on whole files.
 | |
| LZ11 functionality has been removed since it is not necessary for MMBN3
 | |
| 
 | |
| https://github.com/magical/nlzss
 | |
| """
 | |
| 
 | |
def gba_decompress(data: bytearray):
    """Decompress LZSS-compressed bytes. Returns a bytearray.

    The first four bytes of ``data`` are a header: byte 0 must be ``0x10``
    (the LZSS10 marker) and bytes 1-3 hold the decompressed size as a
    little-endian 24-bit integer.  Everything after the header is handed to
    ``decompress_raw_lzss10``.

    Raises:
        DecompressionError: if ``data`` is shorter than the header, does not
            start with the ``0x10`` marker, or the payload does not decode to
            the size announced in the header.
    """
    header_size = 4
    if len(data) < header_size:
        raise DecompressionError("data is too short to contain an lzss header")

    if data[0] != 0x10:
        raise DecompressionError("not an lzss-compressed file")

    # Three size bytes, little-endian; int.from_bytes avoids padding the
    # slice out to four bytes just to satisfy struct.
    decompressed_size = int.from_bytes(data[1:header_size], "little")

    return decompress_raw_lzss10(data[header_size:], decompressed_size)
 | |
| 
 | |
| 
 | |
def gba_compress(data: bytearray):
    """Compress ``data`` into an LZSS10 stream and return it as a bytearray.

    Layout of the result:
      * 4-byte little-endian header: ``0x10`` in the low byte, ``len(data)``
        in the upper three bytes;
      * groups of up to eight tokens, each group preceded by one flag byte
        (bit set = back-reference, bit clear = literal byte);
      * ``0xff`` padding so the body length is a multiple of four.
    """
    header = pack("<L", (len(data) << 8) + 0x10)
    out = bytearray(header)

    for group in chunkit(_compress(data), 8):
        is_reference = [isinstance(token, tuple) for token in group]
        out += pack(">B", packflags(is_reference))

        for token, reference in zip(group, is_reference):
            if not reference:
                # Literal byte, copied through unchanged.
                out += pack(">B", token)
                continue

            count, disp = token
            # Displacements arrive as negative offsets; the stream stores
            # (distance - 1) in 12 bits and (count - 3) in the top 4 bits.
            stored_disp = (-disp) - 1
            assert 0 <= stored_disp < 4096
            out += pack(">H", ((count - 3) << 12) | stored_disp)

    # Pad the body (everything after the header) to a 4-byte boundary.
    body_length = len(out) - len(header)
    padding = (-body_length) % 4
    if padding:
        out.extend(b'\xff' * padding)
    return out
 | |
| 
 | |
| 
 | |
class SlidingWindow:
    """Index of recently seen bytes used to find LZSS back-references.

    The window walks over ``buf`` one byte at a time.  For every byte value it
    remembers the positions (oldest first) at which that value occurred within
    the last ``size`` bytes, so ``search`` only has to try candidates that
    start with the right byte.
    """

    # The size of the sliding window
    size = 4096

    # The minimum displacement.
    disp_min = 2

    # The hard minimum — a disp less than this can't be represented in the
    # compressed stream.
    disp_start = 1

    # The minimum length for a successful match in the window
    match_min = 3

    # The maximum length of a successful match, inclusive.
    match_max = 3 + 0xf

    def __init__(self, buf):
        """Create an empty window positioned at the start of ``buf``."""
        self.data = buf
        # byte value -> positions of that byte inside the window, oldest
        # first.  A deque makes evicting the oldest position O(1).
        self.hash = defaultdict(deque)
        self.full = False

        # [start, stop) is the range of positions currently indexed;
        # index is the position the next search() will look at.
        self.start = 0
        self.stop = 0
        self.index = 0

        assert self.match_max is not None

    def next(self):
        """Slide the window forward by one byte."""
        if self.index < self.disp_start - 1:
            self.index += 1
            return

        if self.full:
            # Evict the oldest position; it is always the head of its list
            # because positions are appended in increasing order.
            olditem = self.data[self.start]
            assert self.hash[olditem][0] == self.start
            self.hash[olditem].popleft()

        item = self.data[self.stop]
        self.hash[item].append(self.stop)
        self.stop += 1
        self.index += 1

        if self.full:
            self.start += 1
        elif self.size <= self.stop:
            self.full = True

    def advance(self, n=1):
        """Advance the window by n bytes"""
        for _ in range(n):
            self.next()

    def search(self):
        """Find the longest usable match for the data at ``self.index``.

        Returns a ``(length, -displacement)`` tuple, or ``None`` when no
        match of at least ``match_min`` bytes with a displacement of at least
        ``disp_min`` exists.  Among equally long matches the oldest one wins;
        a match that reaches ``match_max`` is returned immediately.
        """
        match_max = self.match_max
        match_min = self.match_min

        best = None
        for candidate in self.hash[self.data[self.index]]:
            matchlen = self.match(candidate, self.index)
            if matchlen < match_min:
                continue

            disp = self.index - candidate
            if disp < self.disp_min:
                continue

            if matchlen >= match_max:
                # Cannot do better than the maximum; stop looking.
                return (matchlen, -disp)
            if best is None or matchlen > best[0]:
                best = (matchlen, -disp)

        return best

    def match(self, start, bufstart):
        """Count how many bytes at ``bufstart`` repeat the bytes at ``start``.

        The comparison wraps around every ``self.index - start`` bytes, so a
        match may run past the current index (an overlapping copy).  The
        result is capped at ``match_max`` and at the end of the data.
        """
        size = self.index - start

        if size == 0:
            return 0

        matchlen = 0
        limit = min(len(self.data) - bufstart, self.match_max)
        for i in range(limit):
            if self.data[start + (i % size)] != self.data[bufstart + i]:
                break
            matchlen += 1
        return matchlen
 | |
| 
 | |
| 
 | |
def _compress(input, windowclass=SlidingWindow):
    """Yield the token stream for ``input``.

    Each token is either a single literal byte (an int) or a
    ``(count, displacement)`` tuple describing a back-reference found by the
    window's ``search`` method.
    """
    window = windowclass(input)
    position = 0

    while position < len(input):
        found = window.search()
        if not found:
            # No usable back-reference: emit the byte itself.
            yield input[position]
            window.next()
            position += 1
            continue

        yield found
        consumed = found[0]
        window.advance(consumed)
        position += consumed
 | |
| 
 | |
| 
 | |
def packflags(flags):
    """Pack up to eight truthy/falsy flags into one byte, first flag in bit 7.

    Missing trailing flags (fewer than eight supplied) count as zero bits;
    anything beyond the eighth flag is ignored.
    """
    packed = 0
    for position in range(8):
        try:
            bit = 1 if flags[position] else 0
        except IndexError:
            # Ran off the end of a short group: pad with a clear bit.
            bit = 0
        packed = (packed << 1) | bit
    return packed
 | |
| 
 | |
| 
 | |
def chunkit(it, n):
    """Yield the items of ``it`` as consecutive lists of ``n`` items.

    The final list may be shorter when the items run out; nothing is yielded
    for an empty iterable.
    """
    pending = []
    for item in it:
        pending.append(item)
        if len(pending) < n:
            continue
        yield pending
        pending = []

    # Flush whatever is left over from an incomplete final group.
    if pending:
        yield pending
 | |
| 
 | |
| 
 | |
def bits(byte):
    """Return the low eight bits of ``byte`` as a tuple, most significant first."""
    return tuple((byte >> shift) & 1 for shift in range(7, -1, -1))
 | |
| 
 | |
| 
 | |
def decompress_raw_lzss10(indata, decompressed_size, _overlay=False):
    """Decompress LZSS-compressed bytes. Returns a bytearray.

    ``indata`` is the compressed body without the 4-byte header.  It is a
    sequence of groups: one flag byte followed by up to eight tokens.  Flag
    bits are consumed most significant first; a clear bit means the token is
    one literal byte, a set bit means a big-endian 16-bit back-reference whose
    top 4 bits are ``count - 3`` and whose low 12 bits are the displacement.

    Args:
        indata: iterable of ints (bytes, bytearray, ...) holding the body.
        decompressed_size: number of bytes the output must contain.
        _overlay: when true, 3 instead of 1 is added to every stored
            displacement.

    Raises:
        DecompressionError: if the input ends early, a back-reference points
            before the start of the output, or the output does not end up
            exactly ``decompressed_size`` bytes long.
    """
    out = bytearray()
    stream = iter(indata)
    disp_extra = 3 if _overlay else 1

    try:
        while len(out) < decompressed_size:
            flag_byte = next(stream)
            for shift in range(7, -1, -1):
                if (flag_byte >> shift) & 1:
                    # Back-reference: big-endian short.
                    high = next(stream)
                    low = next(stream)
                    sh = (high << 8) | low
                    count = (sh >> 0xc) + 3
                    disp = (sh & 0xfff) + disp_extra

                    if disp > len(out):
                        raise DecompressionError(
                            "back-reference points before the start of the data")

                    # Copy byte by byte: source and destination may overlap,
                    # which is how runs are encoded.
                    for _ in range(count):
                        out.append(out[-disp])
                else:
                    # Literal byte.
                    out.append(next(stream))

                if decompressed_size <= len(out):
                    break
    except StopIteration:
        # A bare StopIteration tells the caller nothing useful; report the
        # truncation as a decompression failure instead.
        raise DecompressionError("compressed data ended unexpectedly") from None

    if len(out) != decompressed_size:
        raise DecompressionError("decompressed size does not match the expected size")

    return out


class DecompressionError(ValueError):
    """Raised when data cannot be decoded as an LZSS stream."""
 |