266 lines
6.4 KiB
Python
266 lines
6.4 KiB
Python
from collections import defaultdict
|
|
from operator import itemgetter
|
|
import struct
|
|
from typing import Union
|
|
|
|
# Bytes-like buffer types; used as the parameter annotation of decompress().
ByteString = Union[bytes, bytearray, memoryview]
|
|
|
|
|
|
"""
|
|
Taken from the Archipelago Metroid: Zero Mission implementation by Lil David at:
|
|
https://github.com/lilDavid/Archipelago-Metroid-Zero-Mission/blob/main/lz10.py
|
|
|
|
Tweaked version of nlzss modified to work with raw data and return bytes instead of operating on whole files.
|
|
LZ11 functionality has been removed since it is not necessary for Zero Mission nor Circle of the Moon.
|
|
|
|
https://github.com/magical/nlzss
|
|
"""
|
|
|
|
|
|
def decompress(data: ByteString):
    """Decompress LZSS-compressed bytes that begin with their 4-byte header.

    The header is one type byte, which must be 0x10, followed by the
    decompressed size stored as a 24-bit little-endian integer. Everything
    after the header is passed to decompress_raw_lzss10 as the compressed body.

    Returns a bytearray containing the decompressed data.

    Raises DecompressionError if fewer than 4 header bytes are present or if
    the type byte is not 0x10. Errors raised by decompress_raw_lzss10
    propagate unchanged.
    """
    header = data[:4]
    # A short header would otherwise raise a bare IndexError (empty input) or
    # silently produce a size computed from fewer than three bytes.
    if len(header) < 4:
        raise DecompressionError("data is too short to contain an lzss header")
    if header[0] != 0x10:
        raise DecompressionError("not an lzss-compressed file")

    decompressed_size = int.from_bytes(header[1:], "little")
    return decompress_raw_lzss10(data[4:], decompressed_size)
|
|
|
|
|
|
def compress(data: bytearray):
    """Compress data into an LZSS (type 0x10) stream and return it as a bytearray.

    The output is a 4-byte header, then groups of up to eight tokens each
    preceded by one flag byte, then 0xff padding so the body length is a
    multiple of four.
    """
    # Header: 0x10 in the low byte, input length in the upper 24 bits,
    # written little-endian.
    packed = bytearray(struct.pack("<L", (len(data) << 8) + 0x10))

    body_size = 0
    for group in chunkit(_compress(data), 8):
        # A tuple token is a (count, displacement) back-reference; anything
        # else is a literal byte value.
        is_match = [type(token) is tuple for token in group]
        packed += struct.pack(">B", packflags(is_match))

        for token, matched in zip(group, is_match):
            if not matched:
                packed += struct.pack(">B", token)
                continue

            run, offset = token
            # Displacements arrive negative; store them as (distance - 1).
            offset = -offset - 1
            assert 0 <= offset < 4096
            # 4-bit (count - 3) in the high nibble, 12-bit offset below it.
            packed += struct.pack(">H", ((run - 3) << 12) | offset)

        # One flag byte, one byte per token, plus one extra byte per match.
        body_size += 1 + len(group) + sum(is_match)

    # Pad the body up to the next multiple of four bytes.
    pad = (-body_size) % 4
    packed += b'\xff' * pad
    return packed
|
|
|
|
|
|
class SlidingWindow:
    """Sliding window over a buffer, indexing recent positions by byte value.

    Positions in [start, stop) are recorded in ``hash``, which maps a byte
    value to the list of positions holding that value in ascending order.
    ``index`` is the position that search() tries to match against earlier
    data.
    """

    # Number of positions indexed before the window starts dropping the oldest.
    size = 4096

    # Smallest displacement search() will report; nearer candidates are skipped.
    disp_min = 2

    # Hard lower bound: a displacement below this cannot be represented in
    # the compressed stream.
    disp_start = 1

    # Shortest match length search() will report.
    match_min = 3

    # Longest match length that can be reported (inclusive).
    match_max = 3 + 0xf

    def __init__(self, buf):
        self.data = buf
        self.hash = defaultdict(list)
        self.full = False
        self.start = self.stop = self.index = 0

        assert self.match_max is not None

    def next(self):
        """Move the window forward by one position."""
        # Until index reaches disp_start - 1, only index moves.
        if self.index < self.disp_start - 1:
            self.index += 1
            return

        was_full = self.full
        if was_full:
            # Drop the oldest position; it must be first in its bucket.
            leaving = self.data[self.start]
            bucket = self.hash[leaving]
            assert bucket[0] == self.start
            bucket.pop(0)

        entering = self.stop
        self.hash[self.data[entering]].append(entering)
        self.stop = entering + 1
        self.index += 1

        if was_full:
            self.start += 1
        elif self.size <= self.stop:
            self.full = True

    def advance(self, n=1):
        """Advance the window by n bytes"""
        for _step in range(n):
            self.next()

    def search(self):
        """Find the longest usable match for the data at ``index``.

        Returns a (length, -displacement) tuple, or None when no candidate
        is at least match_min long and at least disp_min away. Among
        equally long candidates the first one listed in the hash bucket is
        returned; a candidate that reaches match_max is returned at once.
        """
        here = self.index
        longest_allowed = self.match_max
        shortest_allowed = self.match_min

        best = None
        for candidate in self.hash[self.data[here]]:
            run = self.match(candidate, here)
            if run < shortest_allowed:
                continue

            distance = here - candidate
            if distance < self.disp_min:
                continue

            if run >= longest_allowed:
                return (run, -distance)
            # Strictly greater, so the earliest of several equal runs wins.
            if best is None or run > best[0]:
                best = (run, -distance)

        return best

    def match(self, start, bufstart):
        """Count how many bytes from ``bufstart`` repeat the data at ``start``.

        The source wraps around every (index - start) bytes, so a match may
        overlap the position being encoded. The result never exceeds
        match_max or the number of bytes left in the buffer.
        """
        period = self.index - start
        if period == 0:
            return 0

        buf = self.data
        limit = min(len(buf) - bufstart, self.match_max)
        run = 0
        while run < limit and buf[start + (run % period)] == buf[bufstart + run]:
            run += 1
        return run
