diff options
author | yenatch <yenatch@gmail.com> | 2015-02-07 13:11:38 -0800 |
---|---|---|
committer | yenatch <yenatch@gmail.com> | 2015-02-07 13:11:38 -0800 |
commit | fb23f2754c9b94eac871a33621ce43fa71b53170 (patch) | |
tree | 08e59e8043fb1d773817e1734ace1a160b32af5b /pokemontools/lz.py | |
parent | 36f7ee513e0a2d3fe7278e58a56d05d1a34177f0 (diff) |
Split the lz compression tools out of gfx.py.
Diffstat (limited to 'pokemontools/lz.py')
-rw-r--r-- | pokemontools/lz.py | 566 |
1 files changed, 566 insertions, 0 deletions
# -*- coding: utf-8 -*-
# Reconstructed from: diff --git a/pokemontools/lz.py b/pokemontools/lz.py
# (new file mode 100644, index 0000000..2f30ddf)
"""
Pokemon Crystal data de/compression.
"""

# A rundown of Pokemon Crystal's compression scheme:
#
# Control commands occupy bits 5-7.
# Bits 0-4 serve as the first parameter <n> for each command.
lz_commands = {
    'literal':   0, # n values for n bytes
    'iterate':   1, # one value for n bytes
    'alternate': 2, # alternate two values for n bytes
    'blank':     3, # zero for n bytes
}

# Repeater commands repeat any data that was just decompressed.
# They take an additional signed parameter <s> to mark a relative starting point.
# These wrap around (positive from the start, negative from the current position).
lz_commands.update({
    'repeat':    4, # n bytes starting from s
    'flip':      5, # n bytes in reverse bit order starting from s
    'reverse':   6, # n bytes backwards starting from s
})

# The long command is used when 5 bits aren't enough. Bits 2-4 contain a new control code.
# Bits 0-1 are appended to a new byte as 8-9, allowing a 10-bit parameter.
lz_commands.update({
    'long':      7, # n is now 10 bits for a new control code
})
max_length = 1 << 10 # can't go higher than 10 bits
lowmax = 1 << 5 # standard 5-bit param

# If 0xff is encountered instead of a command, decompression ends.
lz_end = 0xff


# bit_flipped[b] is the byte b with its eight bits in reverse order.
bit_flipped = [
    sum(((byte >> i) & 1) << (7 - i) for i in range(8))
    for byte in range(0x100)
]


