Core: implement APProcedurePatch and APTokenMixin (#2536)

* initial work on procedure patch

* more flexibility

load default procedure for version 5 patches
add args for procedure
add default extension for tokens and bsdiff
allow specifying additional required extensions for generation

* pushing current changes to go fix tloz bug

* move tokens into a separate inheritable class

* forgot the commit to remove token from ProcedurePatch

* further cleaning from bad commit

* start on docstrings

* further work on docstrings and typing

* improve docstrings

* fix incorrect docstring

* cleanup

* clean defaults and docstring

* define interface that has only the bare minimum required
for `Patch.create_rom_file`

* change to dictionary.get

* remove unnecessary if statement

* update to explicitly check for procedure, restore compatible version and manual override

* Update Files.py

* remove struct uses

* ensure returning bytes, add token type checking

* Apply suggestions from code review

Co-authored-by: Doug Hoskisson <beauxq@users.noreply.github.com>

* pep8

---------

Co-authored-by: beauxq <beauxq@yahoo.com>
Co-authored-by: Doug Hoskisson <beauxq@users.noreply.github.com>
This commit is contained in:
Silvris 2024-03-19 17:08:29 -05:00 committed by GitHub
parent 8a8263fa61
commit 94650a02de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 247 additions and 35 deletions

View File

@ -3,10 +3,11 @@ from __future__ import annotations
import abc
import json
import zipfile
from enum import IntEnum
import os
import threading
from typing import ClassVar, Dict, List, Literal, Tuple, Any, Optional, Union, BinaryIO
from typing import ClassVar, Dict, List, Literal, Tuple, Any, Optional, Union, BinaryIO, overload
import bsdiff4
@ -38,6 +39,32 @@ class AutoPatchRegister(abc.ABCMeta):
return None
class AutoPatchExtensionRegister(abc.ABCMeta):
extension_types: ClassVar[Dict[str, AutoPatchExtensionRegister]] = {}
required_extensions: List[str] = []
def __new__(mcs, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> AutoPatchExtensionRegister:
# construct class
new_class = super().__new__(mcs, name, bases, dct)
if "game" in dct:
AutoPatchExtensionRegister.extension_types[dct["game"]] = new_class
return new_class
@staticmethod
def get_handler(game: str) -> Union[AutoPatchExtensionRegister, List[AutoPatchExtensionRegister]]:
handler = AutoPatchExtensionRegister.extension_types.get(game, APPatchExtension)
if handler.required_extensions:
handlers = [handler]
for required in handler.required_extensions:
ext = AutoPatchExtensionRegister.extension_types.get(required)
if not ext:
raise NotImplementedError(f"No handler for {required}.")
handlers.append(ext)
return handlers
else:
return handler
container_version: int = 6
@ -157,27 +184,14 @@ class APAutoPatchInterface(APPatch, abc.ABC, metaclass=AutoPatchRegister):
""" create the output file with the file name `target` """
class APDeltaPatch(APAutoPatchInterface):
"""An implementation of `APAutoPatchInterface` that additionally
has delta.bsdiff4 containing a delta patch to get the desired file."""
class APProcedurePatch(APAutoPatchInterface):
"""
An APPatch that defines a procedure to produce the desired file.
"""
hash: Optional[str] # base checksum of source file
patch_file_ending: str = ""
delta: Optional[bytes] = None
source_data: bytes
procedure = None # delete this line when APPP is added
def __init__(self, *args: Any, patched_path: str = "", **kwargs: Any) -> None:
self.patched_path = patched_path
super(APDeltaPatch, self).__init__(*args, **kwargs)
def get_manifest(self) -> Dict[str, Any]:
manifest = super(APDeltaPatch, self).get_manifest()
manifest["base_checksum"] = self.hash
manifest["result_file_ending"] = self.result_file_ending
manifest["patch_file_ending"] = self.patch_file_ending
manifest["compatible_version"] = 5 # delete this line when APPP is added
return manifest
patch_file_ending: str = ""
files: Dict[str, bytes] = {}
@classmethod
def get_source_data(cls) -> bytes:
@ -190,21 +204,219 @@ class APDeltaPatch(APAutoPatchInterface):
cls.source_data = cls.get_source_data()
return cls.source_data
def write_contents(self, opened_zipfile: zipfile.ZipFile):
super(APDeltaPatch, self).write_contents(opened_zipfile)
# write Delta
opened_zipfile.writestr("delta.bsdiff4",
bsdiff4.diff(self.get_source_data_with_cache(), open(self.patched_path, "rb").read()),
compress_type=zipfile.ZIP_STORED) # bsdiff4 is a format with integrated compression
def __init__(self, *args: Any, **kwargs: Any):
super(APProcedurePatch, self).__init__(*args, **kwargs)
def read_contents(self, opened_zipfile: zipfile.ZipFile):
super(APDeltaPatch, self).read_contents(opened_zipfile)
self.delta = opened_zipfile.read("delta.bsdiff4")
def get_manifest(self) -> Dict[str, Any]:
manifest = super(APProcedurePatch, self).get_manifest()
manifest["base_checksum"] = self.hash
manifest["result_file_ending"] = self.result_file_ending
manifest["patch_file_ending"] = self.patch_file_ending
manifest["procedure"] = self.procedure
if self.procedure == APDeltaPatch.procedure:
manifest["compatible_version"] = 5
return manifest
def patch(self, target: str):
"""Base + Delta -> Patched"""
if not self.delta:
def read_contents(self, opened_zipfile: zipfile.ZipFile) -> None:
super(APProcedurePatch, self).read_contents(opened_zipfile)
with opened_zipfile.open("archipelago.json", "r") as f:
manifest = json.load(f)
if "procedure" not in manifest:
# support patching files made before moving to procedures
self.procedure = [("apply_bsdiff4", ["delta.bsdiff4"])]
else:
self.procedure = manifest["procedure"]
for file in opened_zipfile.namelist():
if file not in ["archipelago.json"]:
self.files[file] = opened_zipfile.read(file)
def write_contents(self, opened_zipfile: zipfile.ZipFile) -> None:
super(APProcedurePatch, self).write_contents(opened_zipfile)
for file in self.files:
opened_zipfile.writestr(file, self.files[file],
compress_type=zipfile.ZIP_STORED if file.endswith(".bsdiff4") else None)
def get_file(self, file: str) -> bytes:
""" Retrieves a file from the patch container."""
if file not in self.files:
self.read()
result = bsdiff4.patch(self.get_source_data_with_cache(), self.delta)
with open(target, "wb") as f:
f.write(result)
return self.files[file]
def write_file(self, file_name: str, file: bytes) -> None:
""" Writes a file to the patch container, to be retrieved upon patching. """
self.files[file_name] = file
def patch(self, target: str) -> None:
self.read()
base_data = self.get_source_data_with_cache()
patch_extender = AutoPatchExtensionRegister.get_handler(self.game)
assert not isinstance(self.procedure, str), f"{type(self)} must define procedures"
for step, args in self.procedure:
if isinstance(patch_extender, list):
extension = next((item for item in [getattr(extender, step, None) for extender in patch_extender]
if item is not None), None)
else:
extension = getattr(patch_extender, step, None)
if extension is not None:
base_data = extension(self, base_data, *args)
else:
raise NotImplementedError(f"Unknown procedure {step} for {self.game}.")
with open(target, 'wb') as f:
f.write(base_data)
class APDeltaPatch(APProcedurePatch):
"""An APProcedurePatch that additionally has delta.bsdiff4
containing a delta patch to get the desired file, often a rom."""
procedure = [
("apply_bsdiff4", ["delta.bsdiff4"])
]
def __init__(self, *args: Any, patched_path: str = "", **kwargs: Any) -> None:
super(APDeltaPatch, self).__init__(*args, **kwargs)
self.patched_path = patched_path
def write_contents(self, opened_zipfile: zipfile.ZipFile):
self.write_file("delta.bsdiff4",
bsdiff4.diff(self.get_source_data_with_cache(), open(self.patched_path, "rb").read()))
super(APDeltaPatch, self).write_contents(opened_zipfile)
class APTokenTypes(IntEnum):
WRITE = 0
COPY = 1
RLE = 2
AND_8 = 3
OR_8 = 4
XOR_8 = 5
class APTokenMixin:
"""
A class that defines functions for generating a token binary, for use in patches.
"""
tokens: List[
Tuple[APTokenTypes, int, Union[
bytes, # WRITE
Tuple[int, int], # COPY, RLE
int # AND_8, OR_8, XOR_8
]]] = []
def get_token_binary(self) -> bytes:
"""
Returns the token binary created from stored tokens.
:return: A bytes object representing the token data.
"""
data = bytearray()
data.extend(len(self.tokens).to_bytes(4, "little"))
for token_type, offset, args in self.tokens:
data.append(token_type)
data.extend(offset.to_bytes(4, "little"))
if token_type in [APTokenTypes.AND_8, APTokenTypes.OR_8, APTokenTypes.XOR_8]:
assert isinstance(args, int), f"Arguments to AND/OR/XOR must be of type int, not {type(args)}"
data.extend(int.to_bytes(1, 4, "little"))
data.append(args)
elif token_type in [APTokenTypes.COPY, APTokenTypes.RLE]:
assert isinstance(args, tuple), f"Arguments to COPY/RLE must be of type tuple, not {type(args)}"
data.extend(int.to_bytes(4, 4, "little"))
data.extend(args[0].to_bytes(4, "little"))
data.extend(args[1].to_bytes(4, "little"))
elif token_type == APTokenTypes.WRITE:
assert isinstance(args, bytes), f"Arguments to WRITE must be of type bytes, not {type(args)}"
data.extend(len(args).to_bytes(4, "little"))
data.extend(args)
else:
raise ValueError(f"Unknown token type {token_type}")
return bytes(data)
@overload
def write_token(self,
token_type: Literal[APTokenTypes.AND_8, APTokenTypes.OR_8, APTokenTypes.XOR_8],
offset: int,
data: int) -> None:
...
@overload
def write_token(self,
token_type: Literal[APTokenTypes.COPY, APTokenTypes.RLE],
offset: int,
data: Tuple[int, int]) -> None:
...
@overload
def write_token(self,
token_type: Literal[APTokenTypes.WRITE],
offset: int,
data: bytes) -> None:
...
def write_token(self, token_type: APTokenTypes, offset: int, data: Union[bytes, Tuple[int, int], int]):
"""
Stores a token to be used by patching.
"""
self.tokens.append((token_type, offset, data))
class APPatchExtension(metaclass=AutoPatchExtensionRegister):
"""Class that defines patch extension functions for a given game.
Patch extension functions must have the following two arguments in the following order:
caller: APProcedurePatch (used to retrieve files from the patch container)
rom: bytes (the data to patch)
Further arguments are passed in from the procedure as defined.
Patch extension functions must return the changed bytes.
"""
game: str
required_extensions: List[str] = []
@staticmethod
def apply_bsdiff4(caller: APProcedurePatch, rom: bytes, patch: str):
"""Applies the given bsdiff4 from the patch onto the current file."""
return bsdiff4.patch(rom, caller.get_file(patch))
@staticmethod
def apply_tokens(caller: APProcedurePatch, rom: bytes, token_file: str) -> bytes:
"""Applies the given token file from the patch onto the current file."""
token_data = caller.get_file(token_file)
rom_data = bytearray(rom)
token_count = int.from_bytes(token_data[0:4], "little")
bpr = 4
for _ in range(token_count):
token_type = token_data[bpr:bpr + 1][0]
offset = int.from_bytes(token_data[bpr + 1:bpr + 5], "little")
size = int.from_bytes(token_data[bpr + 5:bpr + 9], "little")
data = token_data[bpr + 9:bpr + 9 + size]
if token_type in [APTokenTypes.AND_8, APTokenTypes.OR_8, APTokenTypes.XOR_8]:
arg = data[0]
if token_type == APTokenTypes.AND_8:
rom_data[offset] = rom_data[offset] & arg
elif token_type == APTokenTypes.OR_8:
rom_data[offset] = rom_data[offset] | arg
else:
rom_data[offset] = rom_data[offset] ^ arg
elif token_type in [APTokenTypes.COPY, APTokenTypes.RLE]:
length = int.from_bytes(data[:4], "little")
value = int.from_bytes(data[4:], "little")
if token_type == APTokenTypes.COPY:
rom_data[offset: offset + length] = rom_data[value: value + length]
else:
rom_data[offset: offset + length] = bytes([value] * length)
else:
rom_data[offset:offset + len(data)] = data
bpr += 9 + size
return bytes(rom_data)
@staticmethod
def calc_snes_crc(caller: APProcedurePatch, rom: bytes):
"""Calculates and applies a valid CRC for the SNES rom header."""
rom_data = bytearray(rom)
if len(rom) < 0x8000:
raise Exception("Tried to calculate SNES CRC on file too small to be a SNES ROM.")
crc = (sum(rom_data[:0x7FDC] + rom_data[0x7FE0:]) + 0x01FE) & 0xFFFF
inv = crc ^ 0xFFFF
rom_data[0x7FDC:0x7FE0] = [inv & 0xFF, (inv >> 8) & 0xFF, crc & 0xFF, (crc >> 8) & 0xFF]
return bytes(rom_data)