2022-03-15 12:55:57 +00:00
import itertools
2023-04-08 20:52:34 +00:00
from . TotalSMZ3 . Text . Texts import openFile
2022-03-15 12:55:57 +00:00
def range_union(ranges):
    """Merge ``ranges`` into a sorted list of disjoint, non-touching ranges.

    Only ``start`` and ``stop`` of each input are used (any step is ignored),
    so every input is treated as the contiguous interval ``[start, stop)``.
    Overlapping or directly adjacent intervals are coalesced into one range and
    empty intervals are dropped.

    >>> range_union([range(5, 8), range(0, 3), range(2, 5)])
    [range(0, 8)]
    """
    merged = []
    for begin, end in sorted((r.start, r.stop) for r in ranges):
        if end <= begin:
            # Empty range: contributes no addresses to the union.
            continue
        if merged and merged[-1][1] >= begin:
            # Overlaps or touches the previous interval: extend it.
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([begin, end])
    return [range(begin, end) for begin, end in merged]
# adapted from ips-util for python 3.2 (https://pypi.org/project/ips-util/)
class IPS_Patch(object):
    """In-memory representation of an IPS patch.

    Each entry of ``self.records`` is a dict with keys ``'address'`` (int),
    ``'data'`` (bytes-like) and ``'size'`` (number of bytes the record writes).
    RLE records additionally carry ``'rle_count'``; their ``'data'`` is exactly
    one byte, which is repeated ``rle_count`` times.
    """

    # 0x454F46 spells b'EOF', the IPS end marker, so no record may start there.
    _EOF_ADDRESS = int.from_bytes(b'EOF', byteorder='big')

    def __init__(self, patchDict=None):
        """Create a patch, optionally seeded from ``{address: iterable of byte values}``."""
        self.records = []
        # Length the output is cut to after applying (IPS truncation extension), or None.
        self.truncate_length = None
        # One past the highest address written by any record.
        self.max_size = 0
        if patchDict is not None:
            for addr, data in patchDict.items():
                self.add_record(addr, bytearray(data))

    def toDict(self):
        """Return ``{address: [byte values]}`` with RLE records expanded.

        Records that share a start address overwrite each other (the later one wins).
        """
        ret = {}
        for record in self.records:
            if 'rle_count' in record:
                ret[record['address']] = [int.from_bytes(record['data'], 'little')] * record['rle_count']
            else:
                ret[record['address']] = [int(b) for b in record['data']]
        return ret

    @staticmethod
    def load(filename):
        """Read an IPS patch file opened through ``openFile(filename, 'rb')``.

        Raises:
            Exception: if the header is not ``PATCH``, or the file ends before the
                ``EOF`` marker or in the middle of a record.
        """
        loaded_patch = IPS_Patch()
        with openFile(filename, 'rb') as file:

            def read_exact(size):
                # A short read means the patch is truncated; without this check a
                # file missing its EOF marker would loop forever on b'' reads.
                chunk = file.read(size)
                if len(chunk) != size:
                    raise Exception('Unexpected end of IPS patch file!')
                return chunk

            header = file.read(5)
            if header != b'PATCH':
                raise Exception('Not a valid IPS patch file!')
            while True:
                address_bytes = read_exact(3)
                if address_bytes == b'EOF':
                    break
                address = int.from_bytes(address_bytes, byteorder='big')
                length = int.from_bytes(read_exact(2), byteorder='big')
                if length == 0:
                    # RLE record: 2-byte repeat count, then the single byte to repeat.
                    rle_count = int.from_bytes(read_exact(2), byteorder='big')
                    data = read_exact(1)
                    # A zero-count RLE record writes nothing; skip it like other empty records.
                    if rle_count > 0:
                        loaded_patch.add_rle_record(address, data, rle_count)
                else:
                    loaded_patch.add_record(address, read_exact(length))
            # Optional 3-byte truncation length after the EOF marker.
            truncate_bytes = file.read(3)
            if len(truncate_bytes) == 3:
                loaded_patch.set_truncate_length(int.from_bytes(truncate_bytes, byteorder='big'))
        return loaded_patch

    @staticmethod
    def create(original_data, patched_data):
        """Build a patch that turns ``original_data`` into ``patched_data``.

        Neither argument is modified. If the patched data is shorter, the patch
        gets a truncation length; if it is longer, the original is treated as
        zero-padded to the same length.
        """
        # The heuristics for optimizing a patch were chosen with reference to
        # the source code of Flips: https://github.com/Alcaro/Flips
        patch = IPS_Patch()
        if len(original_data) > len(patched_data):
            patch.set_truncate_length(len(patched_data))
            original_data = original_data[:len(patched_data)]
        elif len(original_data) < len(patched_data):
            # Build a new padded buffer; `+=` would grow a caller's bytearray in place.
            original_data = bytes(original_data) + bytes(len(patched_data) - len(original_data))
            if original_data[-1] == 0 and patched_data[-1] == 0:
                # The last byte is unchanged and so is in no diff run; write it
                # explicitly so applying the patch grows the output to full length.
                patch.add_record(len(patched_data) - 1, bytes([0]))

        for start, data in IPS_Patch._diff_runs(original_data, patched_data):
            if start == IPS_Patch._EOF_ADDRESS:
                # A record may not start at the 'EOF' address: start one byte
                # earlier and include that (unchanged) byte from the new start.
                start -= 1
                data = bytes([patched_data[start]]) + data
            patch._add_diff_run(start, data)
        return patch

    @staticmethod
    def _diff_runs(original_data, patched_data):
        """Return ``(start, bytearray)`` pairs for each maximal run of differing bytes."""
        runs = []
        run_start = None
        run_data = bytearray()
        for index, (original, patched) in enumerate(zip(original_data, patched_data)):
            if original != patched:
                if run_start is None:
                    run_start = index
                    run_data = bytearray()
                run_data.append(patched)
            elif run_start is not None:
                runs.append((run_start, run_data))
                run_start = None
        if run_start is not None:
            runs.append((run_start, run_data))
        return runs

    def _add_rle_run(self, pos, value, count):
        """Add RLE records repeating byte ``value`` ``count`` times from ``pos``.

        The RLE count field is 2 bytes, so runs longer than 0xffff are split
        across several records. Returns the address just past the run.
        """
        while count > 0xffff:
            self.add_rle_record(pos, bytes([value]), 0xffff)
            pos += 0xffff
            count -= 0xffff
        self.add_rle_record(pos, bytes([value]), count)
        return pos + count

    def _add_diff_run(self, start, data):
        """Add records writing ``data`` at ``start``, using RLE where it is likely shorter."""
        groups = [(value, sum(1 for _ in group)) for value, group in itertools.groupby(data)]
        last_index = len(groups) - 1
        record_in_progress = bytearray()
        pos = start
        for index, (value, count) in enumerate(groups):
            if record_in_progress:
                # We don't want to interrupt a record in progress with a new header
                # unless this group is longer than two complete headers.
                if count > 13:
                    self.add_record(pos, record_in_progress)
                    pos += len(record_in_progress)
                    record_in_progress = bytearray()
                    pos = self._add_rle_run(pos, value, count)
                else:
                    record_in_progress += bytes([value]) * count
            elif count > 8 or (count > 3 and index == last_index):
                # Use RLE when the group is longer than 8 bytes, or longer than 3
                # bytes and the last group of this diff.
                pos = self._add_rle_run(pos, value, count)
            else:
                # Just begin a new standard record.
                record_in_progress += bytes([value]) * count
            # Standard records hold at most 0xffff bytes; emit full chunks as they fill.
            while len(record_in_progress) > 0xffff:
                self.add_record(pos, record_in_progress[:0xffff])
                record_in_progress = record_in_progress[0xffff:]
                pos += 0xffff
        # Finalize any record still in progress.
        if record_in_progress:
            self.add_record(pos, record_in_progress)

    def _check_address(self, address):
        """Raise RuntimeError if ``address`` cannot start an IPS record."""
        if address == self._EOF_ADDRESS:
            raise RuntimeError('Start address {0:x} is invalid in the IPS format. Please shift your starting address back by one byte to avoid it.'.format(address))
        if address > 0xffffff:
            raise RuntimeError('Start address {0:x} is too large for the IPS format. Addresses must fit into 3 bytes.'.format(address))

    def add_record(self, address, data):
        """Add a standard record writing ``data`` at ``address``; empty data is ignored.

        Raises:
            RuntimeError: if the address or length cannot be encoded in IPS.
        """
        self._check_address(address)
        if len(data) > 0xffff:
            raise RuntimeError('Record with length {0} is too large for the IPS format. Records must be less than 65536 bytes.'.format(len(data)))
        if len(data) == 0:  # ignore empty records
            return
        self.appendRecord({'address': address, 'data': data, 'size': len(data)})

    def add_rle_record(self, address, data, count):
        """Add an RLE record writing the single byte ``data`` ``count`` times at ``address``.

        Raises:
            RuntimeError: if the address or count cannot be encoded, or ``data``
                is not exactly one byte.
        """
        self._check_address(address)
        if count > 0xffff:
            raise RuntimeError('RLE record with length {0} is too large for the IPS format. RLE records must be less than 65536 bytes.'.format(count))
        if len(data) != 1:
            raise RuntimeError('Data for RLE record must be exactly one byte! Received {0}.'.format(data))
        self.appendRecord({'address': address, 'data': data, 'rle_count': count, 'size': count})

    def appendRecord(self, record):
        """Append a prebuilt record dict and update ``max_size``."""
        self.max_size = max(self.max_size, record['address'] + record['size'])
        self.records.append(record)

    def set_truncate_length(self, truncate_length):
        """Set the length the patched output is truncated to."""
        self.truncate_length = truncate_length

    def encode(self):
        """Serialize the patch to IPS bytes (header, records, EOF, optional truncation)."""
        encoded_bytes = bytearray(b'PATCH')
        for record in self.records:
            encoded_bytes += record['address'].to_bytes(3, byteorder='big')
            if 'rle_count' in record:
                # A zero length field marks an RLE record; the count follows.
                encoded_bytes += (0).to_bytes(2, byteorder='big')
                encoded_bytes += record['rle_count'].to_bytes(2, byteorder='big')
            else:
                encoded_bytes += len(record['data']).to_bytes(2, byteorder='big')
            encoded_bytes += record['data']
        encoded_bytes += b'EOF'
        if self.truncate_length is not None:
            encoded_bytes += self.truncate_length.to_bytes(3, byteorder='big')
        return encoded_bytes

    # save patch into IPS file
    def save(self, path):
        """Write the encoded patch to ``path``."""
        with open(path, 'wb') as ipsFile:
            ipsFile.write(self.encode())

    # applies patch on an existing bytearray
    def apply(self, in_data):
        """Return a new bytearray holding ``in_data`` with this patch applied.

        Records beyond the end of the input grow the output (gaps are
        zero-filled); the truncation length, if set, is applied last.
        """
        out_data = bytearray(in_data)
        for record in self.records:
            address = record['address']
            if address >= len(out_data):
                # Pad so the slice assignment lands at `address`, not at the current end.
                out_data += bytes(address - len(out_data) + 1)
            if 'rle_count' in record:
                out_data[address:address + record['rle_count']] = b''.join([record['data']] * record['rle_count'])
            else:
                out_data[address:address + len(record['data'])] = record['data']
        if self.truncate_length is not None:
            out_data = out_data[:self.truncate_length]
        return out_data

    # applies patch on an opened file
    def applyFile(self, handle):
        """Write each record into the seekable binary ``handle`` at its address.

        NOTE(review): unlike ``apply``, this does not apply ``truncate_length``.
        """
        for record in self.records:
            handle.seek(record['address'])
            if 'rle_count' in record:
                handle.write(bytearray(record['data']) * record['rle_count'])
            else:
                handle.write(record['data'])

    # appends an IPS_Patch on top of this one
    def append(self, patch):
        """Append ``patch``'s records after ours; keep the larger truncation length."""
        if patch.truncate_length is not None and (self.truncate_length is None or patch.truncate_length > self.truncate_length):
            self.set_truncate_length(patch.truncate_length)
        for record in patch.records:
            if record['size'] > 0:  # ignore empty records
                self.appendRecord(record)

    # gets address ranges written to by this patch
    def getRanges(self):
        """Return the merged address ranges written by this patch (via ``range_union``)."""
        return range_union([range(record['address'], record['address'] + record['size']) for record in self.records])