class Compressed(object):

    """
    Compress data into the lz command stream described at the top of this module.

    Usage:
        lz = Compressed(data).output
    or
        lz = Compressed().compress(data)
    or
        c = Compressed()
        c.data = data
        lz = c.compress()

    The output is a list of integer byte values terminated by lz_end.

    There are some issues with reproducing the target compressor.
    Some notes are listed here:
        - the criteria for detecting a lookback is inconsistent
            - sometimes lookbacks that are mostly 0s are pruned, sometimes not
        - target appears to skip ahead if it can use a lookback soon, stopping the current command short or in some cases truncating it with literals.
            - this has been implemented, but the specifics are unknown
        - self.min_scores: It's unknown if blank's minimum score should be 1 or 2. Most likely it's 1, with some other hack to account for edge cases.
            - may be related to the above
        - target does not appear to compress backwards
    """

    def __init__(self, *args, **kwargs):
        """
        Accepts data, commands, debug and literal_only, positionally (in that
        order) or as keywords. If data is given, it is compressed immediately.
        """

        # A command is only emitted when its score (run length) exceeds this.
        self.min_scores = {
            'blank':     1,
            'iterate':   2,
            'alternate': 3,
            'repeat':    3,
            'reverse':   3,
            'flip':      3,
        }

        # Tie-breaking order between commands with equal net score.
        self.preference = [
            'repeat',
            'blank',
            'flip',
            'reverse',
            'iterate',
            'alternate',
            #'literal',
        ]

        self.lookback_methods = 'repeat', 'reverse', 'flip'

        # NOTE: literal_only is stored but not consulted anywhere in this module.
        self.__dict__.update({
            'data': None,
            'commands': lz_commands,
            'debug': False,
            'literal_only': False,
        })

        self.arg_names = 'data', 'commands', 'debug', 'literal_only'

        # Positional arguments take precedence over keyword arguments.
        self.__dict__.update(kwargs)
        self.__dict__.update(dict(zip(self.arg_names, args)))

        if self.data is not None:
            self.compress()

    def compress(self, data=None):
        """
        Compress data (or self.data if data is None).

        Returns the compressed stream as a list of ints, also kept in self.output.
        """
        if data is not None:
            self.data = data

        self.data = list(bytearray(self.data))

        # Per-run caches: positions of each byte value, and lookback results.
        self.indexes = {}
        self.lookbacks = {}
        for method in self.lookback_methods:
            self.lookbacks[method] = {}

        self.address = 0
        self.end = len(self.data)
        self.output = []
        self.literal = None  # start address of the pending literal run, if any

        while self.address < self.end:

            if self.score():
                # Flush any pending literal before the winning command.
                self.do_literal()
                self.do_winner()

            else:
                # Nothing worthwhile here; extend (or start) the literal run.
                if self.literal is None:
                    self.literal = self.address
                self.address += 1

        self.do_literal()

        self.output += [lz_end]
        return self.output

    def reset_scores(self):
        """Clear the per-address scores, offsets and helper payloads."""
        self.scores = dict.fromkeys(self.min_scores, 0)
        self.offsets = {}
        self.helpers = {}

    def bit_flip(self, byte):
        """Return byte with its bits in reverse order."""
        return bit_flipped[byte]

    def do_literal(self):
        """
        Emit the pending literal run (self.literal up to self.address), if any.

        Runs longer than max_length are split into several literal commands,
        since a single command can only encode a 10-bit length.
        """
        if self.literal is None:
            return

        length = abs(self.address - self.literal)
        start = min(self.literal, self.address + 1)
        pending = self.data[start:start + length]

        for chunk_start in range(0, len(pending), max_length):
            chunk = pending[chunk_start:chunk_start + max_length]
            self.helpers['literal'] = chunk
            self.do_cmd('literal', len(chunk))

        self.literal = None

    def _is_worthwhile(self, method, score):
        """True if score beats method's minimum, allowing one extra byte for long commands."""
        return score > self.min_scores[method] + int(score > lowmax)

    def score(self):
        """
        Score every command at self.address.

        Fills self.scores / self.offsets / self.helpers and returns True if
        at least one command is worth emitting here.
        """
        self.reset_scores()

        for method in ('iterate', 'alternate', 'blank'):
            self.score_literal(method)

        for method in self.lookback_methods:
            self.scores[method], self.offsets[method] = self.find_lookback(method, self.address)

        # Compatibility:
        # If a lookback is close, reduce the scores of other commands
        _, best_score = max(
            self.scores.items(),
            key=lambda item: (
                item[1],
                -self.preference.index(item[0])
            )
        )
        for method in self.lookback_methods:
            for address in range(self.address + 1, self.address + min(best_score, 6)):
                if self.find_lookback(method, address)[0] > max(self.min_scores[method], best_score):
                    # BUG: lookbacks can reduce themselves. This appears to be a bug in the target also.
                    for m in list(self.scores):
                        self.scores[m] = min(self.scores[m], address - self.address)

        return any(
            self._is_worthwhile(method, score)
            for method, score in self.scores.items()
        )

    def read(self, address=None):
        """Return the data byte at address (default self.address), or None if out of range."""
        if address is None:
            address = self.address
        if 0 <= address < len(self.data):
            return self.data[address]
        return None

    def find_all_lookbacks(self):
        """Populate the lookback cache for every method at every address."""
        for method in self.lookback_methods:
            for address in range(len(self.data)):
                self.find_lookback(method, address)

    def find_lookback(self, method, address=None):
        """
        Find the best earlier match for the data at address using method.

        Returns (length, index); (0, None) when there is no match.
        Results are cached in self.lookbacks.
        """
        if address is None:
            address = self.address

        existing = self.lookbacks.get(method, {}).get(address)
        if existing is not None:
            return existing

        lookback = 0, None

        # Better to not carelessly optimize at the moment.
        #if address < 2:
        #    return lookback

        byte = self.read(address)
        if byte is None:
            return lookback

        direction, mutate = {
            'repeat':  ( 1, int),
            'reverse': (-1, int),
            'flip':    ( 1, self.bit_flip),
        }[method]

        # Doesn't seem to help
        #if mutate == self.bit_flip:
        #    if byte == 0:
        #        self.lookbacks[method][address] = lookback
        #        return lookback

        data_len = len(self.data)

        def is_two_byte_index(index):
            # Indexes more than 0x7f behind address need a two-byte offset.
            # "No candidate yet" (None) counts as two-byte, which is what the
            # Python 2 ordering None < int produced in the original expression.
            return int(index is None or index < address - 0x7f)

        for index in self.get_indexes(mutate(byte)):

            if index >= address:
                break

            old_length, old_index = lookback
            if direction == 1:
                if old_length > data_len - index:
                    break
            else:
                if old_length > index:
                    continue

            length = 1 # we know there's at least one match, or we wouldn't be checking this index
            while 1:
                this_byte = self.read(address + length)
                that_byte = self.read(index + length * direction)
                if that_byte is None or this_byte != mutate(that_byte):
                    break
                length += 1

            #if direction == 1:
            #    if not any(self.data[address+2:address+length]): continue

            if length - is_two_byte_index(index) >= old_length - is_two_byte_index(old_index): # XXX >?
                # XXX maybe avoid two-byte indexes when possible
                lookback = length, index

        self.lookbacks.setdefault(method, {})[address] = lookback
        return lookback

    def get_indexes(self, byte):
        """Return (and cache) the ascending list of positions in self.data holding byte."""
        if byte not in self.indexes:
            self.indexes[byte] = [
                index for index, value in enumerate(self.data) if value == byte
            ]
        return self.indexes[byte]

    def score_literal(self, method):
        """
        Score one of the run commands ('blank', 'iterate', 'alternate') at self.address.

        Stores the run length in self.scores and the payload bytes in self.helpers.
        """
        address = self.address

        compare = {
            'blank': [0],
            'iterate': [self.read(address)],
            'alternate': [self.read(address), self.read(address + 1)],
        }[method]

        # XXX may or may not be correct
        if method == 'alternate' and compare[0] == 0:
            return

        length = 0
        while self.read(address + length) == compare[length % len(compare)]:
            length += 1

        self.scores[method] = length
        self.helpers[method] = compare

    def do_winner(self):
        """Emit the best-scoring command at self.address and advance past the data it covers."""
        winners = [
            (method, score)
            for method, score in self.scores.items()
            if self._is_worthwhile(method, score)
        ]
        # Highest net gain first; ties broken by self.preference order.
        winners.sort(
            key=lambda item: (
                -(item[1] - self.min_scores[item[0]] - int(item[1] > lowmax)),
                self.preference.index(item[0])
            )
        )
        winner, score = winners[0]

        length = min(score, max_length)
        self.do_cmd(winner, length)
        self.address += length

    def do_cmd(self, cmd, length):
        """Append the encoding of command cmd covering length bytes to self.output."""
        start_address = self.address

        # Lengths are stored off by one.
        cmd_length = length - 1

        output = []

        if length > lowmax:
            # Long form: 'long' in bits 5-7, the real command in bits 2-4,
            # and a 10-bit length split across the low 2 bits and the next byte.
            output.append(
                (self.commands['long'] << 5)
                + (self.commands[cmd] << 2)
                + (cmd_length >> 8)
            )
            output.append(
                cmd_length & 0xff
            )
        else:
            output.append(
                (self.commands[cmd] << 5)
                + cmd_length
            )

        self.helpers['blank'] = [] # quick hack
        output += self.helpers.get(cmd, [])

        if cmd in self.lookback_methods:
            offset = self.offsets[cmd]
            # Negative offsets are one byte.
            # Positive offsets are two.
            if start_address - offset <= 0x7f:
                offset = start_address - offset + 0x80
                offset -= 1 # this seems to work
                output += [offset]
            else:
                # NOTE(review): an absolute offset of 0x8000 or more would set the
                # high bit of the first byte and be decoded as a negative offset.
                output += [offset // 0x100, offset % 0x100] # big endian

        if self.debug:
            print(' '.join(map(str, [
                cmd, length, '\t',
                ' '.join(map('{:02x}'.format, output)),
                self.data[start_address:start_address+length] if cmd in self.lookback_methods else '',
            ])))

        self.output += output



class Decompressed(object):
    """
    Interpret and decompress lz-compressed data, usually 2bpp.

    Usage:
        data = Decompressed(lz).output
    or
        data = Decompressed().decompress(lz)
    or
        d = Decompressed()
        d.lz = lz
        data = d.decompress()

    To decompress from offset 0x80000 in a rom:
        data = Decompressed(rom, start=0x80000).output

    After decompression, self.output is a list of ints, self.address points
    just past the terminator, and self.compressed_data holds the consumed input.
    """

    lz = None
    start = 0
    commands = lz_commands
    debug = False

    arg_names = 'lz', 'start', 'commands', 'debug'

    def __init__(self, *args, **kwargs):
        """
        Accepts lz, start, commands and debug, positionally (in that order)
        or as keywords. If lz is given, it is decompressed immediately.
        """
        self.__dict__.update(dict(zip(self.arg_names, args)))
        self.__dict__.update(kwargs)

        # Map command codes back to their names.
        self.command_names = dict(
            (code, name) for name, code in self.commands.items()
        )
        self.address = self.start

        # Defined up front so command_list() works even before decompress().
        self.used_commands = []
        self.output = []

        if self.lz is not None:
            self.decompress()

        if self.debug:
            print(self.command_list())


    def command_list(self):
        """
        Return a printable list of commands that were used. Useful for debugging.
        """

        text = ''

        for name, attrs in self.used_commands:
            length    = attrs['length']
            address   = attrs['address']
            offset    = attrs['offset']
            direction = attrs['direction']

            text += '{0}: {1}'.format(name, length)
            text += '\t' + ' '.join(
                '{:02x}'.format(int(byte))
                for byte in self.lz[ address : address + attrs['cmd_length'] ]
            )

            if offset is not None:
                # Index one byte at a time: a negative-step slice breaks
                # when its stop index would fall below zero.
                repeated_data = [
                    self.output[offset + i * direction]
                    for i in range(length)
                ]
                text += ' [' + ' '.join(map('{:02x}'.format, repeated_data)) + ']'

            text += '\n'

        return text


    def decompress(self, lz=None):
        """
        Decompress lz (or self.lz if lz is None), starting at self.start.

        Returns the decompressed data as a list of ints, also kept in self.output.
        Raises KeyError if a command code has no handler, and IndexError
        if the input ends before the lz_end terminator.
        """

        if lz is not None:
            self.lz = lz

        self.lz = bytearray(self.lz)

        # Always begin at the configured start so the instance can be reused.
        self.address = self.start

        self.used_commands = []
        self.output = []

        while 1:

            cmd_address = self.address
            self.offset = None
            self.direction = None

            if (self.byte == lz_end):
                self.next()
                break

            self.cmd = (self.byte & 0b11100000) >> 5

            if self.cmd_name == 'long':
                # 10-bit length
                self.cmd = (self.byte & 0b00011100) >> 2
                self.length = (self.next() & 0b00000011) * 0x100
                self.length += self.next() + 1
            else:
                # 5-bit length
                self.length = (self.next() & 0b00011111) + 1

            # Resolve through the instance so subclasses inherit the handlers.
            name = self.cmd_name
            handler = getattr(self, name, None) if name is not None else None
            if not callable(handler):
                raise KeyError(name)
            handler()

            self.used_commands += [(
                self.cmd_name,
                {
                    'length':     self.length,
                    'address':    cmd_address,
                    'offset':     self.offset,
                    'cmd_length': self.address - cmd_address,
                    'direction':  self.direction,
                }
            )]

        # Keep track of the data we just decompressed.
        self.compressed_data = self.lz[self.start : self.address]

        return self.output


    @property
    def byte(self):
        """The input byte at the current address."""
        return self.lz[ self.address ]

    def next(self):
        """Return the input byte at the current address and advance by one."""
        byte = self.byte
        self.address += 1
        return byte

    @property
    def cmd_name(self):
        """Name of the current command code, or None if the code is unknown."""
        return self.command_names.get(self.cmd)


    def get_offset(self):
        """Read a one-byte relative or two-byte absolute offset into self.offset."""

        if self.byte >= 0x80: # negative
            # One byte, counted back from the end of the output so far.
            offset = self.next() & 0x7f
            offset = len(self.output) - offset - 1
        else:
            # Two bytes, big endian, counted from the start of the output.
            offset = self.next() * 0x100
            offset += self.next()

        self.offset = offset


    def literal(self):
        """
        Copy data directly.
        """
        self.output += self.lz[ self.address : self.address + self.length ]
        self.address += self.length

    def iterate(self):
        """
        Write one byte repeatedly.
        """
        self.output += [self.next()] * self.length

    def alternate(self):
        """
        Write alternating bytes.
        """
        alts = [self.next(), self.next()]
        self.output += [ alts[x & 1] for x in range(self.length) ]

    def blank(self):
        """
        Write zeros.
        """
        self.output += [0] * self.length

    def flip(self):
        """
        Repeat flipped bytes from output.

        Example: 11100100 -> 00100111
        """
        self._repeat(table=bit_flipped)

    def reverse(self):
        """
        Repeat reversed bytes from output.
        """
        self._repeat(direction=-1)

    def repeat(self):
        """
        Repeat bytes from output.
        """
        self._repeat()

    def _repeat(self, direction=1, table=None):
        """Copy self.length bytes from earlier output, stepping by direction, optionally mapped through table."""
        self.get_offset()
        self.direction = direction
        # Note: appends must be one at a time (this way, repeats can draw from themselves if required)
        for i in range(self.length):
            byte = self.output[ self.offset + i * direction ]
            self.output.append( table[byte] if table else byte )