diff options
-rw-r--r-- | README.md | 5 | ||||
-rw-r--r-- | comparator.py | 40 | ||||
-rw-r--r-- | crystal.py | 1410 | ||||
-rwxr-xr-x | dump_sections | 14 | ||||
-rwxr-xr-x | dump_sections.py | 130 | ||||
-rw-r--r-- | gbz80disasm.py | 28 | ||||
-rw-r--r-- | gfx.py | 58 | ||||
-rw-r--r-- | graph.py | 13 | ||||
-rw-r--r-- | interval_map.py | 104 | ||||
-rw-r--r-- | item_constants.py | 23 | ||||
-rw-r--r-- | labels.py | 9 | ||||
-rw-r--r-- | move_constants.py | 2 | ||||
-rw-r--r-- | pksv.py | 16 | ||||
-rw-r--r-- | pointers.py | 8 | ||||
-rw-r--r-- | pokemon_constants.py | 2 | ||||
-rw-r--r-- | romstr.py | 354 | ||||
-rw-r--r-- | test_dump_sections.py | 74 | ||||
-rw-r--r-- | tests.py | 1015 | ||||
-rw-r--r-- | type_constants.py | 21 |
19 files changed, 1714 insertions, 1612 deletions
@@ -30,9 +30,8 @@ After running those lines, `cp extras/output.txt main.asm` and run `git diff mai Unit tests cover most of the classes. -```python -import crystal -crystal.run_tests() +```bash +python tests.py ``` #### Parsing a script at a known address diff --git a/comparator.py b/comparator.py index 690fa52..6d981e4 100644 --- a/comparator.py +++ b/comparator.py @@ -1,34 +1,30 @@ -#!/usr/bin/python # -*- coding: utf-8 -*- -# author: Bryan Bishop <kanzure@gmail.com> -# date: 2012-05-29 -# purpose: find shared functions between red/crystal - -from crystal import get_label_from_line, \ - get_address_from_line_comment, \ - AsmSection - -from romstr import RomStr, AsmList +""" +Finds shared functions between red/crystal. +""" + +from crystal import ( + get_label_from_line, + get_address_from_line_comment, + AsmSection, + direct_load_rom, + direct_load_asm, +) + +from romstr import ( + RomStr, + AsmList, +) def load_rom(path): """ Loads a ROM file into an abbreviated RomStr object. """ - - fh = open(path, "r") - x = RomStr(fh.read()) - fh.close() - - return x + return direct_load_rom(filename=path) def load_asm(path): """ Loads source ASM into an abbreviated AsmList object. """ - - fh = open(path, "r") - x = AsmList(fh.read().split("\n")) - fh.close() - - return x + return direct_load_asm(filename=path) def findall_iter(sub, string): # url: http://stackoverflow.com/a/3874760/687783 @@ -1,29 +1,19 @@ # -*- coding: utf-8 -*- # utilities to help disassemble pokémon crystal -import sys, os, inspect, hashlib, json +import os +import sys +import inspect +import hashlib +import json from copy import copy, deepcopy import subprocess from new import classobj import random -# for IntervalMap -from bisect import bisect_left, bisect_right -from itertools import izip - -# for testing all this crap -try: - import unittest2 as unittest -except ImportError: - import unittest - # for capwords import string -# Check for things we need in unittest. -if not hasattr(unittest.TestCase, 'setUpClass'): - sys.stderr.write("The unittest2 module or Python 2.7 is required to run this script.") - sys.exit(1) - +# for python2.6 if not hasattr(json, "dumps"): json.dumps = json.write @@ -31,6 +21,14 @@ if not hasattr(json, "dumps"): if not hasattr(json, "read"): json.read = json.loads +from labels import ( + remove_quoted_text, + line_has_comment_address, + line_has_label, + get_label_from_line, + get_address_from_line_comment +) + spacing = "\t" lousy_dragon_shrine_hack = [0x18d079, 0x18d0a9, 0x18d061, 0x18d091] @@ -63,132 +61,25 @@ constant_abbreviation_bytes = {} # Import the characters from its module. from chars import chars, jap_chars -from trainers import * +from trainers import ( + trainer_group_pointer_table_address, # 0x39999 + trainer_group_pointer_table_address_gs, # 0x3993E + trainer_group_names, +) from move_constants import moves # for fixing trainer_group_names import re -trainer_group_pointer_table_address = 0x39999 -trainer_group_pointer_table_address_gs = 0x3993E - -class Size(): - """a simple way to track whether or not a size - includes the first value or not, like for - whether or not the size of a command in a script - also includes the command byte or not""" - - def __init__(self, size, inclusive=False): - self.inclusive = inclusive - if inclusive: size = size-1 - self.size = size - - def inclusive(self): - return self.size + 1 - - def exclusive(self): - return self.size - -class IntervalMap(object): - """ - This class maps a set of intervals to a set of values. - - >>> i = IntervalMap() - >>> i[0:5] = "hello world" - >>> i[6:10] = "hello cruel world" - >>> print i[4] - "hello world" - """ - - def __init__(self): - """initializes an empty IntervalMap""" - self._bounds = [] - self._items = [] - self._upperitem = None - - def __setitem__(self, _slice, _value): - """sets an interval mapping""" - assert isinstance(_slice, slice), 'The key must be a slice object' - - if _slice.start is None: - start_point = -1 - else: - start_point = bisect_left(self._bounds, _slice.start) - - if _slice.stop is None: - end_point = -1 - else: - end_point = bisect_left(self._bounds, _slice.stop) - - if start_point>=0: - if start_point < len(self._bounds) and self._bounds[start_point]<_slice.start: - start_point += 1 - - if end_point>=0: - self._bounds[start_point:end_point] = [_slice.start, _slice.stop] - if start_point < len(self._items): - self._items[start_point:end_point] = [self._items[start_point], _value] - else: - self._items[start_point:end_point] = [self._upperitem, _value] - else: - self._bounds[start_point:] = [_slice.start] - if start_point < len(self._items): - self._items[start_point:] = [self._items[start_point], _value] - else: - self._items[start_point:] = [self._upperitem] - self._upperitem = _value - else: - if end_point>=0: - self._bounds[:end_point] = [_slice.stop] - self._items[:end_point] = [_value] - else: - self._bounds[:] = [] - self._items[:] = [] - self._upperitem = _value - - def __getitem__(self,_point): - """gets a value from the mapping""" - assert not isinstance(_point, slice), 'The key cannot be a slice object' - - index = bisect_right(self._bounds, _point) - if index < len(self._bounds): - return self._items[index] - else: - return self._upperitem - - def items(self): - """returns an iterator with each item being - ((low_bound, high_bound), value) - these items are returned in order""" - previous_bound = None - for (b, v) in izip(self._bounds, self._items): - if v is not None: - yield (previous_bound, b), v - previous_bound = b - if self._upperitem is not None: - yield (previous_bound, None), self._upperitem - - def values(self): - """returns an iterator with each item being a stored value - the items are returned in order""" - for v in self._items: - if v is not None: - yield v - if self._upperitem is not None: - yield self._upperitem - - def __repr__(self): - s = [] - for b,v in self.items(): - if v is not None: - s.append('[%r, %r] => %r'%( - b[0], - b[1], - v - )) - return '{'+', '.join(s)+'}' +from interval_map import IntervalMap +from pksv import ( + pksv_gs, + pksv_crystal, + pksv_crystal_unknowns, + pksv_crystal_more_enders +) # ---- script_parse_table explanation ---- # This is an IntervalMap that keeps track of previously parsed scripts, texts @@ -206,7 +97,8 @@ script_parse_table = IntervalMap() def is_script_already_parsed_at(address): """looks up whether or not a script is parsed at a certain address""" - if script_parse_table[address] == None: return False + if script_parse_table[address] == None: + return False return True def script_parse_table_pretty_printer(): @@ -230,7 +122,10 @@ def map_name_cleaner(input): replace("hooh", "HoOh").\ replace(" ", "") -from romstr import RomStr, AsmList +from romstr import ( + RomStr, + AsmList, +) rom = RomStr(None) @@ -253,12 +148,16 @@ def load_rom(filename="../baserom.gbc"): elif os.lstat(filename).st_size != len(rom): return direct_load_rom(filename) +def direct_load_asm(filename="../main.asm"): + """returns asm source code (AsmList) from a file""" + asm = open(filename, "r").read().split("\n") + asm = AsmList(asm) + return asm def load_asm(filename="../main.asm"): - """loads the asm source code into memory""" + """returns asm source code (AsmList) from a file (uses a global)""" global asm - asm = open(filename, "r").read().split("\n") - asm = AsmList(asm) + asm = direct_load_asm(filename=filename) return asm def grouper(some_list, count=2): @@ -269,11 +168,14 @@ def grouper(some_list, count=2): def is_valid_address(address): """is_valid_rom_address""" - if address == None: return False + if address == None: + return False if type(address) == str: address = int(address, 16) - if 0 <= address <= 2097152: return True - else: return False + if 0 <= address <= 2097152: + return True + else: + return False def rom_interval(offset, length, strings=True, debug=True): """returns hex values for the rom starting at offset until offset+length""" @@ -302,7 +204,10 @@ def load_map_group_offsets(): map_group_offsets.append(offset) return map_group_offsets -from pointers import calculate_bank, calculate_pointer +from pointers import ( + calculate_bank, + calculate_pointer, +) def calculate_pointer_from_bytes_at(address, bank=False): """calculates a pointer from 2 bytes at a location @@ -317,7 +222,7 @@ def calculate_pointer_from_bytes_at(address, bank=False): elif type(bank) == int: pass else: - raise Exception, "bad bank given to calculate_pointer_from_bytes_at" + raise Exception("bad bank given to calculate_pointer_from_bytes_at") byte1 = ord(rom[address]) byte2 = ord(rom[address+1]) temp = byte1 + (byte2 << 8) @@ -345,6 +250,20 @@ def clean_up_long_info(long_info): long_info = "\n".join(new_lines) return long_info +from pokemon_constants import pokemon_constants + +def get_pokemon_constant_by_id(id): + if id == 0: + return None + else: + return pokemon_constants[id] + +from item_constants import ( + item_constants, + find_item_label_by_id, + generate_item_constants, +) + def command_debug_information(command_byte=None, map_group=None, map_id=None, address=0, info=None, long_info=None, pksv_name=None): "used to help debug in parse_script_engine_script_at" info1 = "parsing command byte " + hex(command_byte) + " for map " + \ @@ -376,7 +295,7 @@ class TextScript: self.force = force if is_script_already_parsed_at(address) and not force: - raise Exception, "TextScript already parsed at "+hex(address) + raise Exception("TextScript already parsed at "+hex(address)) if not label: label = self.base_label + hex(address) @@ -449,11 +368,11 @@ class TextScript: print "self.commands is: " + str(commands) print "command 0 address is: " + hex(commands[0].address) + " last_address="+hex(commands[0].last_address) print "command 1 address is: " + hex(commands[1].address) + " last_address="+hex(commands[1].last_address) - raise Exception, "going beyond the bounds for this text script" + raise Exception("going beyond the bounds for this text script") # no matching command found if scripting_command_class == None: - raise Exception, "unable to parse text command $%.2x in the text script at %s at %s" % (cur_byte, hex(start_address), hex(current_address)) + raise Exception("unable to parse text command $%.2x in the text script at %s at %s" % (cur_byte, hex(start_address), hex(current_address))) # create an instance of the command class and let it parse its parameter bytes cls = scripting_command_class(address=current_address, map_group=self.map_group, map_id=self.map_id, debug=self.debug, force=self.force) @@ -631,7 +550,7 @@ class OldTextScript: if is_script_already_parsed_at(address) and not force: print "text is already parsed at this location: " + hex(address) - raise Exception, "text is already parsed, what's going on ?" + raise Exception("text is already parsed, what's going on ?") return script_parse_table[address] total_text_commands = 0 @@ -845,7 +764,7 @@ class OldTextScript: def get_dependencies(self, recompute=False, global_dependencies=set()): #if recompute: - # raise NotImplementedError, bryan_message + # raise NotImplementedError(bryan_message) global_dependencies.update(self.dependencies) return self.dependencies @@ -1157,7 +1076,8 @@ def parse_text_at3(address, map_group=None, map_id=None, debug=False): text = TextScript(address, map_group=map_group, map_id=map_id, debug=debug) if text.is_valid(): return text - else: return None + else: + return None def rom_text_at(address, count=10): """prints out raw text from the ROM @@ -1166,8 +1086,10 @@ def rom_text_at(address, count=10): def get_map_constant_label(map_group=None, map_id=None): """returns PALLET_TOWN for some map group/id pair""" - if map_group == None: raise Exception, "need map_group" - if map_id == None: raise Exception, "need map_id" + if map_group == None: + raise Exception("need map_group") + if map_id == None: + raise Exception("need map_id") global map_internal_ids for (id, each) in map_internal_ids.items(): @@ -1185,7 +1107,8 @@ def get_id_for_map_constant_label(label): PALLET_TOWN = 1, for instance.""" global map_internal_ids for (id, each) in map_internal_ids.items(): - if each["label"] == label: return id + if each["label"] == label: + return id return None def generate_map_constant_labels(): @@ -1264,37 +1187,16 @@ def transform_wildmons(asm): returnlines.append(line) return "\n".join(returnlines) -from pokemon_constants import pokemon_constants - -def get_pokemon_constant_by_id(id): - if id == 0: return None - return pokemon_constants[id] - def parse_script_asm_at(*args, **kwargs): # XXX TODO return None -from item_constants import item_constants - -def find_item_label_by_id(id): - if id in item_constants.keys(): - return item_constants[id] - else: return None - -def generate_item_constants(): - """make a list of items to put in constants.asm""" - output = "" - for (id, item) in item_constants.items(): - val = ("$%.2x"%id).upper() - while len(item)<13: item+= " " - output += item + " EQU " + val + "\n" - return output - def find_all_text_pointers_in_script_engine_script(script, bank=None, debug=False): """returns a list of text pointers based on each script-engine script command""" # TODO: recursively follow any jumps in the script - if script == None: return [] + if script == None: + return [] addresses = set() for (k, command) in enumerate(script.commands): if debug: @@ -1329,16 +1231,16 @@ def translate_command_byte(crystal=None, gold=None): if 0x53 <= crystal <= 0x9E: return crystal-1 if crystal == 0x9F: return None if 0xA0 <= crystal <= 0xA5: return crystal-2 - if crystal > 0xA5: raise Exception, "dunno yet if crystal has new insertions after crystal:0xA5 (gold:0xA3)" + if crystal > 0xA5: + raise Exception("dunno yet if crystal has new insertions after crystal:0xA5 (gold:0xA3)") elif gold != None: # convert to crystal if gold <= 0x51: return gold if 0x52 <= gold <= 0x9D: return gold+1 if 0x9E <= gold <= 0xA3: return gold+2 - if gold > 0xA3: raise Exception, "dunno yet if crystal has new insertions after gold:0xA3 (crystal:0xA5)" - else: raise Exception, "translate_command_byte needs either a crystal or gold command" - -from pksv import pksv_gs, pksv_crystal, pksv_crystal_unknowns,\ - pksv_crystal_more_enders + if gold > 0xA3: + raise Exception("dunno yet if crystal has new insertions after gold:0xA3 (crystal:0xA5)") + else: + raise Exception("translate_command_byte needs either a crystal or gold command") class SingleByteParam(): """or SingleByte(CommandParam)""" @@ -1351,14 +1253,14 @@ class SingleByteParam(): setattr(self, key, value) # check address if not hasattr(self, "address"): - raise Exception, "an address is a requirement" + raise Exception("an address is a requirement") elif self.address == None: - raise Exception, "address must not be None" + raise Exception("address must not be None") elif not is_valid_address(self.address): - raise Exception, "address must be valid" + raise Exception("address must be valid") # check size if not hasattr(self, "size") or self.size == None: - raise Exception, "size is probably 1?" + raise Exception("size is probably 1?") # parse bytes from ROM self.parse() @@ -1368,18 +1270,23 @@ class SingleByteParam(): return [] def to_asm(self): - if not self.should_be_decimal: return hex(self.byte).replace("0x", "$") - else: return str(self.byte) + if not self.should_be_decimal: + return hex(self.byte).replace("0x", "$") + else: + return str(self.byte) class DollarSignByte(SingleByteParam): - def to_asm(self): return hex(self.byte).replace("0x", "$") + def to_asm(self): + return hex(self.byte).replace("0x", "$") HexByte=DollarSignByte class ItemLabelByte(DollarSignByte): def to_asm(self): label = find_item_label_by_id(self.byte) - if label: return label - elif not label: return DollarSignByte.to_asm(self) + if label: + return label + elif not label: + return DollarSignByte.to_asm(self) class DecimalParam(SingleByteParam): @@ -1398,12 +1305,12 @@ class MultiByteParam(): setattr(self, key, value) # check address if not hasattr(self, "address") or self.address == None: - raise Exception, "an address is a requirement" + raise Exception("an address is a requirement") elif not is_valid_address(self.address): - raise Exception, "address must be valid" + raise Exception("address must be valid") # check size if not hasattr(self, "size") or self.size == None: - raise Exception, "don't know how many bytes to read (size)" + raise Exception("don't know how many bytes to read (size)") self.parse() def parse(self): @@ -1445,9 +1352,9 @@ class PointerLabelParam(MultiByteParam): self.size = self.default_size + 1 self.given_bank = kwargs["bank"] #if kwargs["bank"] not in [None, False, True, "reverse"]: - # raise Exception, "bank cannot be: " + str(kwargs["bank"]) + # raise Exception("bank cannot be: " + str(kwargs["bank"])) if self.size > 3: - raise Exception, "param size is too large" + raise Exception("param size is too large") # continue instantiation.. self.bank will be set down the road MultiByteParam.__init__(self, *args, **kwargs) @@ -1520,15 +1427,16 @@ class PointerLabelParam(MultiByteParam): return pointer_part+", "+bank_part elif bank == True: # bank, pointer return bank_part+", "+pointer_part - else: raise Exception, "this should never happen" - raise Exception, "this should never happen" + else: + raise Exception("this should never happen") + raise Exception("this should never happen") # this next one will either return the label or the raw bytes elif bank == False or bank == None: # pointer return pointer_part # this could be the same as label else: - #raise Exception, "this should never happen" + #raise Exception("this should never happen") return pointer_part # probably in the same bank ? - raise Exception, "this should never happen" + raise Exception("this should never happen") class PointerLabelBeforeBank(PointerLabelParam): bank = True # bank appears first, see calculate_pointer_from_bytes_at @@ -1549,12 +1457,12 @@ class ScriptPointerLabelBeforeBank(PointerLabelBeforeBank): pass class ScriptPointerLabelAfterBank(PointerLabelAfterBank): pass -def _parse_script_pointer_bytes(self): +def _parse_script_pointer_bytes(self, debug = False): PointerLabelParam.parse(self) - print "_parse_script_pointer_bytes - calculating the pointer located at " + hex(self.address) + if debug: print "_parse_script_pointer_bytes - calculating the pointer located at " + hex(self.address) address = calculate_pointer_from_bytes_at(self.address, bank=self.bank) if address != None and address > 0x4000: - print "_parse_script_pointer_bytes - the pointer is: " + hex(address) + if debug: print "_parse_script_pointer_bytes - the pointer is: " + hex(address) self.script = parse_script_engine_script_at(address, debug=self.debug, force=self.force, map_group=self.map_group, map_id=self.map_id) ScriptPointerLabelParam.parse = _parse_script_pointer_bytes ScriptPointerLabelBeforeBank.parse = _parse_script_pointer_bytes @@ -1587,8 +1495,10 @@ class RAMAddressParam(MultiByteParam): def to_asm(self): address = calculate_pointer_from_bytes_at(self.address, bank=False) label = get_ram_label(address) - if label: return label - else: return "$"+"".join(["%.2x"%x for x in reversed(self.bytes)])+"" + if label: + return label + else: + return "$"+"".join(["%.2x"%x for x in reversed(self.bytes)])+"" class MoneyByteParam(MultiByteParam): @@ -1646,9 +1556,11 @@ class MapGroupParam(SingleByteParam): def to_asm(self): map_id = ord(rom[self.address+1]) map_constant_label = get_map_constant_label(map_id=map_id, map_group=self.byte) # like PALLET_TOWN - if map_constant_label == None: return str(self.byte) + if map_constant_label == None: + return str(self.byte) #else: return "GROUP("+map_constant_label+")" - else: return "GROUP_"+map_constant_label + else: + return "GROUP_"+map_constant_label class MapIdParam(SingleByteParam): @@ -1659,9 +1571,11 @@ class MapIdParam(SingleByteParam): def to_asm(self): map_group = ord(rom[self.address-1]) map_constant_label = get_map_constant_label(map_id=self.byte, map_group=map_group) - if map_constant_label == None: return str(self.byte) + if map_constant_label == None: + return str(self.byte) #else: return "MAP("+map_constant_label+")" - else: return "MAP_"+map_constant_label + else: + return "MAP_"+map_constant_label class MapGroupIdParam(MultiByteParam): @@ -1680,13 +1594,15 @@ class MapGroupIdParam(MultiByteParam): class PokemonParam(SingleByteParam): def to_asm(self): pokemon_constant = get_pokemon_constant_by_id(self.byte) - if pokemon_constant: return pokemon_constant - else: return str(self.byte) + if pokemon_constant: + return pokemon_constant + else: + return str(self.byte) class PointerParamToItemAndLetter(MultiByteParam): # [2F][2byte pointer to item no + 0x20 bytes letter text] - #raise NotImplementedError, bryan_message + #raise NotImplementedError(bryan_message) pass @@ -1702,7 +1618,7 @@ class TrainerIdParam(SingleByteParam): i += 1 if foundit == None: - raise Exception, "didn't find a TrainerGroupParam in this command??" + raise Exception("didn't find a TrainerGroupParam in this command??") # now get the trainer group id trainer_group_id = self.parent.params[foundit].byte @@ -1729,7 +1645,7 @@ class MoveParam(SingleByteParam): class MenuDataPointerParam(PointerLabelParam): # read menu data at the target site - #raise NotImplementedError, bryan_message + #raise NotImplementedError(bryan_message) pass @@ -1813,7 +1729,7 @@ class MovementPointerLabelParam(PointerLabelParam): global_dependencies.add(self.movement) return [self.movement] + self.movement.get_dependencies(recompute=recompute, global_dependencies=global_dependencies) else: - raise Exception, "MovementPointerLabelParam hasn't been parsed yet" + raise Exception("MovementPointerLabelParam hasn't been parsed yet") class MapDataPointerParam(PointerLabelParam): pass @@ -1838,7 +1754,7 @@ class Command: """ defaults = {"force": False, "debug": False, "map_group": None, "map_id": None} if not is_valid_address(address): - raise Exception, "address is invalid" + raise Exception("address is invalid") # set up some variables self.address = address self.last_address = None @@ -1878,7 +1794,8 @@ class Command: # output += "_" output += self.macro_name # return if there are no params - if len(self.param_types.keys()) == 0: return output + if len(self.param_types.keys()) == 0: + return output # first one will have no prefixing comma first = True # start reading the bytes after the command byte @@ -1923,7 +1840,7 @@ class Command: current_address = self.address byte = ord(rom[self.address]) if not self.override_byte_check and (not byte == self.id): - raise Exception, "byte ("+hex(byte)+") != self.id ("+hex(self.id)+")" + raise Exception("byte ("+hex(byte)+") != self.id ("+hex(self.id)+")") i = 0 for (key, param_type) in self.param_types.items(): name = param_type["name"] @@ -1959,7 +1876,7 @@ class GivePoke(Command): self.params = {} byte = ord(rom[self.address]) if not byte == self.id: - raise Exception, "this should never happen" + raise Exception("this should never happen") current_address = self.address+1 i = 0 self.size = 1 @@ -2120,7 +2037,8 @@ def create_movement_commands(debug=False): direction = "left" elif x == 3: direction = "right" - else: raise Exception, "this should never happen" + else: + raise Exception("this should never happen") cmd_name = cmd[0].replace(" ", "_") + "_" + direction klass_name = cmd_name+"Command" @@ -2355,11 +2273,11 @@ class MainText(TextCommand): print "bytes are: " + str(self.bytes) print "self.size is: " + str(self.size) print "self.last_address is: " + hex(self.last_address) - raise Exception, "last_address is wrong for 0x9c00e" + raise Exception("last_address is wrong for 0x9c00e") def to_asm(self): if self.size < 2 or len(self.bytes) < 1: - raise Exception, "$0 text command can't end itself with no follow-on bytes" + raise Exception("$0 text command can't end itself with no follow-on bytes") if self.use_zero: output = "db $0" @@ -2390,13 +2308,13 @@ class MainText(TextCommand): for byte in self.bytes: if end: - raise Exception, "the text ended due to a $50 or $57 but there are more bytes?" + raise Exception("the text ended due to a $50 or $57 but there are more bytes?") if new_line: if in_quotes: - raise Exception, "can't be in_quotes on a newline" + raise Exception("can't be in_quotes on a newline") elif was_comma: - raise Exception, "last line's last character can't be a comma" + raise Exception("last line's last character can't be a comma") output += "db " @@ -2490,7 +2408,7 @@ class MainText(TextCommand): was_comma = False end = False else: - # raise Exception, "unknown byte in text script ($%.2x)" % (byte) + # raise Exception("unknown byte in text script ($%.2x)" % (byte)) # just add an unknown byte directly to the text.. what's the worse that can happen? if in_quotes: @@ -2511,7 +2429,7 @@ class MainText(TextCommand): # this shouldn't happen because of the rom_until calls in the parse method if not end: - raise Exception, "ran out of bytes without the script ending? starts at "+hex(self.address) + raise Exception("ran out of bytes without the script ending? starts at "+hex(self.address)) # last character may or may not be allowed to be a newline? # Script.to_asm() has command.to_asm()+"\n" @@ -2817,6 +2735,7 @@ pksv_crystal_more = { 0x4F: ["loadmenudata", ["data", MenuDataPointerParam]], 0x50: ["writebackup"], 0x51: ["jumptextfaceplayer", ["text_pointer", RawTextPointerLabelParam]], + 0x52: ["3jumptext", ["text_pointer", PointerLabelBeforeBank]], 0x53: ["jumptext", ["text_pointer", RawTextPointerLabelParam]], 0x54: ["closetext"], 0x55: ["keeptextopen"], @@ -3088,17 +3007,17 @@ class Script: self.address = None self.commands = None if len(kwargs) == 0 and len(args) == 0: - raise Exception, "Script.__init__ must be given some arguments" + raise Exception("Script.__init__ must be given some arguments") # first positional argument is address if len(args) == 1: address = args[0] if type(address) == str: address = int(address, 16) elif type(address) != int: - raise Exception, "address must be an integer or string" + raise Exception("address must be an integer or string") self.address = address elif len(args) > 1: - raise Exception, "don't know what to do with second (or later) positional arguments" + raise Exception("don't know what to do with second (or later) positional arguments") self.dependencies = None if "label" in kwargs.keys(): label = kwargs["label"] @@ -3160,15 +3079,15 @@ class Script: """ global command_classes, rom, script_parse_table current_address = start_address - print "Script.parse address="+hex(self.address) +" map_group="+str(map_group)+" map_id="+str(map_id) + if debug: print "Script.parse address="+hex(self.address) +" map_group="+str(map_group)+" map_id="+str(map_id) if start_address in stop_points and force == False: - print "script parsing is stopping at stop_point=" + hex(start_address) + " at map_group="+str(map_group)+" map_id="+str(map_id) + if debug: print "script parsing is stopping at stop_point=" + hex(start_address) + " at map_group="+str(map_group)+" map_id="+str(map_id) return None if start_address < 0x4000 and start_address not in [0x26ef, 0x114, 0x1108]: - print "address is less than 0x4000.. address is: " + hex(start_address) + if debug: print "address is less than 0x4000.. address is: " + hex(start_address) sys.exit(1) if is_script_already_parsed_at(start_address) and not force and not force_top: - raise Exception, "this script has already been parsed before, please use that instance ("+hex(start_address)+")" + raise Exception("this script has already been parsed before, please use that instance ("+hex(start_address)+")") # load up the rom if it hasn't been loaded already load_rom() @@ -3198,13 +3117,13 @@ class Script: # no matching command found (not implemented yet)- just end this script # NOTE: might be better to raise an exception and end the program? if scripting_command_class == None: - print "parsing script; current_address is: " + hex(current_address) + if debug: print "parsing script; current_address is: " + hex(current_address) current_address += 1 asm_output = "\n".join([command.to_asm() for command in commands]) end = True continue # maybe the program should exit with failure instead? - #raise Exception, "no command found? id: " + hex(cur_byte) + " at " + hex(current_address) + " asm is:\n" + asm_output + #raise Exception("no command found? id: " + hex(cur_byte) + " at " + hex(current_address) + " asm is:\n" + asm_output) # create an instance of the command class and let it parse its parameter bytes #print "about to parse command(script@"+hex(start_address)+"): " + str(scripting_command_class.macro_name) @@ -3231,7 +3150,7 @@ class Script: script_parse_table[start_address:current_address] = self asm_output = "\n".join([command.to_asm() for command in commands]) - print "--------------\n"+asm_output + if debug: print "--------------\n"+asm_output # store the script self.commands = commands @@ -3529,7 +3448,8 @@ class TrainerFragment(Command): def get_dependencies(self, recompute=False, global_dependencies=set()): deps = [] - if not is_valid_address(self.address): return deps + if not is_valid_address(self.address): + return deps if self.dependencies != None and not recompute: global_dependencies.update(self.dependencies) return self.dependencies @@ -3867,7 +3787,7 @@ class TrainerHeader: break if party_mon_parser == None: - raise Exception, "no trainer party mon parser found to parse data type " + hex(self.data_type) + raise Exception("no trainer party mon parser found to parse data type " + hex(self.data_type)) self.party_mons = party_mon_parser(address=current_address, group_id=self.trainer_group_id, trainer_id=self.trainer_id, parent=self) @@ -4422,7 +4342,8 @@ class SignpostRemoteBase: def to_asm(self): """very similar to Command.to_asm""" - if len(self.params) == 0: return "" + if len(self.params) == 0: + return "" #output = ", ".join([p.to_asm() for p in self.params]) output = "" for param in self.params: @@ -4670,7 +4591,7 @@ class Signpost(Command): mb = PointerLabelParam(address=self.address+3, map_group=self.map_group, map_id=self.map_id, debug=self.debug) self.params.append(mb) else: - raise Exception, "unknown signpost type byte="+hex(func) + " signpost@"+hex(self.address) + raise Exception("unknown signpost type byte="+hex(func) + " signpost@"+hex(self.address)) def get_dependencies(self, recompute=False, global_dependencies=set()): dependencies = [] @@ -4684,13 +4605,15 @@ class Signpost(Command): def to_asm(self): output = self.macro_name + " " - if self.params == []: raise Exception, "signpost has no params?" + if self.params == []: + raise Exception("signpost has no params?") output += ", ".join([p.to_asm() for p in self.params]) return output all_signposts = [] def parse_signposts(address, signpost_count, bank=None, map_group=None, map_id=None, debug=True): - if bank == None: raise Exception, "signposts need to know their bank" + if bank == None: + raise Exception("signposts need to know their bank") signposts = [] current_address = address id = 0 @@ -5273,7 +5196,7 @@ class Connection: wrong_norths.append(data) # this will only happen if there's a bad formula - raise Exception, "tauwasser strip_pointer calculation was wrong? strip_pointer="+hex(strip_pointer) + " p="+hex(p) + raise Exception("tauwasser strip_pointer calculation was wrong? strip_pointer="+hex(strip_pointer) + " p="+hex(p)) calculated_destination = None method = "strip_destination_default" @@ -5295,7 +5218,7 @@ class Connection: x_movement_of_the_connection_strip_in_blocks = strip_destination - 0xC703 print "(north) x_movement_of_the_connection_strip_in_blocks is: " + str(x_movement_of_the_connection_strip_in_blocks) if x_movement_of_the_connection_strip_in_blocks < 0: - raise Exception, "x_movement_of_the_connection_strip_in_blocks is wrong? " + str(x_movement_of_the_connection_strip_in_blocks) + raise Exception("x_movement_of_the_connection_strip_in_blocks is wrong? " + str(x_movement_of_the_connection_strip_in_blocks)) elif ldirection == "south": # strip_destination = # 0xc703 + (current_map_height + 3) * (current_map_width + 6) + x_movement_of_the_connection_strip_in_blocks @@ -5570,11 +5493,11 @@ class Connection: yoffset = self.yoffset # y_position_after_map_change if ldirection == "south" and yoffset != 0: - raise Exception, "tauwasser was wrong about yoffset=0 for south? it's: " + str(yoffset) + raise Exception("tauwasser was wrong about yoffset=0 for south? it's: " + str(yoffset)) elif ldirection == "north" and yoffset != ((connected_map_height * 2) - 1): - raise Exception, "tauwasser was wrong about yoffset for north? it's: " + str(yoffset) + raise Exception("tauwasser was wrong about yoffset for north? it's: " + str(yoffset)) #elif not ((yoffset % -2) == 0): - # raise Exception, "tauwasser was wrong about yoffset for west/east? it's not divisible by -2: " + str(yoffset) + # raise Exception("tauwasser was wrong about yoffset for west/east? it's not divisible by -2: " + str(yoffset)) # Left: (Width_of_connected_map * 2) - 1 # Right: 0 @@ -5582,11 +5505,11 @@ class Connection: xoffset = self.xoffset # x_position_after_map_change if ldirection == "east" and xoffset != 0: - raise Exception, "tauwasser was wrong about xoffset=0 for east? it's: " + str(xoffset) + raise Exception("tauwasser was wrong about xoffset=0 for east? it's: " + str(xoffset)) elif ldirection == "west" and xoffset != ((connected_map_width * 2) - 1): - raise Exception, "tauwasser was wrong about xoffset for west? it's: " + str(xoffset) + raise Exception("tauwasser was wrong about xoffset for west? it's: " + str(xoffset)) #elif not ((xoffset % -2) == 0): - # raise Exception, "tauwasser was wrong about xoffset for north/south? it's not divisible by -2: " + str(xoffset) + # raise Exception("tauwasser was wrong about xoffset for north/south? it's not divisible by -2: " + str(xoffset)) output += "db " @@ -5703,7 +5626,7 @@ class MapBlockData: self.width = width self.height = height else: - raise Exception, "MapBlockData needs to know the width/height of its map" + raise Exception("MapBlockData needs to know the width/height of its map") label = self.make_label() self.label = Label(name=label, address=address, object=self) self.last_address = self.address + (self.width.byte * self.height.byte) @@ -6270,14 +6193,14 @@ def parse_map_header_by_id(*args, **kwargs): map_id = kwargs["map_id"] if (map_group == None and map_id != None) or \ (map_group != None and map_id == None): - raise Exception, "map_group and map_id must both be provided" + raise Exception("map_group and map_id must both be provided") elif map_group == None and map_id == None and len(args) == 0: - raise Exception, "must be given an argument" + raise Exception("must be given an argument") elif len(args) == 1 and type(args[0]) == str: map_group = int(args[0].split(".")[0]) map_id = int(args[0].split(".")[1]) else: - raise Exception, "dunno what to do with input" + raise Exception("dunno what to do with input") offset = map_names[map_group]["offset"] map_header_offset = offset + ((map_id - 1) * map_header_byte_size) return parse_map_header_at(map_header_offset, map_group=map_group, map_id=map_id) @@ -6286,7 +6209,7 @@ def parse_all_map_headers(debug=True): """calls parse_map_header_at for each map in each map group""" global map_names if not map_names[1].has_key("offset"): - raise Exception, "dunno what to do - map_names should have groups with pre-calculated offsets by now" + raise Exception("dunno what to do - map_names should have groups with pre-calculated offsets by now") for group_id, group_data in map_names.items(): offset = group_data["offset"] # we only care about the maps @@ -7045,7 +6968,7 @@ def find_incbin_to_replace_for(address, debug=False, rom_file="../baserom.gbc"): if you were to insert bytes into main.asm""" if type(address) == str: address = int(address, 16) if not (0 <= address <= os.lstat(rom_file).st_size): - raise IndexError, "address is out of bounds" + raise IndexError("address is out of bounds") for incbin_key in processed_incbins.keys(): incbin = processed_incbins[incbin_key] start = incbin["start"] @@ -7069,9 +6992,9 @@ def split_incbin_line_into_three(line, start_address, byte_count, rom_file="../b """ if type(start_address) == str: start_address = int(start_address, 16) if not (0 <= start_address <= os.lstat(rom_file).st_size): - raise IndexError, "start_address is out of bounds" + raise IndexError("start_address is out of bounds") if len(processed_incbins) == 0: - raise Exception, "processed_incbins must be populated" + raise Exception("processed_incbins must be populated") original_incbin = processed_incbins[line] start = original_incbin["start"] @@ -7191,7 +7114,7 @@ class Incbin: start = eval(start) except Exception, e: print "start is: " + str(start) - raise Exception, "problem with evaluating interval range: " + str(e) + raise Exception("problem with evaluating interval range: " + str(e)) start_hex = hex(start).replace("0x", "$") @@ -7212,11 +7135,12 @@ class Incbin: def to_asm(self): if self.interval > 0: return self.line - else: return "" + else: + return "" def split(self, start_address, byte_count): """splits this incbin into three separate incbins""" if start_address < self.start_address or start_address > self.end_address: - raise Exception, "this incbin doesn't handle this address" + raise Exception("this incbin doesn't handle this address") incbins = [] if self.debug: @@ -7358,7 +7282,7 @@ class Asm: if not hasattr(new_object, "last_address"): print debugmsg - raise Exception, "object needs to have a last_address property" + raise Exception("object needs to have a last_address property") end_address = new_object.last_address debugmsg += " last_address="+hex(end_address) @@ -7384,7 +7308,7 @@ class Asm: print "start_address="+hex(start_address)+" end_address="+hex(end_address) if hasattr(new_object, "to_asm"): print to_asm(new_object) - raise Exception, "Asm.insert was given an object with a bad address range" + raise Exception("Asm.insert was given an object with a bad address range") # 1) find which object needs to be replaced # or @@ -7426,7 +7350,7 @@ class Asm: found = True break if not found: - raise Exception, "unable to insert object into Asm" + raise Exception("unable to insert object into Asm") self.labels.append(new_object.label) return True def insert_with_dependencies(self, input): @@ -7458,9 +7382,9 @@ class Asm: # just some old debugging #if object.label.name == "UnknownText_0x60128": - # raise Exception, "debugging..." + # raise Exception("debugging...") #elif object.label.name == "UnknownScript_0x60011": - # raise Exception, "debugging.. dependencies are: " + str(object.dependencies) + " versus: " + str(object.get_dependencies()) + # raise Exception("debugging.. dependencies are: " + str(object.dependencies) + " versus: " + str(object.get_dependencies())) def insert_single_with_dependencies(self, object): self.insert_with_dependencies(object) def insert_multiple_with_dependencies(self, objects): @@ -7516,7 +7440,7 @@ class Asm: current_requested_newlines_before = 2 current_requested_newlines_after = 2 else: - raise Exception, "dunno what to do with("+str(each)+") in Asm.parts" + raise Exception("dunno what to do with("+str(each)+") in Asm.parts") if write_something: if not first: @@ -7546,7 +7470,7 @@ def list_texts_in_bank(bank): that you will be inserting into Asm. """ if len(all_texts) == 0: - raise Exception, "all_texts is blank.. run_main() will populate it" + raise Exception("all_texts is blank.. run_main() will populate it") assert bank != None, "list_texts_in_banks must be given a particular bank" @@ -7564,7 +7488,7 @@ def list_movements_in_bank(bank): to speed up Asm insertion. """ if len(all_movements) == 0: - raise Exception, "all_movements is blank.. run_main() will populate it" + raise Exception("all_movements is blank.. run_main() will populate it") assert bank != None, "list_movements_in_bank must be given a particular bank" assert 0 <= bank < 0x80, "bank doesn't exist in the ROM (out of bounds)" @@ -7673,7 +7597,7 @@ def get_label_for(address): if address == None: return None if type(address) != int: - raise Exception, "get_label_for requires an integer address, got: " + str(type(address)) + raise Exception("get_label_for requires an integer address, got: " + str(type(address))) # lousy hack to get around recursive scripts in dragon shrine if address in lousy_dragon_shrine_hack: @@ -7787,10 +7711,6 @@ class Label: name = object.make_label() return name -from labels import remove_quoted_text, line_has_comment_address, \ - line_has_label, get_label_from_line, \ - get_address_from_line_comment - def find_labels_without_addresses(): """scans the asm source and finds labels that are unmarked""" without_addresses = [] @@ -7872,14 +7792,9 @@ def scan_for_predefined_labels(debug=False): abbreviation_next = "1" # calculate the start/stop line numbers for this bank - for a in (abbreviation, abbreviation.lower()): - start_line_id = index(asm, lambda line: "\"bank" + a + "\"" in line) - if start_line_id != None: break - + start_line_id = index(asm, lambda line: "\"bank" + abbreviation.lower() + "\"" in line.lower()) if bank_id != 0x7F: - for a in (abbreviation_next, abbreviation_next.lower()): - end_line_id = index(asm, lambda line: "\"bank" + a + "\"" in line) - if end_line_id != None: break + end_line_id = index(asm, lambda line: "\"bank" + abbreviation_next.lower() + "\"" in line.lower()) end_line_id += 1 else: end_line_id = len(asm) - 1 @@ -7907,924 +7822,6 @@ def scan_for_predefined_labels(debug=False): write_all_labels(all_labels) return all_labels -#### generic testing #### - -class TestCram(unittest.TestCase): - "this is where i cram all of my unit tests together" - - @classmethod - def setUpClass(cls): - global rom - cls.rom = direct_load_rom() - rom = cls.rom - - @classmethod - def tearDownClass(cls): - del cls.rom - - def test_generic_useless(self): - "do i know how to write a test?" - self.assertEqual(1, 1) - - def test_map_name_cleaner(self): - name = "hello world" - cleaned_name = map_name_cleaner(name) - self.assertNotEqual(name, cleaned_name) - self.failUnless(" " not in cleaned_name) - name = "Some Random Pokémon Center" - cleaned_name = map_name_cleaner(name) - self.assertNotEqual(name, cleaned_name) - self.failIf(" " in cleaned_name) - self.failIf("é" in cleaned_name) - - def test_grouper(self): - data = range(0, 10) - groups = grouper(data, count=2) - self.assertEquals(len(groups), 5) - data = range(0, 20) - groups = grouper(data, count=2) - self.assertEquals(len(groups), 10) - self.assertNotEqual(data, groups) - self.assertNotEqual(len(data), len(groups)) - - def test_direct_load_rom(self): - rom = self.rom - self.assertEqual(len(rom), 2097152) - self.failUnless(isinstance(rom, RomStr)) - - def test_load_rom(self): - global rom - rom = None - load_rom() - self.failIf(rom == None) - rom = RomStr(None) - load_rom() - self.failIf(rom == RomStr(None)) - - def test_load_asm(self): - asm = load_asm() - joined_lines = "\n".join(asm) - self.failUnless("SECTION" in joined_lines) - self.failUnless("bank" in joined_lines) - self.failUnless(isinstance(asm, AsmList)) - - def test_rom_file_existence(self): - "ROM file must exist" - self.failUnless("baserom.gbc" in os.listdir("../")) - - def test_rom_md5(self): - "ROM file must have the correct md5 sum" - rom = self.rom - correct = "9f2922b235a5eeb78d65594e82ef5dde" - md5 = hashlib.md5() - md5.update(rom) - md5sum = md5.hexdigest() - self.assertEqual(md5sum, correct) - - def test_bizarre_http_presence(self): - rom_segment = self.rom[0x112116:0x112116+8] - self.assertEqual(rom_segment, "HTTP/1.0") - - def test_rom_interval(self): - address = 0x100 - interval = 10 - correct_strings = ['0x0', '0xc3', '0x6e', '0x1', '0xce', - '0xed', '0x66', '0x66', '0xcc', '0xd'] - byte_strings = rom_interval(address, interval, strings=True) - self.assertEqual(byte_strings, correct_strings) - correct_ints = [0, 195, 110, 1, 206, 237, 102, 102, 204, 13] - ints = rom_interval(address, interval, strings=False) - self.assertEqual(ints, correct_ints) - - def test_rom_until(self): - address = 0x1337 - byte = 0x13 - bytes = rom_until(address, byte, strings=True) - self.failUnless(len(bytes) == 3) - self.failUnless(bytes[0] == '0xd5') - bytes = rom_until(address, byte, strings=False) - self.failUnless(len(bytes) == 3) - self.failUnless(bytes[0] == 0xd5) - - def test_how_many_until(self): - how_many = how_many_until(chr(0x13), 0x1337) - self.assertEqual(how_many, 3) - - def test_calculate_bank(self): - self.failUnless(calculate_bank(0x8000) == 2) - self.failUnless(calculate_bank("0x9000") == 2) - self.failUnless(calculate_bank(0) == 0) - for address in [0x4000, 0x5000, 0x6000, 0x7000]: - self.assertRaises(Exception, calculate_bank, address) - - def test_calculate_pointer(self): - # for offset <= 0x4000 - self.assertEqual(calculate_pointer(0x0000), 0x0000) - self.assertEqual(calculate_pointer(0x3FFF), 0x3FFF) - # for 0x4000 <= offset <= 0x7FFFF - self.assertEqual(calculate_pointer(0x430F, bank=5), 0x1430F) - # for offset >= 0x7FFF - self.assertEqual(calculate_pointer(0x8FFF, bank=6), calculate_pointer(0x8FFF, bank=7)) - - def test_calculate_pointer_from_bytes_at(self): - addr1 = calculate_pointer_from_bytes_at(0x100, bank=False) - self.assertEqual(addr1, 0xc300) - addr2 = calculate_pointer_from_bytes_at(0x100, bank=True) - self.assertEqual(addr2, 0x2ec3) - - def test_rom_text_at(self): - self.assertEquals(rom_text_at(0x112116, 8), "HTTP/1.0") - - def test_translate_command_byte(self): - self.failUnless(translate_command_byte(crystal=0x0) == 0x0) - self.failUnless(translate_command_byte(crystal=0x10) == 0x10) - self.failUnless(translate_command_byte(crystal=0x40) == 0x40) - self.failUnless(translate_command_byte(gold=0x0) == 0x0) - self.failUnless(translate_command_byte(gold=0x10) == 0x10) - self.failUnless(translate_command_byte(gold=0x40) == 0x40) - self.assertEqual(translate_command_byte(gold=0x0), translate_command_byte(crystal=0x0)) - self.failUnless(translate_command_byte(gold=0x52) == 0x53) - self.failUnless(translate_command_byte(gold=0x53) == 0x54) - self.failUnless(translate_command_byte(crystal=0x53) == 0x52) - self.failUnless(translate_command_byte(crystal=0x52) == None) - self.assertRaises(Exception, translate_command_byte, None, gold=0xA4) - - def test_pksv_integrity(self): - "does pksv_gs look okay?" - self.assertEqual(pksv_gs[0x00], "2call") - self.assertEqual(pksv_gs[0x2D], "givepoke") - self.assertEqual(pksv_gs[0x85], "waitbutton") - self.assertEqual(pksv_crystal[0x00], "2call") - self.assertEqual(pksv_crystal[0x86], "waitbutton") - self.assertEqual(pksv_crystal[0xA2], "credits") - - def test_chars_integrity(self): - self.assertEqual(chars[0x80], "A") - self.assertEqual(chars[0xA0], "a") - self.assertEqual(chars[0xF0], "¥") - self.assertEqual(jap_chars[0x44], "ぱ") - - def test_map_names_integrity(self): - def map_name(map_group, map_id): return map_names[map_group][map_id]["name"] - self.assertEqual(map_name(2, 7), "Mahogany Town") - self.assertEqual(map_name(3, 0x34), "Ilex Forest") - self.assertEqual(map_name(7, 0x11), "Cerulean City") - - def test_load_map_group_offsets(self): - addresses = load_map_group_offsets() - self.assertEqual(len(addresses), 26, msg="there should be 26 map groups") - addresses = load_map_group_offsets() - self.assertEqual(len(addresses), 26, msg="there should still be 26 map groups") - self.assertIn(0x94034, addresses) - for address in addresses: - self.assertGreaterEqual(address, 0x4000) - self.failIf(0x4000 <= address <= 0x7FFF) - self.failIf(address <= 0x4000) - - def test_index(self): - self.assertTrue(index([1,2,3,4], lambda f: True) == 0) - self.assertTrue(index([1,2,3,4], lambda f: f==3) == 2) - - def test_get_pokemon_constant_by_id(self): - x = get_pokemon_constant_by_id - self.assertEqual(x(1), "BULBASAUR") - self.assertEqual(x(151), "MEW") - self.assertEqual(x(250), "HO_OH") - - def test_find_item_label_by_id(self): - x = find_item_label_by_id - self.assertEqual(x(249), "HM_07") - self.assertEqual(x(173), "BERRY") - self.assertEqual(x(45), None) - - def test_generate_item_constants(self): - x = generate_item_constants - r = x() - self.failUnless("HM_07" in r) - self.failUnless("EQU" in r) - - def test_get_label_for(self): - global all_labels - temp = copy(all_labels) - # this is basd on the format defined in get_labels_between - all_labels = [{"label": "poop", "address": 0x5, - "offset": 0x5, "bank": 0, - "line_number": 2 - }] - self.assertEqual(get_label_for(5), "poop") - all_labels = temp - - def test_generate_map_constant_labels(self): - ids = generate_map_constant_labels() - self.assertEqual(ids[0]["label"], "OLIVINE_POKECENTER_1F") - self.assertEqual(ids[1]["label"], "OLIVINE_GYM") - - def test_get_id_for_map_constant_label(self): - global map_internal_ids - map_internal_ids = generate_map_constant_labels() - self.assertEqual(get_id_for_map_constant_label("OLIVINE_GYM"), 1) - self.assertEqual(get_id_for_map_constant_label("OLIVINE_POKECENTER_1F"), 0) - - def test_get_map_constant_label_by_id(self): - global map_internal_ids - map_internal_ids = generate_map_constant_labels() - self.assertEqual(get_map_constant_label_by_id(0), "OLIVINE_POKECENTER_1F") - self.assertEqual(get_map_constant_label_by_id(1), "OLIVINE_GYM") - - def test_is_valid_address(self): - self.assertTrue(is_valid_address(0)) - self.assertTrue(is_valid_address(1)) - self.assertTrue(is_valid_address(10)) - self.assertTrue(is_valid_address(100)) - self.assertTrue(is_valid_address(1000)) - self.assertTrue(is_valid_address(10000)) - self.assertFalse(is_valid_address(2097153)) - self.assertFalse(is_valid_address(2098000)) - addresses = [random.randrange(0,2097153) for i in range(0, 9+1)] - for address in addresses: - self.assertTrue(is_valid_address(address)) - - -class TestIntervalMap(unittest.TestCase): - def test_intervals(self): - i = IntervalMap() - first = "hello world" - second = "testing 123" - i[0:5] = first - i[5:10] = second - self.assertEqual(i[0], first) - self.assertEqual(i[1], first) - self.assertNotEqual(i[5], first) - self.assertEqual(i[6], second) - i[3:10] = second - self.assertEqual(i[3], second) - self.assertNotEqual(i[4], first) - - def test_items(self): - i = IntervalMap() - first = "hello world" - second = "testing 123" - i[0:5] = first - i[5:10] = second - results = list(i.items()) - self.failUnless(len(results) == 2) - self.assertEqual(results[0], ((0, 5), "hello world")) - self.assertEqual(results[1], ((5, 10), "testing 123")) - - -class TestRomStr(unittest.TestCase): - """RomStr is a class that should act exactly like str() - except that it never shows the contents of it string - unless explicitly forced""" - sample_text = "hello world!" - sample = None - - def setUp(self): - if self.sample == None: - self.__class__.sample = RomStr(self.sample_text) - - def test_equals(self): - "check if RomStr() == str()" - self.assertEquals(self.sample_text, self.sample) - - def test_not_equal(self): - "check if RomStr('a') != RomStr('b')" - self.assertNotEqual(RomStr('a'), RomStr('b')) - - def test_appending(self): - "check if RomStr()+'a'==str()+'a'" - self.assertEquals(self.sample_text+'a', self.sample+'a') - - def test_conversion(self): - "check if RomStr() -> str() works" - self.assertEquals(str(self.sample), self.sample_text) - - def test_inheritance(self): - self.failUnless(issubclass(RomStr, str)) - - def test_length(self): - self.assertEquals(len(self.sample_text), len(self.sample)) - self.assertEquals(len(self.sample_text), self.sample.length()) - self.assertEquals(len(self.sample), self.sample.length()) - - def test_rom_interval(self): - global rom - load_rom() - address = 0x100 - interval = 10 - correct_strings = ['0x0', '0xc3', '0x6e', '0x1', '0xce', - '0xed', '0x66', '0x66', '0xcc', '0xd'] - byte_strings = rom.interval(address, interval, strings=True) - self.assertEqual(byte_strings, correct_strings) - correct_ints = [0, 195, 110, 1, 206, 237, 102, 102, 204, 13] - ints = rom.interval(address, interval, strings=False) - self.assertEqual(ints, correct_ints) - - def test_rom_until(self): - global rom - load_rom() - address = 0x1337 - byte = 0x13 - bytes = rom.until(address, byte, strings=True) - self.failUnless(len(bytes) == 3) - self.failUnless(bytes[0] == '0xd5') - bytes = rom.until(address, byte, strings=False) - self.failUnless(len(bytes) == 3) - self.failUnless(bytes[0] == 0xd5) - - -class TestAsmList(unittest.TestCase): - """AsmList is a class that should act exactly like list() - except that it never shows the contents of its list - unless explicitly forced""" - - def test_equals(self): - base = [1,2,3] - asm = AsmList(base) - self.assertEquals(base, asm) - self.assertEquals(asm, base) - self.assertEquals(base, list(asm)) - - def test_inheritance(self): - self.failUnless(issubclass(AsmList, list)) - - def test_length(self): - base = range(0, 10) - asm = AsmList(base) - self.assertEquals(len(base), len(asm)) - self.assertEquals(len(base), asm.length()) - self.assertEquals(len(base), len(list(asm))) - self.assertEquals(len(asm), asm.length()) - - def test_remove_quoted_text(self): - x = remove_quoted_text - self.assertEqual(x("hello world"), "hello world") - self.assertEqual(x("hello \"world\""), "hello ") - input = 'hello world "testing 123"' - self.assertNotEqual(x(input), input) - input = "hello world 'testing 123'" - self.assertNotEqual(x(input), input) - self.failIf("testing" in x(input)) - - def test_line_has_comment_address(self): - x = line_has_comment_address - self.assertFalse(x("")) - self.assertFalse(x(";")) - self.assertFalse(x(";;;")) - self.assertFalse(x(":;")) - self.assertFalse(x(":;:")) - self.assertFalse(x(";:")) - self.assertFalse(x(" ")) - self.assertFalse(x("".join(" " * 5))) - self.assertFalse(x("".join(" " * 10))) - self.assertFalse(x("hello world")) - self.assertFalse(x("hello_world")) - self.assertFalse(x("hello_world:")) - self.assertFalse(x("hello_world:;")) - self.assertFalse(x("hello_world: ;")) - self.assertFalse(x("hello_world: ; ")) - self.assertFalse(x("hello_world: ;" + "".join(" " * 5))) - self.assertFalse(x("hello_world: ;" + "".join(" " * 10))) - self.assertTrue(x(";1")) - self.assertTrue(x(";F")) - self.assertTrue(x(";$00FF")) - self.assertTrue(x(";0x00FF")) - self.assertTrue(x("; 0x00FF")) - self.assertTrue(x(";$3:$300")) - self.assertTrue(x(";0x3:$300")) - self.assertTrue(x(";$3:0x300")) - self.assertTrue(x(";3:300")) - self.assertTrue(x(";3:FFAA")) - self.assertFalse(x('hello world "how are you today;0x1"')) - self.assertTrue(x('hello world "how are you today:0x1";1')) - returnable = {} - self.assertTrue(x("hello_world: ; 0x4050", returnable=returnable, bank=5)) - self.assertTrue(returnable["address"] == 0x14050) - - def test_line_has_label(self): - x = line_has_label - self.assertTrue(x("hi:")) - self.assertTrue(x("Hello: ")) - self.assertTrue(x("MyLabel: ; test xyz")) - self.assertFalse(x(":")) - self.assertFalse(x(";HelloWorld:")) - self.assertFalse(x("::::")) - self.assertFalse(x(":;:;:;:::")) - - def test_get_label_from_line(self): - x = get_label_from_line - self.assertEqual(x("HelloWorld: "), "HelloWorld") - self.assertEqual(x("HiWorld:"), "HiWorld") - self.assertEqual(x("HiWorld"), None) - - def test_find_labels_without_addresses(self): - global asm - asm = ["hello_world: ; 0x1", "hello_world2: ;"] - labels = find_labels_without_addresses() - self.failUnless(labels[0]["label"] == "hello_world2") - asm = ["hello world: ;1", "hello_world: ;2"] - labels = find_labels_without_addresses() - self.failUnless(len(labels) == 0) - asm = None - - def test_get_labels_between(self): - global asm - x = get_labels_between#(start_line_id, end_line_id, bank) - asm = ["HelloWorld: ;1", - "hi:", - "no label on this line", - ] - labels = x(0, 2, 0x12) - self.assertEqual(len(labels), 1) - self.assertEqual(labels[0]["label"], "HelloWorld") - del asm - - def test_scan_for_predefined_labels(self): - # label keys: line_number, bank, label, offset, address - load_asm() - all_labels = scan_for_predefined_labels() - label_names = [x["label"] for x in all_labels] - self.assertIn("GetFarByte", label_names) - self.assertIn("AddNTimes", label_names) - self.assertIn("CheckShininess", label_names) - - def test_write_all_labels(self): - """dumping json into a file""" - filename = "test_labels.json" - # remove the current file - if os.path.exists(filename): - os.system("rm " + filename) - # make up some labels - labels = [] - # fake label 1 - label = {"line_number": 5, "bank": 0, "label": "SomeLabel", "address": 0x10} - labels.append(label) - # fake label 2 - label = {"line_number": 15, "bank": 2, "label": "SomeOtherLabel", "address": 0x9F0A} - labels.append(label) - # dump to file - write_all_labels(labels, filename=filename) - # open the file and read the contents - file_handler = open(filename, "r") - contents = file_handler.read() - file_handler.close() - # parse into json - obj = json.read(contents) - # begin testing - self.assertEqual(len(obj), len(labels)) - self.assertEqual(len(obj), 2) - self.assertEqual(obj, labels) - - def test_isolate_incbins(self): - global asm - asm = ["123", "456", "789", "abc", "def", "ghi", - 'INCBIN "baserom.gbc",$12DA,$12F8 - $12DA', - "jkl", - 'INCBIN "baserom.gbc",$137A,$13D0 - $137A'] - lines = isolate_incbins() - self.assertIn(asm[6], lines) - self.assertIn(asm[8], lines) - for line in lines: - self.assertIn("baserom", line) - - def test_process_incbins(self): - global incbin_lines, processed_incbins, asm - incbin_lines = ['INCBIN "baserom.gbc",$12DA,$12F8 - $12DA', - 'INCBIN "baserom.gbc",$137A,$13D0 - $137A'] - asm = copy(incbin_lines) - asm.insert(1, "some other random line") - processed_incbins = process_incbins() - self.assertEqual(len(processed_incbins), len(incbin_lines)) - self.assertEqual(processed_incbins[0]["line"], incbin_lines[0]) - self.assertEqual(processed_incbins[2]["line"], incbin_lines[1]) - - def test_reset_incbins(self): - global asm, incbin_lines, processed_incbins - # temporarily override the functions - global load_asm, isolate_incbins, process_incbins - temp1, temp2, temp3 = load_asm, isolate_incbins, process_incbins - def load_asm(): pass - def isolate_incbins(): pass - def process_incbins(): pass - # call reset - reset_incbins() - # check the results - self.assertTrue(asm == [] or asm == None) - self.assertTrue(incbin_lines == []) - self.assertTrue(processed_incbins == {}) - # reset the original functions - load_asm, isolate_incbins, process_incbins = temp1, temp2, temp3 - - def test_find_incbin_to_replace_for(self): - global asm, incbin_lines, processed_incbins - asm = ['first line', 'second line', 'third line', - 'INCBIN "baserom.gbc",$90,$200 - $90', - 'fifth line', 'last line'] - isolate_incbins() - process_incbins() - line_num = find_incbin_to_replace_for(0x100) - # must be the 4th line (the INBIN line) - self.assertEqual(line_num, 3) - - def test_split_incbin_line_into_three(self): - global asm, incbin_lines, processed_incbins - asm = ['first line', 'second line', 'third line', - 'INCBIN "baserom.gbc",$90,$200 - $90', - 'fifth line', 'last line'] - isolate_incbins() - process_incbins() - content = split_incbin_line_into_three(3, 0x100, 10) - # must end up with three INCBINs in output - self.failUnless(content.count("INCBIN") == 3) - - def test_analyze_intervals(self): - global asm, incbin_lines, processed_incbins - asm, incbin_lines, processed_incbins = None, [], {} - asm = ['first line', 'second line', 'third line', - 'INCBIN "baserom.gbc",$90,$200 - $90', - 'fifth line', 'last line', - 'INCBIN "baserom.gbc",$33F,$4000 - $33F'] - isolate_incbins() - process_incbins() - largest = analyze_intervals() - self.assertEqual(largest[0]["line_number"], 6) - self.assertEqual(largest[0]["line"], asm[6]) - self.assertEqual(largest[1]["line_number"], 3) - self.assertEqual(largest[1]["line"], asm[3]) - - def test_generate_diff_insert(self): - global asm - asm = ['first line', 'second line', 'third line', - 'INCBIN "baserom.gbc",$90,$200 - $90', - 'fifth line', 'last line', - 'INCBIN "baserom.gbc",$33F,$4000 - $33F'] - diff = generate_diff_insert(0, "the real first line", debug=False) - self.assertIn("the real first line", diff) - self.assertIn("INCBIN", diff) - self.assertNotIn("No newline at end of file", diff) - self.assertIn("+"+asm[1], diff) - - -class TestMapParsing(unittest.TestCase): - def test_parse_all_map_headers(self): - global parse_map_header_at, old_parse_map_header_at, counter - counter = 0 - for k in map_names.keys(): - if "offset" not in map_names[k].keys(): - map_names[k]["offset"] = 0 - temp = parse_map_header_at - temp2 = old_parse_map_header_at - def parse_map_header_at(address, map_group=None, map_id=None, debug=False): - global counter - counter += 1 - return {} - old_parse_map_header_at = parse_map_header_at - parse_all_map_headers(debug=False) - # parse_all_map_headers is currently doing it 2x - # because of the new/old map header parsing routines - self.assertEqual(counter, 388 * 2) - parse_map_header_at = temp - old_parse_map_header_at = temp2 - -class TestTextScript(unittest.TestCase): - """for testing 'in-script' commands, etc.""" - #def test_to_asm(self): - # pass # or raise NotImplementedError, bryan_message - #def test_find_addresses(self): - # pass # or raise NotImplementedError, bryan_message - #def test_parse_text_at(self): - # pass # or raise NotImplementedError, bryan_message - - -class TestEncodedText(unittest.TestCase): - """for testing chars-table encoded text chunks""" - - def test_process_00_subcommands(self): - g = process_00_subcommands(0x197186, 0x197186+601, debug=False) - self.assertEqual(len(g), 42) - self.assertEqual(len(g[0]), 13) - self.assertEqual(g[1], [184, 174, 180, 211, 164, 127, 20, 231, 81]) - - def test_parse_text_at2(self): - oakspeech = parse_text_at2(0x197186, 601, debug=False) - self.assertIn("encyclopedia", oakspeech) - self.assertIn("researcher", oakspeech) - self.assertIn("dependable", oakspeech) - - def test_parse_text_engine_script_at(self): - p = parse_text_engine_script_at(0x197185, debug=False) - self.assertEqual(len(p.commands), 2) - self.assertEqual(len(p.commands[0]["lines"]), 41) - - # don't really care about these other two - def test_parse_text_from_bytes(self): pass - def test_parse_text_at(self): pass - - -class TestScript(unittest.TestCase): - """for testing parse_script_engine_script_at and script parsing in - general. Script should be a class.""" - #def test_parse_script_engine_script_at(self): - # pass # or raise NotImplementedError, bryan_message - - def test_find_all_text_pointers_in_script_engine_script(self): - address = 0x197637 # 0x197634 - script = parse_script_engine_script_at(address, debug=False) - bank = calculate_bank(address) - r = find_all_text_pointers_in_script_engine_script(script, bank=bank, debug=False) - results = list(r) - self.assertIn(0x197661, results) - - -class TestLabel(unittest.TestCase): - def test_label_making(self): - line_number = 2 - address = 0xf0c0 - label_name = "poop" - l = Label(name=label_name, address=address, line_number=line_number) - self.failUnless(hasattr(l, "name")) - self.failUnless(hasattr(l, "address")) - self.failUnless(hasattr(l, "line_number")) - self.failIf(isinstance(l.address, str)) - self.failIf(isinstance(l.line_number, str)) - self.failUnless(isinstance(l.name, str)) - self.assertEqual(l.line_number, line_number) - self.assertEqual(l.name, label_name) - self.assertEqual(l.address, address) - - -class TestByteParams(unittest.TestCase): - @classmethod - def setUpClass(cls): - load_rom() - cls.address = 10 - cls.sbp = SingleByteParam(address=cls.address) - - @classmethod - def tearDownClass(cls): - del cls.sbp - - def test__init__(self): - self.assertEqual(self.sbp.size, 1) - self.assertEqual(self.sbp.address, self.address) - - def test_parse(self): - self.sbp.parse() - self.assertEqual(str(self.sbp.byte), str(45)) - - def test_to_asm(self): - self.assertEqual(self.sbp.to_asm(), "$2d") - self.sbp.should_be_decimal = True - self.assertEqual(self.sbp.to_asm(), str(45)) - - # HexByte and DollarSignByte are the same now - def test_HexByte_to_asm(self): - h = HexByte(address=10) - a = h.to_asm() - self.assertEqual(a, "$2d") - - def test_DollarSignByte_to_asm(self): - d = DollarSignByte(address=10) - a = d.to_asm() - self.assertEqual(a, "$2d") - - def test_ItemLabelByte_to_asm(self): - i = ItemLabelByte(address=433) - self.assertEqual(i.byte, 54) - self.assertEqual(i.to_asm(), "COIN_CASE") - self.assertEqual(ItemLabelByte(address=10).to_asm(), "$2d") - - def test_DecimalParam_to_asm(self): - d = DecimalParam(address=10) - x = d.to_asm() - self.assertEqual(x, str(0x2d)) - - -class TestMultiByteParam(unittest.TestCase): - def setup_for(self, somecls, byte_size=2, address=443, **kwargs): - self.cls = somecls(address=address, size=byte_size, **kwargs) - self.assertEqual(self.cls.address, address) - self.assertEqual(self.cls.bytes, rom_interval(address, byte_size, strings=False)) - self.assertEqual(self.cls.size, byte_size) - - def test_two_byte_param(self): - self.setup_for(MultiByteParam, byte_size=2) - self.assertEqual(self.cls.to_asm(), "$f0c0") - - def test_three_byte_param(self): - self.setup_for(MultiByteParam, byte_size=3) - - def test_PointerLabelParam_no_bank(self): - self.setup_for(PointerLabelParam, bank=None) - # assuming no label at this location.. - self.assertEqual(self.cls.to_asm(), "$f0c0") - global all_labels - # hm.. maybe all_labels should be using a class? - all_labels = [{"label": "poop", "address": 0xf0c0, - "offset": 0xf0c0, "bank": 0, - "line_number": 2 - }] - self.assertEqual(self.cls.to_asm(), "poop") - - -class TestPostParsing: #(unittest.TestCase): - """tests that must be run after parsing all maps""" - @classmethod - def setUpClass(cls): - run_main() - - def test_signpost_counts(self): - self.assertEqual(len(map_names[1][1]["signposts"]), 0) - self.assertEqual(len(map_names[1][2]["signposts"]), 2) - self.assertEqual(len(map_names[10][5]["signposts"]), 7) - - def test_warp_counts(self): - self.assertEqual(map_names[10][5]["warp_count"], 9) - self.assertEqual(map_names[18][5]["warp_count"], 3) - self.assertEqual(map_names[15][1]["warp_count"], 2) - - def test_map_sizes(self): - self.assertEqual(map_names[15][1]["height"], 18) - self.assertEqual(map_names[15][1]["width"], 10) - self.assertEqual(map_names[7][1]["height"], 4) - self.assertEqual(map_names[7][1]["width"], 4) - - def test_map_connection_counts(self): - self.assertEqual(map_names[7][1]["connections"], 0) - self.assertEqual(map_names[10][1]["connections"], 12) - self.assertEqual(map_names[10][2]["connections"], 12) - self.assertEqual(map_names[11][1]["connections"], 9) # or 13? - - def test_second_map_header_address(self): - self.assertEqual(map_names[11][1]["second_map_header_address"], 0x9509c) - self.assertEqual(map_names[1][5]["second_map_header_address"], 0x95bd0) - - def test_event_address(self): - self.assertEqual(map_names[17][5]["event_address"], 0x194d67) - self.assertEqual(map_names[23][3]["event_address"], 0x1a9ec9) - - def test_people_event_counts(self): - self.assertEqual(len(map_names[23][3]["people_events"]), 4) - self.assertEqual(len(map_names[10][3]["people_events"]), 9) - - -class TestMetaTesting(unittest.TestCase): - """test whether or not i am finding at least - some of the tests in this file""" - tests = None - - def setUp(self): - if self.tests == None: - self.__class__.tests = assemble_test_cases() - - def test_assemble_test_cases_count(self): - "does assemble_test_cases find some tests?" - self.failUnless(len(self.tests) > 0) - - def test_assemble_test_cases_inclusion(self): - "is this class found by assemble_test_cases?" - # i guess it would have to be for this to be running? - self.failUnless(self.__class__ in self.tests) - - def test_assemble_test_cases_others(self): - "test other inclusions for assemble_test_cases" - self.failUnless(TestRomStr in self.tests) - self.failUnless(TestCram in self.tests) - - def test_check_has_test(self): - self.failUnless(check_has_test("beaver", ["test_beaver"])) - self.failUnless(check_has_test("beaver", ["test_beaver_2"])) - self.failIf(check_has_test("beaver_1", ["test_beaver"])) - - def test_find_untested_methods(self): - untested = find_untested_methods() - # the return type must be an iterable - self.failUnless(hasattr(untested, "__iter__")) - #.. basically, a list - self.failUnless(isinstance(untested, list)) - - def test_find_untested_methods_method(self): - """create a function and see if it is found""" - # setup a function in the global namespace - global some_random_test_method - # define the method - def some_random_test_method(): pass - # first make sure it is in the global scope - members = inspect.getmembers(sys.modules[__name__], inspect.isfunction) - func_names = [functuple[0] for functuple in members] - self.assertIn("some_random_test_method", func_names) - # test whether or not it is found by find_untested_methods - untested = find_untested_methods() - self.assertIn("some_random_test_method", untested) - # remove the test method from the global namespace - del some_random_test_method - - def test_load_tests(self): - loader = unittest.TestLoader() - suite = load_tests(loader, None, None) - suite._tests[0]._testMethodName - membership_test = lambda member: \ - inspect.isclass(member) and issubclass(member, unittest.TestCase) - tests = inspect.getmembers(sys.modules[__name__], membership_test) - classes = [x[1] for x in tests] - for test in suite._tests: - self.assertIn(test.__class__, classes) - - def test_report_untested(self): - untested = find_untested_methods() - output = report_untested() - if len(untested) > 0: - self.assertIn("NOT TESTED", output) - for name in untested: - self.assertIn(name, output) - elif len(untested) == 0: - self.assertNotIn("NOT TESTED", output) - - -def assemble_test_cases(): - """finds classes that inherit from unittest.TestCase - because i am too lazy to remember to add them to a - global list of tests for the suite runner""" - classes = [] - clsmembers = inspect.getmembers(sys.modules[__name__], inspect.isclass) - for (name, some_class) in clsmembers: - if issubclass(some_class, unittest.TestCase): - classes.append(some_class) - return classes - -def load_tests(loader, tests, pattern): - suite = unittest.TestSuite() - for test_class in assemble_test_cases(): - tests = loader.loadTestsFromTestCase(test_class) - suite.addTests(tests) - return suite - -def check_has_test(func_name, tested_names): - """checks if there is a test dedicated to this function""" - if "test_"+func_name in tested_names: - return True - for name in tested_names: - if "test_"+func_name in name: - return True - return False - -def find_untested_methods(): - """finds all untested functions in this module - by searching for method names in test case - method names.""" - untested = [] - avoid_funcs = ["main", "run_tests", "run_main", "copy", "deepcopy"] - test_funcs = [] - # get a list of all classes in this module - classes = inspect.getmembers(sys.modules[__name__], inspect.isclass) - # for each class.. - for (name, klass) in classes: - # only look at those that have tests - if issubclass(klass, unittest.TestCase): - # look at this class' methods - funcs = inspect.getmembers(klass, inspect.ismethod) - # for each method.. - for (name2, func) in funcs: - # store the ones that begin with test_ - if "test_" in name2 and name2[0:5] == "test_": - test_funcs.append([name2, func]) - # assemble a list of all test method names (test_x, test_y, ..) - tested_names = [funcz[0] for funcz in test_funcs] - # now get a list of all functions in this module - funcs = inspect.getmembers(sys.modules[__name__], inspect.isfunction) - # for each function.. - for (name, func) in funcs: - # we don't care about some of these - if name in avoid_funcs: continue - # skip functions beginning with _ - if name[0] == "_": continue - # check if this function has a test named after it - has_test = check_has_test(name, tested_names) - if not has_test: - untested.append(name) - return untested - -def report_untested(): - untested = find_untested_methods() - output = "NOT TESTED: [" - first = True - for name in untested: - if first: - output += name - first = False - else: output += ", "+name - output += "]\n" - output += "total untested: " + str(len(untested)) - return output - -#### ways to run this file #### - -def run_tests(): # rather than unittest.main() - loader = unittest.TestLoader() - suite = load_tests(loader, None, None) - unittest.TextTestRunner(verbosity=2).run(suite) - print report_untested() - def run_main(): # read the rom and figure out the offsets for maps direct_load_rom() @@ -8848,10 +7845,9 @@ def run_main(): make_trainer_group_name_trainer_ids(trainer_group_table) # just a helpful alias -main=run_main -# when you run the file.. do unit tests -if __name__ == "__main__": - run_tests() +main = run_main + # when you load the module.. parse everything -elif __name__ == "crystal": pass - #run_main() +if __name__ == "crystal": + pass + diff --git a/dump_sections b/dump_sections new file mode 100755 index 0000000..362318f --- /dev/null +++ b/dump_sections @@ -0,0 +1,14 @@ +#!/bin/bash +# This wraps dump_sections.py so that other python scripts can import the +# functions. If dump_sections.py was instead called dump_sections, then other +# python source code would be unable to use the functions via import +# statements. + +# figure out the path to this script +cwd="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +# construct the path to dump_sections.py +secpath=$cwd/dump_sections.py + +# run dump_sections.py +$secpath $1 diff --git a/dump_sections.py b/dump_sections.py new file mode 100755 index 0000000..91306e4 --- /dev/null +++ b/dump_sections.py @@ -0,0 +1,130 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +""" +Use this tool to dump an asm file for a new source code or disassembly project. + +usage: + + from dump_sections import dump_sections + + output = dump_sections("../../butt.gbc") + + file_handler = open("main.asm", "w") + file_handler.write(output) + file_handler.close() + +You can also use this script from the shell, where it will look for +"baserom.gbc" in the current working path or whatever file path you pass in the +first positional argument. +""" + +import os +import sys +import argparse + +def upper_hex(input): + """ + Converts the input to an uppercase hex string. + """ + if input in [0, "0"]: + return "0" + elif input <= 0xF: + return ("%.x" % (input)).upper() + else: + return ("%.2x" % (input)).upper() + +def format_bank_number(address, bank_size=0x4000): + """ + Returns a str of the hex number of the bank based on the address. + """ + return upper_hex(address / bank_size) + +def calculate_bank_quantity(path, bank_size=0x4000): + """ + Returns the number of 0x4000 banks in the file at path. + """ + return float(os.lstat(path).st_size) / bank_size + +def dump_section(bank_number, separator="\n\n"): + """ + Returns a str of a section header for the asm file. + """ + output = "SECTION \"" + if bank_number in [0, "0"]: + output += "bank0\",HOME" + else: + output += "bank" + output += bank_number + output += "\",DATA,BANK[$" + output += bank_number + output += "]" + output += separator + return output + +def dump_incbin_for_section(address, bank_size=0x4000, baserom="baserom.gbc", separator="\n\n"): + """ + Returns a str for an INCBIN line for an entire section. + """ + output = "INCBIN \"" + output += baserom + output += "\",$" + output += upper_hex(address) + output += ",$" + output += upper_hex(bank_size) + output += separator + return output + +def dump_sections(path, bank_size=0x4000, initial_bank=0, last_bank=None, separator="\n\n"): + """ + Returns a str of assembly source code. The source code delineates each + SECTION and includes bytes from the file specified by baserom. + """ + if not last_bank: + last_bank = calculate_bank_quantity(path, bank_size=bank_size) + + if last_bank < initial_bank: + raise Exception("last_bank must be greater than or equal to initial_bank") + + baserom_name = os.path.basename(path) + + output = "" + + banks = range(initial_bank, last_bank) + + for bank_number in banks: + address = bank_number * bank_size + + # get a formatted hex number of the bank based on the address + formatted_bank_number = format_bank_number(address, bank_size=bank_size) + + # SECTION + output += dump_section(formatted_bank_number, separator=separator) + + # INCBIN a range of bytes from the ROM + output += dump_incbin_for_section(address, bank_size=bank_size, baserom=baserom_name, separator=separator) + + # clean up newlines at the end of the output + if output[-2:] == "\n\n": + output = output[:-2] + output += "\n" + + return output + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("rompath", nargs="?", metavar="rompath", type=str) + args = parser.parse_args() + + # default to "baserom.gbc" in the current working directory + baserom = "baserom.gbc" + + # but let the user override the default + if args.rompath: + baserom = args.rompath + + # generate some asm + output = dump_sections(baserom) + + # dump everything to stdout + sys.stdout.write(output) + diff --git a/gbz80disasm.py b/gbz80disasm.py index 48739e0..f2ba483 100644 --- a/gbz80disasm.py +++ b/gbz80disasm.py @@ -1,26 +1,28 @@ -#author: Bryan Bishop <kanzure@gmail.com> -#date: 2012-01-09 +# -*- coding: utf-8 -*- + import os import sys from copy import copy, deepcopy from ctypes import c_int8 -import json import random +import json -spacing = "\t" +# New versions of json don't have read anymore. +if not hasattr(json, "read"): + json.read = json.loads -class XRomStr(str): - def __repr__(self): - return "RomStr(too long)" +from romstr import RomStr def load_rom(filename="../baserom.gbc"): """loads bytes into memory""" global rom - file_handler = open(filename, "rb") - rom = XRomStr(file_handler.read()) + file_handler = open(filename, "rb") + rom = RomStr(file_handler.read()) file_handler.close() return rom +spacing = "\t" + temp_opt_table = [ [ "ADC A", 0x8f, 0 ], [ "ADC B", 0x88, 0 ], @@ -550,7 +552,7 @@ end_08_scripts_with = [ 0xc9, #ret ###0xda, 0xe9, 0xd2, 0xc2, 0xca, 0xc3, 0x38, 0x30, 0x20, 0x28, 0x18, 0xd8, 0xd0, 0xc0, 0xc8, 0xc9 ] -relative_jumps = [0x38, 0x30, 0x20, 0x28, 0x18, 0xc3, 0xda, 0xc2] +relative_jumps = [0x38, 0x30, 0x20, 0x28, 0x18, 0xc3, 0xda, 0xc2] relative_unconditional_jumps = [0xc3, 0x18] call_commands = [0xdc, 0xd4, 0xc4, 0xcc, 0xcd] @@ -559,7 +561,7 @@ all_labels = {} def load_labels(filename="labels.json"): global all_labels if os.path.exists(filename): - all_labels = json.loads(open(filename, "r").read()) + all_labels = json.read(open(filename, "r").read()) else: print "You must run crystal.scan_for_predefined_labels() to create \"labels.json\". Trying..." import crystal @@ -601,10 +603,10 @@ def output_bank_opcodes(original_offset, max_byte_count=0x4000, debug = False): #i = offset #ad = end_address #a, oa = current_byte_number - + load_labels() load_rom() - + bank_id = 0 if original_offset > 0x8000: bank_id = original_offset / 0x4000 @@ -1043,14 +1043,16 @@ def decompress_monsters(type = front): # decompress monster = decompress_monster_by_id(id, type) if monster != None: # no unowns here - filename = str(id+1).zfill(3) + '.2bpp' # 001.2bpp if not type: # front - folder = '../gfx/frontpics/' + filename = 'front.2bpp' + folder = '../gfx/pics/' + str(id+1).zfill(3) + '/' to_file(folder+filename, monster.pic) - folder = '../gfx/anim/' + filename = 'tiles.2bpp' + folder = '../gfx/pics/' + str(id+1).zfill(3) + '/' to_file(folder+filename, monster.animtiles) else: # back - folder = '../gfx/backpics/' + filename = 'back.2bpp' + folder = '../gfx/pics/' + str(id+1).zfill(3) + '/' to_file(folder+filename, monster.pic) @@ -1073,14 +1075,16 @@ def decompress_unowns(type = front): # decompress unown = decompress_unown_by_id(letter, type) - filename = str(unown_dex).zfill(3) + chr(ord('a') + letter) + '.2bpp' # 201a.2bpp if not type: # front - folder = '../gfx/frontpics/' + filename = 'front.2bpp' + folder = '../gfx/pics/' + str(unown_dex).zfill(3) + chr(ord('a') + letter) + '/' to_file(folder+filename, unown.pic) + filename = 'tiles.2bpp' folder = '../gfx/anim/' to_file(folder+filename, unown.animtiles) else: # back - folder = '../gfx/backpics/' + filename = 'back.2bpp' + folder = '../gfx/pics/' + str(unown_dex).zfill(3) + chr(ord('a') + letter) + '/' to_file(folder+filename, unown.pic) @@ -1255,8 +1259,8 @@ def compress_file(filein, fileout, mode = 'horiz'): def compress_monster_frontpic(id, fileout): mode = 'vert' - fpic = '../gfx/frontpics/' + str(id).zfill(3) + '.2bpp' - fanim = '../gfx/anim/' + str(id).zfill(3) + '.2bpp' + fpic = '../gfx/pics/' + str(id).zfill(3) + '/front.2bpp' + fanim = '../gfx/pics/' + str(id).zfill(3) + '/tiles.2bpp' pic = open(fpic, 'rb').read() anim = open(fanim, 'rb').read() @@ -1264,7 +1268,7 @@ def compress_monster_frontpic(id, fileout): lz = Compressed(image, mode, 5) - out = '../gfx/frontpics/lz/' + str(id).zfill(3) + '.lz' + out = '../gfx/pics/' + str(id).zfill(3) + '/front.lz' to_file(out, lz.output) @@ -1283,6 +1287,28 @@ def get_uncompressed_gfx(start, num_tiles, filename): +def hex_to_rgb(word): + red = word & 0b11111 + word >>= 5 + green = word & 0b11111 + word >>= 5 + blue = word & 0b11111 + return (red, green, blue) + +def grab_palettes(address, length = 0x80): + output = '' + for word in range(length/2): + color = ord(rom[address+1])*0x100 + ord(rom[address]) + address += 2 + color = hex_to_rgb(color) + red = str(color[0]).zfill(2) + green = str(color[1]).zfill(2) + blue = str(color[2]).zfill(2) + output += '\tRGB '+red+', '+green+', '+blue + output += '\n' + return output + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('cmd', nargs='?', metavar='cmd', type=str) @@ -1317,7 +1343,11 @@ if __name__ == "__main__": # python gfx.py un [address] [num_tiles] [filename] get_uncompressed_gfx(int(args.arg1,16), int(args.arg2), args.arg3) - else: - # python gfx.py - decompress_all() - if debug: print 'decompressed known gfx to ../gfx/!' + elif args.cmd == 'pal': + # python gfx.py pal [address] [length] + print grab_palettes(int(args.arg1,16), int(args.arg2)) + + #else: + ## python gfx.py + #decompress_all() + #if debug: print 'decompressed known gfx to ../gfx/!' @@ -1,12 +1,13 @@ -#!/usr/bin/python -# author: Bryan Bishop <kanzure@gmail.com> -# date: 2012-06-20 +# -*- coding: utf-8 -*- import networkx as nx -from romstr import RomStr, DisAsm, \ - relative_jumps, call_commands, \ - relative_unconditional_jumps +from romstr import ( + RomStr, + relative_jumps, + call_commands, + relative_unconditional_jumps, +) class RomGraph(nx.DiGraph): """ Graphs various functions pointing to each other. diff --git a/interval_map.py b/interval_map.py new file mode 100644 index 0000000..7e6c5cd --- /dev/null +++ b/interval_map.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- + +from bisect import bisect_left, bisect_right +from itertools import izip + +class IntervalMap(object): + """ + This class maps a set of intervals to a set of values. + + >>> i = IntervalMap() + >>> i[0:5] = "hello world" + >>> i[6:10] = "hello cruel world" + >>> print i[4] + "hello world" + """ + + def __init__(self): + """initializes an empty IntervalMap""" + self._bounds = [] + self._items = [] + self._upperitem = None + + def __setitem__(self, _slice, _value): + """sets an interval mapping""" + assert isinstance(_slice, slice), 'The key must be a slice object' + + if _slice.start is None: + start_point = -1 + else: + start_point = bisect_left(self._bounds, _slice.start) + + if _slice.stop is None: + end_point = -1 + else: + end_point = bisect_left(self._bounds, _slice.stop) + + if start_point>=0: + if start_point < len(self._bounds) and self._bounds[start_point]<_slice.start: + start_point += 1 + + if end_point>=0: + self._bounds[start_point:end_point] = [_slice.start, _slice.stop] + if start_point < len(self._items): + self._items[start_point:end_point] = [self._items[start_point], _value] + else: + self._items[start_point:end_point] = [self._upperitem, _value] + else: + self._bounds[start_point:] = [_slice.start] + if start_point < len(self._items): + self._items[start_point:] = [self._items[start_point], _value] + else: + self._items[start_point:] = [self._upperitem] + self._upperitem = _value + else: + if end_point>=0: + self._bounds[:end_point] = [_slice.stop] + self._items[:end_point] = [_value] + else: + self._bounds[:] = [] + self._items[:] = [] + self._upperitem = _value + + def __getitem__(self,_point): + """gets a value from the mapping""" + assert not isinstance(_point, slice), 'The key cannot be a slice object' + + index = bisect_right(self._bounds, _point) + if index < len(self._bounds): + return self._items[index] + else: + return self._upperitem + + def items(self): + """returns an iterator with each item being + ((low_bound, high_bound), value) + these items are returned in order""" + previous_bound = None + for (b, v) in izip(self._bounds, self._items): + if v is not None: + yield (previous_bound, b), v + previous_bound = b + if self._upperitem is not None: + yield (previous_bound, None), self._upperitem + + def values(self): + """returns an iterator with each item being a stored value + the items are returned in order""" + for v in self._items: + if v is not None: + yield v + if self._upperitem is not None: + yield self._upperitem + + def __repr__(self): + s = [] + for b,v in self.items(): + if v is not None: + s.append('[%r, %r] => %r'%( + b[0], + b[1], + v + )) + return '{'+', '.join(s)+'}' + diff --git a/item_constants.py b/item_constants.py index d60dfb1..a050637 100644 --- a/item_constants.py +++ b/item_constants.py @@ -1,4 +1,7 @@ -item_constants = {1: 'MASTER_BALL', +# -*- coding: utf-8 -*- + +item_constants = { +1: 'MASTER_BALL', 2: 'ULTRA_BALL', 3: 'BRIGHTPOWDER', 4: 'GREAT_BALL', @@ -219,4 +222,20 @@ item_constants = {1: 'MASTER_BALL', 246: 'HM_04', 247: 'HM_05', 248: 'HM_06', -249: 'HM_07'} +249: 'HM_07', +} + +def find_item_label_by_id(id): + if id in item_constants.keys(): + return item_constants[id] + else: return None + +def generate_item_constants(): + """make a list of items to put in constants.asm""" + output = "" + for (id, item) in item_constants.items(): + val = ("$%.2x"%id).upper() + while len(item)<13: item+= " " + output += item + " EQU " + val + "\n" + return output + @@ -1,7 +1,12 @@ -""" Various label/line-related functions. +# -*- coding: utf-8 -*- +""" +Various label/line-related functions. """ -from pointers import calculate_pointer, calculate_bank +from pointers import ( + calculate_pointer, + calculate_bank, +) def remove_quoted_text(line): """get rid of content inside quotes diff --git a/move_constants.py b/move_constants.py index 929f1fa..a20af85 100644 --- a/move_constants.py +++ b/move_constants.py @@ -1,3 +1,5 @@ +# -*- coding: utf-8 -*- + moves = { 0x01: "POUND", 0x02: "KARATE_CHOP", @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- pksv_gs = { 0x00: "2call", @@ -34,7 +35,7 @@ pksv_gs = { 0x21: "checkitem", 0x22: "givemoney", 0x23: "takemoney", - 0x24: "checkmonkey", + 0x24: "checkmoney", 0x25: "givecoins", 0x26: "takecoins", 0x27: "checkcoins", @@ -141,8 +142,8 @@ pksv_gs = { 0xA3: "displaylocation", } -#see http://www.pokecommunity.com/showpost.php?p=4347261 -#NOTE: this has some updates that need to be back-ported to gold +# see http://www.pokecommunity.com/showpost.php?p=4347261 +# NOTE: this has some updates that need to be back-ported to gold pksv_crystal = { 0x00: "2call", 0x01: "3call", @@ -179,7 +180,7 @@ pksv_crystal = { 0x21: "checkitem", 0x22: "givemoney", 0x23: "takemoney", - 0x24: "checkmonkey", + 0x24: "checkmoney", 0x25: "givecoins", 0x26: "takecoins", 0x27: "checkcoins", @@ -292,13 +293,14 @@ pksv_crystal = { } #these cause the script to end; used in create_command_classes -pksv_crystal_more_enders = [0x03, 0x04, 0x05, 0x0C, 0x51, 0x53, - 0x8D, 0x8F, 0x90, 0x91, 0x92, 0x9B, +pksv_crystal_more_enders = [0x03, 0x04, 0x05, 0x0C, 0x51, 0x52, + 0x53, 0x8D, 0x8F, 0x90, 0x91, 0x92, + 0x9B, 0xB2, #maybe? 0xCC, #maybe? ] -#these have no pksv names as of pksv 2.1.1 +# these have no pksv names as of pksv 2.1.1 pksv_crystal_unknowns = [ 0x9F, 0xA6, 0xA7, 0xA8, 0xA9, 0xAA, 0xAB, 0xAC, 0xAD, 0xAE, 0xAF, diff --git a/pointers.py b/pointers.py index f392241..8fe3df3 100644 --- a/pointers.py +++ b/pointers.py @@ -1,10 +1,12 @@ -""" Various functions related to pointer and address math. Mostly to avoid - depedency loops. +# -*- coding: utf-8 -*- +""" +Various functions related to pointer and address math. Mostly to avoid +depedency loops. """ def calculate_bank(address): """you are too lazy to divide on your own?""" - if type(address) == str: + if type(address) == str: address = int(address, 16) #if 0x4000 <= address <= 0x7FFF: # raise Exception, "bank 1 does not exist" diff --git a/pokemon_constants.py b/pokemon_constants.py index 33b6a0e..221a31c 100644 --- a/pokemon_constants.py +++ b/pokemon_constants.py @@ -1,3 +1,5 @@ +# -*- coding: utf-8 -*- + pokemon_constants = { 1: "BULBASAUR", 2: "IVYSAUR", @@ -1,8 +1,21 @@ -import sys, os, time, datetime, json -from gbz80disasm import opt_table +# -*- coding: utf-8 -*- + +import sys +import os +import time +import datetime from ctypes import c_int8 -from copy import copy, deepcopy -from labels import get_label_from_line, get_address_from_line_comment +from copy import copy +import json + +# New versions of json don't have read anymore. +if not hasattr(json, "read"): + json.read = json.loads + +from labels import ( + get_label_from_line, + get_address_from_line_comment, +) relative_jumps = [0x38, 0x30, 0x20, 0x28, 0x18, 0xc3, 0xda, 0xc2, 0x32] relative_unconditional_jumps = [0xc3, 0x18] @@ -91,7 +104,7 @@ class RomStr(str): file_handler.close() # load the labels from the file - self.labels = json.loads(open(filename, "r").read()) + self.labels = json.read(open(filename, "r").read()) def get_address_for(self, label): """ Returns the address of a label. This is slow and could be improved @@ -137,7 +150,7 @@ class RomStr(str): that will be parsed, so that large patches of data aren't parsed as code. """ - if type(address) == str and "0x" in address: + if type(address) in [str, unicode] and "0x" in address: address = int(address, 16) start_address = address @@ -166,333 +179,8 @@ class RomStr(str): elif end_address != None and size == None: size = end_address - start_address - return DisAsm(start_address=start_address, end_address=end_address, size=size, max_size=max_size, debug=debug, rom=self) - -class DisAsm: - """ z80 disassembler - """ - - def __init__(self, start_address=None, end_address=None, size=None, max_size=0x4000, debug=True, rom=None): - assert start_address != None, "start_address must be given" - - if rom == None: - file_handler = open("../baserom.gbc", "r") - bytes = file_handler.read() - file_handler.close() - rom = RomStr(bytes) - - if debug not in [None, True, False]: - raise Exception, "debug param is invalid" - if debug == None: - debug = False - - # get end_address and size in sync with each other - if end_address == None and size != None: - end_address = start_address + size - elif end_address != None and size == None: - size = end_address - start_address - elif end_address != None and size != None: - size = max(end_address - start_address, size) - end_address = start_address + size - - # check that the bounds make sense - if end_address != None: - if end_address <= start_address: - raise Exception, "end_address is out of bounds" - elif (end_address - start_address) > max_size: - raise Exception, "end_address goes beyond max_size" - - # check more edge cases - if not start_address >= 0: - raise Exception, "start_address must be at least 0" - elif end_address != None and not end_address >= 0: - raise Exception, "end_address must be at least 0" - - self.rom = rom - self.start_address = start_address - self.end_address = end_address - self.size = size - self.max_size = max_size - self.debug = debug - - self.parse() - - def parse(self): - """ Disassembles stuff and things. - """ - - rom = self.rom - start_address = self.start_address - end_address = self.end_address - max_size = self.max_size - debug = self.debug - - bank_id = start_address / 0x4000 - - # [{"command": 0x20, "bytes": [0x20, 0x40, 0x50], - # "asm": "jp $5040", "label": "Unknown5040"}] - asm_commands = {} - - offset = start_address - - last_hl_address = None - last_a_address = None - used_3d97 = False - - keep_reading = True - - while (end_address != 0 and offset <= end_address) or keep_reading: - # read the current opcode byte - current_byte = ord(rom[offset]) - current_byte_number = len(asm_commands.keys()) - - # setup this next/upcoming command - if offset in asm_commands.keys(): - asm_command = asm_commands[offset] - else: - asm_command = {} - - asm_command["address"] = offset - - if not "references" in asm_command.keys(): - # This counts how many times relative jumps reference this - # byte. This is used to determine whether or not to print out a - # label later. - asm_command["references"] = 0 - - # some commands have two opcodes - next_byte = ord(rom[offset+1]) - - if self.debug: - print "offset: \t\t" + hex(offset) - print "current_byte: \t\t" + hex(current_byte) - print "next_byte: \t\t" + hex(next_byte) - - # all two-byte opcodes also have their first byte in there somewhere - if (current_byte in opt_table.keys()) or ((current_byte + (next_byte << 8)) in opt_table.keys()): - # this might be a two-byte opcode - possible_opcode = current_byte + (next_byte << 8) - - # check if this is a two-byte opcode - if possible_opcode in opt_table.keys(): - op_code = possible_opcode - else: - op_code = current_byte - - op = opt_table[op_code] - - opstr = op[0].lower() - optype = op[1] - - if self.debug: - print "opstr: " + opstr - - asm_command["type"] = "op" - asm_command["id"] = op_code - asm_command["format"] = opstr - asm_command["opnumberthing"] = optype - - opstr2 = None - base_opstr = copy(opstr) - - if "x" in opstr: - for x in range(0, opstr.count("x")): - insertion = ord(rom[offset + 1]) - - # Certain opcodes will have a local relative jump label - # here instead of a raw hex value, but this is - # controlled through asm output. - insertion = "$" + hex(insertion)[2:] - - opstr = opstr[:opstr.find("x")].lower() + insertion + opstr[opstr.find("x")+1:].lower() - - if op_code in relative_jumps: - target_address = offset + 2 + c_int8(ord(rom[offset + 1])).value - insertion = "asm_" + hex(target_address) - - if str(target_address) in self.rom.labels.keys(): - insertion = self.rom.labels[str(target_address)] - - opstr2 = base_opstr[:base_opstr.find("x")].lower() + insertion + base_opstr[base_opstr.find("x")+1:].lower() - asm_command["formatted_with_labels"] = opstr2 - asm_command["target_address"] = target_address - - current_byte_number += 1 - offset += 1 - - if "?" in opstr: - for y in range(0, opstr.count("?")): - byte1 = ord(rom[offset + 1]) - byte2 = ord(rom[offset + 2]) - - number = byte1 - number += byte2 << 8; - - # In most cases, you can use a label here. Labels will - # be shown during asm output. - insertion = "$%.4x" % (number) - - opstr = opstr[:opstr.find("?")].lower() + insertion + opstr[opstr.find("?")+1:].lower() - - # This version of the formatted string has labels. In - # the future, the actual labels should be parsed - # straight out of the "main.asm" file. - target_address = number % 0x4000 - insertion = "asm_" + hex(target_address) - - if str(target_address) in self.rom.labels.keys(): - insertion = self.rom.labels[str(target_address)] - - opstr2 = base_opstr[:base_opstr.find("?")].lower() + insertion + base_opstr[base_opstr.find("?")+1:].lower() - asm_command["formatted_with_labels"] = opstr2 - asm_command["target_address"] = target_address - - current_byte_number += 2 - offset += 2 - - # Check for relative jumps, construct the formatted asm line. - # Also set the usage of labels. - if current_byte in [0x18, 0x20] + relative_jumps: # jr or jr nz - # generate a label for the byte we're jumping to - target_address = offset + 1 + c_int8(ord(rom[offset])).value - - if target_address in asm_commands.keys(): - asm_commands[target_address]["references"] += 1 - remote_label = "asm_" + hex(target_address) - asm_commands[target_address]["current_label"] = remote_label - asm_command["remote_label"] = remote_label - - # Not sure how to set this, can't be True because an - # address referenced multiple times will use a label - # despite the label not necessarily being used in the - # output. The "use_remote_label" values should be - # calculated when rendering the asm output, based on - # which addresses and which op codes will be displayed - # (within the range). - asm_command["use_remote_label"] = "unknown" - else: - remote_label = "asm_" + hex(target_address) - - # This remote address might not be part of this - # function. - asm_commands[target_address] = { - "references": 1, - "current_label": remote_label, - "address": target_address, - } - # Also, target_address can be negative (before the - # start_address that the user originally requested), - # and it shouldn't be shown on asm output because the - # intermediate bytes (between a negative target_address - # and start_address) won't be disassembled. - - # Don't know yet if this remote address is part of this - # function or not. When the remote address is not part - # of this function, the label name should not be used, - # because that label will not be disassembled in the - # output, until the user asks it to. - asm_command["use_remote_label"] = "unknown" - asm_command["remote_label"] = remote_label - elif current_byte == 0x3e: - last_a_address = ord(rom[offset + 1]) - - # store the formatted string for the output later - asm_command["formatted"] = opstr - - if current_byte == 0x21: - last_hl_address = byte1 + (byte2 << 8) - - # this is leftover from pokered, might be meaningless - if current_byte == 0xcd: - if number == 0x3d97: - used_3d97 = True - - if current_byte == 0xc3 or current_byte in relative_unconditional_jumps: - if current_byte == 0xc3: - if number == 0x3d97: - used_3d97 = True - - # stop reading at a jump, relative jump or return - if current_byte in end_08_scripts_with: - is_data = False - - if not self.has_outstanding_labels(asm_commands, offset): - keep_reading = False - break - else: - keep_reading = True - else: - keep_reading = True - - else: - # This shouldn't really happen, and means that this area of the - # ROM probably doesn't represent instructions. - asm_command["type"] = "data" # db - asm_command["value"] = current_byte - keep_reading = False - - # save this new command in the list - asm_commands[asm_command["address"]] = asm_command - - # jump forward by a byte - offset += 1 - - # also save the last command if necessary - if len(asm_commands.keys()) > 0 and asm_commands[asm_commands.keys()[-1]] is not asm_command: - asm_commands[asm_command["address"]] = asm_command - - # store the set of commands on this object - self.asm_commands = asm_commands - - self.end_address = offset + 1 - self.last_address = self.end_address - - def has_outstanding_labels(self, asm_commands, offset): - """ Checks if there are any labels that haven't yet been created. - """ # is this really necessary?? - return False - - def used_addresses(self): - """ Returns a list of unique addresses that this function will probably - call. - """ - addresses = set() - - for (id, command) in self.asm_commands.items(): - if command.has_key("target_address") and command["id"] in call_commands: - addresses.add(command["target_address"]) - - return addresses - - def __str__(self): - """ ASM pretty printer. - """ - output = "" - - for (key, line) in self.asm_commands.items(): - # skip anything from before the beginning - if key < self.start_address: - continue - - # show a label - if line["references"] > 0 and "current_label" in line.keys(): - if line["address"] == self.start_address: - output += "thing: ; " + hex(line["address"]) + "\n" - else: - output += "." + line["current_label"] + "\@ ; " + hex(line["address"]) + "\n" - - # show the actual line - if line.has_key("formatted_with_labels"): - output += spacing + line["formatted_with_labels"] - elif line.has_key("formatted"): - output += spacing + line["formatted"] - #output += " ; to " + - output += "\n" - - # show the next address after this chunk - output += "; " + hex(self.end_address) - - return output + raise NotImplementedError("DisAsm was removed and never worked; hook up another disassembler please.") + #return DisAsm(start_address=start_address, end_address=end_address, size=size, max_size=max_size, debug=debug, rom=self) class AsmList(list): """ Simple wrapper to prevent all asm lines from being shown on screen. diff --git a/test_dump_sections.py b/test_dump_sections.py new file mode 100644 index 0000000..b73b86f --- /dev/null +++ b/test_dump_sections.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- + +try: + import unittest2 as unittest +except ImportError: + import unittest + +# check for things we need in unittest +if not hasattr(unittest.TestCase, 'setUpClass'): + sys.stderr.write("The unittest2 module or Python 2.7 is required to run this script.") + sys.exit(1) + +from dump_sections import ( + upper_hex, + format_bank_number, + calculate_bank_quantity, + dump_section, + dump_incbin_for_section, +) + +class TestDumpSections(unittest.TestCase): + def test_upper_hex(self): + number = 0x52 + self.assertEquals(number, int("0x" + upper_hex(number), 16)) + + number = 0x1 + self.assertEquals(number, int("0x" + upper_hex(number), 16)) + + number = 0x0 + self.assertEquals(number, int("0x" + upper_hex(number), 16)) + + number = 0xAA + self.assertEquals(number, int("0x" + upper_hex(number), 16)) + + number = 0xFFFFAAA0000 + self.assertEquals(number, int("0x" + upper_hex(number), 16)) + + def test_format_bank_number(self): + address = 0x0 + self.assertEquals("0", format_bank_number(address)) + + address = 0x4000 + self.assertEquals("1", format_bank_number(address)) + + address = 0x1FC000 + self.assertEquals("7F", format_bank_number(address)) + + def test_dump_section(self): + self.assertIn("SECTION", dump_section(str(0))) + self.assertIn("HOME", dump_section(str(0))) + self.assertNotIn("HOME", dump_section(str(1))) + self.assertIn("DATA", dump_section(str(2))) + self.assertIn("BANK", dump_section(str(40))) + self.assertNotIn("BANK", dump_section(str(0))) + + def test_dump_incbin_for_section(self): + self.assertIn("INCBIN", dump_incbin_for_section(0)) + + def test_dump_incbin_for_section_separator(self): + separator = "\n\n" + self.assertIn(separator, dump_incbin_for_section(0, separator=separator)) + + separator = "\t\t" # dumb + self.assertIn(separator, dump_incbin_for_section(0, separator=separator)) + + def test_dump_incbin_for_section_default(self): + rom = "baserom.gbc" + self.assertIn(rom, dump_incbin_for_section(0)) + + rom = "baserom" + self.assertIn(rom, dump_incbin_for_section(0x4000)) + +if __name__ == "__main__": + unittest.main() diff --git a/tests.py b/tests.py new file mode 100644 index 0000000..61f46d6 --- /dev/null +++ b/tests.py @@ -0,0 +1,1015 @@ +# -*- coding: utf-8 -*- + +import os +import sys +import inspect +from copy import copy +import hashlib +import random +import json + +from interval_map import IntervalMap +from chars import chars, jap_chars + +from romstr import ( + RomStr, + AsmList, +) + +from item_constants import ( + item_constants, + find_item_label_by_id, + generate_item_constants, +) + +from pointers import ( + calculate_bank, + calculate_pointer, +) + +from pksv import ( + pksv_gs, + pksv_crystal, +) + +from labels import ( + remove_quoted_text, + line_has_comment_address, + line_has_label, + get_label_from_line, +) + +from crystal import ( + rom, + load_rom, + rom_until, + direct_load_rom, + parse_script_engine_script_at, + parse_text_engine_script_at, + parse_text_at2, + find_all_text_pointers_in_script_engine_script, + SingleByteParam, + HexByte, + MultiByteParam, + PointerLabelParam, + ItemLabelByte, + DollarSignByte, + DecimalParam, + rom_interval, + map_names, + Label, + scan_for_predefined_labels, + all_labels, + write_all_labels, + parse_map_header_at, + old_parse_map_header_at, + process_00_subcommands, + parse_all_map_headers, + translate_command_byte, + map_name_cleaner, + load_map_group_offsets, + load_asm, + asm, + is_valid_address, + index, + how_many_until, + grouper, + get_pokemon_constant_by_id, + generate_map_constant_labels, + get_map_constant_label_by_id, + get_id_for_map_constant_label, + calculate_pointer_from_bytes_at, + isolate_incbins, + process_incbins, + get_labels_between, + generate_diff_insert, + find_labels_without_addresses, + rom_text_at, + get_label_for, + split_incbin_line_into_three, + reset_incbins, +) + +# for testing all this crap +try: + import unittest2 as unittest +except ImportError: + import unittest + +# check for things we need in unittest +if not hasattr(unittest.TestCase, 'setUpClass'): + sys.stderr.write("The unittest2 module or Python 2.7 is required to run this script.") + sys.exit(1) + +class TestCram(unittest.TestCase): + "this is where i cram all of my unit tests together" + + @classmethod + def setUpClass(cls): + global rom + cls.rom = direct_load_rom() + rom = cls.rom + + @classmethod + def tearDownClass(cls): + del cls.rom + + def test_generic_useless(self): + "do i know how to write a test?" + self.assertEqual(1, 1) + + def test_map_name_cleaner(self): + name = "hello world" + cleaned_name = map_name_cleaner(name) + self.assertNotEqual(name, cleaned_name) + self.failUnless(" " not in cleaned_name) + name = "Some Random Pokémon Center" + cleaned_name = map_name_cleaner(name) + self.assertNotEqual(name, cleaned_name) + self.failIf(" " in cleaned_name) + self.failIf("é" in cleaned_name) + + def test_grouper(self): + data = range(0, 10) + groups = grouper(data, count=2) + self.assertEquals(len(groups), 5) + data = range(0, 20) + groups = grouper(data, count=2) + self.assertEquals(len(groups), 10) + self.assertNotEqual(data, groups) + self.assertNotEqual(len(data), len(groups)) + + def test_direct_load_rom(self): + rom = self.rom + self.assertEqual(len(rom), 2097152) + self.failUnless(isinstance(rom, RomStr)) + + def test_load_rom(self): + global rom + rom = None + load_rom() + self.failIf(rom == None) + rom = RomStr(None) + load_rom() + self.failIf(rom == RomStr(None)) + + def test_load_asm(self): + asm = load_asm() + joined_lines = "\n".join(asm) + self.failUnless("SECTION" in joined_lines) + self.failUnless("bank" in joined_lines) + self.failUnless(isinstance(asm, AsmList)) + + def test_rom_file_existence(self): + "ROM file must exist" + self.failUnless("baserom.gbc" in os.listdir("../")) + + def test_rom_md5(self): + "ROM file must have the correct md5 sum" + rom = self.rom + correct = "9f2922b235a5eeb78d65594e82ef5dde" + md5 = hashlib.md5() + md5.update(rom) + md5sum = md5.hexdigest() + self.assertEqual(md5sum, correct) + + def test_bizarre_http_presence(self): + rom_segment = self.rom[0x112116:0x112116+8] + self.assertEqual(rom_segment, "HTTP/1.0") + + def test_rom_interval(self): + address = 0x100 + interval = 10 + correct_strings = ['0x0', '0xc3', '0x6e', '0x1', '0xce', + '0xed', '0x66', '0x66', '0xcc', '0xd'] + byte_strings = rom_interval(address, interval, strings=True) + self.assertEqual(byte_strings, correct_strings) + correct_ints = [0, 195, 110, 1, 206, 237, 102, 102, 204, 13] + ints = rom_interval(address, interval, strings=False) + self.assertEqual(ints, correct_ints) + + def test_rom_until(self): + address = 0x1337 + byte = 0x13 + bytes = rom_until(address, byte, strings=True) + self.failUnless(len(bytes) == 3) + self.failUnless(bytes[0] == '0xd5') + bytes = rom_until(address, byte, strings=False) + self.failUnless(len(bytes) == 3) + self.failUnless(bytes[0] == 0xd5) + + def test_how_many_until(self): + how_many = how_many_until(chr(0x13), 0x1337) + self.assertEqual(how_many, 3) + + def test_calculate_bank(self): + self.failUnless(calculate_bank(0x8000) == 2) + self.failUnless(calculate_bank("0x9000") == 2) + self.failUnless(calculate_bank(0) == 0) + for address in [0x4000, 0x5000, 0x6000, 0x7000]: + self.assertRaises(Exception, calculate_bank, address) + + def test_calculate_pointer(self): + # for offset <= 0x4000 + self.assertEqual(calculate_pointer(0x0000), 0x0000) + self.assertEqual(calculate_pointer(0x3FFF), 0x3FFF) + # for 0x4000 <= offset <= 0x7FFFF + self.assertEqual(calculate_pointer(0x430F, bank=5), 0x1430F) + # for offset >= 0x7FFF + self.assertEqual(calculate_pointer(0x8FFF, bank=6), calculate_pointer(0x8FFF, bank=7)) + + def test_calculate_pointer_from_bytes_at(self): + addr1 = calculate_pointer_from_bytes_at(0x100, bank=False) + self.assertEqual(addr1, 0xc300) + addr2 = calculate_pointer_from_bytes_at(0x100, bank=True) + self.assertEqual(addr2, 0x2ec3) + + def test_rom_text_at(self): + self.assertEquals(rom_text_at(0x112116, 8), "HTTP/1.0") + + def test_translate_command_byte(self): + self.failUnless(translate_command_byte(crystal=0x0) == 0x0) + self.failUnless(translate_command_byte(crystal=0x10) == 0x10) + self.failUnless(translate_command_byte(crystal=0x40) == 0x40) + self.failUnless(translate_command_byte(gold=0x0) == 0x0) + self.failUnless(translate_command_byte(gold=0x10) == 0x10) + self.failUnless(translate_command_byte(gold=0x40) == 0x40) + self.assertEqual(translate_command_byte(gold=0x0), translate_command_byte(crystal=0x0)) + self.failUnless(translate_command_byte(gold=0x52) == 0x53) + self.failUnless(translate_command_byte(gold=0x53) == 0x54) + self.failUnless(translate_command_byte(crystal=0x53) == 0x52) + self.failUnless(translate_command_byte(crystal=0x52) == None) + self.assertRaises(Exception, translate_command_byte, None, gold=0xA4) + + def test_pksv_integrity(self): + "does pksv_gs look okay?" + self.assertEqual(pksv_gs[0x00], "2call") + self.assertEqual(pksv_gs[0x2D], "givepoke") + self.assertEqual(pksv_gs[0x85], "waitbutton") + self.assertEqual(pksv_crystal[0x00], "2call") + self.assertEqual(pksv_crystal[0x86], "waitbutton") + self.assertEqual(pksv_crystal[0xA2], "credits") + + def test_chars_integrity(self): + self.assertEqual(chars[0x80], "A") + self.assertEqual(chars[0xA0], "a") + self.assertEqual(chars[0xF0], "¥") + self.assertEqual(jap_chars[0x44], "ぱ") + + def test_map_names_integrity(self): + def map_name(map_group, map_id): return map_names[map_group][map_id]["name"] + self.assertEqual(map_name(2, 7), "Mahogany Town") + self.assertEqual(map_name(3, 0x34), "Ilex Forest") + self.assertEqual(map_name(7, 0x11), "Cerulean City") + + def test_load_map_group_offsets(self): + addresses = load_map_group_offsets() + self.assertEqual(len(addresses), 26, msg="there should be 26 map groups") + addresses = load_map_group_offsets() + self.assertEqual(len(addresses), 26, msg="there should still be 26 map groups") + self.assertIn(0x94034, addresses) + for address in addresses: + self.assertGreaterEqual(address, 0x4000) + self.failIf(0x4000 <= address <= 0x7FFF) + self.failIf(address <= 0x4000) + + def test_index(self): + self.assertTrue(index([1,2,3,4], lambda f: True) == 0) + self.assertTrue(index([1,2,3,4], lambda f: f==3) == 2) + + def test_get_pokemon_constant_by_id(self): + x = get_pokemon_constant_by_id + self.assertEqual(x(1), "BULBASAUR") + self.assertEqual(x(151), "MEW") + self.assertEqual(x(250), "HO_OH") + + def test_find_item_label_by_id(self): + x = find_item_label_by_id + self.assertEqual(x(249), "HM_07") + self.assertEqual(x(173), "BERRY") + self.assertEqual(x(45), None) + + def test_generate_item_constants(self): + x = generate_item_constants + r = x() + self.failUnless("HM_07" in r) + self.failUnless("EQU" in r) + + def test_get_label_for(self): + global all_labels + temp = copy(all_labels) + # this is basd on the format defined in get_labels_between + all_labels = [{"label": "poop", "address": 0x5, + "offset": 0x5, "bank": 0, + "line_number": 2 + }] + self.assertEqual(get_label_for(5), "poop") + all_labels = temp + + def test_generate_map_constant_labels(self): + ids = generate_map_constant_labels() + self.assertEqual(ids[0]["label"], "OLIVINE_POKECENTER_1F") + self.assertEqual(ids[1]["label"], "OLIVINE_GYM") + + def test_get_id_for_map_constant_label(self): + global map_internal_ids + map_internal_ids = generate_map_constant_labels() + self.assertEqual(get_id_for_map_constant_label("OLIVINE_GYM"), 1) + self.assertEqual(get_id_for_map_constant_label("OLIVINE_POKECENTER_1F"), 0) + + def test_get_map_constant_label_by_id(self): + global map_internal_ids + map_internal_ids = generate_map_constant_labels() + self.assertEqual(get_map_constant_label_by_id(0), "OLIVINE_POKECENTER_1F") + self.assertEqual(get_map_constant_label_by_id(1), "OLIVINE_GYM") + + def test_is_valid_address(self): + self.assertTrue(is_valid_address(0)) + self.assertTrue(is_valid_address(1)) + self.assertTrue(is_valid_address(10)) + self.assertTrue(is_valid_address(100)) + self.assertTrue(is_valid_address(1000)) + self.assertTrue(is_valid_address(10000)) + self.assertFalse(is_valid_address(2097153)) + self.assertFalse(is_valid_address(2098000)) + addresses = [random.randrange(0,2097153) for i in range(0, 9+1)] + for address in addresses: + self.assertTrue(is_valid_address(address)) + +class TestIntervalMap(unittest.TestCase): + def test_intervals(self): + i = IntervalMap() + first = "hello world" + second = "testing 123" + i[0:5] = first + i[5:10] = second + self.assertEqual(i[0], first) + self.assertEqual(i[1], first) + self.assertNotEqual(i[5], first) + self.assertEqual(i[6], second) + i[3:10] = second + self.assertEqual(i[3], second) + self.assertNotEqual(i[4], first) + + def test_items(self): + i = IntervalMap() + first = "hello world" + second = "testing 123" + i[0:5] = first + i[5:10] = second + results = list(i.items()) + self.failUnless(len(results) == 2) + self.assertEqual(results[0], ((0, 5), "hello world")) + self.assertEqual(results[1], ((5, 10), "testing 123")) + +class TestRomStr(unittest.TestCase): + """RomStr is a class that should act exactly like str() + except that it never shows the contents of it string + unless explicitly forced""" + sample_text = "hello world!" + sample = None + + def setUp(self): + if self.sample == None: + self.__class__.sample = RomStr(self.sample_text) + + def test_equals(self): + "check if RomStr() == str()" + self.assertEquals(self.sample_text, self.sample) + + def test_not_equal(self): + "check if RomStr('a') != RomStr('b')" + self.assertNotEqual(RomStr('a'), RomStr('b')) + + def test_appending(self): + "check if RomStr()+'a'==str()+'a'" + self.assertEquals(self.sample_text+'a', self.sample+'a') + + def test_conversion(self): + "check if RomStr() -> str() works" + self.assertEquals(str(self.sample), self.sample_text) + + def test_inheritance(self): + self.failUnless(issubclass(RomStr, str)) + + def test_length(self): + self.assertEquals(len(self.sample_text), len(self.sample)) + self.assertEquals(len(self.sample_text), self.sample.length()) + self.assertEquals(len(self.sample), self.sample.length()) + + def test_rom_interval(self): + global rom + load_rom() + address = 0x100 + interval = 10 + correct_strings = ['0x0', '0xc3', '0x6e', '0x1', '0xce', + '0xed', '0x66', '0x66', '0xcc', '0xd'] + byte_strings = rom.interval(address, interval, strings=True) + self.assertEqual(byte_strings, correct_strings) + correct_ints = [0, 195, 110, 1, 206, 237, 102, 102, 204, 13] + ints = rom.interval(address, interval, strings=False) + self.assertEqual(ints, correct_ints) + + def test_rom_until(self): + global rom + load_rom() + address = 0x1337 + byte = 0x13 + bytes = rom.until(address, byte, strings=True) + self.failUnless(len(bytes) == 3) + self.failUnless(bytes[0] == '0xd5') + bytes = rom.until(address, byte, strings=False) + self.failUnless(len(bytes) == 3) + self.failUnless(bytes[0] == 0xd5) + +class TestAsmList(unittest.TestCase): + """AsmList is a class that should act exactly like list() + except that it never shows the contents of its list + unless explicitly forced""" + + def test_equals(self): + base = [1,2,3] + asm = AsmList(base) + self.assertEquals(base, asm) + self.assertEquals(asm, base) + self.assertEquals(base, list(asm)) + + def test_inheritance(self): + self.failUnless(issubclass(AsmList, list)) + + def test_length(self): + base = range(0, 10) + asm = AsmList(base) + self.assertEquals(len(base), len(asm)) + self.assertEquals(len(base), asm.length()) + self.assertEquals(len(base), len(list(asm))) + self.assertEquals(len(asm), asm.length()) + + def test_remove_quoted_text(self): + x = remove_quoted_text + self.assertEqual(x("hello world"), "hello world") + self.assertEqual(x("hello \"world\""), "hello ") + input = 'hello world "testing 123"' + self.assertNotEqual(x(input), input) + input = "hello world 'testing 123'" + self.assertNotEqual(x(input), input) + self.failIf("testing" in x(input)) + + def test_line_has_comment_address(self): + x = line_has_comment_address + self.assertFalse(x("")) + self.assertFalse(x(";")) + self.assertFalse(x(";;;")) + self.assertFalse(x(":;")) + self.assertFalse(x(":;:")) + self.assertFalse(x(";:")) + self.assertFalse(x(" ")) + self.assertFalse(x("".join(" " * 5))) + self.assertFalse(x("".join(" " * 10))) + self.assertFalse(x("hello world")) + self.assertFalse(x("hello_world")) + self.assertFalse(x("hello_world:")) + self.assertFalse(x("hello_world:;")) + self.assertFalse(x("hello_world: ;")) + self.assertFalse(x("hello_world: ; ")) + self.assertFalse(x("hello_world: ;" + "".join(" " * 5))) + self.assertFalse(x("hello_world: ;" + "".join(" " * 10))) + self.assertTrue(x(";1")) + self.assertTrue(x(";F")) + self.assertTrue(x(";$00FF")) + self.assertTrue(x(";0x00FF")) + self.assertTrue(x("; 0x00FF")) + self.assertTrue(x(";$3:$300")) + self.assertTrue(x(";0x3:$300")) + self.assertTrue(x(";$3:0x300")) + self.assertTrue(x(";3:300")) + self.assertTrue(x(";3:FFAA")) + self.assertFalse(x('hello world "how are you today;0x1"')) + self.assertTrue(x('hello world "how are you today:0x1";1')) + returnable = {} + self.assertTrue(x("hello_world: ; 0x4050", returnable=returnable, bank=5)) + self.assertTrue(returnable["address"] == 0x14050) + + def test_line_has_label(self): + x = line_has_label + self.assertTrue(x("hi:")) + self.assertTrue(x("Hello: ")) + self.assertTrue(x("MyLabel: ; test xyz")) + self.assertFalse(x(":")) + self.assertFalse(x(";HelloWorld:")) + self.assertFalse(x("::::")) + self.assertFalse(x(":;:;:;:::")) + + def test_get_label_from_line(self): + x = get_label_from_line + self.assertEqual(x("HelloWorld: "), "HelloWorld") + self.assertEqual(x("HiWorld:"), "HiWorld") + self.assertEqual(x("HiWorld"), None) + + def test_find_labels_without_addresses(self): + global asm + asm = ["hello_world: ; 0x1", "hello_world2: ;"] + labels = find_labels_without_addresses() + self.failUnless(labels[0]["label"] == "hello_world2") + asm = ["hello world: ;1", "hello_world: ;2"] + labels = find_labels_without_addresses() + self.failUnless(len(labels) == 0) + asm = None + + def test_get_labels_between(self): + global asm + x = get_labels_between#(start_line_id, end_line_id, bank) + asm = ["HelloWorld: ;1", + "hi:", + "no label on this line", + ] + labels = x(0, 2, 0x12) + self.assertEqual(len(labels), 1) + self.assertEqual(labels[0]["label"], "HelloWorld") + del asm + + # this test takes a lot of time :( + def xtest_scan_for_predefined_labels(self): + # label keys: line_number, bank, label, offset, address + load_asm() + all_labels = scan_for_predefined_labels() + label_names = [x["label"] for x in all_labels] + self.assertIn("GetFarByte", label_names) + self.assertIn("AddNTimes", label_names) + self.assertIn("CheckShininess", label_names) + + def test_write_all_labels(self): + """dumping json into a file""" + filename = "test_labels.json" + # remove the current file + if os.path.exists(filename): + os.system("rm " + filename) + # make up some labels + labels = [] + # fake label 1 + label = {"line_number": 5, "bank": 0, "label": "SomeLabel", "address": 0x10} + labels.append(label) + # fake label 2 + label = {"line_number": 15, "bank": 2, "label": "SomeOtherLabel", "address": 0x9F0A} + labels.append(label) + # dump to file + write_all_labels(labels, filename=filename) + # open the file and read the contents + file_handler = open(filename, "r") + contents = file_handler.read() + file_handler.close() + # parse into json + obj = json.read(contents) + # begin testing + self.assertEqual(len(obj), len(labels)) + self.assertEqual(len(obj), 2) + self.assertEqual(obj, labels) + + def test_isolate_incbins(self): + global asm + asm = ["123", "456", "789", "abc", "def", "ghi", + 'INCBIN "baserom.gbc",$12DA,$12F8 - $12DA', + "jkl", + 'INCBIN "baserom.gbc",$137A,$13D0 - $137A'] + lines = isolate_incbins() + self.assertIn(asm[6], lines) + self.assertIn(asm[8], lines) + for line in lines: + self.assertIn("baserom", line) + + def test_process_incbins(self): + global incbin_lines, processed_incbins, asm + incbin_lines = ['INCBIN "baserom.gbc",$12DA,$12F8 - $12DA', + 'INCBIN "baserom.gbc",$137A,$13D0 - $137A'] + asm = copy(incbin_lines) + asm.insert(1, "some other random line") + processed_incbins = process_incbins() + self.assertEqual(len(processed_incbins), len(incbin_lines)) + self.assertEqual(processed_incbins[0]["line"], incbin_lines[0]) + self.assertEqual(processed_incbins[2]["line"], incbin_lines[1]) + + def test_reset_incbins(self): + global asm, incbin_lines, processed_incbins + # temporarily override the functions + global load_asm, isolate_incbins, process_incbins + temp1, temp2, temp3 = load_asm, isolate_incbins, process_incbins + def load_asm(): pass + def isolate_incbins(): pass + def process_incbins(): pass + # call reset + reset_incbins() + # check the results + self.assertTrue(asm == [] or asm == None) + self.assertTrue(incbin_lines == []) + self.assertTrue(processed_incbins == {}) + # reset the original functions + load_asm, isolate_incbins, process_incbins = temp1, temp2, temp3 + + def test_find_incbin_to_replace_for(self): + global asm, incbin_lines, processed_incbins + asm = ['first line', 'second line', 'third line', + 'INCBIN "baserom.gbc",$90,$200 - $90', + 'fifth line', 'last line'] + isolate_incbins() + process_incbins() + line_num = find_incbin_to_replace_for(0x100) + # must be the 4th line (the INBIN line) + self.assertEqual(line_num, 3) + + def test_split_incbin_line_into_three(self): + global asm, incbin_lines, processed_incbins + asm = ['first line', 'second line', 'third line', + 'INCBIN "baserom.gbc",$90,$200 - $90', + 'fifth line', 'last line'] + isolate_incbins() + process_incbins() + content = split_incbin_line_into_three(3, 0x100, 10) + # must end up with three INCBINs in output + self.failUnless(content.count("INCBIN") == 3) + + def test_analyze_intervals(self): + global asm, incbin_lines, processed_incbins + asm, incbin_lines, processed_incbins = None, [], {} + asm = ['first line', 'second line', 'third line', + 'INCBIN "baserom.gbc",$90,$200 - $90', + 'fifth line', 'last line', + 'INCBIN "baserom.gbc",$33F,$4000 - $33F'] + isolate_incbins() + process_incbins() + largest = analyze_intervals() + self.assertEqual(largest[0]["line_number"], 6) + self.assertEqual(largest[0]["line"], asm[6]) + self.assertEqual(largest[1]["line_number"], 3) + self.assertEqual(largest[1]["line"], asm[3]) + + def test_generate_diff_insert(self): + global asm + asm = ['first line', 'second line', 'third line', + 'INCBIN "baserom.gbc",$90,$200 - $90', + 'fifth line', 'last line', + 'INCBIN "baserom.gbc",$33F,$4000 - $33F'] + diff = generate_diff_insert(0, "the real first line", debug=False) + self.assertIn("the real first line", diff) + self.assertIn("INCBIN", diff) + self.assertNotIn("No newline at end of file", diff) + self.assertIn("+"+asm[1], diff) + +class TestMapParsing(unittest.TestCase): + def test_parse_all_map_headers(self): + global parse_map_header_at, old_parse_map_header_at, counter + counter = 0 + for k in map_names.keys(): + if "offset" not in map_names[k].keys(): + map_names[k]["offset"] = 0 + temp = parse_map_header_at + temp2 = old_parse_map_header_at + def parse_map_header_at(address, map_group=None, map_id=None, debug=False): + global counter + counter += 1 + return {} + old_parse_map_header_at = parse_map_header_at + parse_all_map_headers(debug=False) + # parse_all_map_headers is currently doing it 2x + # because of the new/old map header parsing routines + self.assertEqual(counter, 388 * 2) + parse_map_header_at = temp + old_parse_map_header_at = temp2 + +class TestTextScript(unittest.TestCase): + """for testing 'in-script' commands, etc.""" + #def test_to_asm(self): + # pass # or raise NotImplementedError, bryan_message + #def test_find_addresses(self): + # pass # or raise NotImplementedError, bryan_message + #def test_parse_text_at(self): + # pass # or raise NotImplementedError, bryan_message + +class TestEncodedText(unittest.TestCase): + """for testing chars-table encoded text chunks""" + + def test_process_00_subcommands(self): + g = process_00_subcommands(0x197186, 0x197186+601, debug=False) + self.assertEqual(len(g), 42) + self.assertEqual(len(g[0]), 13) + self.assertEqual(g[1], [184, 174, 180, 211, 164, 127, 20, 231, 81]) + + def test_parse_text_at2(self): + oakspeech = parse_text_at2(0x197186, 601, debug=False) + self.assertIn("encyclopedia", oakspeech) + self.assertIn("researcher", oakspeech) + self.assertIn("dependable", oakspeech) + + def test_parse_text_engine_script_at(self): + p = parse_text_engine_script_at(0x197185, debug=False) + self.assertEqual(len(p.commands), 2) + self.assertEqual(len(p.commands[0]["lines"]), 41) + + # don't really care about these other two + def test_parse_text_from_bytes(self): pass + def test_parse_text_at(self): pass + +class TestScript(unittest.TestCase): + """for testing parse_script_engine_script_at and script parsing in + general. Script should be a class.""" + #def test_parse_script_engine_script_at(self): + # pass # or raise NotImplementedError, bryan_message + + def test_find_all_text_pointers_in_script_engine_script(self): + address = 0x197637 # 0x197634 + script = parse_script_engine_script_at(address, debug=False) + bank = calculate_bank(address) + r = find_all_text_pointers_in_script_engine_script(script, bank=bank, debug=False) + results = list(r) + self.assertIn(0x197661, results) + +class TestLabel(unittest.TestCase): + def test_label_making(self): + line_number = 2 + address = 0xf0c0 + label_name = "poop" + l = Label(name=label_name, address=address, line_number=line_number) + self.failUnless(hasattr(l, "name")) + self.failUnless(hasattr(l, "address")) + self.failUnless(hasattr(l, "line_number")) + self.failIf(isinstance(l.address, str)) + self.failIf(isinstance(l.line_number, str)) + self.failUnless(isinstance(l.name, str)) + self.assertEqual(l.line_number, line_number) + self.assertEqual(l.name, label_name) + self.assertEqual(l.address, address) + +class TestByteParams(unittest.TestCase): + @classmethod + def setUpClass(cls): + load_rom() + cls.address = 10 + cls.sbp = SingleByteParam(address=cls.address) + + @classmethod + def tearDownClass(cls): + del cls.sbp + + def test__init__(self): + self.assertEqual(self.sbp.size, 1) + self.assertEqual(self.sbp.address, self.address) + + def test_parse(self): + self.sbp.parse() + self.assertEqual(str(self.sbp.byte), str(45)) + + def test_to_asm(self): + self.assertEqual(self.sbp.to_asm(), "$2d") + self.sbp.should_be_decimal = True + self.assertEqual(self.sbp.to_asm(), str(45)) + + # HexByte and DollarSignByte are the same now + def test_HexByte_to_asm(self): + h = HexByte(address=10) + a = h.to_asm() + self.assertEqual(a, "$2d") + + def test_DollarSignByte_to_asm(self): + d = DollarSignByte(address=10) + a = d.to_asm() + self.assertEqual(a, "$2d") + + def test_ItemLabelByte_to_asm(self): + i = ItemLabelByte(address=433) + self.assertEqual(i.byte, 54) + self.assertEqual(i.to_asm(), "COIN_CASE") + self.assertEqual(ItemLabelByte(address=10).to_asm(), "$2d") + + def test_DecimalParam_to_asm(self): + d = DecimalParam(address=10) + x = d.to_asm() + self.assertEqual(x, str(0x2d)) + +class TestMultiByteParam(unittest.TestCase): + def setup_for(self, somecls, byte_size=2, address=443, **kwargs): + self.cls = somecls(address=address, size=byte_size, **kwargs) + self.assertEqual(self.cls.address, address) + self.assertEqual(self.cls.bytes, rom_interval(address, byte_size, strings=False)) + self.assertEqual(self.cls.size, byte_size) + + def test_two_byte_param(self): + self.setup_for(MultiByteParam, byte_size=2) + self.assertEqual(self.cls.to_asm(), "$f0c0") + + def test_three_byte_param(self): + self.setup_for(MultiByteParam, byte_size=3) + + def test_PointerLabelParam_no_bank(self): + self.setup_for(PointerLabelParam, bank=None) + # assuming no label at this location.. + self.assertEqual(self.cls.to_asm(), "$f0c0") + global all_labels + # hm.. maybe all_labels should be using a class? + all_labels = [{"label": "poop", "address": 0xf0c0, + "offset": 0xf0c0, "bank": 0, + "line_number": 2 + }] + self.assertEqual(self.cls.to_asm(), "poop") + +class TestPostParsing: #(unittest.TestCase): + """tests that must be run after parsing all maps""" + @classmethod + def setUpClass(cls): + run_main() + + def test_signpost_counts(self): + self.assertEqual(len(map_names[1][1]["signposts"]), 0) + self.assertEqual(len(map_names[1][2]["signposts"]), 2) + self.assertEqual(len(map_names[10][5]["signposts"]), 7) + + def test_warp_counts(self): + self.assertEqual(map_names[10][5]["warp_count"], 9) + self.assertEqual(map_names[18][5]["warp_count"], 3) + self.assertEqual(map_names[15][1]["warp_count"], 2) + + def test_map_sizes(self): + self.assertEqual(map_names[15][1]["height"], 18) + self.assertEqual(map_names[15][1]["width"], 10) + self.assertEqual(map_names[7][1]["height"], 4) + self.assertEqual(map_names[7][1]["width"], 4) + + def test_map_connection_counts(self): + self.assertEqual(map_names[7][1]["connections"], 0) + self.assertEqual(map_names[10][1]["connections"], 12) + self.assertEqual(map_names[10][2]["connections"], 12) + self.assertEqual(map_names[11][1]["connections"], 9) # or 13? + + def test_second_map_header_address(self): + self.assertEqual(map_names[11][1]["second_map_header_address"], 0x9509c) + self.assertEqual(map_names[1][5]["second_map_header_address"], 0x95bd0) + + def test_event_address(self): + self.assertEqual(map_names[17][5]["event_address"], 0x194d67) + self.assertEqual(map_names[23][3]["event_address"], 0x1a9ec9) + + def test_people_event_counts(self): + self.assertEqual(len(map_names[23][3]["people_events"]), 4) + self.assertEqual(len(map_names[10][3]["people_events"]), 9) + +class TestMetaTesting(unittest.TestCase): + """test whether or not i am finding at least + some of the tests in this file""" + tests = None + + def setUp(self): + if self.tests == None: + self.__class__.tests = assemble_test_cases() + + def test_assemble_test_cases_count(self): + "does assemble_test_cases find some tests?" + self.failUnless(len(self.tests) > 0) + + def test_assemble_test_cases_inclusion(self): + "is this class found by assemble_test_cases?" + # i guess it would have to be for this to be running? + self.failUnless(self.__class__ in self.tests) + + def test_assemble_test_cases_others(self): + "test other inclusions for assemble_test_cases" + self.failUnless(TestRomStr in self.tests) + self.failUnless(TestCram in self.tests) + + def test_check_has_test(self): + self.failUnless(check_has_test("beaver", ["test_beaver"])) + self.failUnless(check_has_test("beaver", ["test_beaver_2"])) + self.failIf(check_has_test("beaver_1", ["test_beaver"])) + + def test_find_untested_methods(self): + untested = find_untested_methods() + # the return type must be an iterable + self.failUnless(hasattr(untested, "__iter__")) + #.. basically, a list + self.failUnless(isinstance(untested, list)) + + def test_find_untested_methods_method(self): + """create a function and see if it is found""" + # setup a function in the global namespace + global some_random_test_method + # define the method + def some_random_test_method(): pass + # first make sure it is in the global scope + members = inspect.getmembers(sys.modules[__name__], inspect.isfunction) + func_names = [functuple[0] for functuple in members] + self.assertIn("some_random_test_method", func_names) + # test whether or not it is found by find_untested_methods + untested = find_untested_methods() + self.assertIn("some_random_test_method", untested) + # remove the test method from the global namespace + del some_random_test_method + + def test_load_tests(self): + loader = unittest.TestLoader() + suite = load_tests(loader, None, None) + suite._tests[0]._testMethodName + membership_test = lambda member: \ + inspect.isclass(member) and issubclass(member, unittest.TestCase) + tests = inspect.getmembers(sys.modules[__name__], membership_test) + classes = [x[1] for x in tests] + for test in suite._tests: + self.assertIn(test.__class__, classes) + + def test_report_untested(self): + untested = find_untested_methods() + output = report_untested() + if len(untested) > 0: + self.assertIn("NOT TESTED", output) + for name in untested: + self.assertIn(name, output) + elif len(untested) == 0: + self.assertNotIn("NOT TESTED", output) + +def assemble_test_cases(): + """finds classes that inherit from unittest.TestCase + because i am too lazy to remember to add them to a + global list of tests for the suite runner""" + classes = [] + clsmembers = inspect.getmembers(sys.modules[__name__], inspect.isclass) + for (name, some_class) in clsmembers: + if issubclass(some_class, unittest.TestCase): + classes.append(some_class) + return classes + +def load_tests(loader, tests, pattern): + suite = unittest.TestSuite() + for test_class in assemble_test_cases(): + tests = loader.loadTestsFromTestCase(test_class) + suite.addTests(tests) + return suite + +def check_has_test(func_name, tested_names): + """checks if there is a test dedicated to this function""" + if "test_"+func_name in tested_names: + return True + for name in tested_names: + if "test_"+func_name in name: + return True + return False + +def find_untested_methods(): + """finds all untested functions in this module + by searching for method names in test case + method names.""" + untested = [] + avoid_funcs = ["main", "run_tests", "run_main", "copy", "deepcopy"] + test_funcs = [] + # get a list of all classes in this module + classes = inspect.getmembers(sys.modules[__name__], inspect.isclass) + # for each class.. + for (name, klass) in classes: + # only look at those that have tests + if issubclass(klass, unittest.TestCase): + # look at this class' methods + funcs = inspect.getmembers(klass, inspect.ismethod) + # for each method.. + for (name2, func) in funcs: + # store the ones that begin with test_ + if "test_" in name2 and name2[0:5] == "test_": + test_funcs.append([name2, func]) + # assemble a list of all test method names (test_x, test_y, ..) + tested_names = [funcz[0] for funcz in test_funcs] + # now get a list of all functions in this module + funcs = inspect.getmembers(sys.modules[__name__], inspect.isfunction) + # for each function.. + for (name, func) in funcs: + # we don't care about some of these + if name in avoid_funcs: continue + # skip functions beginning with _ + if name[0] == "_": continue + # check if this function has a test named after it + has_test = check_has_test(name, tested_names) + if not has_test: + untested.append(name) + return untested + +def report_untested(): + """ + This reports about untested functions in the global namespace. This was + originally in the crystal module, where it would list out the majority of + the functions. Maybe it should be moved back. + """ + untested = find_untested_methods() + output = "NOT TESTED: [" + first = True + for name in untested: + if first: + output += name + first = False + else: output += ", "+name + output += "]\n" + output += "total untested: " + str(len(untested)) + return output + +def run_tests(): # rather than unittest.main() + loader = unittest.TestLoader() + suite = load_tests(loader, None, None) + unittest.TextTestRunner(verbosity=2).run(suite) + print report_untested() + +# run the unit tests when this file is executed directly +if __name__ == "__main__": + run_tests() + diff --git a/type_constants.py b/type_constants.py new file mode 100644 index 0000000..da89b0b --- /dev/null +++ b/type_constants.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +type_constants = { + "NORMAL": 0x00, + "FIGHTING": 0x01, + "FLYING": 0x02, + "POISON": 0x03, + "GROUND": 0x04, + "ROCK": 0x05, + "BUG": 0x07, + "GHOST": 0x08, + "STEEL": 0x09, + "CURSE_T": 0x13, + "FIRE": 0x14, + "WATER": 0x15, + "GRASS": 0x16, + "ELECTRIC": 0x17, + "PSYCHIC": 0x18, + "ICE": 0x19, + "DRAGON": 0x1A, + "DARK": 0x1B, +} |