|
|
|
|
|
|
def _compress(input, windowclass=SlidingWindow):
    """Yield the token stream for ``input``.

    Each token is either a literal byte (int) or a (count, displacement)
    tuple as returned by the window's search().
    """
    window = windowclass(input)

    pos = 0
    while pos < len(input):
        found = window.search()
        if not found:
            # No usable match: emit the byte itself and step forward by one.
            yield input[pos]
            window.next()
            pos += 1
            continue

        yield found
        # Skip over every byte the match covers.
        window.advance(found[0])
        pos += found[0]
|
|
|
|
|
|
def packflags(flags):
    """Pack up to eight truthy/falsy flags into one byte value.

    The first flag becomes the most significant bit. Missing flags count as
    0 and flags beyond the eighth are ignored.
    """
    packed = 0
    for bit_pos, flag in zip(range(7, -1, -1), flags):
        if flag:
            packed |= 1 << bit_pos
    return packed
|
|
|
|
|
|
def chunkit(it, n):
    """Yield the items of ``it`` as consecutive lists of n items.

    The final list may be shorter. Every yielded list holds at least one
    item, so an n below 1 produces single-item lists.
    """
    stream = iter(it)
    exhausted = object()
    for head in stream:
        chunk = [head]
        # Top the chunk up to n items, stopping early if the stream ends.
        while len(chunk) < n:
            item = next(stream, exhausted)
            if item is exhausted:
                break
            chunk.append(item)
        yield chunk
|
|
|
|
|
|
def bits(byte):
    """Return the low eight bits of ``byte`` as a tuple of 0/1 ints, most significant bit first."""
    return tuple((byte >> shift) & 1 for shift in range(7, -1, -1))
|
|
|
|
|
|
def decompress_raw_lzss10(indata, decompressed_size, _overlay=False):
    """Decompress LZSS-compressed bytes. Returns a bytearray.

    indata: iterable of byte values holding the compressed body, without
        the 4-byte header.
    decompressed_size: number of bytes the body is expected to produce.
    _overlay: when true, 3 is added to each encoded displacement instead
        of 1.

    Raises DecompressionError if the input ends before decompressed_size
    bytes were produced, if a back-reference points before the start of the
    output, or if the output length does not equal decompressed_size.
    """
    out = bytearray()
    stream = iter(indata)
    disp_extra = 3 if _overlay else 1

    def readbyte():
        # Truncated input must not leak StopIteration out of a plain
        # function: callers running inside an iterator or generator would
        # see it as silent exhaustion or a RuntimeError.
        try:
            return next(stream)
        except StopIteration:
            raise DecompressionError("unexpected end of compressed data") from None

    while len(out) < decompressed_size:
        # Each flag byte describes the next eight tokens, MSB first.
        for flag in bits(readbyte()):
            if flag:
                # Back-reference: 16 bits, big-endian. The high nibble is
                # (count - 3), the low 12 bits are the encoded displacement.
                high = readbyte()
                low = readbyte()
                sh = (high << 8) | low
                count = (sh >> 0xc) + 3
                disp = (sh & 0xfff) + disp_extra

                if disp > len(out):
                    raise DecompressionError(
                        "back-reference points before the start of the output")

                # Copy byte by byte because the source may overlap the bytes
                # being written.
                for _ in range(count):
                    out.append(out[-disp])
            else:
                # Literal byte copied straight through.
                out.append(readbyte())

            if decompressed_size <= len(out):
                break

    # A final back-reference may overshoot the expected size.
    if len(out) != decompressed_size:
        raise DecompressionError("decompressed size does not match the expected size")

    return out
|
|
|
|
|
|
class DecompressionError(ValueError):
    """Raised when input cannot be decompressed as an LZSS stream."""
|