summaryrefslogtreecommitdiff
path: root/pokemontools/gfx.py
diff options
context:
space:
mode:
authoryenatch <yenatch@gmail.com>2014-06-02 17:38:30 -0700
committeryenatch <yenatch@gmail.com>2014-06-02 17:57:43 -0700
commitb07f9c7f76e221b15ec8a153fe52734a8174aa0d (patch)
treed5a79d5d13d1608596e5fb9740a58c5fbe7d7dc7 /pokemontools/gfx.py
parentc2c45d61ebf352b72d2692116925e8d1328d708e (diff)
Rewrite the lz compressor.
Diffstat (limited to 'pokemontools/gfx.py')
-rw-r--r--pokemontools/gfx.py576
1 files changed, 206 insertions, 370 deletions
diff --git a/pokemontools/gfx.py b/pokemontools/gfx.py
index 92cc0c1..93c8941 100644
--- a/pokemontools/gfx.py
+++ b/pokemontools/gfx.py
@@ -181,388 +181,188 @@ lz_end = 0xff
class Compressed:
- """
- Compress arbitrary data, usually 2bpp.
- """
-
- def __init__(self, image=None, mode='horiz', size=None):
- assert image, 'need something to compress!'
- image = list(image)
- self.image = image
- self.pic = []
- self.animtiles = []
-
- # only transpose pic (animtiles were never transposed in decompression)
- if size != None:
- for byte in range((size*size)*16):
- self.pic += image[byte]
- for byte in range(((size*size)*16),len(image)):
- self.animtiles += image[byte]
- else:
- self.pic = image
-
- if mode == 'vert':
- self.tiles = get_tiles(self.pic)
- self.tiles = transpose(self.tiles)
- self.pic = connect(self.tiles)
-
- self.image = self.pic + self.animtiles
-
- self.end = len(self.image)
-
- self.byte = None
- self.address = 0
-
- self.stream = []
-
- self.zeros = []
- self.alts = []
- self.iters = []
- self.repeats = []
- self.flips = []
- self.reverses = []
- self.literals = []
-
- self.output = []
-
+ def __init__(self, data=None, commands=lz_commands, debug=False):
+ self.data = list(bytearray(data))
+ self.commands = commands
+ self.debug = debug
self.compress()
+ def byte_at(self, address):
+ if address < len(self.data):
+ return self.data[address]
+ return None
def compress(self):
"""
- Incomplete, but outputs working compressed data.
+ This algorithm is greedy.
+ It aims to match the compressor it's based on as closely as possible.
+ It doesn't, but in the meantime the output is smaller.
"""
-
self.address = 0
-
- # todo
- #self.scanRepeats()
-
- while ( self.address < self.end ):
-
- #if (self.repeats):
- # self.doRepeats()
-
- #if (self.flips):
- # self.doFlips()
-
- #if (self.reverses):
- # self.doReverses
-
- if (self.checkWhitespace()):
- self.doLiterals()
- self.doWhitespace()
-
- elif (self.checkIter()):
- self.doLiterals()
- self.doIter()
-
- elif (self.checkAlts()):
- self.doLiterals()
- self.doAlts()
-
- else: # doesn't fit any pattern -> literal
- self.addLiteral()
- self.next()
-
- self.doStream()
-
- # add any literals we've been sitting on
- self.doLiterals()
-
- # done
- self.output.append(lz_end)
-
-
- def getCurByte(self):
- if self.address < self.end:
- self.byte = ord(self.image[self.address])
- else: self.byte = None
-
- def next(self):
- self.address += 1
- self.getCurByte()
-
- def addLiteral(self):
- self.getCurByte()
- self.literals.append(self.byte)
- if len(self.literals) > max_length:
- raise Exception, "literals exceeded max length and the compressor didn't catch it"
- elif len(self.literals) == max_length:
- self.doLiterals()
-
- def doLiterals(self):
- if len(self.literals) > lowmax:
- self.output.append( (lz_commands['long'] << 5) | (lz_commands['literal'] << 2) | ((len(self.literals) - 1) >> 8) )
- self.output.append( (len(self.literals) - 1) & 0xff )
- elif len(self.literals) > 0:
- self.output.append( (lz_commands['literal'] << 5) | (len(self.literals) - 1) )
- for byte in self.literals:
- self.output.append(byte)
- self.literals = []
-
- def doStream(self):
- for byte in self.stream:
- self.output.append(byte)
- self.stream = []
-
-
- def scanRepeats(self):
- """
- Works, but doesn't do flipped/reversed streams yet.
-
- This takes up most of the compress time and only saves a few bytes.
- It might be more effective to exclude it entirely.
- """
-
- self.repeats = []
- self.flips = []
- self.reverses = []
-
- # make a 5-letter word list of the sequence
- letters = 5 # how many bytes it costs to use a repeat over a literal
- # any shorter and it's not worth the trouble
- num_words = len(self.image) - letters
- words = []
- for i in range(self.address,num_words):
- word = []
- for j in range(letters):
- word.append( ord(self.image[i+j]) )
- words.append((word, i))
-
- zeros = []
- for zero in range(letters):
- zeros.append( 0 )
-
- # check for matches
- def get_matches():
- # TODO:
- # append to 3 different match lists instead of yielding to one
- #
- #flipped = []
- #for byte in enumerate(this[0]):
- # flipped.append( sum(1<<(7-i) for i in range(8) if (this[0][byte])>>i&1) )
- #reversed = this[0][::-1]
- #
- for whereabout, this in enumerate(words):
- for that in range(whereabout+1,len(words)):
- if words[that][0] == this[0]:
- if words[that][1] - this[1] >= letters:
- # remove zeros
- if this[0] != zeros:
- yield [this[0], this[1], words[that][1]]
-
- matches = list(get_matches())
-
- # remove more zeros
- buffer = []
- for match in matches:
- # count consecutive zeros in a word
- num_zeros = 0
- highest = 0
- for j in range(letters):
- if match[0][j] == 0:
- num_zeros += 1
- else:
- if highest < num_zeros: highest = num_zeros
- num_zeros = 0
- if highest < 4:
- # any more than 3 zeros in a row isn't worth it
- # (and likely to already be accounted for)
- buffer.append(match)
- matches = buffer
-
- # combine overlapping matches
- buffer = []
- for this, match in enumerate(matches):
- if this < len(matches) - 1: # special case for the last match
- if matches[this+1][1] <= (match[1] + len(match[0])): # check overlap
- if match[1] + len(match[0]) < match[2]:
- # next match now contains this match's bytes too
- # this only appends the last byte (assumes overlaps are +1
- match[0].append(matches[this+1][0][-1])
- matches[this+1] = match
- elif match[1] + len(match[0]) == match[2]:
- # we've run into the thing we matched
- buffer.append(match)
- # else we've gone past it and we can ignore it
- else: # no more overlaps
- buffer.append(match)
- else: # last match, so there's nothing to check
- buffer.append(match)
- matches = buffer
-
- # remove alternating sequences
- buffer = []
- for match in matches:
- for i in range(6 if letters > 6 else letters):
- if match[0][i] != match[0][i&1]:
- buffer.append(match)
+ self.end = len(self.data)
+ self.output = []
+ self.literal = []
+
+ while self.address < self.end:
+ # Tally up the number of bytes that can be compressed
+ # by a single command from the current address.
+ self.scores = {}
+ for method in self.commands.keys():
+ self.scores[method] = 0
+
+ # The most common byte by far is 0 (whitespace in
+ # images and padding in tilemaps and regular data).
+ address = self.address
+ while self.byte_at(address) == 0x00:
+ self.scores['blank'] += 1
+ address += 1
+
+ # In the same vein, see how long the same byte repeats for.
+ address = self.address
+ self.iter = self.byte_at(address)
+ while self.byte_at(address) == self.iter:
+ self.scores['iterate'] += 1
+ address += 1
+
+ # Do it again, but for alternating bytes.
+ address = self.address
+ self.alts = []
+ self.alts += [self.byte_at(address)]
+ self.alts += [self.byte_at(address + 1)]
+ while self.byte_at(address) == self.alts[(address - self.address) % 2]:
+ self.scores['alternate'] += 1
+ address += 1
+
+ # Check if we can repeat any data that the
+ # decompressor just output (here, the input data).
+ # TODO this includes the current command's output
+ self.matches = {}
+ last_matches = {}
+ address = self.address
+ min_length = 4 # minimum worthwhile length
+ max_length = 9 # any further and the time loss is too significant
+ for length in xrange(min_length, min(len(self.data) - address, max_length)):
+ keyword = self.data[address:address+length]
+ for offset, byte in enumerate(self.data[:address]):
+ # offset ranges are -0x80:-1 and 0:0x7fff
+ if offset > 0x7fff and offset < address - 0x80:
+ continue
+ if byte == keyword[0]:
+ # Straight repeat...
+ if self.data[offset:offset+length] == keyword:
+ if self.scores['repeat'] < length:
+ self.scores['repeat'] = length
+ self.matches['repeat'] = offset
+ # In reverse...
+ if self.data[offset-1:offset-length-1:-1] == keyword:
+ if self.scores['reverse'] < length:
+ self.scores['reverse'] = length
+ self.matches['reverse'] = offset
+ # Or bitflipped
+ if self.bit_flip([byte]) == self.bit_flip([keyword[0]]):
+ if self.bit_flip(self.data[offset:offset+length]) == self.bit_flip(keyword):
+ if self.scores['flip'] < length:
+ self.scores['flip'] = length
+ self.matches['flip'] = offset
+ if self.matches == last_matches:
break
- matches = buffer
-
- self.repeats = matches
-
-
- def doRepeats(self):
- """doesn't output the right values yet"""
-
- unusedrepeats = []
- for repeat in self.repeats:
- if self.address >= repeat[2]:
-
- # how far in we are
- length = (len(repeat[0]) - (self.address - repeat[2]))
-
- # decide which side we're copying from
- if (self.address - repeat[1]) <= 0x80:
- self.doLiterals()
- self.stream.append( (lz_commands['repeat'] << 5) | length - 1 )
-
- # wrong?
- self.stream.append( (((self.address - repeat[1])^0xff)+1)&0xff )
-
- else:
- self.doLiterals()
- self.stream.append( (lz_commands['repeat'] << 5) | length - 1 )
-
- # wrong?
- self.stream.append(repeat[1]>>8)
- self.stream.append(repeat[1]&0xff)
-
- #print hex(self.address) + ': ' + hex(len(self.output)) + ' ' + hex(length)
- self.address += length
-
- else: unusedrepeats.append(repeat)
-
- self.repeats = unusedrepeats
-
-
- def checkWhitespace(self):
- self.zeros = []
- self.getCurByte()
- original_address = self.address
-
- if ( self.byte == 0 ):
- while ( self.byte == 0 ) & ( len(self.zeros) <= max_length ):
- self.zeros.append(self.byte)
- self.next()
- if len(self.zeros) > 1:
- return True
- self.address = original_address
- return False
-
- def doWhitespace(self):
- if (len(self.zeros) + 1) >= lowmax:
- self.stream.append( (lz_commands['long'] << 5) | (lz_commands['blank'] << 2) | ((len(self.zeros) - 1) >> 8) )
- self.stream.append( (len(self.zeros) - 1) & 0xff )
- elif len(self.zeros) > 1:
- self.stream.append( lz_commands['blank'] << 5 | (len(self.zeros) - 1) )
+ last_matches = list(self.matches)
+
+ # If the scores are too low, try again from the next byte.
+ if not any(map(lambda x: {
+ 'blank': 1,
+ 'iterate': 2,
+ 'alternate': 3,
+ 'repeat': 3,
+ 'reverse': 3,
+ 'flip': 3,
+ }.get(x[0], 10000) < x[1], self.scores.items())):
+ self.literal += [self.data[self.address]]
+ self.address += 1
+
+ else: # payload
+ # bug: literal [00] is a byte longer than blank 1.
+ # this bug exists in the target compressor as well,
+ # so don't fix until we've given up on replicating it.
+ self.do_literal()
+ self.do_scored()
+
+ # unload any literals we're sitting on
+ self.do_literal()
+ self.output += [lz_end]
+
+ def bit_flip(self, data):
+ return [sum(((byte >> i) & 1) << (7 - i) for i in xrange(8)) for byte in data]
+
+ def do_literal(self):
+ if self.literal:
+ cmd = self.commands['literal']
+ length = len(self.literal)
+ self.do_cmd(cmd, length)
+ # self.address has already been
+ # incremented in the main loop
+ self.literal = []
+
+ def do_cmd(self, cmd, length):
+ if length > max_length:
+ length = max_length
+
+ cmd_length = length - 1
+
+ if length > lowmax:
+ output = [(self.commands['long'] << 5) + (cmd << 2) + (cmd_length >> 8)]
+ output += [cmd_length & 0xff]
else:
- raise Exception, "checkWhitespace() should prevent this from happening"
-
-
- def checkAlts(self):
- self.alts = []
- self.getCurByte()
- original_address = self.address
- num_alts = 0
-
- # make sure we don't check for alts at the end of the file
- if self.address+3 >= self.end: return False
-
- self.alts.append(self.byte)
- self.alts.append(ord(self.image[self.address+1]))
-
- # are we onto smething?
- if ( ord(self.image[self.address+2]) == self.alts[0] ):
- cur_alt = 0
- while (ord(self.image[(self.address)+1]) == self.alts[num_alts&1]) & (num_alts <= max_length):
- num_alts += 1
- self.next()
- # include the last alternated byte
- num_alts += 1
- self.address = original_address
- if num_alts > lowmax:
- return True
- elif num_alts > 2:
- return True
- return False
-
- def doAlts(self):
- original_address = self.address
- self.getCurByte()
-
- #self.alts = []
- #num_alts = 0
-
- #self.alts.append(self.byte)
- #self.alts.append(ord(self.image[self.address+1]))
-
- #i = 0
- #while (ord(self.image[self.address+1]) == self.alts[i^1]) & (num_alts <= max_length):
- # num_alts += 1
- # i ^=1
- # self.next()
- ## include the last alternated byte
- #num_alts += 1
-
- num_alts = len(self.iters) + 1
-
- if num_alts > lowmax:
- self.stream.append( (lz_commands['long'] << 5) | (lz_commands['alternate'] << 2) | ((num_alts - 1) >> 8) )
- self.stream.append( num_alts & 0xff )
- self.stream.append( self.alts[0] )
- self.stream.append( self.alts[1] )
- elif num_alts > 2:
- self.stream.append( (lz_commands['alternate'] << 5) | (num_alts - 1) )
- self.stream.append( self.alts[0] )
- self.stream.append( self.alts[1] )
+ output = [(cmd << 5) + cmd_length]
+
+ if cmd == self.commands['literal']:
+ output += self.literal
+ elif cmd == self.commands['iterate']:
+ output += [self.iter]
+ elif cmd == self.commands['alternate']:
+ output += self.alts
else:
- raise Exception, "checkAlts() should prevent this from happening"
-
- self.address = original_address
- self.address += num_alts
-
-
- def checkIter(self):
- self.iters = []
- self.getCurByte()
- iter = self.byte
- original_address = self.address
- while (self.byte == iter) & (len(self.iters) < max_length):
- self.iters.append(self.byte)
- self.next()
- self.address = original_address
- if len(self.iters) > 3:
- # 3 or fewer isn't worth the trouble and actually longer
- # if part of a larger literal set
- return True
-
- return False
-
- def doIter(self):
- self.getCurByte()
- iter = self.byte
- original_address = self.address
+ for command in ['repeat', 'reverse', 'flip']:
+ if cmd == self.commands[command]:
+ offset = self.matches[command]
+ # negative offsets are a byte shorter
+ if self.address - offset <= 0x80:
+ offset = self.address - offset + 0x80
+ if cmd == self.commands['repeat']:
+ offset -= 1 # this is a hack, but it seems to work
+ output += [offset]
+ else:
+ output += [offset / 0x100, offset % 0x100]
+
+ if self.debug:
+ print (
+ dict(map(reversed, self.commands.items()))[cmd],
+ length, '\t',
+ ' '.join(map('{:02x}'.format, output))
+ )
- self.iters = []
- while (self.byte == iter) & (len(self.iters) < max_length):
- self.iters.append(self.byte)
- self.next()
+ self.output += output
+ return length
+
+ def do_scored(self):
+ # Which command did the best?
+ winner, score = sorted(
+ self.scores.items(),
+ key=lambda x:(-x[1], [
+ 'blank',
+ 'repeat',
+ 'reverse',
+ 'flip',
+ 'iterate',
+ 'alternate',
+ 'literal',
+ 'long', # hack
+ ].index(x[0]))
+ )[0]
+ cmd = self.commands[winner]
+ length = self.do_cmd(cmd, score)
+ self.address += length
- if (len(self.iters) - 1) >= lowmax:
- self.stream.append( (lz_commands['long'] << 5) | (lz_commands['iterate'] << 2) | ((len(self.iters)-1) >> 8) )
- self.stream.append( (len(self.iters) - 1) & 0xff )
- self.stream.append( iter )
- elif len(self.iters) > 3:
- # 3 or fewer isn't worth the trouble and actually longer
- # if part of a larger literal set
- self.stream.append( (lz_commands['iterate'] << 5) | (len(self.iters) - 1) )
- self.stream.append( iter )
- else:
- self.address = original_address
- raise Exception, "checkIter() should prevent this from happening"
class Decompressed:
@@ -615,6 +415,42 @@ class Decompressed:
self.output = self.pic + self.animtiles
+ def command_list(self):
+ """
+ Print a list of commands that were used. Useful for debugging.
+ """
+
+ data = bytearray(self.lz)
+ address = self.address
+ while 1:
+ cmd_addr = address
+ byte = data[address]
+ address += 1
+ if byte == lz_end: break
+ cmd = (byte >> 5) & 0b111
+ if cmd == lz_commands['long']:
+ cmd = (byte >> 2) & 0b111
+ length = (byte & 0b11) << 8
+ length += data[address]
+ address += 1
+ else:
+ length = byte & 0b11111
+ length += 1
+ name = dict(map(reversed, lz_commands.items()))[cmd]
+ if name == 'iterate':
+ address += 1
+ elif name == 'alternate':
+ address += 2
+ elif name in ['repeat', 'reverse', 'flip']:
+ if data[address] < 0x80:
+ address += 2
+ else:
+ address += 1
+ elif name == 'literal':
+ address += length
+ print name, length, '\t', ' '.join(map('{:02x}'.format, list(data)[cmd_addr:address]))
+
+
def decompress(self):
"""
Replica of crystal's decompression.