diff options
author | yenatch <yenatch@gmail.com> | 2014-06-02 17:38:30 -0700 |
---|---|---|
committer | yenatch <yenatch@gmail.com> | 2014-06-02 17:57:43 -0700 |
commit | b07f9c7f76e221b15ec8a153fe52734a8174aa0d (patch) | |
tree | d5a79d5d13d1608596e5fb9740a58c5fbe7d7dc7 /pokemontools/gfx.py | |
parent | c2c45d61ebf352b72d2692116925e8d1328d708e (diff) |
Rewrite the lz compressor.
Diffstat (limited to 'pokemontools/gfx.py')
-rw-r--r-- | pokemontools/gfx.py | 576 |
1 files changed, 206 insertions, 370 deletions
diff --git a/pokemontools/gfx.py b/pokemontools/gfx.py index 92cc0c1..93c8941 100644 --- a/pokemontools/gfx.py +++ b/pokemontools/gfx.py @@ -181,388 +181,188 @@ lz_end = 0xff class Compressed: - """ - Compress arbitrary data, usually 2bpp. - """ - - def __init__(self, image=None, mode='horiz', size=None): - assert image, 'need something to compress!' - image = list(image) - self.image = image - self.pic = [] - self.animtiles = [] - - # only transpose pic (animtiles were never transposed in decompression) - if size != None: - for byte in range((size*size)*16): - self.pic += image[byte] - for byte in range(((size*size)*16),len(image)): - self.animtiles += image[byte] - else: - self.pic = image - - if mode == 'vert': - self.tiles = get_tiles(self.pic) - self.tiles = transpose(self.tiles) - self.pic = connect(self.tiles) - - self.image = self.pic + self.animtiles - - self.end = len(self.image) - - self.byte = None - self.address = 0 - - self.stream = [] - - self.zeros = [] - self.alts = [] - self.iters = [] - self.repeats = [] - self.flips = [] - self.reverses = [] - self.literals = [] - - self.output = [] - + def __init__(self, data=None, commands=lz_commands, debug=False): + self.data = list(bytearray(data)) + self.commands = commands + self.debug = debug self.compress() + def byte_at(self, address): + if address < len(self.data): + return self.data[address] + return None def compress(self): """ - Incomplete, but outputs working compressed data. + This algorithm is greedy. + It aims to match the compressor it's based on as closely as possible. + It doesn't, but in the meantime the output is smaller. """ - self.address = 0 - - # todo - #self.scanRepeats() - - while ( self.address < self.end ): - - #if (self.repeats): - # self.doRepeats() - - #if (self.flips): - # self.doFlips() - - #if (self.reverses): - # self.doReverses - - if (self.checkWhitespace()): - self.doLiterals() - self.doWhitespace() - - elif (self.checkIter()): - self.doLiterals() - self.doIter() - - elif (self.checkAlts()): - self.doLiterals() - self.doAlts() - - else: # doesn't fit any pattern -> literal - self.addLiteral() - self.next() - - self.doStream() - - # add any literals we've been sitting on - self.doLiterals() - - # done - self.output.append(lz_end) - - - def getCurByte(self): - if self.address < self.end: - self.byte = ord(self.image[self.address]) - else: self.byte = None - - def next(self): - self.address += 1 - self.getCurByte() - - def addLiteral(self): - self.getCurByte() - self.literals.append(self.byte) - if len(self.literals) > max_length: - raise Exception, "literals exceeded max length and the compressor didn't catch it" - elif len(self.literals) == max_length: - self.doLiterals() - - def doLiterals(self): - if len(self.literals) > lowmax: - self.output.append( (lz_commands['long'] << 5) | (lz_commands['literal'] << 2) | ((len(self.literals) - 1) >> 8) ) - self.output.append( (len(self.literals) - 1) & 0xff ) - elif len(self.literals) > 0: - self.output.append( (lz_commands['literal'] << 5) | (len(self.literals) - 1) ) - for byte in self.literals: - self.output.append(byte) - self.literals = [] - - def doStream(self): - for byte in self.stream: - self.output.append(byte) - self.stream = [] - - - def scanRepeats(self): - """ - Works, but doesn't do flipped/reversed streams yet. - - This takes up most of the compress time and only saves a few bytes. - It might be more effective to exclude it entirely. - """ - - self.repeats = [] - self.flips = [] - self.reverses = [] - - # make a 5-letter word list of the sequence - letters = 5 # how many bytes it costs to use a repeat over a literal - # any shorter and it's not worth the trouble - num_words = len(self.image) - letters - words = [] - for i in range(self.address,num_words): - word = [] - for j in range(letters): - word.append( ord(self.image[i+j]) ) - words.append((word, i)) - - zeros = [] - for zero in range(letters): - zeros.append( 0 ) - - # check for matches - def get_matches(): - # TODO: - # append to 3 different match lists instead of yielding to one - # - #flipped = [] - #for byte in enumerate(this[0]): - # flipped.append( sum(1<<(7-i) for i in range(8) if (this[0][byte])>>i&1) ) - #reversed = this[0][::-1] - # - for whereabout, this in enumerate(words): - for that in range(whereabout+1,len(words)): - if words[that][0] == this[0]: - if words[that][1] - this[1] >= letters: - # remove zeros - if this[0] != zeros: - yield [this[0], this[1], words[that][1]] - - matches = list(get_matches()) - - # remove more zeros - buffer = [] - for match in matches: - # count consecutive zeros in a word - num_zeros = 0 - highest = 0 - for j in range(letters): - if match[0][j] == 0: - num_zeros += 1 - else: - if highest < num_zeros: highest = num_zeros - num_zeros = 0 - if highest < 4: - # any more than 3 zeros in a row isn't worth it - # (and likely to already be accounted for) - buffer.append(match) - matches = buffer - - # combine overlapping matches - buffer = [] - for this, match in enumerate(matches): - if this < len(matches) - 1: # special case for the last match - if matches[this+1][1] <= (match[1] + len(match[0])): # check overlap - if match[1] + len(match[0]) < match[2]: - # next match now contains this match's bytes too - # this only appends the last byte (assumes overlaps are +1 - match[0].append(matches[this+1][0][-1]) - matches[this+1] = match - elif match[1] + len(match[0]) == match[2]: - # we've run into the thing we matched - buffer.append(match) - # else we've gone past it and we can ignore it - else: # no more overlaps - buffer.append(match) - else: # last match, so there's nothing to check - buffer.append(match) - matches = buffer - - # remove alternating sequences - buffer = [] - for match in matches: - for i in range(6 if letters > 6 else letters): - if match[0][i] != match[0][i&1]: - buffer.append(match) + self.end = len(self.data) + self.output = [] + self.literal = [] + + while self.address < self.end: + # Tally up the number of bytes that can be compressed + # by a single command from the current address. + self.scores = {} + for method in self.commands.keys(): + self.scores[method] = 0 + + # The most common byte by far is 0 (whitespace in + # images and padding in tilemaps and regular data). + address = self.address + while self.byte_at(address) == 0x00: + self.scores['blank'] += 1 + address += 1 + + # In the same vein, see how long the same byte repeats for. + address = self.address + self.iter = self.byte_at(address) + while self.byte_at(address) == self.iter: + self.scores['iterate'] += 1 + address += 1 + + # Do it again, but for alternating bytes. + address = self.address + self.alts = [] + self.alts += [self.byte_at(address)] + self.alts += [self.byte_at(address + 1)] + while self.byte_at(address) == self.alts[(address - self.address) % 2]: + self.scores['alternate'] += 1 + address += 1 + + # Check if we can repeat any data that the + # decompressor just output (here, the input data). + # TODO this includes the current command's output + self.matches = {} + last_matches = {} + address = self.address + min_length = 4 # minimum worthwhile length + max_length = 9 # any further and the time loss is too significant + for length in xrange(min_length, min(len(self.data) - address, max_length)): + keyword = self.data[address:address+length] + for offset, byte in enumerate(self.data[:address]): + # offset ranges are -0x80:-1 and 0:0x7fff + if offset > 0x7fff and offset < address - 0x80: + continue + if byte == keyword[0]: + # Straight repeat... + if self.data[offset:offset+length] == keyword: + if self.scores['repeat'] < length: + self.scores['repeat'] = length + self.matches['repeat'] = offset + # In reverse... + if self.data[offset-1:offset-length-1:-1] == keyword: + if self.scores['reverse'] < length: + self.scores['reverse'] = length + self.matches['reverse'] = offset + # Or bitflipped + if self.bit_flip([byte]) == self.bit_flip([keyword[0]]): + if self.bit_flip(self.data[offset:offset+length]) == self.bit_flip(keyword): + if self.scores['flip'] < length: + self.scores['flip'] = length + self.matches['flip'] = offset + if self.matches == last_matches: break - matches = buffer - - self.repeats = matches - - - def doRepeats(self): - """doesn't output the right values yet""" - - unusedrepeats = [] - for repeat in self.repeats: - if self.address >= repeat[2]: - - # how far in we are - length = (len(repeat[0]) - (self.address - repeat[2])) - - # decide which side we're copying from - if (self.address - repeat[1]) <= 0x80: - self.doLiterals() - self.stream.append( (lz_commands['repeat'] << 5) | length - 1 ) - - # wrong? - self.stream.append( (((self.address - repeat[1])^0xff)+1)&0xff ) - - else: - self.doLiterals() - self.stream.append( (lz_commands['repeat'] << 5) | length - 1 ) - - # wrong? - self.stream.append(repeat[1]>>8) - self.stream.append(repeat[1]&0xff) - - #print hex(self.address) + ': ' + hex(len(self.output)) + ' ' + hex(length) - self.address += length - - else: unusedrepeats.append(repeat) - - self.repeats = unusedrepeats - - - def checkWhitespace(self): - self.zeros = [] - self.getCurByte() - original_address = self.address - - if ( self.byte == 0 ): - while ( self.byte == 0 ) & ( len(self.zeros) <= max_length ): - self.zeros.append(self.byte) - self.next() - if len(self.zeros) > 1: - return True - self.address = original_address - return False - - def doWhitespace(self): - if (len(self.zeros) + 1) >= lowmax: - self.stream.append( (lz_commands['long'] << 5) | (lz_commands['blank'] << 2) | ((len(self.zeros) - 1) >> 8) ) - self.stream.append( (len(self.zeros) - 1) & 0xff ) - elif len(self.zeros) > 1: - self.stream.append( lz_commands['blank'] << 5 | (len(self.zeros) - 1) ) + last_matches = list(self.matches) + + # If the scores are too low, try again from the next byte. + if not any(map(lambda x: { + 'blank': 1, + 'iterate': 2, + 'alternate': 3, + 'repeat': 3, + 'reverse': 3, + 'flip': 3, + }.get(x[0], 10000) < x[1], self.scores.items())): + self.literal += [self.data[self.address]] + self.address += 1 + + else: # payload + # bug: literal [00] is a byte longer than blank 1. + # this bug exists in the target compressor as well, + # so don't fix until we've given up on replicating it. + self.do_literal() + self.do_scored() + + # unload any literals we're sitting on + self.do_literal() + self.output += [lz_end] + + def bit_flip(self, data): + return [sum(((byte >> i) & 1) << (7 - i) for i in xrange(8)) for byte in data] + + def do_literal(self): + if self.literal: + cmd = self.commands['literal'] + length = len(self.literal) + self.do_cmd(cmd, length) + # self.address has already been + # incremented in the main loop + self.literal = [] + + def do_cmd(self, cmd, length): + if length > max_length: + length = max_length + + cmd_length = length - 1 + + if length > lowmax: + output = [(self.commands['long'] << 5) + (cmd << 2) + (cmd_length >> 8)] + output += [cmd_length & 0xff] else: - raise Exception, "checkWhitespace() should prevent this from happening" - - - def checkAlts(self): - self.alts = [] - self.getCurByte() - original_address = self.address - num_alts = 0 - - # make sure we don't check for alts at the end of the file - if self.address+3 >= self.end: return False - - self.alts.append(self.byte) - self.alts.append(ord(self.image[self.address+1])) - - # are we onto smething? - if ( ord(self.image[self.address+2]) == self.alts[0] ): - cur_alt = 0 - while (ord(self.image[(self.address)+1]) == self.alts[num_alts&1]) & (num_alts <= max_length): - num_alts += 1 - self.next() - # include the last alternated byte - num_alts += 1 - self.address = original_address - if num_alts > lowmax: - return True - elif num_alts > 2: - return True - return False - - def doAlts(self): - original_address = self.address - self.getCurByte() - - #self.alts = [] - #num_alts = 0 - - #self.alts.append(self.byte) - #self.alts.append(ord(self.image[self.address+1])) - - #i = 0 - #while (ord(self.image[self.address+1]) == self.alts[i^1]) & (num_alts <= max_length): - # num_alts += 1 - # i ^=1 - # self.next() - ## include the last alternated byte - #num_alts += 1 - - num_alts = len(self.iters) + 1 - - if num_alts > lowmax: - self.stream.append( (lz_commands['long'] << 5) | (lz_commands['alternate'] << 2) | ((num_alts - 1) >> 8) ) - self.stream.append( num_alts & 0xff ) - self.stream.append( self.alts[0] ) - self.stream.append( self.alts[1] ) - elif num_alts > 2: - self.stream.append( (lz_commands['alternate'] << 5) | (num_alts - 1) ) - self.stream.append( self.alts[0] ) - self.stream.append( self.alts[1] ) + output = [(cmd << 5) + cmd_length] + + if cmd == self.commands['literal']: + output += self.literal + elif cmd == self.commands['iterate']: + output += [self.iter] + elif cmd == self.commands['alternate']: + output += self.alts else: - raise Exception, "checkAlts() should prevent this from happening" - - self.address = original_address - self.address += num_alts - - - def checkIter(self): - self.iters = [] - self.getCurByte() - iter = self.byte - original_address = self.address - while (self.byte == iter) & (len(self.iters) < max_length): - self.iters.append(self.byte) - self.next() - self.address = original_address - if len(self.iters) > 3: - # 3 or fewer isn't worth the trouble and actually longer - # if part of a larger literal set - return True - - return False - - def doIter(self): - self.getCurByte() - iter = self.byte - original_address = self.address + for command in ['repeat', 'reverse', 'flip']: + if cmd == self.commands[command]: + offset = self.matches[command] + # negative offsets are a byte shorter + if self.address - offset <= 0x80: + offset = self.address - offset + 0x80 + if cmd == self.commands['repeat']: + offset -= 1 # this is a hack, but it seems to work + output += [offset] + else: + output += [offset / 0x100, offset % 0x100] + + if self.debug: + print ( + dict(map(reversed, self.commands.items()))[cmd], + length, '\t', + ' '.join(map('{:02x}'.format, output)) + ) - self.iters = [] - while (self.byte == iter) & (len(self.iters) < max_length): - self.iters.append(self.byte) - self.next() + self.output += output + return length + + def do_scored(self): + # Which command did the best? + winner, score = sorted( + self.scores.items(), + key=lambda x:(-x[1], [ + 'blank', + 'repeat', + 'reverse', + 'flip', + 'iterate', + 'alternate', + 'literal', + 'long', # hack + ].index(x[0])) + )[0] + cmd = self.commands[winner] + length = self.do_cmd(cmd, score) + self.address += length - if (len(self.iters) - 1) >= lowmax: - self.stream.append( (lz_commands['long'] << 5) | (lz_commands['iterate'] << 2) | ((len(self.iters)-1) >> 8) ) - self.stream.append( (len(self.iters) - 1) & 0xff ) - self.stream.append( iter ) - elif len(self.iters) > 3: - # 3 or fewer isn't worth the trouble and actually longer - # if part of a larger literal set - self.stream.append( (lz_commands['iterate'] << 5) | (len(self.iters) - 1) ) - self.stream.append( iter ) - else: - self.address = original_address - raise Exception, "checkIter() should prevent this from happening" class Decompressed: @@ -615,6 +415,42 @@ class Decompressed: self.output = self.pic + self.animtiles + def command_list(self): + """ + Print a list of commands that were used. Useful for debugging. + """ + + data = bytearray(self.lz) + address = self.address + while 1: + cmd_addr = address + byte = data[address] + address += 1 + if byte == lz_end: break + cmd = (byte >> 5) & 0b111 + if cmd == lz_commands['long']: + cmd = (byte >> 2) & 0b111 + length = (byte & 0b11) << 8 + length += data[address] + address += 1 + else: + length = byte & 0b11111 + length += 1 + name = dict(map(reversed, lz_commands.items()))[cmd] + if name == 'iterate': + address += 1 + elif name == 'alternate': + address += 2 + elif name in ['repeat', 'reverse', 'flip']: + if data[address] < 0x80: + address += 2 + else: + address += 1 + elif name == 'literal': + address += length + print name, length, '\t', ' '.join(map('{:02x}'.format, list(data)[cmd_addr:address])) + + def decompress(self): """ Replica of crystal's decompression. |