Archipelago/worlds/cvcotm/lz10.py

266 lines
6.4 KiB
Python

import struct
from collections import defaultdict, deque
from operator import itemgetter
from typing import Union
# Type alias for the bytes-like inputs accepted by decompress().
ByteString = Union[bytes, bytearray, memoryview]
# NOTE: the string below comes after the imports and the alias above, so
# Python does not treat it as the module docstring; it only records where this
# module's code came from.
"""
Taken from the Archipelago Metroid: Zero Mission implementation by Lil David at:
https://github.com/lilDavid/Archipelago-Metroid-Zero-Mission/blob/main/lz10.py
Tweaked version of nlzss modified to work with raw data and return bytes instead of operating on whole files.
LZ11 functionality has been removed since it is not necessary for Zero Mission nor Circle of the Moon.
https://github.com/magical/nlzss
"""
def decompress(data: ByteString):
    """Decompress an LZ10 stream that starts with its 4-byte header.

    The header is the type byte 0x10 followed by the decompressed size as a
    24-bit little-endian integer; the compressed body follows immediately.

    :param data: Compressed bytes, header included.
    :returns: A bytearray containing the decompressed data.
    :raises DecompressionError: If the header is shorter than 4 bytes, the
        type byte is not 0x10, or the body does not decode to the size stored
        in the header.
    """
    # Guard first: indexing empty input used to raise IndexError, and a 1-3
    # byte input silently produced a wrong size from a partial header.
    if len(data) < 4:
        raise DecompressionError("lzss header is truncated")
    if data[0] != 0x10:
        raise DecompressionError("not an lzss-compressed file")
    decompressed_size = int.from_bytes(data[1:4], "little")
    return decompress_raw_lzss10(data[4:], decompressed_size)
def compress(data: bytearray):
    """Compress ``data`` into an LZ10 stream with a 4-byte header.

    Layout: a little-endian u32 header equal to ``(len(data) << 8) | 0x10``,
    then groups made of one flag byte followed by up to 8 tokens. A set flag
    bit marks a 2-byte big-endian back-reference; a clear bit marks one
    literal byte. The whole stream is padded with 0xFF to a multiple of
    4 bytes.

    :param data: Bytes to compress; at most 0xFFFFFF bytes, because the
        header stores the size in 24 bits.
    :returns: A bytearray containing the compressed stream.
    :raises ValueError: If ``data`` is too large for the 24-bit size field.
    """
    if len(data) > 0xFFFFFF:
        raise ValueError(
            f"data too large for an lz10 header: {len(data)} bytes (max 0xFFFFFF)"
        )
    out = bytearray(struct.pack("<L", (len(data) << 8) | 0x10))
    for tokens in chunkit(_compress(data), 8):
        out.append(packflags([type(t) is tuple for t in tokens]))
        for t in tokens:
            if type(t) is tuple:
                # The compressor yields (length, -displacement); the stream
                # stores length - 3 in the top 4 bits and displacement - 1 in
                # the low 12 bits of a big-endian short.
                count, disp = t
                encoded_disp = (-disp) - 1
                assert 0 <= encoded_disp < 4096
                out.extend(struct.pack(">H", ((count - 3) << 12) | encoded_disp))
            else:
                out.append(t)
    # The header is itself 4 bytes, so aligning the whole buffer pads the
    # body exactly as counting body bytes separately did.
    out.extend(b"\xff" * (-len(out) % 4))
    return out
class SlidingWindow:
    """Match finder over the already-consumed part of the input.

    ``hash`` maps each byte value to the positions (oldest first) where it
    occurs inside the window ``[start, stop)``; ``index`` is the position
    currently being encoded. ``next``/``advance`` slide the window forward and
    ``search`` returns the best back-reference for ``index``.
    """

    # The size of the sliding window
    size = 4096
    # The minimum displacement. search() rejects matches closer than this even
    # though disp_start would allow them.
    # NOTE(review): presumably chosen for VRAM-safe (16-bit write) output — confirm.
    disp_min = 2
    # The hard minimum — a disp less than this can't be represented in the
    # compressed stream.
    disp_start = 1
    # The minimum length for a successful match in the window
    match_min = 3
    # The maximum length of a successful match, inclusive.
    match_max = 3 + 0xf

    def __init__(self, buf):
        self.data = buf
        # byte value -> deque of positions, oldest first. A deque makes
        # evicting the oldest position O(1); list.pop(0) was O(bucket size).
        self.hash = defaultdict(deque)
        self.full = False
        self.start = 0  # oldest position still inside the window
        self.stop = 0  # one past the newest position inside the window
        self.index = 0  # position currently being encoded
        assert self.match_max is not None

    def next(self):
        """Slide the window forward by one byte."""
        # Unreachable with the default disp_start of 1, since index starts at 0.
        if self.index < self.disp_start - 1:
            self.index += 1
            return
        if self.full:
            # Evict the oldest position; it must head its bucket.
            olditem = self.data[self.start]
            assert self.hash[olditem][0] == self.start
            self.hash[olditem].popleft()
        item = self.data[self.stop]
        self.hash[item].append(self.stop)
        self.stop += 1
        self.index += 1
        if self.full:
            self.start += 1
        elif self.size <= self.stop:
            self.full = True

    def advance(self, n=1):
        """Advance the window by n bytes"""
        for _ in range(n):
            self.next()

    def search(self):
        """Find the best back-reference for the bytes at ``self.index``.

        :returns: ``(length, -displacement)`` for the longest match of at least
            ``match_min`` bytes with displacement >= ``disp_min`` (the oldest
            such match on ties), or ``None`` if there is none.
        """
        match_max = self.match_max
        match_min = self.match_min
        counts = []
        indices = self.hash[self.data[self.index]]
        for i in indices:
            matchlen = self.match(i, self.index)
            if matchlen >= match_min:
                disp = self.index - i
                if self.disp_min <= disp:
                    counts.append((matchlen, -disp))
                    # A maximal match cannot be beaten; stop scanning.
                    if matchlen >= match_max:
                        return counts[-1]
        if counts:
            return max(counts, key=itemgetter(0))
        return None

    def match(self, start, bufstart):
        """Count how many bytes from ``bufstart`` repeat the data at ``start``.

        The source is treated as periodic with period ``self.index - start``,
        mirroring the byte-by-byte overlapping copy the decompressor performs.
        At most ``match_max`` bytes are compared, never past the end of data.
        """
        size = self.index - start
        if size == 0:
            return 0
        matchlen = 0
        for i in range(min(len(self.data) - bufstart, self.match_max)):
            if self.data[start + (i % size)] != self.data[bufstart + i]:
                break
            matchlen += 1
        return matchlen
def _compress(input, windowclass=SlidingWindow):
    """Yield the token stream for ``input``.

    Each token is either a literal byte (an int taken from ``input``) or the
    match tuple returned by the window's ``search()``, whose first element is
    the match length. With the default SlidingWindow that tuple is
    ``(count, -displacement)``.
    """
    window = windowclass(input)
    pos = 0
    while pos < len(input):
        token = window.search()
        if not token:
            # No usable match here: emit a single literal byte.
            yield input[pos]
            window.next()
            pos += 1
            continue
        yield token
        step = token[0]
        window.advance(step)
        pos += step
def packflags(flags):
    """Pack up to 8 flags into one byte, the first flag in the MSB.

    Flags are tested for truthiness. When fewer than 8 are supplied the
    remaining low bits are 0; flags past the eighth are ignored. Any iterable
    is accepted; ``zip`` stops after 8 items without raising, so no
    exception-driven padding is needed.
    """
    n = 0
    for shift, flag in zip(range(7, -1, -1), flags):
        if flag:
            n |= 1 << shift
    return n
def chunkit(it, n):
    """Yield consecutive lists of ``n`` items from ``it``.

    The final list may be shorter when the item count is not a multiple of
    ``n``; nothing is yielded for an empty iterable.
    """
    pending = []
    for element in it:
        pending.append(element)
        if len(pending) < n:
            continue
        yield pending
        pending = []
    if pending:
        yield pending
def bits(byte):
    """Return the low 8 bits of ``byte`` as a tuple of 0/1 ints, MSB first."""
    return tuple((byte >> shift) & 1 for shift in range(7, -1, -1))
def decompress_raw_lzss10(indata, decompressed_size, _overlay=False):
    """Decompress a headerless LZ10 body. Returns a bytearray.

    The body is a sequence of flag bytes, each followed by up to 8 tokens
    (most significant flag bit first). A clear bit means one literal byte.
    A set bit means a big-endian 16-bit back-reference whose top 4 bits hold
    ``count - 3`` and whose low 12 bits hold ``disp - disp_extra``.

    :param indata: Iterable of ints 0-255 (e.g. bytes) holding the body.
    :param decompressed_size: Number of bytes to produce.
    :param _overlay: Use a displacement bias of 3 instead of 1.
    :returns: The decompressed data.
    :raises DecompressionError: If the input ends before ``decompressed_size``
        bytes are produced, a back-reference points before the start of the
        output, or the output size does not match ``decompressed_size``.
    """
    data = bytearray()
    it = iter(indata)
    disp_extra = 3 if _overlay else 1
    try:
        while len(data) < decompressed_size:
            flags = next(it)
            for shift in range(7, -1, -1):
                if (flags >> shift) & 1:
                    high = next(it)
                    low = next(it)
                    sh = (high << 8) | low
                    count = (sh >> 12) + 3
                    disp = (sh & 0xfff) + disp_extra
                    if disp > len(data):
                        raise DecompressionError(
                            "back-reference points before the start of the output"
                        )
                    # Copy one byte at a time: when disp < count the source
                    # overlaps bytes written earlier in this same copy.
                    for _ in range(count):
                        data.append(data[-disp])
                else:
                    data.append(next(it))
                if decompressed_size <= len(data):
                    break
    except StopIteration:
        # Previously a bare StopIteration escaped to the caller.
        raise DecompressionError(
            "compressed data ended before the expected size was reached"
        ) from None
    if len(data) != decompressed_size:
        raise DecompressionError("decompressed size does not match the expected size")
    return data
class DecompressionError(ValueError):
    """Raised when LZSS input has a bad header or decodes to the wrong size."""