summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md5
-rw-r--r--comparator.py40
-rw-r--r--crystal.py1410
-rwxr-xr-xdump_sections14
-rwxr-xr-xdump_sections.py130
-rw-r--r--gbz80disasm.py28
-rw-r--r--gfx.py58
-rw-r--r--graph.py13
-rw-r--r--interval_map.py104
-rw-r--r--item_constants.py23
-rw-r--r--labels.py9
-rw-r--r--move_constants.py2
-rw-r--r--pksv.py16
-rw-r--r--pointers.py8
-rw-r--r--pokemon_constants.py2
-rw-r--r--romstr.py354
-rw-r--r--test_dump_sections.py74
-rw-r--r--tests.py1015
-rw-r--r--type_constants.py21
19 files changed, 1714 insertions, 1612 deletions
diff --git a/README.md b/README.md
index d0a2119..0131fc0 100644
--- a/README.md
+++ b/README.md
@@ -30,9 +30,8 @@ After running those lines, `cp extras/output.txt main.asm` and run `git diff mai
Unit tests cover most of the classes.
-```python
-import crystal
-crystal.run_tests()
+```bash
+python tests.py
```
#### Parsing a script at a known address
diff --git a/comparator.py b/comparator.py
index 690fa52..6d981e4 100644
--- a/comparator.py
+++ b/comparator.py
@@ -1,34 +1,30 @@
-#!/usr/bin/python
# -*- coding: utf-8 -*-
-# author: Bryan Bishop <kanzure@gmail.com>
-# date: 2012-05-29
-# purpose: find shared functions between red/crystal
-
-from crystal import get_label_from_line, \
- get_address_from_line_comment, \
- AsmSection
-
-from romstr import RomStr, AsmList
+"""
+Finds shared functions between red/crystal.
+"""
+
+from crystal import (
+ get_label_from_line,
+ get_address_from_line_comment,
+ AsmSection,
+ direct_load_rom,
+ direct_load_asm,
+)
+
+from romstr import (
+ RomStr,
+ AsmList,
+)
def load_rom(path):
""" Loads a ROM file into an abbreviated RomStr object.
"""
-
- fh = open(path, "r")
- x = RomStr(fh.read())
- fh.close()
-
- return x
+ return direct_load_rom(filename=path)
def load_asm(path):
""" Loads source ASM into an abbreviated AsmList object.
"""
-
- fh = open(path, "r")
- x = AsmList(fh.read().split("\n"))
- fh.close()
-
- return x
+ return direct_load_asm(filename=path)
def findall_iter(sub, string):
# url: http://stackoverflow.com/a/3874760/687783
diff --git a/crystal.py b/crystal.py
index 3f30ed0..06d54ae 100644
--- a/crystal.py
+++ b/crystal.py
@@ -1,29 +1,19 @@
# -*- coding: utf-8 -*-
# utilities to help disassemble pokémon crystal
-import sys, os, inspect, hashlib, json
+import os
+import sys
+import inspect
+import hashlib
+import json
from copy import copy, deepcopy
import subprocess
from new import classobj
import random
-# for IntervalMap
-from bisect import bisect_left, bisect_right
-from itertools import izip
-
-# for testing all this crap
-try:
- import unittest2 as unittest
-except ImportError:
- import unittest
-
# for capwords
import string
-# Check for things we need in unittest.
-if not hasattr(unittest.TestCase, 'setUpClass'):
- sys.stderr.write("The unittest2 module or Python 2.7 is required to run this script.")
- sys.exit(1)
-
+# for python2.6
if not hasattr(json, "dumps"):
json.dumps = json.write
@@ -31,6 +21,14 @@ if not hasattr(json, "dumps"):
if not hasattr(json, "read"):
json.read = json.loads
+from labels import (
+ remove_quoted_text,
+ line_has_comment_address,
+ line_has_label,
+ get_label_from_line,
+ get_address_from_line_comment
+)
+
spacing = "\t"
lousy_dragon_shrine_hack = [0x18d079, 0x18d0a9, 0x18d061, 0x18d091]
@@ -63,132 +61,25 @@ constant_abbreviation_bytes = {}
# Import the characters from its module.
from chars import chars, jap_chars
-from trainers import *
+from trainers import (
+ trainer_group_pointer_table_address, # 0x39999
+ trainer_group_pointer_table_address_gs, # 0x3993E
+ trainer_group_names,
+)
from move_constants import moves
# for fixing trainer_group_names
import re
-trainer_group_pointer_table_address = 0x39999
-trainer_group_pointer_table_address_gs = 0x3993E
-
-class Size():
- """a simple way to track whether or not a size
- includes the first value or not, like for
- whether or not the size of a command in a script
- also includes the command byte or not"""
-
- def __init__(self, size, inclusive=False):
- self.inclusive = inclusive
- if inclusive: size = size-1
- self.size = size
-
- def inclusive(self):
- return self.size + 1
-
- def exclusive(self):
- return self.size
-
-class IntervalMap(object):
- """
- This class maps a set of intervals to a set of values.
-
- >>> i = IntervalMap()
- >>> i[0:5] = "hello world"
- >>> i[6:10] = "hello cruel world"
- >>> print i[4]
- "hello world"
- """
-
- def __init__(self):
- """initializes an empty IntervalMap"""
- self._bounds = []
- self._items = []
- self._upperitem = None
-
- def __setitem__(self, _slice, _value):
- """sets an interval mapping"""
- assert isinstance(_slice, slice), 'The key must be a slice object'
-
- if _slice.start is None:
- start_point = -1
- else:
- start_point = bisect_left(self._bounds, _slice.start)
-
- if _slice.stop is None:
- end_point = -1
- else:
- end_point = bisect_left(self._bounds, _slice.stop)
-
- if start_point>=0:
- if start_point < len(self._bounds) and self._bounds[start_point]<_slice.start:
- start_point += 1
-
- if end_point>=0:
- self._bounds[start_point:end_point] = [_slice.start, _slice.stop]
- if start_point < len(self._items):
- self._items[start_point:end_point] = [self._items[start_point], _value]
- else:
- self._items[start_point:end_point] = [self._upperitem, _value]
- else:
- self._bounds[start_point:] = [_slice.start]
- if start_point < len(self._items):
- self._items[start_point:] = [self._items[start_point], _value]
- else:
- self._items[start_point:] = [self._upperitem]
- self._upperitem = _value
- else:
- if end_point>=0:
- self._bounds[:end_point] = [_slice.stop]
- self._items[:end_point] = [_value]
- else:
- self._bounds[:] = []
- self._items[:] = []
- self._upperitem = _value
-
- def __getitem__(self,_point):
- """gets a value from the mapping"""
- assert not isinstance(_point, slice), 'The key cannot be a slice object'
-
- index = bisect_right(self._bounds, _point)
- if index < len(self._bounds):
- return self._items[index]
- else:
- return self._upperitem
-
- def items(self):
- """returns an iterator with each item being
- ((low_bound, high_bound), value)
- these items are returned in order"""
- previous_bound = None
- for (b, v) in izip(self._bounds, self._items):
- if v is not None:
- yield (previous_bound, b), v
- previous_bound = b
- if self._upperitem is not None:
- yield (previous_bound, None), self._upperitem
-
- def values(self):
- """returns an iterator with each item being a stored value
- the items are returned in order"""
- for v in self._items:
- if v is not None:
- yield v
- if self._upperitem is not None:
- yield self._upperitem
-
- def __repr__(self):
- s = []
- for b,v in self.items():
- if v is not None:
- s.append('[%r, %r] => %r'%(
- b[0],
- b[1],
- v
- ))
- return '{'+', '.join(s)+'}'
+from interval_map import IntervalMap
+from pksv import (
+ pksv_gs,
+ pksv_crystal,
+ pksv_crystal_unknowns,
+ pksv_crystal_more_enders
+)
# ---- script_parse_table explanation ----
# This is an IntervalMap that keeps track of previously parsed scripts, texts
@@ -206,7 +97,8 @@ script_parse_table = IntervalMap()
def is_script_already_parsed_at(address):
"""looks up whether or not a script is parsed at a certain address"""
- if script_parse_table[address] == None: return False
+ if script_parse_table[address] == None:
+ return False
return True
def script_parse_table_pretty_printer():
@@ -230,7 +122,10 @@ def map_name_cleaner(input):
replace("hooh", "HoOh").\
replace(" ", "")
-from romstr import RomStr, AsmList
+from romstr import (
+ RomStr,
+ AsmList,
+)
rom = RomStr(None)
@@ -253,12 +148,16 @@ def load_rom(filename="../baserom.gbc"):
elif os.lstat(filename).st_size != len(rom):
return direct_load_rom(filename)
+def direct_load_asm(filename="../main.asm"):
+ """returns asm source code (AsmList) from a file"""
+ asm = open(filename, "r").read().split("\n")
+ asm = AsmList(asm)
+ return asm
def load_asm(filename="../main.asm"):
- """loads the asm source code into memory"""
+ """returns asm source code (AsmList) from a file (uses a global)"""
global asm
- asm = open(filename, "r").read().split("\n")
- asm = AsmList(asm)
+ asm = direct_load_asm(filename=filename)
return asm
def grouper(some_list, count=2):
@@ -269,11 +168,14 @@ def grouper(some_list, count=2):
def is_valid_address(address):
"""is_valid_rom_address"""
- if address == None: return False
+ if address == None:
+ return False
if type(address) == str:
address = int(address, 16)
- if 0 <= address <= 2097152: return True
- else: return False
+ if 0 <= address <= 2097152:
+ return True
+ else:
+ return False
def rom_interval(offset, length, strings=True, debug=True):
"""returns hex values for the rom starting at offset until offset+length"""
@@ -302,7 +204,10 @@ def load_map_group_offsets():
map_group_offsets.append(offset)
return map_group_offsets
-from pointers import calculate_bank, calculate_pointer
+from pointers import (
+ calculate_bank,
+ calculate_pointer,
+)
def calculate_pointer_from_bytes_at(address, bank=False):
"""calculates a pointer from 2 bytes at a location
@@ -317,7 +222,7 @@ def calculate_pointer_from_bytes_at(address, bank=False):
elif type(bank) == int:
pass
else:
- raise Exception, "bad bank given to calculate_pointer_from_bytes_at"
+ raise Exception("bad bank given to calculate_pointer_from_bytes_at")
byte1 = ord(rom[address])
byte2 = ord(rom[address+1])
temp = byte1 + (byte2 << 8)
@@ -345,6 +250,20 @@ def clean_up_long_info(long_info):
long_info = "\n".join(new_lines)
return long_info
+from pokemon_constants import pokemon_constants
+
+def get_pokemon_constant_by_id(id):
+ if id == 0:
+ return None
+ else:
+ return pokemon_constants[id]
+
+from item_constants import (
+ item_constants,
+ find_item_label_by_id,
+ generate_item_constants,
+)
+
def command_debug_information(command_byte=None, map_group=None, map_id=None, address=0, info=None, long_info=None, pksv_name=None):
"used to help debug in parse_script_engine_script_at"
info1 = "parsing command byte " + hex(command_byte) + " for map " + \
@@ -376,7 +295,7 @@ class TextScript:
self.force = force
if is_script_already_parsed_at(address) and not force:
- raise Exception, "TextScript already parsed at "+hex(address)
+ raise Exception("TextScript already parsed at "+hex(address))
if not label:
label = self.base_label + hex(address)
@@ -449,11 +368,11 @@ class TextScript:
print "self.commands is: " + str(commands)
print "command 0 address is: " + hex(commands[0].address) + " last_address="+hex(commands[0].last_address)
print "command 1 address is: " + hex(commands[1].address) + " last_address="+hex(commands[1].last_address)
- raise Exception, "going beyond the bounds for this text script"
+ raise Exception("going beyond the bounds for this text script")
# no matching command found
if scripting_command_class == None:
- raise Exception, "unable to parse text command $%.2x in the text script at %s at %s" % (cur_byte, hex(start_address), hex(current_address))
+ raise Exception("unable to parse text command $%.2x in the text script at %s at %s" % (cur_byte, hex(start_address), hex(current_address)))
# create an instance of the command class and let it parse its parameter bytes
cls = scripting_command_class(address=current_address, map_group=self.map_group, map_id=self.map_id, debug=self.debug, force=self.force)
@@ -631,7 +550,7 @@ class OldTextScript:
if is_script_already_parsed_at(address) and not force:
print "text is already parsed at this location: " + hex(address)
- raise Exception, "text is already parsed, what's going on ?"
+ raise Exception("text is already parsed, what's going on ?")
return script_parse_table[address]
total_text_commands = 0
@@ -845,7 +764,7 @@ class OldTextScript:
def get_dependencies(self, recompute=False, global_dependencies=set()):
#if recompute:
- # raise NotImplementedError, bryan_message
+ # raise NotImplementedError(bryan_message)
global_dependencies.update(self.dependencies)
return self.dependencies
@@ -1157,7 +1076,8 @@ def parse_text_at3(address, map_group=None, map_id=None, debug=False):
text = TextScript(address, map_group=map_group, map_id=map_id, debug=debug)
if text.is_valid():
return text
- else: return None
+ else:
+ return None
def rom_text_at(address, count=10):
"""prints out raw text from the ROM
@@ -1166,8 +1086,10 @@ def rom_text_at(address, count=10):
def get_map_constant_label(map_group=None, map_id=None):
"""returns PALLET_TOWN for some map group/id pair"""
- if map_group == None: raise Exception, "need map_group"
- if map_id == None: raise Exception, "need map_id"
+ if map_group == None:
+ raise Exception("need map_group")
+ if map_id == None:
+ raise Exception("need map_id")
global map_internal_ids
for (id, each) in map_internal_ids.items():
@@ -1185,7 +1107,8 @@ def get_id_for_map_constant_label(label):
PALLET_TOWN = 1, for instance."""
global map_internal_ids
for (id, each) in map_internal_ids.items():
- if each["label"] == label: return id
+ if each["label"] == label:
+ return id
return None
def generate_map_constant_labels():
@@ -1264,37 +1187,16 @@ def transform_wildmons(asm):
returnlines.append(line)
return "\n".join(returnlines)
-from pokemon_constants import pokemon_constants
-
-def get_pokemon_constant_by_id(id):
- if id == 0: return None
- return pokemon_constants[id]
-
def parse_script_asm_at(*args, **kwargs):
# XXX TODO
return None
-from item_constants import item_constants
-
-def find_item_label_by_id(id):
- if id in item_constants.keys():
- return item_constants[id]
- else: return None
-
-def generate_item_constants():
- """make a list of items to put in constants.asm"""
- output = ""
- for (id, item) in item_constants.items():
- val = ("$%.2x"%id).upper()
- while len(item)<13: item+= " "
- output += item + " EQU " + val + "\n"
- return output
-
def find_all_text_pointers_in_script_engine_script(script, bank=None, debug=False):
"""returns a list of text pointers
based on each script-engine script command"""
# TODO: recursively follow any jumps in the script
- if script == None: return []
+ if script == None:
+ return []
addresses = set()
for (k, command) in enumerate(script.commands):
if debug:
@@ -1329,16 +1231,16 @@ def translate_command_byte(crystal=None, gold=None):
if 0x53 <= crystal <= 0x9E: return crystal-1
if crystal == 0x9F: return None
if 0xA0 <= crystal <= 0xA5: return crystal-2
- if crystal > 0xA5: raise Exception, "dunno yet if crystal has new insertions after crystal:0xA5 (gold:0xA3)"
+ if crystal > 0xA5:
+ raise Exception("dunno yet if crystal has new insertions after crystal:0xA5 (gold:0xA3)")
elif gold != None: # convert to crystal
if gold <= 0x51: return gold
if 0x52 <= gold <= 0x9D: return gold+1
if 0x9E <= gold <= 0xA3: return gold+2
- if gold > 0xA3: raise Exception, "dunno yet if crystal has new insertions after gold:0xA3 (crystal:0xA5)"
- else: raise Exception, "translate_command_byte needs either a crystal or gold command"
-
-from pksv import pksv_gs, pksv_crystal, pksv_crystal_unknowns,\
- pksv_crystal_more_enders
+ if gold > 0xA3:
+ raise Exception("dunno yet if crystal has new insertions after gold:0xA3 (crystal:0xA5)")
+ else:
+ raise Exception("translate_command_byte needs either a crystal or gold command")
class SingleByteParam():
"""or SingleByte(CommandParam)"""
@@ -1351,14 +1253,14 @@ class SingleByteParam():
setattr(self, key, value)
# check address
if not hasattr(self, "address"):
- raise Exception, "an address is a requirement"
+ raise Exception("an address is a requirement")
elif self.address == None:
- raise Exception, "address must not be None"
+ raise Exception("address must not be None")
elif not is_valid_address(self.address):
- raise Exception, "address must be valid"
+ raise Exception("address must be valid")
# check size
if not hasattr(self, "size") or self.size == None:
- raise Exception, "size is probably 1?"
+ raise Exception("size is probably 1?")
# parse bytes from ROM
self.parse()
@@ -1368,18 +1270,23 @@ class SingleByteParam():
return []
def to_asm(self):
- if not self.should_be_decimal: return hex(self.byte).replace("0x", "$")
- else: return str(self.byte)
+ if not self.should_be_decimal:
+ return hex(self.byte).replace("0x", "$")
+ else:
+ return str(self.byte)
class DollarSignByte(SingleByteParam):
- def to_asm(self): return hex(self.byte).replace("0x", "$")
+ def to_asm(self):
+ return hex(self.byte).replace("0x", "$")
HexByte=DollarSignByte
class ItemLabelByte(DollarSignByte):
def to_asm(self):
label = find_item_label_by_id(self.byte)
- if label: return label
- elif not label: return DollarSignByte.to_asm(self)
+ if label:
+ return label
+ elif not label:
+ return DollarSignByte.to_asm(self)
class DecimalParam(SingleByteParam):
@@ -1398,12 +1305,12 @@ class MultiByteParam():
setattr(self, key, value)
# check address
if not hasattr(self, "address") or self.address == None:
- raise Exception, "an address is a requirement"
+ raise Exception("an address is a requirement")
elif not is_valid_address(self.address):
- raise Exception, "address must be valid"
+ raise Exception("address must be valid")
# check size
if not hasattr(self, "size") or self.size == None:
- raise Exception, "don't know how many bytes to read (size)"
+ raise Exception("don't know how many bytes to read (size)")
self.parse()
def parse(self):
@@ -1445,9 +1352,9 @@ class PointerLabelParam(MultiByteParam):
self.size = self.default_size + 1
self.given_bank = kwargs["bank"]
#if kwargs["bank"] not in [None, False, True, "reverse"]:
- # raise Exception, "bank cannot be: " + str(kwargs["bank"])
+ # raise Exception("bank cannot be: " + str(kwargs["bank"]))
if self.size > 3:
- raise Exception, "param size is too large"
+ raise Exception("param size is too large")
# continue instantiation.. self.bank will be set down the road
MultiByteParam.__init__(self, *args, **kwargs)
@@ -1520,15 +1427,16 @@ class PointerLabelParam(MultiByteParam):
return pointer_part+", "+bank_part
elif bank == True: # bank, pointer
return bank_part+", "+pointer_part
- else: raise Exception, "this should never happen"
- raise Exception, "this should never happen"
+ else:
+ raise Exception("this should never happen")
+ raise Exception("this should never happen")
# this next one will either return the label or the raw bytes
elif bank == False or bank == None: # pointer
return pointer_part # this could be the same as label
else:
- #raise Exception, "this should never happen"
+ #raise Exception("this should never happen")
return pointer_part # probably in the same bank ?
- raise Exception, "this should never happen"
+ raise Exception("this should never happen")
class PointerLabelBeforeBank(PointerLabelParam):
bank = True # bank appears first, see calculate_pointer_from_bytes_at
@@ -1549,12 +1457,12 @@ class ScriptPointerLabelBeforeBank(PointerLabelBeforeBank): pass
class ScriptPointerLabelAfterBank(PointerLabelAfterBank): pass
-def _parse_script_pointer_bytes(self):
+def _parse_script_pointer_bytes(self, debug = False):
PointerLabelParam.parse(self)
- print "_parse_script_pointer_bytes - calculating the pointer located at " + hex(self.address)
+ if debug: print "_parse_script_pointer_bytes - calculating the pointer located at " + hex(self.address)
address = calculate_pointer_from_bytes_at(self.address, bank=self.bank)
if address != None and address > 0x4000:
- print "_parse_script_pointer_bytes - the pointer is: " + hex(address)
+ if debug: print "_parse_script_pointer_bytes - the pointer is: " + hex(address)
self.script = parse_script_engine_script_at(address, debug=self.debug, force=self.force, map_group=self.map_group, map_id=self.map_id)
ScriptPointerLabelParam.parse = _parse_script_pointer_bytes
ScriptPointerLabelBeforeBank.parse = _parse_script_pointer_bytes
@@ -1587,8 +1495,10 @@ class RAMAddressParam(MultiByteParam):
def to_asm(self):
address = calculate_pointer_from_bytes_at(self.address, bank=False)
label = get_ram_label(address)
- if label: return label
- else: return "$"+"".join(["%.2x"%x for x in reversed(self.bytes)])+""
+ if label:
+ return label
+ else:
+ return "$"+"".join(["%.2x"%x for x in reversed(self.bytes)])+""
class MoneyByteParam(MultiByteParam):
@@ -1646,9 +1556,11 @@ class MapGroupParam(SingleByteParam):
def to_asm(self):
map_id = ord(rom[self.address+1])
map_constant_label = get_map_constant_label(map_id=map_id, map_group=self.byte) # like PALLET_TOWN
- if map_constant_label == None: return str(self.byte)
+ if map_constant_label == None:
+ return str(self.byte)
#else: return "GROUP("+map_constant_label+")"
- else: return "GROUP_"+map_constant_label
+ else:
+ return "GROUP_"+map_constant_label
class MapIdParam(SingleByteParam):
@@ -1659,9 +1571,11 @@ class MapIdParam(SingleByteParam):
def to_asm(self):
map_group = ord(rom[self.address-1])
map_constant_label = get_map_constant_label(map_id=self.byte, map_group=map_group)
- if map_constant_label == None: return str(self.byte)
+ if map_constant_label == None:
+ return str(self.byte)
#else: return "MAP("+map_constant_label+")"
- else: return "MAP_"+map_constant_label
+ else:
+ return "MAP_"+map_constant_label
class MapGroupIdParam(MultiByteParam):
@@ -1680,13 +1594,15 @@ class MapGroupIdParam(MultiByteParam):
class PokemonParam(SingleByteParam):
def to_asm(self):
pokemon_constant = get_pokemon_constant_by_id(self.byte)
- if pokemon_constant: return pokemon_constant
- else: return str(self.byte)
+ if pokemon_constant:
+ return pokemon_constant
+ else:
+ return str(self.byte)
class PointerParamToItemAndLetter(MultiByteParam):
# [2F][2byte pointer to item no + 0x20 bytes letter text]
- #raise NotImplementedError, bryan_message
+ #raise NotImplementedError(bryan_message)
pass
@@ -1702,7 +1618,7 @@ class TrainerIdParam(SingleByteParam):
i += 1
if foundit == None:
- raise Exception, "didn't find a TrainerGroupParam in this command??"
+ raise Exception("didn't find a TrainerGroupParam in this command??")
# now get the trainer group id
trainer_group_id = self.parent.params[foundit].byte
@@ -1729,7 +1645,7 @@ class MoveParam(SingleByteParam):
class MenuDataPointerParam(PointerLabelParam):
# read menu data at the target site
- #raise NotImplementedError, bryan_message
+ #raise NotImplementedError(bryan_message)
pass
@@ -1813,7 +1729,7 @@ class MovementPointerLabelParam(PointerLabelParam):
global_dependencies.add(self.movement)
return [self.movement] + self.movement.get_dependencies(recompute=recompute, global_dependencies=global_dependencies)
else:
- raise Exception, "MovementPointerLabelParam hasn't been parsed yet"
+ raise Exception("MovementPointerLabelParam hasn't been parsed yet")
class MapDataPointerParam(PointerLabelParam):
pass
@@ -1838,7 +1754,7 @@ class Command:
"""
defaults = {"force": False, "debug": False, "map_group": None, "map_id": None}
if not is_valid_address(address):
- raise Exception, "address is invalid"
+ raise Exception("address is invalid")
# set up some variables
self.address = address
self.last_address = None
@@ -1878,7 +1794,8 @@ class Command:
# output += "_"
output += self.macro_name
# return if there are no params
- if len(self.param_types.keys()) == 0: return output
+ if len(self.param_types.keys()) == 0:
+ return output
# first one will have no prefixing comma
first = True
# start reading the bytes after the command byte
@@ -1923,7 +1840,7 @@ class Command:
current_address = self.address
byte = ord(rom[self.address])
if not self.override_byte_check and (not byte == self.id):
- raise Exception, "byte ("+hex(byte)+") != self.id ("+hex(self.id)+")"
+ raise Exception("byte ("+hex(byte)+") != self.id ("+hex(self.id)+")")
i = 0
for (key, param_type) in self.param_types.items():
name = param_type["name"]
@@ -1959,7 +1876,7 @@ class GivePoke(Command):
self.params = {}
byte = ord(rom[self.address])
if not byte == self.id:
- raise Exception, "this should never happen"
+ raise Exception("this should never happen")
current_address = self.address+1
i = 0
self.size = 1
@@ -2120,7 +2037,8 @@ def create_movement_commands(debug=False):
direction = "left"
elif x == 3:
direction = "right"
- else: raise Exception, "this should never happen"
+ else:
+ raise Exception("this should never happen")
cmd_name = cmd[0].replace(" ", "_") + "_" + direction
klass_name = cmd_name+"Command"
@@ -2355,11 +2273,11 @@ class MainText(TextCommand):
print "bytes are: " + str(self.bytes)
print "self.size is: " + str(self.size)
print "self.last_address is: " + hex(self.last_address)
- raise Exception, "last_address is wrong for 0x9c00e"
+ raise Exception("last_address is wrong for 0x9c00e")
def to_asm(self):
if self.size < 2 or len(self.bytes) < 1:
- raise Exception, "$0 text command can't end itself with no follow-on bytes"
+ raise Exception("$0 text command can't end itself with no follow-on bytes")
if self.use_zero:
output = "db $0"
@@ -2390,13 +2308,13 @@ class MainText(TextCommand):
for byte in self.bytes:
if end:
- raise Exception, "the text ended due to a $50 or $57 but there are more bytes?"
+ raise Exception("the text ended due to a $50 or $57 but there are more bytes?")
if new_line:
if in_quotes:
- raise Exception, "can't be in_quotes on a newline"
+ raise Exception("can't be in_quotes on a newline")
elif was_comma:
- raise Exception, "last line's last character can't be a comma"
+ raise Exception("last line's last character can't be a comma")
output += "db "
@@ -2490,7 +2408,7 @@ class MainText(TextCommand):
was_comma = False
end = False
else:
- # raise Exception, "unknown byte in text script ($%.2x)" % (byte)
+ # raise Exception("unknown byte in text script ($%.2x)" % (byte))
# just add an unknown byte directly to the text.. what's the worse that can happen?
if in_quotes:
@@ -2511,7 +2429,7 @@ class MainText(TextCommand):
# this shouldn't happen because of the rom_until calls in the parse method
if not end:
- raise Exception, "ran out of bytes without the script ending? starts at "+hex(self.address)
+ raise Exception("ran out of bytes without the script ending? starts at "+hex(self.address))
# last character may or may not be allowed to be a newline?
# Script.to_asm() has command.to_asm()+"\n"
@@ -2817,6 +2735,7 @@ pksv_crystal_more = {
0x4F: ["loadmenudata", ["data", MenuDataPointerParam]],
0x50: ["writebackup"],
0x51: ["jumptextfaceplayer", ["text_pointer", RawTextPointerLabelParam]],
+ 0x52: ["3jumptext", ["text_pointer", PointerLabelBeforeBank]],
0x53: ["jumptext", ["text_pointer", RawTextPointerLabelParam]],
0x54: ["closetext"],
0x55: ["keeptextopen"],
@@ -3088,17 +3007,17 @@ class Script:
self.address = None
self.commands = None
if len(kwargs) == 0 and len(args) == 0:
- raise Exception, "Script.__init__ must be given some arguments"
+ raise Exception("Script.__init__ must be given some arguments")
# first positional argument is address
if len(args) == 1:
address = args[0]
if type(address) == str:
address = int(address, 16)
elif type(address) != int:
- raise Exception, "address must be an integer or string"
+ raise Exception("address must be an integer or string")
self.address = address
elif len(args) > 1:
- raise Exception, "don't know what to do with second (or later) positional arguments"
+ raise Exception("don't know what to do with second (or later) positional arguments")
self.dependencies = None
if "label" in kwargs.keys():
label = kwargs["label"]
@@ -3160,15 +3079,15 @@ class Script:
"""
global command_classes, rom, script_parse_table
current_address = start_address
- print "Script.parse address="+hex(self.address) +" map_group="+str(map_group)+" map_id="+str(map_id)
+ if debug: print "Script.parse address="+hex(self.address) +" map_group="+str(map_group)+" map_id="+str(map_id)
if start_address in stop_points and force == False:
- print "script parsing is stopping at stop_point=" + hex(start_address) + " at map_group="+str(map_group)+" map_id="+str(map_id)
+ if debug: print "script parsing is stopping at stop_point=" + hex(start_address) + " at map_group="+str(map_group)+" map_id="+str(map_id)
return None
if start_address < 0x4000 and start_address not in [0x26ef, 0x114, 0x1108]:
- print "address is less than 0x4000.. address is: " + hex(start_address)
+ if debug: print "address is less than 0x4000.. address is: " + hex(start_address)
sys.exit(1)
if is_script_already_parsed_at(start_address) and not force and not force_top:
- raise Exception, "this script has already been parsed before, please use that instance ("+hex(start_address)+")"
+ raise Exception("this script has already been parsed before, please use that instance ("+hex(start_address)+")")
# load up the rom if it hasn't been loaded already
load_rom()
@@ -3198,13 +3117,13 @@ class Script:
# no matching command found (not implemented yet)- just end this script
# NOTE: might be better to raise an exception and end the program?
if scripting_command_class == None:
- print "parsing script; current_address is: " + hex(current_address)
+ if debug: print "parsing script; current_address is: " + hex(current_address)
current_address += 1
asm_output = "\n".join([command.to_asm() for command in commands])
end = True
continue
# maybe the program should exit with failure instead?
- #raise Exception, "no command found? id: " + hex(cur_byte) + " at " + hex(current_address) + " asm is:\n" + asm_output
+ #raise Exception("no command found? id: " + hex(cur_byte) + " at " + hex(current_address) + " asm is:\n" + asm_output)
# create an instance of the command class and let it parse its parameter bytes
#print "about to parse command(script@"+hex(start_address)+"): " + str(scripting_command_class.macro_name)
@@ -3231,7 +3150,7 @@ class Script:
script_parse_table[start_address:current_address] = self
asm_output = "\n".join([command.to_asm() for command in commands])
- print "--------------\n"+asm_output
+ if debug: print "--------------\n"+asm_output
# store the script
self.commands = commands
@@ -3529,7 +3448,8 @@ class TrainerFragment(Command):
def get_dependencies(self, recompute=False, global_dependencies=set()):
deps = []
- if not is_valid_address(self.address): return deps
+ if not is_valid_address(self.address):
+ return deps
if self.dependencies != None and not recompute:
global_dependencies.update(self.dependencies)
return self.dependencies
@@ -3867,7 +3787,7 @@ class TrainerHeader:
break
if party_mon_parser == None:
- raise Exception, "no trainer party mon parser found to parse data type " + hex(self.data_type)
+ raise Exception("no trainer party mon parser found to parse data type " + hex(self.data_type))
self.party_mons = party_mon_parser(address=current_address, group_id=self.trainer_group_id, trainer_id=self.trainer_id, parent=self)
@@ -4422,7 +4342,8 @@ class SignpostRemoteBase:
def to_asm(self):
"""very similar to Command.to_asm"""
- if len(self.params) == 0: return ""
+ if len(self.params) == 0:
+ return ""
#output = ", ".join([p.to_asm() for p in self.params])
output = ""
for param in self.params:
@@ -4670,7 +4591,7 @@ class Signpost(Command):
mb = PointerLabelParam(address=self.address+3, map_group=self.map_group, map_id=self.map_id, debug=self.debug)
self.params.append(mb)
else:
- raise Exception, "unknown signpost type byte="+hex(func) + " signpost@"+hex(self.address)
+ raise Exception("unknown signpost type byte="+hex(func) + " signpost@"+hex(self.address))
def get_dependencies(self, recompute=False, global_dependencies=set()):
dependencies = []
@@ -4684,13 +4605,15 @@ class Signpost(Command):
def to_asm(self):
output = self.macro_name + " "
- if self.params == []: raise Exception, "signpost has no params?"
+ if self.params == []:
+ raise Exception("signpost has no params?")
output += ", ".join([p.to_asm() for p in self.params])
return output
all_signposts = []
def parse_signposts(address, signpost_count, bank=None, map_group=None, map_id=None, debug=True):
- if bank == None: raise Exception, "signposts need to know their bank"
+ if bank == None:
+ raise Exception("signposts need to know their bank")
signposts = []
current_address = address
id = 0
@@ -5273,7 +5196,7 @@ class Connection:
wrong_norths.append(data)
# this will only happen if there's a bad formula
- raise Exception, "tauwasser strip_pointer calculation was wrong? strip_pointer="+hex(strip_pointer) + " p="+hex(p)
+ raise Exception("tauwasser strip_pointer calculation was wrong? strip_pointer="+hex(strip_pointer) + " p="+hex(p))
calculated_destination = None
method = "strip_destination_default"
@@ -5295,7 +5218,7 @@ class Connection:
x_movement_of_the_connection_strip_in_blocks = strip_destination - 0xC703
print "(north) x_movement_of_the_connection_strip_in_blocks is: " + str(x_movement_of_the_connection_strip_in_blocks)
if x_movement_of_the_connection_strip_in_blocks < 0:
- raise Exception, "x_movement_of_the_connection_strip_in_blocks is wrong? " + str(x_movement_of_the_connection_strip_in_blocks)
+ raise Exception("x_movement_of_the_connection_strip_in_blocks is wrong? " + str(x_movement_of_the_connection_strip_in_blocks))
elif ldirection == "south":
# strip_destination =
# 0xc703 + (current_map_height + 3) * (current_map_width + 6) + x_movement_of_the_connection_strip_in_blocks
@@ -5570,11 +5493,11 @@ class Connection:
yoffset = self.yoffset # y_position_after_map_change
if ldirection == "south" and yoffset != 0:
- raise Exception, "tauwasser was wrong about yoffset=0 for south? it's: " + str(yoffset)
+ raise Exception("tauwasser was wrong about yoffset=0 for south? it's: " + str(yoffset))
elif ldirection == "north" and yoffset != ((connected_map_height * 2) - 1):
- raise Exception, "tauwasser was wrong about yoffset for north? it's: " + str(yoffset)
+ raise Exception("tauwasser was wrong about yoffset for north? it's: " + str(yoffset))
#elif not ((yoffset % -2) == 0):
- # raise Exception, "tauwasser was wrong about yoffset for west/east? it's not divisible by -2: " + str(yoffset)
+ # raise Exception("tauwasser was wrong about yoffset for west/east? it's not divisible by -2: " + str(yoffset))
# Left: (Width_of_connected_map * 2) - 1
# Right: 0
@@ -5582,11 +5505,11 @@ class Connection:
xoffset = self.xoffset # x_position_after_map_change
if ldirection == "east" and xoffset != 0:
- raise Exception, "tauwasser was wrong about xoffset=0 for east? it's: " + str(xoffset)
+ raise Exception("tauwasser was wrong about xoffset=0 for east? it's: " + str(xoffset))
elif ldirection == "west" and xoffset != ((connected_map_width * 2) - 1):
- raise Exception, "tauwasser was wrong about xoffset for west? it's: " + str(xoffset)
+ raise Exception("tauwasser was wrong about xoffset for west? it's: " + str(xoffset))
#elif not ((xoffset % -2) == 0):
- # raise Exception, "tauwasser was wrong about xoffset for north/south? it's not divisible by -2: " + str(xoffset)
+ # raise Exception("tauwasser was wrong about xoffset for north/south? it's not divisible by -2: " + str(xoffset))
output += "db "
@@ -5703,7 +5626,7 @@ class MapBlockData:
self.width = width
self.height = height
else:
- raise Exception, "MapBlockData needs to know the width/height of its map"
+ raise Exception("MapBlockData needs to know the width/height of its map")
label = self.make_label()
self.label = Label(name=label, address=address, object=self)
self.last_address = self.address + (self.width.byte * self.height.byte)
@@ -6270,14 +6193,14 @@ def parse_map_header_by_id(*args, **kwargs):
map_id = kwargs["map_id"]
if (map_group == None and map_id != None) or \
(map_group != None and map_id == None):
- raise Exception, "map_group and map_id must both be provided"
+ raise Exception("map_group and map_id must both be provided")
elif map_group == None and map_id == None and len(args) == 0:
- raise Exception, "must be given an argument"
+ raise Exception("must be given an argument")
elif len(args) == 1 and type(args[0]) == str:
map_group = int(args[0].split(".")[0])
map_id = int(args[0].split(".")[1])
else:
- raise Exception, "dunno what to do with input"
+ raise Exception("dunno what to do with input")
offset = map_names[map_group]["offset"]
map_header_offset = offset + ((map_id - 1) * map_header_byte_size)
return parse_map_header_at(map_header_offset, map_group=map_group, map_id=map_id)
@@ -6286,7 +6209,7 @@ def parse_all_map_headers(debug=True):
"""calls parse_map_header_at for each map in each map group"""
global map_names
if not map_names[1].has_key("offset"):
- raise Exception, "dunno what to do - map_names should have groups with pre-calculated offsets by now"
+ raise Exception("dunno what to do - map_names should have groups with pre-calculated offsets by now")
for group_id, group_data in map_names.items():
offset = group_data["offset"]
# we only care about the maps
@@ -7045,7 +6968,7 @@ def find_incbin_to_replace_for(address, debug=False, rom_file="../baserom.gbc"):
if you were to insert bytes into main.asm"""
if type(address) == str: address = int(address, 16)
if not (0 <= address <= os.lstat(rom_file).st_size):
- raise IndexError, "address is out of bounds"
+ raise IndexError("address is out of bounds")
for incbin_key in processed_incbins.keys():
incbin = processed_incbins[incbin_key]
start = incbin["start"]
@@ -7069,9 +6992,9 @@ def split_incbin_line_into_three(line, start_address, byte_count, rom_file="../b
"""
if type(start_address) == str: start_address = int(start_address, 16)
if not (0 <= start_address <= os.lstat(rom_file).st_size):
- raise IndexError, "start_address is out of bounds"
+ raise IndexError("start_address is out of bounds")
if len(processed_incbins) == 0:
- raise Exception, "processed_incbins must be populated"
+ raise Exception("processed_incbins must be populated")
original_incbin = processed_incbins[line]
start = original_incbin["start"]
@@ -7191,7 +7114,7 @@ class Incbin:
start = eval(start)
except Exception, e:
print "start is: " + str(start)
- raise Exception, "problem with evaluating interval range: " + str(e)
+ raise Exception("problem with evaluating interval range: " + str(e))
start_hex = hex(start).replace("0x", "$")
@@ -7212,11 +7135,12 @@ class Incbin:
def to_asm(self):
if self.interval > 0:
return self.line
- else: return ""
+ else:
+ return ""
def split(self, start_address, byte_count):
"""splits this incbin into three separate incbins"""
if start_address < self.start_address or start_address > self.end_address:
- raise Exception, "this incbin doesn't handle this address"
+ raise Exception("this incbin doesn't handle this address")
incbins = []
if self.debug:
@@ -7358,7 +7282,7 @@ class Asm:
if not hasattr(new_object, "last_address"):
print debugmsg
- raise Exception, "object needs to have a last_address property"
+ raise Exception("object needs to have a last_address property")
end_address = new_object.last_address
debugmsg += " last_address="+hex(end_address)
@@ -7384,7 +7308,7 @@ class Asm:
print "start_address="+hex(start_address)+" end_address="+hex(end_address)
if hasattr(new_object, "to_asm"):
print to_asm(new_object)
- raise Exception, "Asm.insert was given an object with a bad address range"
+ raise Exception("Asm.insert was given an object with a bad address range")
# 1) find which object needs to be replaced
# or
@@ -7426,7 +7350,7 @@ class Asm:
found = True
break
if not found:
- raise Exception, "unable to insert object into Asm"
+ raise Exception("unable to insert object into Asm")
self.labels.append(new_object.label)
return True
def insert_with_dependencies(self, input):
@@ -7458,9 +7382,9 @@ class Asm:
# just some old debugging
#if object.label.name == "UnknownText_0x60128":
- # raise Exception, "debugging..."
+ # raise Exception("debugging...")
#elif object.label.name == "UnknownScript_0x60011":
- # raise Exception, "debugging.. dependencies are: " + str(object.dependencies) + " versus: " + str(object.get_dependencies())
+ # raise Exception("debugging.. dependencies are: " + str(object.dependencies) + " versus: " + str(object.get_dependencies()))
def insert_single_with_dependencies(self, object):
self.insert_with_dependencies(object)
def insert_multiple_with_dependencies(self, objects):
@@ -7516,7 +7440,7 @@ class Asm:
current_requested_newlines_before = 2
current_requested_newlines_after = 2
else:
- raise Exception, "dunno what to do with("+str(each)+") in Asm.parts"
+ raise Exception("dunno what to do with("+str(each)+") in Asm.parts")
if write_something:
if not first:
@@ -7546,7 +7470,7 @@ def list_texts_in_bank(bank):
that you will be inserting into Asm.
"""
if len(all_texts) == 0:
- raise Exception, "all_texts is blank.. run_main() will populate it"
+ raise Exception("all_texts is blank.. run_main() will populate it")
assert bank != None, "list_texts_in_banks must be given a particular bank"
@@ -7564,7 +7488,7 @@ def list_movements_in_bank(bank):
to speed up Asm insertion.
"""
if len(all_movements) == 0:
- raise Exception, "all_movements is blank.. run_main() will populate it"
+ raise Exception("all_movements is blank.. run_main() will populate it")
assert bank != None, "list_movements_in_bank must be given a particular bank"
assert 0 <= bank < 0x80, "bank doesn't exist in the ROM (out of bounds)"
@@ -7673,7 +7597,7 @@ def get_label_for(address):
if address == None:
return None
if type(address) != int:
- raise Exception, "get_label_for requires an integer address, got: " + str(type(address))
+ raise Exception("get_label_for requires an integer address, got: " + str(type(address)))
# lousy hack to get around recursive scripts in dragon shrine
if address in lousy_dragon_shrine_hack:
@@ -7787,10 +7711,6 @@ class Label:
name = object.make_label()
return name
-from labels import remove_quoted_text, line_has_comment_address, \
- line_has_label, get_label_from_line, \
- get_address_from_line_comment
-
def find_labels_without_addresses():
"""scans the asm source and finds labels that are unmarked"""
without_addresses = []
@@ -7872,14 +7792,9 @@ def scan_for_predefined_labels(debug=False):
abbreviation_next = "1"
# calculate the start/stop line numbers for this bank
- for a in (abbreviation, abbreviation.lower()):
- start_line_id = index(asm, lambda line: "\"bank" + a + "\"" in line)
- if start_line_id != None: break
-
+ start_line_id = index(asm, lambda line: "\"bank" + abbreviation.lower() + "\"" in line.lower())
if bank_id != 0x7F:
- for a in (abbreviation_next, abbreviation_next.lower()):
- end_line_id = index(asm, lambda line: "\"bank" + a + "\"" in line)
- if end_line_id != None: break
+ end_line_id = index(asm, lambda line: "\"bank" + abbreviation_next.lower() + "\"" in line.lower())
end_line_id += 1
else:
end_line_id = len(asm) - 1
@@ -7907,924 +7822,6 @@ def scan_for_predefined_labels(debug=False):
write_all_labels(all_labels)
return all_labels
-#### generic testing ####
-
-class TestCram(unittest.TestCase):
- "this is where i cram all of my unit tests together"
-
- @classmethod
- def setUpClass(cls):
- global rom
- cls.rom = direct_load_rom()
- rom = cls.rom
-
- @classmethod
- def tearDownClass(cls):
- del cls.rom
-
- def test_generic_useless(self):
- "do i know how to write a test?"
- self.assertEqual(1, 1)
-
- def test_map_name_cleaner(self):
- name = "hello world"
- cleaned_name = map_name_cleaner(name)
- self.assertNotEqual(name, cleaned_name)
- self.failUnless(" " not in cleaned_name)
- name = "Some Random Pokémon Center"
- cleaned_name = map_name_cleaner(name)
- self.assertNotEqual(name, cleaned_name)
- self.failIf(" " in cleaned_name)
- self.failIf("é" in cleaned_name)
-
- def test_grouper(self):
- data = range(0, 10)
- groups = grouper(data, count=2)
- self.assertEquals(len(groups), 5)
- data = range(0, 20)
- groups = grouper(data, count=2)
- self.assertEquals(len(groups), 10)
- self.assertNotEqual(data, groups)
- self.assertNotEqual(len(data), len(groups))
-
- def test_direct_load_rom(self):
- rom = self.rom
- self.assertEqual(len(rom), 2097152)
- self.failUnless(isinstance(rom, RomStr))
-
- def test_load_rom(self):
- global rom
- rom = None
- load_rom()
- self.failIf(rom == None)
- rom = RomStr(None)
- load_rom()
- self.failIf(rom == RomStr(None))
-
- def test_load_asm(self):
- asm = load_asm()
- joined_lines = "\n".join(asm)
- self.failUnless("SECTION" in joined_lines)
- self.failUnless("bank" in joined_lines)
- self.failUnless(isinstance(asm, AsmList))
-
- def test_rom_file_existence(self):
- "ROM file must exist"
- self.failUnless("baserom.gbc" in os.listdir("../"))
-
- def test_rom_md5(self):
- "ROM file must have the correct md5 sum"
- rom = self.rom
- correct = "9f2922b235a5eeb78d65594e82ef5dde"
- md5 = hashlib.md5()
- md5.update(rom)
- md5sum = md5.hexdigest()
- self.assertEqual(md5sum, correct)
-
- def test_bizarre_http_presence(self):
- rom_segment = self.rom[0x112116:0x112116+8]
- self.assertEqual(rom_segment, "HTTP/1.0")
-
- def test_rom_interval(self):
- address = 0x100
- interval = 10
- correct_strings = ['0x0', '0xc3', '0x6e', '0x1', '0xce',
- '0xed', '0x66', '0x66', '0xcc', '0xd']
- byte_strings = rom_interval(address, interval, strings=True)
- self.assertEqual(byte_strings, correct_strings)
- correct_ints = [0, 195, 110, 1, 206, 237, 102, 102, 204, 13]
- ints = rom_interval(address, interval, strings=False)
- self.assertEqual(ints, correct_ints)
-
- def test_rom_until(self):
- address = 0x1337
- byte = 0x13
- bytes = rom_until(address, byte, strings=True)
- self.failUnless(len(bytes) == 3)
- self.failUnless(bytes[0] == '0xd5')
- bytes = rom_until(address, byte, strings=False)
- self.failUnless(len(bytes) == 3)
- self.failUnless(bytes[0] == 0xd5)
-
- def test_how_many_until(self):
- how_many = how_many_until(chr(0x13), 0x1337)
- self.assertEqual(how_many, 3)
-
- def test_calculate_bank(self):
- self.failUnless(calculate_bank(0x8000) == 2)
- self.failUnless(calculate_bank("0x9000") == 2)
- self.failUnless(calculate_bank(0) == 0)
- for address in [0x4000, 0x5000, 0x6000, 0x7000]:
- self.assertRaises(Exception, calculate_bank, address)
-
- def test_calculate_pointer(self):
- # for offset <= 0x4000
- self.assertEqual(calculate_pointer(0x0000), 0x0000)
- self.assertEqual(calculate_pointer(0x3FFF), 0x3FFF)
- # for 0x4000 <= offset <= 0x7FFFF
- self.assertEqual(calculate_pointer(0x430F, bank=5), 0x1430F)
- # for offset >= 0x7FFF
- self.assertEqual(calculate_pointer(0x8FFF, bank=6), calculate_pointer(0x8FFF, bank=7))
-
- def test_calculate_pointer_from_bytes_at(self):
- addr1 = calculate_pointer_from_bytes_at(0x100, bank=False)
- self.assertEqual(addr1, 0xc300)
- addr2 = calculate_pointer_from_bytes_at(0x100, bank=True)
- self.assertEqual(addr2, 0x2ec3)
-
- def test_rom_text_at(self):
- self.assertEquals(rom_text_at(0x112116, 8), "HTTP/1.0")
-
- def test_translate_command_byte(self):
- self.failUnless(translate_command_byte(crystal=0x0) == 0x0)
- self.failUnless(translate_command_byte(crystal=0x10) == 0x10)
- self.failUnless(translate_command_byte(crystal=0x40) == 0x40)
- self.failUnless(translate_command_byte(gold=0x0) == 0x0)
- self.failUnless(translate_command_byte(gold=0x10) == 0x10)
- self.failUnless(translate_command_byte(gold=0x40) == 0x40)
- self.assertEqual(translate_command_byte(gold=0x0), translate_command_byte(crystal=0x0))
- self.failUnless(translate_command_byte(gold=0x52) == 0x53)
- self.failUnless(translate_command_byte(gold=0x53) == 0x54)
- self.failUnless(translate_command_byte(crystal=0x53) == 0x52)
- self.failUnless(translate_command_byte(crystal=0x52) == None)
- self.assertRaises(Exception, translate_command_byte, None, gold=0xA4)
-
- def test_pksv_integrity(self):
- "does pksv_gs look okay?"
- self.assertEqual(pksv_gs[0x00], "2call")
- self.assertEqual(pksv_gs[0x2D], "givepoke")
- self.assertEqual(pksv_gs[0x85], "waitbutton")
- self.assertEqual(pksv_crystal[0x00], "2call")
- self.assertEqual(pksv_crystal[0x86], "waitbutton")
- self.assertEqual(pksv_crystal[0xA2], "credits")
-
- def test_chars_integrity(self):
- self.assertEqual(chars[0x80], "A")
- self.assertEqual(chars[0xA0], "a")
- self.assertEqual(chars[0xF0], "¥")
- self.assertEqual(jap_chars[0x44], "ぱ")
-
- def test_map_names_integrity(self):
- def map_name(map_group, map_id): return map_names[map_group][map_id]["name"]
- self.assertEqual(map_name(2, 7), "Mahogany Town")
- self.assertEqual(map_name(3, 0x34), "Ilex Forest")
- self.assertEqual(map_name(7, 0x11), "Cerulean City")
-
- def test_load_map_group_offsets(self):
- addresses = load_map_group_offsets()
- self.assertEqual(len(addresses), 26, msg="there should be 26 map groups")
- addresses = load_map_group_offsets()
- self.assertEqual(len(addresses), 26, msg="there should still be 26 map groups")
- self.assertIn(0x94034, addresses)
- for address in addresses:
- self.assertGreaterEqual(address, 0x4000)
- self.failIf(0x4000 <= address <= 0x7FFF)
- self.failIf(address <= 0x4000)
-
- def test_index(self):
- self.assertTrue(index([1,2,3,4], lambda f: True) == 0)
- self.assertTrue(index([1,2,3,4], lambda f: f==3) == 2)
-
- def test_get_pokemon_constant_by_id(self):
- x = get_pokemon_constant_by_id
- self.assertEqual(x(1), "BULBASAUR")
- self.assertEqual(x(151), "MEW")
- self.assertEqual(x(250), "HO_OH")
-
- def test_find_item_label_by_id(self):
- x = find_item_label_by_id
- self.assertEqual(x(249), "HM_07")
- self.assertEqual(x(173), "BERRY")
- self.assertEqual(x(45), None)
-
- def test_generate_item_constants(self):
- x = generate_item_constants
- r = x()
- self.failUnless("HM_07" in r)
- self.failUnless("EQU" in r)
-
- def test_get_label_for(self):
- global all_labels
- temp = copy(all_labels)
- # this is basd on the format defined in get_labels_between
- all_labels = [{"label": "poop", "address": 0x5,
- "offset": 0x5, "bank": 0,
- "line_number": 2
- }]
- self.assertEqual(get_label_for(5), "poop")
- all_labels = temp
-
- def test_generate_map_constant_labels(self):
- ids = generate_map_constant_labels()
- self.assertEqual(ids[0]["label"], "OLIVINE_POKECENTER_1F")
- self.assertEqual(ids[1]["label"], "OLIVINE_GYM")
-
- def test_get_id_for_map_constant_label(self):
- global map_internal_ids
- map_internal_ids = generate_map_constant_labels()
- self.assertEqual(get_id_for_map_constant_label("OLIVINE_GYM"), 1)
- self.assertEqual(get_id_for_map_constant_label("OLIVINE_POKECENTER_1F"), 0)
-
- def test_get_map_constant_label_by_id(self):
- global map_internal_ids
- map_internal_ids = generate_map_constant_labels()
- self.assertEqual(get_map_constant_label_by_id(0), "OLIVINE_POKECENTER_1F")
- self.assertEqual(get_map_constant_label_by_id(1), "OLIVINE_GYM")
-
- def test_is_valid_address(self):
- self.assertTrue(is_valid_address(0))
- self.assertTrue(is_valid_address(1))
- self.assertTrue(is_valid_address(10))
- self.assertTrue(is_valid_address(100))
- self.assertTrue(is_valid_address(1000))
- self.assertTrue(is_valid_address(10000))
- self.assertFalse(is_valid_address(2097153))
- self.assertFalse(is_valid_address(2098000))
- addresses = [random.randrange(0,2097153) for i in range(0, 9+1)]
- for address in addresses:
- self.assertTrue(is_valid_address(address))
-
-
-class TestIntervalMap(unittest.TestCase):
- def test_intervals(self):
- i = IntervalMap()
- first = "hello world"
- second = "testing 123"
- i[0:5] = first
- i[5:10] = second
- self.assertEqual(i[0], first)
- self.assertEqual(i[1], first)
- self.assertNotEqual(i[5], first)
- self.assertEqual(i[6], second)
- i[3:10] = second
- self.assertEqual(i[3], second)
- self.assertNotEqual(i[4], first)
-
- def test_items(self):
- i = IntervalMap()
- first = "hello world"
- second = "testing 123"
- i[0:5] = first
- i[5:10] = second
- results = list(i.items())
- self.failUnless(len(results) == 2)
- self.assertEqual(results[0], ((0, 5), "hello world"))
- self.assertEqual(results[1], ((5, 10), "testing 123"))
-
-
-class TestRomStr(unittest.TestCase):
- """RomStr is a class that should act exactly like str()
- except that it never shows the contents of it string
- unless explicitly forced"""
- sample_text = "hello world!"
- sample = None
-
- def setUp(self):
- if self.sample == None:
- self.__class__.sample = RomStr(self.sample_text)
-
- def test_equals(self):
- "check if RomStr() == str()"
- self.assertEquals(self.sample_text, self.sample)
-
- def test_not_equal(self):
- "check if RomStr('a') != RomStr('b')"
- self.assertNotEqual(RomStr('a'), RomStr('b'))
-
- def test_appending(self):
- "check if RomStr()+'a'==str()+'a'"
- self.assertEquals(self.sample_text+'a', self.sample+'a')
-
- def test_conversion(self):
- "check if RomStr() -> str() works"
- self.assertEquals(str(self.sample), self.sample_text)
-
- def test_inheritance(self):
- self.failUnless(issubclass(RomStr, str))
-
- def test_length(self):
- self.assertEquals(len(self.sample_text), len(self.sample))
- self.assertEquals(len(self.sample_text), self.sample.length())
- self.assertEquals(len(self.sample), self.sample.length())
-
- def test_rom_interval(self):
- global rom
- load_rom()
- address = 0x100
- interval = 10
- correct_strings = ['0x0', '0xc3', '0x6e', '0x1', '0xce',
- '0xed', '0x66', '0x66', '0xcc', '0xd']
- byte_strings = rom.interval(address, interval, strings=True)
- self.assertEqual(byte_strings, correct_strings)
- correct_ints = [0, 195, 110, 1, 206, 237, 102, 102, 204, 13]
- ints = rom.interval(address, interval, strings=False)
- self.assertEqual(ints, correct_ints)
-
- def test_rom_until(self):
- global rom
- load_rom()
- address = 0x1337
- byte = 0x13
- bytes = rom.until(address, byte, strings=True)
- self.failUnless(len(bytes) == 3)
- self.failUnless(bytes[0] == '0xd5')
- bytes = rom.until(address, byte, strings=False)
- self.failUnless(len(bytes) == 3)
- self.failUnless(bytes[0] == 0xd5)
-
-
-class TestAsmList(unittest.TestCase):
- """AsmList is a class that should act exactly like list()
- except that it never shows the contents of its list
- unless explicitly forced"""
-
- def test_equals(self):
- base = [1,2,3]
- asm = AsmList(base)
- self.assertEquals(base, asm)
- self.assertEquals(asm, base)
- self.assertEquals(base, list(asm))
-
- def test_inheritance(self):
- self.failUnless(issubclass(AsmList, list))
-
- def test_length(self):
- base = range(0, 10)
- asm = AsmList(base)
- self.assertEquals(len(base), len(asm))
- self.assertEquals(len(base), asm.length())
- self.assertEquals(len(base), len(list(asm)))
- self.assertEquals(len(asm), asm.length())
-
- def test_remove_quoted_text(self):
- x = remove_quoted_text
- self.assertEqual(x("hello world"), "hello world")
- self.assertEqual(x("hello \"world\""), "hello ")
- input = 'hello world "testing 123"'
- self.assertNotEqual(x(input), input)
- input = "hello world 'testing 123'"
- self.assertNotEqual(x(input), input)
- self.failIf("testing" in x(input))
-
- def test_line_has_comment_address(self):
- x = line_has_comment_address
- self.assertFalse(x(""))
- self.assertFalse(x(";"))
- self.assertFalse(x(";;;"))
- self.assertFalse(x(":;"))
- self.assertFalse(x(":;:"))
- self.assertFalse(x(";:"))
- self.assertFalse(x(" "))
- self.assertFalse(x("".join(" " * 5)))
- self.assertFalse(x("".join(" " * 10)))
- self.assertFalse(x("hello world"))
- self.assertFalse(x("hello_world"))
- self.assertFalse(x("hello_world:"))
- self.assertFalse(x("hello_world:;"))
- self.assertFalse(x("hello_world: ;"))
- self.assertFalse(x("hello_world: ; "))
- self.assertFalse(x("hello_world: ;" + "".join(" " * 5)))
- self.assertFalse(x("hello_world: ;" + "".join(" " * 10)))
- self.assertTrue(x(";1"))
- self.assertTrue(x(";F"))
- self.assertTrue(x(";$00FF"))
- self.assertTrue(x(";0x00FF"))
- self.assertTrue(x("; 0x00FF"))
- self.assertTrue(x(";$3:$300"))
- self.assertTrue(x(";0x3:$300"))
- self.assertTrue(x(";$3:0x300"))
- self.assertTrue(x(";3:300"))
- self.assertTrue(x(";3:FFAA"))
- self.assertFalse(x('hello world "how are you today;0x1"'))
- self.assertTrue(x('hello world "how are you today:0x1";1'))
- returnable = {}
- self.assertTrue(x("hello_world: ; 0x4050", returnable=returnable, bank=5))
- self.assertTrue(returnable["address"] == 0x14050)
-
- def test_line_has_label(self):
- x = line_has_label
- self.assertTrue(x("hi:"))
- self.assertTrue(x("Hello: "))
- self.assertTrue(x("MyLabel: ; test xyz"))
- self.assertFalse(x(":"))
- self.assertFalse(x(";HelloWorld:"))
- self.assertFalse(x("::::"))
- self.assertFalse(x(":;:;:;:::"))
-
- def test_get_label_from_line(self):
- x = get_label_from_line
- self.assertEqual(x("HelloWorld: "), "HelloWorld")
- self.assertEqual(x("HiWorld:"), "HiWorld")
- self.assertEqual(x("HiWorld"), None)
-
- def test_find_labels_without_addresses(self):
- global asm
- asm = ["hello_world: ; 0x1", "hello_world2: ;"]
- labels = find_labels_without_addresses()
- self.failUnless(labels[0]["label"] == "hello_world2")
- asm = ["hello world: ;1", "hello_world: ;2"]
- labels = find_labels_without_addresses()
- self.failUnless(len(labels) == 0)
- asm = None
-
- def test_get_labels_between(self):
- global asm
- x = get_labels_between#(start_line_id, end_line_id, bank)
- asm = ["HelloWorld: ;1",
- "hi:",
- "no label on this line",
- ]
- labels = x(0, 2, 0x12)
- self.assertEqual(len(labels), 1)
- self.assertEqual(labels[0]["label"], "HelloWorld")
- del asm
-
- def test_scan_for_predefined_labels(self):
- # label keys: line_number, bank, label, offset, address
- load_asm()
- all_labels = scan_for_predefined_labels()
- label_names = [x["label"] for x in all_labels]
- self.assertIn("GetFarByte", label_names)
- self.assertIn("AddNTimes", label_names)
- self.assertIn("CheckShininess", label_names)
-
- def test_write_all_labels(self):
- """dumping json into a file"""
- filename = "test_labels.json"
- # remove the current file
- if os.path.exists(filename):
- os.system("rm " + filename)
- # make up some labels
- labels = []
- # fake label 1
- label = {"line_number": 5, "bank": 0, "label": "SomeLabel", "address": 0x10}
- labels.append(label)
- # fake label 2
- label = {"line_number": 15, "bank": 2, "label": "SomeOtherLabel", "address": 0x9F0A}
- labels.append(label)
- # dump to file
- write_all_labels(labels, filename=filename)
- # open the file and read the contents
- file_handler = open(filename, "r")
- contents = file_handler.read()
- file_handler.close()
- # parse into json
- obj = json.read(contents)
- # begin testing
- self.assertEqual(len(obj), len(labels))
- self.assertEqual(len(obj), 2)
- self.assertEqual(obj, labels)
-
- def test_isolate_incbins(self):
- global asm
- asm = ["123", "456", "789", "abc", "def", "ghi",
- 'INCBIN "baserom.gbc",$12DA,$12F8 - $12DA',
- "jkl",
- 'INCBIN "baserom.gbc",$137A,$13D0 - $137A']
- lines = isolate_incbins()
- self.assertIn(asm[6], lines)
- self.assertIn(asm[8], lines)
- for line in lines:
- self.assertIn("baserom", line)
-
- def test_process_incbins(self):
- global incbin_lines, processed_incbins, asm
- incbin_lines = ['INCBIN "baserom.gbc",$12DA,$12F8 - $12DA',
- 'INCBIN "baserom.gbc",$137A,$13D0 - $137A']
- asm = copy(incbin_lines)
- asm.insert(1, "some other random line")
- processed_incbins = process_incbins()
- self.assertEqual(len(processed_incbins), len(incbin_lines))
- self.assertEqual(processed_incbins[0]["line"], incbin_lines[0])
- self.assertEqual(processed_incbins[2]["line"], incbin_lines[1])
-
- def test_reset_incbins(self):
- global asm, incbin_lines, processed_incbins
- # temporarily override the functions
- global load_asm, isolate_incbins, process_incbins
- temp1, temp2, temp3 = load_asm, isolate_incbins, process_incbins
- def load_asm(): pass
- def isolate_incbins(): pass
- def process_incbins(): pass
- # call reset
- reset_incbins()
- # check the results
- self.assertTrue(asm == [] or asm == None)
- self.assertTrue(incbin_lines == [])
- self.assertTrue(processed_incbins == {})
- # reset the original functions
- load_asm, isolate_incbins, process_incbins = temp1, temp2, temp3
-
- def test_find_incbin_to_replace_for(self):
- global asm, incbin_lines, processed_incbins
- asm = ['first line', 'second line', 'third line',
- 'INCBIN "baserom.gbc",$90,$200 - $90',
- 'fifth line', 'last line']
- isolate_incbins()
- process_incbins()
- line_num = find_incbin_to_replace_for(0x100)
- # must be the 4th line (the INBIN line)
- self.assertEqual(line_num, 3)
-
- def test_split_incbin_line_into_three(self):
- global asm, incbin_lines, processed_incbins
- asm = ['first line', 'second line', 'third line',
- 'INCBIN "baserom.gbc",$90,$200 - $90',
- 'fifth line', 'last line']
- isolate_incbins()
- process_incbins()
- content = split_incbin_line_into_three(3, 0x100, 10)
- # must end up with three INCBINs in output
- self.failUnless(content.count("INCBIN") == 3)
-
- def test_analyze_intervals(self):
- global asm, incbin_lines, processed_incbins
- asm, incbin_lines, processed_incbins = None, [], {}
- asm = ['first line', 'second line', 'third line',
- 'INCBIN "baserom.gbc",$90,$200 - $90',
- 'fifth line', 'last line',
- 'INCBIN "baserom.gbc",$33F,$4000 - $33F']
- isolate_incbins()
- process_incbins()
- largest = analyze_intervals()
- self.assertEqual(largest[0]["line_number"], 6)
- self.assertEqual(largest[0]["line"], asm[6])
- self.assertEqual(largest[1]["line_number"], 3)
- self.assertEqual(largest[1]["line"], asm[3])
-
- def test_generate_diff_insert(self):
- global asm
- asm = ['first line', 'second line', 'third line',
- 'INCBIN "baserom.gbc",$90,$200 - $90',
- 'fifth line', 'last line',
- 'INCBIN "baserom.gbc",$33F,$4000 - $33F']
- diff = generate_diff_insert(0, "the real first line", debug=False)
- self.assertIn("the real first line", diff)
- self.assertIn("INCBIN", diff)
- self.assertNotIn("No newline at end of file", diff)
- self.assertIn("+"+asm[1], diff)
-
-
-class TestMapParsing(unittest.TestCase):
- def test_parse_all_map_headers(self):
- global parse_map_header_at, old_parse_map_header_at, counter
- counter = 0
- for k in map_names.keys():
- if "offset" not in map_names[k].keys():
- map_names[k]["offset"] = 0
- temp = parse_map_header_at
- temp2 = old_parse_map_header_at
- def parse_map_header_at(address, map_group=None, map_id=None, debug=False):
- global counter
- counter += 1
- return {}
- old_parse_map_header_at = parse_map_header_at
- parse_all_map_headers(debug=False)
- # parse_all_map_headers is currently doing it 2x
- # because of the new/old map header parsing routines
- self.assertEqual(counter, 388 * 2)
- parse_map_header_at = temp
- old_parse_map_header_at = temp2
-
-class TestTextScript(unittest.TestCase):
- """for testing 'in-script' commands, etc."""
- #def test_to_asm(self):
- # pass # or raise NotImplementedError, bryan_message
- #def test_find_addresses(self):
- # pass # or raise NotImplementedError, bryan_message
- #def test_parse_text_at(self):
- # pass # or raise NotImplementedError, bryan_message
-
-
-class TestEncodedText(unittest.TestCase):
- """for testing chars-table encoded text chunks"""
-
- def test_process_00_subcommands(self):
- g = process_00_subcommands(0x197186, 0x197186+601, debug=False)
- self.assertEqual(len(g), 42)
- self.assertEqual(len(g[0]), 13)
- self.assertEqual(g[1], [184, 174, 180, 211, 164, 127, 20, 231, 81])
-
- def test_parse_text_at2(self):
- oakspeech = parse_text_at2(0x197186, 601, debug=False)
- self.assertIn("encyclopedia", oakspeech)
- self.assertIn("researcher", oakspeech)
- self.assertIn("dependable", oakspeech)
-
- def test_parse_text_engine_script_at(self):
- p = parse_text_engine_script_at(0x197185, debug=False)
- self.assertEqual(len(p.commands), 2)
- self.assertEqual(len(p.commands[0]["lines"]), 41)
-
- # don't really care about these other two
- def test_parse_text_from_bytes(self): pass
- def test_parse_text_at(self): pass
-
-
-class TestScript(unittest.TestCase):
- """for testing parse_script_engine_script_at and script parsing in
- general. Script should be a class."""
- #def test_parse_script_engine_script_at(self):
- # pass # or raise NotImplementedError, bryan_message
-
- def test_find_all_text_pointers_in_script_engine_script(self):
- address = 0x197637 # 0x197634
- script = parse_script_engine_script_at(address, debug=False)
- bank = calculate_bank(address)
- r = find_all_text_pointers_in_script_engine_script(script, bank=bank, debug=False)
- results = list(r)
- self.assertIn(0x197661, results)
-
-
-class TestLabel(unittest.TestCase):
- def test_label_making(self):
- line_number = 2
- address = 0xf0c0
- label_name = "poop"
- l = Label(name=label_name, address=address, line_number=line_number)
- self.failUnless(hasattr(l, "name"))
- self.failUnless(hasattr(l, "address"))
- self.failUnless(hasattr(l, "line_number"))
- self.failIf(isinstance(l.address, str))
- self.failIf(isinstance(l.line_number, str))
- self.failUnless(isinstance(l.name, str))
- self.assertEqual(l.line_number, line_number)
- self.assertEqual(l.name, label_name)
- self.assertEqual(l.address, address)
-
-
-class TestByteParams(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- load_rom()
- cls.address = 10
- cls.sbp = SingleByteParam(address=cls.address)
-
- @classmethod
- def tearDownClass(cls):
- del cls.sbp
-
- def test__init__(self):
- self.assertEqual(self.sbp.size, 1)
- self.assertEqual(self.sbp.address, self.address)
-
- def test_parse(self):
- self.sbp.parse()
- self.assertEqual(str(self.sbp.byte), str(45))
-
- def test_to_asm(self):
- self.assertEqual(self.sbp.to_asm(), "$2d")
- self.sbp.should_be_decimal = True
- self.assertEqual(self.sbp.to_asm(), str(45))
-
- # HexByte and DollarSignByte are the same now
- def test_HexByte_to_asm(self):
- h = HexByte(address=10)
- a = h.to_asm()
- self.assertEqual(a, "$2d")
-
- def test_DollarSignByte_to_asm(self):
- d = DollarSignByte(address=10)
- a = d.to_asm()
- self.assertEqual(a, "$2d")
-
- def test_ItemLabelByte_to_asm(self):
- i = ItemLabelByte(address=433)
- self.assertEqual(i.byte, 54)
- self.assertEqual(i.to_asm(), "COIN_CASE")
- self.assertEqual(ItemLabelByte(address=10).to_asm(), "$2d")
-
- def test_DecimalParam_to_asm(self):
- d = DecimalParam(address=10)
- x = d.to_asm()
- self.assertEqual(x, str(0x2d))
-
-
-class TestMultiByteParam(unittest.TestCase):
- def setup_for(self, somecls, byte_size=2, address=443, **kwargs):
- self.cls = somecls(address=address, size=byte_size, **kwargs)
- self.assertEqual(self.cls.address, address)
- self.assertEqual(self.cls.bytes, rom_interval(address, byte_size, strings=False))
- self.assertEqual(self.cls.size, byte_size)
-
- def test_two_byte_param(self):
- self.setup_for(MultiByteParam, byte_size=2)
- self.assertEqual(self.cls.to_asm(), "$f0c0")
-
- def test_three_byte_param(self):
- self.setup_for(MultiByteParam, byte_size=3)
-
- def test_PointerLabelParam_no_bank(self):
- self.setup_for(PointerLabelParam, bank=None)
- # assuming no label at this location..
- self.assertEqual(self.cls.to_asm(), "$f0c0")
- global all_labels
- # hm.. maybe all_labels should be using a class?
- all_labels = [{"label": "poop", "address": 0xf0c0,
- "offset": 0xf0c0, "bank": 0,
- "line_number": 2
- }]
- self.assertEqual(self.cls.to_asm(), "poop")
-
-
-class TestPostParsing: #(unittest.TestCase):
- """tests that must be run after parsing all maps"""
- @classmethod
- def setUpClass(cls):
- run_main()
-
- def test_signpost_counts(self):
- self.assertEqual(len(map_names[1][1]["signposts"]), 0)
- self.assertEqual(len(map_names[1][2]["signposts"]), 2)
- self.assertEqual(len(map_names[10][5]["signposts"]), 7)
-
- def test_warp_counts(self):
- self.assertEqual(map_names[10][5]["warp_count"], 9)
- self.assertEqual(map_names[18][5]["warp_count"], 3)
- self.assertEqual(map_names[15][1]["warp_count"], 2)
-
- def test_map_sizes(self):
- self.assertEqual(map_names[15][1]["height"], 18)
- self.assertEqual(map_names[15][1]["width"], 10)
- self.assertEqual(map_names[7][1]["height"], 4)
- self.assertEqual(map_names[7][1]["width"], 4)
-
- def test_map_connection_counts(self):
- self.assertEqual(map_names[7][1]["connections"], 0)
- self.assertEqual(map_names[10][1]["connections"], 12)
- self.assertEqual(map_names[10][2]["connections"], 12)
- self.assertEqual(map_names[11][1]["connections"], 9) # or 13?
-
- def test_second_map_header_address(self):
- self.assertEqual(map_names[11][1]["second_map_header_address"], 0x9509c)
- self.assertEqual(map_names[1][5]["second_map_header_address"], 0x95bd0)
-
- def test_event_address(self):
- self.assertEqual(map_names[17][5]["event_address"], 0x194d67)
- self.assertEqual(map_names[23][3]["event_address"], 0x1a9ec9)
-
- def test_people_event_counts(self):
- self.assertEqual(len(map_names[23][3]["people_events"]), 4)
- self.assertEqual(len(map_names[10][3]["people_events"]), 9)
-
-
-class TestMetaTesting(unittest.TestCase):
- """test whether or not i am finding at least
- some of the tests in this file"""
- tests = None
-
- def setUp(self):
- if self.tests == None:
- self.__class__.tests = assemble_test_cases()
-
- def test_assemble_test_cases_count(self):
- "does assemble_test_cases find some tests?"
- self.failUnless(len(self.tests) > 0)
-
- def test_assemble_test_cases_inclusion(self):
- "is this class found by assemble_test_cases?"
- # i guess it would have to be for this to be running?
- self.failUnless(self.__class__ in self.tests)
-
- def test_assemble_test_cases_others(self):
- "test other inclusions for assemble_test_cases"
- self.failUnless(TestRomStr in self.tests)
- self.failUnless(TestCram in self.tests)
-
- def test_check_has_test(self):
- self.failUnless(check_has_test("beaver", ["test_beaver"]))
- self.failUnless(check_has_test("beaver", ["test_beaver_2"]))
- self.failIf(check_has_test("beaver_1", ["test_beaver"]))
-
- def test_find_untested_methods(self):
- untested = find_untested_methods()
- # the return type must be an iterable
- self.failUnless(hasattr(untested, "__iter__"))
- #.. basically, a list
- self.failUnless(isinstance(untested, list))
-
- def test_find_untested_methods_method(self):
- """create a function and see if it is found"""
- # setup a function in the global namespace
- global some_random_test_method
- # define the method
- def some_random_test_method(): pass
- # first make sure it is in the global scope
- members = inspect.getmembers(sys.modules[__name__], inspect.isfunction)
- func_names = [functuple[0] for functuple in members]
- self.assertIn("some_random_test_method", func_names)
- # test whether or not it is found by find_untested_methods
- untested = find_untested_methods()
- self.assertIn("some_random_test_method", untested)
- # remove the test method from the global namespace
- del some_random_test_method
-
- def test_load_tests(self):
- loader = unittest.TestLoader()
- suite = load_tests(loader, None, None)
- suite._tests[0]._testMethodName
- membership_test = lambda member: \
- inspect.isclass(member) and issubclass(member, unittest.TestCase)
- tests = inspect.getmembers(sys.modules[__name__], membership_test)
- classes = [x[1] for x in tests]
- for test in suite._tests:
- self.assertIn(test.__class__, classes)
-
- def test_report_untested(self):
- untested = find_untested_methods()
- output = report_untested()
- if len(untested) > 0:
- self.assertIn("NOT TESTED", output)
- for name in untested:
- self.assertIn(name, output)
- elif len(untested) == 0:
- self.assertNotIn("NOT TESTED", output)
-
-
-def assemble_test_cases():
- """finds classes that inherit from unittest.TestCase
- because i am too lazy to remember to add them to a
- global list of tests for the suite runner"""
- classes = []
- clsmembers = inspect.getmembers(sys.modules[__name__], inspect.isclass)
- for (name, some_class) in clsmembers:
- if issubclass(some_class, unittest.TestCase):
- classes.append(some_class)
- return classes
-
-def load_tests(loader, tests, pattern):
- suite = unittest.TestSuite()
- for test_class in assemble_test_cases():
- tests = loader.loadTestsFromTestCase(test_class)
- suite.addTests(tests)
- return suite
-
-def check_has_test(func_name, tested_names):
- """checks if there is a test dedicated to this function"""
- if "test_"+func_name in tested_names:
- return True
- for name in tested_names:
- if "test_"+func_name in name:
- return True
- return False
-
-def find_untested_methods():
- """finds all untested functions in this module
- by searching for method names in test case
- method names."""
- untested = []
- avoid_funcs = ["main", "run_tests", "run_main", "copy", "deepcopy"]
- test_funcs = []
- # get a list of all classes in this module
- classes = inspect.getmembers(sys.modules[__name__], inspect.isclass)
- # for each class..
- for (name, klass) in classes:
- # only look at those that have tests
- if issubclass(klass, unittest.TestCase):
- # look at this class' methods
- funcs = inspect.getmembers(klass, inspect.ismethod)
- # for each method..
- for (name2, func) in funcs:
- # store the ones that begin with test_
- if "test_" in name2 and name2[0:5] == "test_":
- test_funcs.append([name2, func])
- # assemble a list of all test method names (test_x, test_y, ..)
- tested_names = [funcz[0] for funcz in test_funcs]
- # now get a list of all functions in this module
- funcs = inspect.getmembers(sys.modules[__name__], inspect.isfunction)
- # for each function..
- for (name, func) in funcs:
- # we don't care about some of these
- if name in avoid_funcs: continue
- # skip functions beginning with _
- if name[0] == "_": continue
- # check if this function has a test named after it
- has_test = check_has_test(name, tested_names)
- if not has_test:
- untested.append(name)
- return untested
-
-def report_untested():
- untested = find_untested_methods()
- output = "NOT TESTED: ["
- first = True
- for name in untested:
- if first:
- output += name
- first = False
- else: output += ", "+name
- output += "]\n"
- output += "total untested: " + str(len(untested))
- return output
-
-#### ways to run this file ####
-
-def run_tests(): # rather than unittest.main()
- loader = unittest.TestLoader()
- suite = load_tests(loader, None, None)
- unittest.TextTestRunner(verbosity=2).run(suite)
- print report_untested()
-
def run_main():
# read the rom and figure out the offsets for maps
direct_load_rom()
@@ -8848,10 +7845,9 @@ def run_main():
make_trainer_group_name_trainer_ids(trainer_group_table)
# just a helpful alias
-main=run_main
-# when you run the file.. do unit tests
-if __name__ == "__main__":
- run_tests()
+main = run_main
+
# when you load the module.. parse everything
-elif __name__ == "crystal": pass
- #run_main()
+if __name__ == "crystal":
+ pass
+
diff --git a/dump_sections b/dump_sections
new file mode 100755
index 0000000..362318f
--- /dev/null
+++ b/dump_sections
@@ -0,0 +1,14 @@
+#!/bin/bash
+# This wraps dump_sections.py so that other python scripts can import the
+# functions. If dump_sections.py was instead called dump_sections, then other
+# python source code would be unable to use the functions via import
+# statements.
+
+# figure out the path to this script
+cwd="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+
+# construct the path to dump_sections.py
+secpath=$cwd/dump_sections.py
+
+# run dump_sections.py
+$secpath $1
diff --git a/dump_sections.py b/dump_sections.py
new file mode 100755
index 0000000..91306e4
--- /dev/null
+++ b/dump_sections.py
@@ -0,0 +1,130 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+"""
+Use this tool to dump an asm file for a new source code or disassembly project.
+
+usage:
+
+ from dump_sections import dump_sections
+
+ output = dump_sections("../../butt.gbc")
+
+ file_handler = open("main.asm", "w")
+ file_handler.write(output)
+ file_handler.close()
+
+You can also use this script from the shell, where it will look for
+"baserom.gbc" in the current working path or whatever file path you pass in the
+first positional argument.
+"""
+
+import os
+import sys
+import argparse
+
+def upper_hex(input):
+ """
+ Converts the input to an uppercase hex string.
+ """
+ if input in [0, "0"]:
+ return "0"
+ elif input <= 0xF:
+ return ("%.x" % (input)).upper()
+ else:
+ return ("%.2x" % (input)).upper()
+
+def format_bank_number(address, bank_size=0x4000):
+ """
+ Returns a str of the hex number of the bank based on the address.
+ """
+ return upper_hex(address / bank_size)
+
+def calculate_bank_quantity(path, bank_size=0x4000):
+ """
+ Returns the number of 0x4000 banks in the file at path.
+ """
+ return float(os.lstat(path).st_size) / bank_size
+
+def dump_section(bank_number, separator="\n\n"):
+ """
+ Returns a str of a section header for the asm file.
+ """
+ output = "SECTION \""
+ if bank_number in [0, "0"]:
+ output += "bank0\",HOME"
+ else:
+ output += "bank"
+ output += bank_number
+ output += "\",DATA,BANK[$"
+ output += bank_number
+ output += "]"
+ output += separator
+ return output
+
+def dump_incbin_for_section(address, bank_size=0x4000, baserom="baserom.gbc", separator="\n\n"):
+ """
+ Returns a str for an INCBIN line for an entire section.
+ """
+ output = "INCBIN \""
+ output += baserom
+ output += "\",$"
+ output += upper_hex(address)
+ output += ",$"
+ output += upper_hex(bank_size)
+ output += separator
+ return output
+
+def dump_sections(path, bank_size=0x4000, initial_bank=0, last_bank=None, separator="\n\n"):
+ """
+ Returns a str of assembly source code. The source code delineates each
+ SECTION and includes bytes from the file specified by baserom.
+ """
+ if not last_bank:
+ last_bank = calculate_bank_quantity(path, bank_size=bank_size)
+
+ if last_bank < initial_bank:
+ raise Exception("last_bank must be greater than or equal to initial_bank")
+
+ baserom_name = os.path.basename(path)
+
+ output = ""
+
+ banks = range(initial_bank, last_bank)
+
+ for bank_number in banks:
+ address = bank_number * bank_size
+
+ # get a formatted hex number of the bank based on the address
+ formatted_bank_number = format_bank_number(address, bank_size=bank_size)
+
+ # SECTION
+ output += dump_section(formatted_bank_number, separator=separator)
+
+ # INCBIN a range of bytes from the ROM
+ output += dump_incbin_for_section(address, bank_size=bank_size, baserom=baserom_name, separator=separator)
+
+ # clean up newlines at the end of the output
+ if output[-2:] == "\n\n":
+ output = output[:-2]
+ output += "\n"
+
+ return output
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("rompath", nargs="?", metavar="rompath", type=str)
+ args = parser.parse_args()
+
+ # default to "baserom.gbc" in the current working directory
+ baserom = "baserom.gbc"
+
+ # but let the user override the default
+ if args.rompath:
+ baserom = args.rompath
+
+ # generate some asm
+ output = dump_sections(baserom)
+
+ # dump everything to stdout
+ sys.stdout.write(output)
+
diff --git a/gbz80disasm.py b/gbz80disasm.py
index 48739e0..f2ba483 100644
--- a/gbz80disasm.py
+++ b/gbz80disasm.py
@@ -1,26 +1,28 @@
-#author: Bryan Bishop <kanzure@gmail.com>
-#date: 2012-01-09
+# -*- coding: utf-8 -*-
+
import os
import sys
from copy import copy, deepcopy
from ctypes import c_int8
-import json
import random
+import json
-spacing = "\t"
+# New versions of json don't have read anymore.
+if not hasattr(json, "read"):
+ json.read = json.loads
-class XRomStr(str):
- def __repr__(self):
- return "RomStr(too long)"
+from romstr import RomStr
def load_rom(filename="../baserom.gbc"):
"""loads bytes into memory"""
global rom
- file_handler = open(filename, "rb")
- rom = XRomStr(file_handler.read())
+ file_handler = open(filename, "rb")
+ rom = RomStr(file_handler.read())
file_handler.close()
return rom
+spacing = "\t"
+
temp_opt_table = [
[ "ADC A", 0x8f, 0 ],
[ "ADC B", 0x88, 0 ],
@@ -550,7 +552,7 @@ end_08_scripts_with = [
0xc9, #ret
###0xda, 0xe9, 0xd2, 0xc2, 0xca, 0xc3, 0x38, 0x30, 0x20, 0x28, 0x18, 0xd8, 0xd0, 0xc0, 0xc8, 0xc9
]
-relative_jumps = [0x38, 0x30, 0x20, 0x28, 0x18, 0xc3, 0xda, 0xc2]
+relative_jumps = [0x38, 0x30, 0x20, 0x28, 0x18, 0xc3, 0xda, 0xc2]
relative_unconditional_jumps = [0xc3, 0x18]
call_commands = [0xdc, 0xd4, 0xc4, 0xcc, 0xcd]
@@ -559,7 +561,7 @@ all_labels = {}
def load_labels(filename="labels.json"):
global all_labels
if os.path.exists(filename):
- all_labels = json.loads(open(filename, "r").read())
+ all_labels = json.read(open(filename, "r").read())
else:
print "You must run crystal.scan_for_predefined_labels() to create \"labels.json\". Trying..."
import crystal
@@ -601,10 +603,10 @@ def output_bank_opcodes(original_offset, max_byte_count=0x4000, debug = False):
#i = offset
#ad = end_address
#a, oa = current_byte_number
-
+
load_labels()
load_rom()
-
+
bank_id = 0
if original_offset > 0x8000:
bank_id = original_offset / 0x4000
diff --git a/gfx.py b/gfx.py
index f36b944..67bb664 100644
--- a/gfx.py
+++ b/gfx.py
@@ -1043,14 +1043,16 @@ def decompress_monsters(type = front):
# decompress
monster = decompress_monster_by_id(id, type)
if monster != None: # no unowns here
- filename = str(id+1).zfill(3) + '.2bpp' # 001.2bpp
if not type: # front
- folder = '../gfx/frontpics/'
+ filename = 'front.2bpp'
+ folder = '../gfx/pics/' + str(id+1).zfill(3) + '/'
to_file(folder+filename, monster.pic)
- folder = '../gfx/anim/'
+ filename = 'tiles.2bpp'
+ folder = '../gfx/pics/' + str(id+1).zfill(3) + '/'
to_file(folder+filename, monster.animtiles)
else: # back
- folder = '../gfx/backpics/'
+ filename = 'back.2bpp'
+ folder = '../gfx/pics/' + str(id+1).zfill(3) + '/'
to_file(folder+filename, monster.pic)
@@ -1073,14 +1075,16 @@ def decompress_unowns(type = front):
# decompress
unown = decompress_unown_by_id(letter, type)
- filename = str(unown_dex).zfill(3) + chr(ord('a') + letter) + '.2bpp' # 201a.2bpp
if not type: # front
- folder = '../gfx/frontpics/'
+ filename = 'front.2bpp'
+ folder = '../gfx/pics/' + str(unown_dex).zfill(3) + chr(ord('a') + letter) + '/'
to_file(folder+filename, unown.pic)
+ filename = 'tiles.2bpp'
folder = '../gfx/anim/'
to_file(folder+filename, unown.animtiles)
else: # back
- folder = '../gfx/backpics/'
+ filename = 'back.2bpp'
+ folder = '../gfx/pics/' + str(unown_dex).zfill(3) + chr(ord('a') + letter) + '/'
to_file(folder+filename, unown.pic)
@@ -1255,8 +1259,8 @@ def compress_file(filein, fileout, mode = 'horiz'):
def compress_monster_frontpic(id, fileout):
mode = 'vert'
- fpic = '../gfx/frontpics/' + str(id).zfill(3) + '.2bpp'
- fanim = '../gfx/anim/' + str(id).zfill(3) + '.2bpp'
+ fpic = '../gfx/pics/' + str(id).zfill(3) + '/front.2bpp'
+ fanim = '../gfx/pics/' + str(id).zfill(3) + '/tiles.2bpp'
pic = open(fpic, 'rb').read()
anim = open(fanim, 'rb').read()
@@ -1264,7 +1268,7 @@ def compress_monster_frontpic(id, fileout):
lz = Compressed(image, mode, 5)
- out = '../gfx/frontpics/lz/' + str(id).zfill(3) + '.lz'
+ out = '../gfx/pics/' + str(id).zfill(3) + '/front.lz'
to_file(out, lz.output)
@@ -1283,6 +1287,28 @@ def get_uncompressed_gfx(start, num_tiles, filename):
+def hex_to_rgb(word):
+ red = word & 0b11111
+ word >>= 5
+ green = word & 0b11111
+ word >>= 5
+ blue = word & 0b11111
+ return (red, green, blue)
+
+def grab_palettes(address, length = 0x80):
+ output = ''
+ for word in range(length/2):
+ color = ord(rom[address+1])*0x100 + ord(rom[address])
+ address += 2
+ color = hex_to_rgb(color)
+ red = str(color[0]).zfill(2)
+ green = str(color[1]).zfill(2)
+ blue = str(color[2]).zfill(2)
+ output += '\tRGB '+red+', '+green+', '+blue
+ output += '\n'
+ return output
+
+
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('cmd', nargs='?', metavar='cmd', type=str)
@@ -1317,7 +1343,11 @@ if __name__ == "__main__":
# python gfx.py un [address] [num_tiles] [filename]
get_uncompressed_gfx(int(args.arg1,16), int(args.arg2), args.arg3)
- else:
- # python gfx.py
- decompress_all()
- if debug: print 'decompressed known gfx to ../gfx/!'
+ elif args.cmd == 'pal':
+ # python gfx.py pal [address] [length]
+ print grab_palettes(int(args.arg1,16), int(args.arg2))
+
+ #else:
+ ## python gfx.py
+ #decompress_all()
+ #if debug: print 'decompressed known gfx to ../gfx/!'
diff --git a/graph.py b/graph.py
index 98f871a..b545083 100644
--- a/graph.py
+++ b/graph.py
@@ -1,12 +1,13 @@
-#!/usr/bin/python
-# author: Bryan Bishop <kanzure@gmail.com>
-# date: 2012-06-20
+# -*- coding: utf-8 -*-
import networkx as nx
-from romstr import RomStr, DisAsm, \
- relative_jumps, call_commands, \
- relative_unconditional_jumps
+from romstr import (
+ RomStr,
+ relative_jumps,
+ call_commands,
+ relative_unconditional_jumps,
+)
class RomGraph(nx.DiGraph):
""" Graphs various functions pointing to each other.
diff --git a/interval_map.py b/interval_map.py
new file mode 100644
index 0000000..7e6c5cd
--- /dev/null
+++ b/interval_map.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+
+from bisect import bisect_left, bisect_right
+from itertools import izip
+
+class IntervalMap(object):
+ """
+ This class maps a set of intervals to a set of values.
+
+ >>> i = IntervalMap()
+ >>> i[0:5] = "hello world"
+ >>> i[6:10] = "hello cruel world"
+ >>> print i[4]
+ "hello world"
+ """
+
+ def __init__(self):
+ """initializes an empty IntervalMap"""
+ self._bounds = []
+ self._items = []
+ self._upperitem = None
+
+ def __setitem__(self, _slice, _value):
+ """sets an interval mapping"""
+ assert isinstance(_slice, slice), 'The key must be a slice object'
+
+ if _slice.start is None:
+ start_point = -1
+ else:
+ start_point = bisect_left(self._bounds, _slice.start)
+
+ if _slice.stop is None:
+ end_point = -1
+ else:
+ end_point = bisect_left(self._bounds, _slice.stop)
+
+ if start_point>=0:
+ if start_point < len(self._bounds) and self._bounds[start_point]<_slice.start:
+ start_point += 1
+
+ if end_point>=0:
+ self._bounds[start_point:end_point] = [_slice.start, _slice.stop]
+ if start_point < len(self._items):
+ self._items[start_point:end_point] = [self._items[start_point], _value]
+ else:
+ self._items[start_point:end_point] = [self._upperitem, _value]
+ else:
+ self._bounds[start_point:] = [_slice.start]
+ if start_point < len(self._items):
+ self._items[start_point:] = [self._items[start_point], _value]
+ else:
+ self._items[start_point:] = [self._upperitem]
+ self._upperitem = _value
+ else:
+ if end_point>=0:
+ self._bounds[:end_point] = [_slice.stop]
+ self._items[:end_point] = [_value]
+ else:
+ self._bounds[:] = []
+ self._items[:] = []
+ self._upperitem = _value
+
+ def __getitem__(self,_point):
+ """gets a value from the mapping"""
+ assert not isinstance(_point, slice), 'The key cannot be a slice object'
+
+ index = bisect_right(self._bounds, _point)
+ if index < len(self._bounds):
+ return self._items[index]
+ else:
+ return self._upperitem
+
+ def items(self):
+ """returns an iterator with each item being
+ ((low_bound, high_bound), value)
+ these items are returned in order"""
+ previous_bound = None
+ for (b, v) in izip(self._bounds, self._items):
+ if v is not None:
+ yield (previous_bound, b), v
+ previous_bound = b
+ if self._upperitem is not None:
+ yield (previous_bound, None), self._upperitem
+
+ def values(self):
+ """returns an iterator with each item being a stored value
+ the items are returned in order"""
+ for v in self._items:
+ if v is not None:
+ yield v
+ if self._upperitem is not None:
+ yield self._upperitem
+
+ def __repr__(self):
+ s = []
+ for b,v in self.items():
+ if v is not None:
+ s.append('[%r, %r] => %r'%(
+ b[0],
+ b[1],
+ v
+ ))
+ return '{'+', '.join(s)+'}'
+
diff --git a/item_constants.py b/item_constants.py
index d60dfb1..a050637 100644
--- a/item_constants.py
+++ b/item_constants.py
@@ -1,4 +1,7 @@
-item_constants = {1: 'MASTER_BALL',
+# -*- coding: utf-8 -*-
+
+item_constants = {
+1: 'MASTER_BALL',
2: 'ULTRA_BALL',
3: 'BRIGHTPOWDER',
4: 'GREAT_BALL',
@@ -219,4 +222,20 @@ item_constants = {1: 'MASTER_BALL',
246: 'HM_04',
247: 'HM_05',
248: 'HM_06',
-249: 'HM_07'}
+249: 'HM_07',
+}
+
+def find_item_label_by_id(id):
+ if id in item_constants.keys():
+ return item_constants[id]
+ else: return None
+
+def generate_item_constants():
+ """make a list of items to put in constants.asm"""
+ output = ""
+ for (id, item) in item_constants.items():
+ val = ("$%.2x"%id).upper()
+ while len(item)<13: item+= " "
+ output += item + " EQU " + val + "\n"
+ return output
+
diff --git a/labels.py b/labels.py
index a25fa3f..e57c6e2 100644
--- a/labels.py
+++ b/labels.py
@@ -1,7 +1,12 @@
-""" Various label/line-related functions.
+# -*- coding: utf-8 -*-
+"""
+Various label/line-related functions.
"""
-from pointers import calculate_pointer, calculate_bank
+from pointers import (
+ calculate_pointer,
+ calculate_bank,
+)
def remove_quoted_text(line):
"""get rid of content inside quotes
diff --git a/move_constants.py b/move_constants.py
index 929f1fa..a20af85 100644
--- a/move_constants.py
+++ b/move_constants.py
@@ -1,3 +1,5 @@
+# -*- coding: utf-8 -*-
+
moves = {
0x01: "POUND",
0x02: "KARATE_CHOP",
diff --git a/pksv.py b/pksv.py
index 8f4bafe..03ad2d0 100644
--- a/pksv.py
+++ b/pksv.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
pksv_gs = {
0x00: "2call",
@@ -34,7 +35,7 @@ pksv_gs = {
0x21: "checkitem",
0x22: "givemoney",
0x23: "takemoney",
- 0x24: "checkmonkey",
+ 0x24: "checkmoney",
0x25: "givecoins",
0x26: "takecoins",
0x27: "checkcoins",
@@ -141,8 +142,8 @@ pksv_gs = {
0xA3: "displaylocation",
}
-#see http://www.pokecommunity.com/showpost.php?p=4347261
-#NOTE: this has some updates that need to be back-ported to gold
+# see http://www.pokecommunity.com/showpost.php?p=4347261
+# NOTE: this has some updates that need to be back-ported to gold
pksv_crystal = {
0x00: "2call",
0x01: "3call",
@@ -179,7 +180,7 @@ pksv_crystal = {
0x21: "checkitem",
0x22: "givemoney",
0x23: "takemoney",
- 0x24: "checkmonkey",
+ 0x24: "checkmoney",
0x25: "givecoins",
0x26: "takecoins",
0x27: "checkcoins",
@@ -292,13 +293,14 @@ pksv_crystal = {
}
#these cause the script to end; used in create_command_classes
-pksv_crystal_more_enders = [0x03, 0x04, 0x05, 0x0C, 0x51, 0x53,
- 0x8D, 0x8F, 0x90, 0x91, 0x92, 0x9B,
+pksv_crystal_more_enders = [0x03, 0x04, 0x05, 0x0C, 0x51, 0x52,
+ 0x53, 0x8D, 0x8F, 0x90, 0x91, 0x92,
+ 0x9B,
0xB2, #maybe?
0xCC, #maybe?
]
-#these have no pksv names as of pksv 2.1.1
+# these have no pksv names as of pksv 2.1.1
pksv_crystal_unknowns = [
0x9F,
0xA6, 0xA7, 0xA8, 0xA9, 0xAA, 0xAB, 0xAC, 0xAD, 0xAE, 0xAF,
diff --git a/pointers.py b/pointers.py
index f392241..8fe3df3 100644
--- a/pointers.py
+++ b/pointers.py
@@ -1,10 +1,12 @@
-""" Various functions related to pointer and address math. Mostly to avoid
- depedency loops.
+# -*- coding: utf-8 -*-
+"""
+Various functions related to pointer and address math. Mostly to avoid
+depedency loops.
"""
def calculate_bank(address):
"""you are too lazy to divide on your own?"""
- if type(address) == str:
+ if type(address) == str:
address = int(address, 16)
#if 0x4000 <= address <= 0x7FFF:
# raise Exception, "bank 1 does not exist"
diff --git a/pokemon_constants.py b/pokemon_constants.py
index 33b6a0e..221a31c 100644
--- a/pokemon_constants.py
+++ b/pokemon_constants.py
@@ -1,3 +1,5 @@
+# -*- coding: utf-8 -*-
+
pokemon_constants = {
1: "BULBASAUR",
2: "IVYSAUR",
diff --git a/romstr.py b/romstr.py
index 66ac507..d2eea44 100644
--- a/romstr.py
+++ b/romstr.py
@@ -1,8 +1,21 @@
-import sys, os, time, datetime, json
-from gbz80disasm import opt_table
+# -*- coding: utf-8 -*-
+
+import sys
+import os
+import time
+import datetime
from ctypes import c_int8
-from copy import copy, deepcopy
-from labels import get_label_from_line, get_address_from_line_comment
+from copy import copy
+import json
+
+# New versions of json don't have read anymore.
+if not hasattr(json, "read"):
+ json.read = json.loads
+
+from labels import (
+ get_label_from_line,
+ get_address_from_line_comment,
+)
relative_jumps = [0x38, 0x30, 0x20, 0x28, 0x18, 0xc3, 0xda, 0xc2, 0x32]
relative_unconditional_jumps = [0xc3, 0x18]
@@ -91,7 +104,7 @@ class RomStr(str):
file_handler.close()
# load the labels from the file
- self.labels = json.loads(open(filename, "r").read())
+ self.labels = json.read(open(filename, "r").read())
def get_address_for(self, label):
""" Returns the address of a label. This is slow and could be improved
@@ -137,7 +150,7 @@ class RomStr(str):
that will be parsed, so that large patches of data aren't parsed as
code.
"""
- if type(address) == str and "0x" in address:
+ if type(address) in [str, unicode] and "0x" in address:
address = int(address, 16)
start_address = address
@@ -166,333 +179,8 @@ class RomStr(str):
elif end_address != None and size == None:
size = end_address - start_address
- return DisAsm(start_address=start_address, end_address=end_address, size=size, max_size=max_size, debug=debug, rom=self)
-
-class DisAsm:
- """ z80 disassembler
- """
-
- def __init__(self, start_address=None, end_address=None, size=None, max_size=0x4000, debug=True, rom=None):
- assert start_address != None, "start_address must be given"
-
- if rom == None:
- file_handler = open("../baserom.gbc", "r")
- bytes = file_handler.read()
- file_handler.close()
- rom = RomStr(bytes)
-
- if debug not in [None, True, False]:
- raise Exception, "debug param is invalid"
- if debug == None:
- debug = False
-
- # get end_address and size in sync with each other
- if end_address == None and size != None:
- end_address = start_address + size
- elif end_address != None and size == None:
- size = end_address - start_address
- elif end_address != None and size != None:
- size = max(end_address - start_address, size)
- end_address = start_address + size
-
- # check that the bounds make sense
- if end_address != None:
- if end_address <= start_address:
- raise Exception, "end_address is out of bounds"
- elif (end_address - start_address) > max_size:
- raise Exception, "end_address goes beyond max_size"
-
- # check more edge cases
- if not start_address >= 0:
- raise Exception, "start_address must be at least 0"
- elif end_address != None and not end_address >= 0:
- raise Exception, "end_address must be at least 0"
-
- self.rom = rom
- self.start_address = start_address
- self.end_address = end_address
- self.size = size
- self.max_size = max_size
- self.debug = debug
-
- self.parse()
-
- def parse(self):
- """ Disassembles stuff and things.
- """
-
- rom = self.rom
- start_address = self.start_address
- end_address = self.end_address
- max_size = self.max_size
- debug = self.debug
-
- bank_id = start_address / 0x4000
-
- # [{"command": 0x20, "bytes": [0x20, 0x40, 0x50],
- # "asm": "jp $5040", "label": "Unknown5040"}]
- asm_commands = {}
-
- offset = start_address
-
- last_hl_address = None
- last_a_address = None
- used_3d97 = False
-
- keep_reading = True
-
- while (end_address != 0 and offset <= end_address) or keep_reading:
- # read the current opcode byte
- current_byte = ord(rom[offset])
- current_byte_number = len(asm_commands.keys())
-
- # setup this next/upcoming command
- if offset in asm_commands.keys():
- asm_command = asm_commands[offset]
- else:
- asm_command = {}
-
- asm_command["address"] = offset
-
- if not "references" in asm_command.keys():
- # This counts how many times relative jumps reference this
- # byte. This is used to determine whether or not to print out a
- # label later.
- asm_command["references"] = 0
-
- # some commands have two opcodes
- next_byte = ord(rom[offset+1])
-
- if self.debug:
- print "offset: \t\t" + hex(offset)
- print "current_byte: \t\t" + hex(current_byte)
- print "next_byte: \t\t" + hex(next_byte)
-
- # all two-byte opcodes also have their first byte in there somewhere
- if (current_byte in opt_table.keys()) or ((current_byte + (next_byte << 8)) in opt_table.keys()):
- # this might be a two-byte opcode
- possible_opcode = current_byte + (next_byte << 8)
-
- # check if this is a two-byte opcode
- if possible_opcode in opt_table.keys():
- op_code = possible_opcode
- else:
- op_code = current_byte
-
- op = opt_table[op_code]
-
- opstr = op[0].lower()
- optype = op[1]
-
- if self.debug:
- print "opstr: " + opstr
-
- asm_command["type"] = "op"
- asm_command["id"] = op_code
- asm_command["format"] = opstr
- asm_command["opnumberthing"] = optype
-
- opstr2 = None
- base_opstr = copy(opstr)
-
- if "x" in opstr:
- for x in range(0, opstr.count("x")):
- insertion = ord(rom[offset + 1])
-
- # Certain opcodes will have a local relative jump label
- # here instead of a raw hex value, but this is
- # controlled through asm output.
- insertion = "$" + hex(insertion)[2:]
-
- opstr = opstr[:opstr.find("x")].lower() + insertion + opstr[opstr.find("x")+1:].lower()
-
- if op_code in relative_jumps:
- target_address = offset + 2 + c_int8(ord(rom[offset + 1])).value
- insertion = "asm_" + hex(target_address)
-
- if str(target_address) in self.rom.labels.keys():
- insertion = self.rom.labels[str(target_address)]
-
- opstr2 = base_opstr[:base_opstr.find("x")].lower() + insertion + base_opstr[base_opstr.find("x")+1:].lower()
- asm_command["formatted_with_labels"] = opstr2
- asm_command["target_address"] = target_address
-
- current_byte_number += 1
- offset += 1
-
- if "?" in opstr:
- for y in range(0, opstr.count("?")):
- byte1 = ord(rom[offset + 1])
- byte2 = ord(rom[offset + 2])
-
- number = byte1
- number += byte2 << 8;
-
- # In most cases, you can use a label here. Labels will
- # be shown during asm output.
- insertion = "$%.4x" % (number)
-
- opstr = opstr[:opstr.find("?")].lower() + insertion + opstr[opstr.find("?")+1:].lower()
-
- # This version of the formatted string has labels. In
- # the future, the actual labels should be parsed
- # straight out of the "main.asm" file.
- target_address = number % 0x4000
- insertion = "asm_" + hex(target_address)
-
- if str(target_address) in self.rom.labels.keys():
- insertion = self.rom.labels[str(target_address)]
-
- opstr2 = base_opstr[:base_opstr.find("?")].lower() + insertion + base_opstr[base_opstr.find("?")+1:].lower()
- asm_command["formatted_with_labels"] = opstr2
- asm_command["target_address"] = target_address
-
- current_byte_number += 2
- offset += 2
-
- # Check for relative jumps, construct the formatted asm line.
- # Also set the usage of labels.
- if current_byte in [0x18, 0x20] + relative_jumps: # jr or jr nz
- # generate a label for the byte we're jumping to
- target_address = offset + 1 + c_int8(ord(rom[offset])).value
-
- if target_address in asm_commands.keys():
- asm_commands[target_address]["references"] += 1
- remote_label = "asm_" + hex(target_address)
- asm_commands[target_address]["current_label"] = remote_label
- asm_command["remote_label"] = remote_label
-
- # Not sure how to set this, can't be True because an
- # address referenced multiple times will use a label
- # despite the label not necessarily being used in the
- # output. The "use_remote_label" values should be
- # calculated when rendering the asm output, based on
- # which addresses and which op codes will be displayed
- # (within the range).
- asm_command["use_remote_label"] = "unknown"
- else:
- remote_label = "asm_" + hex(target_address)
-
- # This remote address might not be part of this
- # function.
- asm_commands[target_address] = {
- "references": 1,
- "current_label": remote_label,
- "address": target_address,
- }
- # Also, target_address can be negative (before the
- # start_address that the user originally requested),
- # and it shouldn't be shown on asm output because the
- # intermediate bytes (between a negative target_address
- # and start_address) won't be disassembled.
-
- # Don't know yet if this remote address is part of this
- # function or not. When the remote address is not part
- # of this function, the label name should not be used,
- # because that label will not be disassembled in the
- # output, until the user asks it to.
- asm_command["use_remote_label"] = "unknown"
- asm_command["remote_label"] = remote_label
- elif current_byte == 0x3e:
- last_a_address = ord(rom[offset + 1])
-
- # store the formatted string for the output later
- asm_command["formatted"] = opstr
-
- if current_byte == 0x21:
- last_hl_address = byte1 + (byte2 << 8)
-
- # this is leftover from pokered, might be meaningless
- if current_byte == 0xcd:
- if number == 0x3d97:
- used_3d97 = True
-
- if current_byte == 0xc3 or current_byte in relative_unconditional_jumps:
- if current_byte == 0xc3:
- if number == 0x3d97:
- used_3d97 = True
-
- # stop reading at a jump, relative jump or return
- if current_byte in end_08_scripts_with:
- is_data = False
-
- if not self.has_outstanding_labels(asm_commands, offset):
- keep_reading = False
- break
- else:
- keep_reading = True
- else:
- keep_reading = True
-
- else:
- # This shouldn't really happen, and means that this area of the
- # ROM probably doesn't represent instructions.
- asm_command["type"] = "data" # db
- asm_command["value"] = current_byte
- keep_reading = False
-
- # save this new command in the list
- asm_commands[asm_command["address"]] = asm_command
-
- # jump forward by a byte
- offset += 1
-
- # also save the last command if necessary
- if len(asm_commands.keys()) > 0 and asm_commands[asm_commands.keys()[-1]] is not asm_command:
- asm_commands[asm_command["address"]] = asm_command
-
- # store the set of commands on this object
- self.asm_commands = asm_commands
-
- self.end_address = offset + 1
- self.last_address = self.end_address
-
- def has_outstanding_labels(self, asm_commands, offset):
- """ Checks if there are any labels that haven't yet been created.
- """ # is this really necessary??
- return False
-
- def used_addresses(self):
- """ Returns a list of unique addresses that this function will probably
- call.
- """
- addresses = set()
-
- for (id, command) in self.asm_commands.items():
- if command.has_key("target_address") and command["id"] in call_commands:
- addresses.add(command["target_address"])
-
- return addresses
-
- def __str__(self):
- """ ASM pretty printer.
- """
- output = ""
-
- for (key, line) in self.asm_commands.items():
- # skip anything from before the beginning
- if key < self.start_address:
- continue
-
- # show a label
- if line["references"] > 0 and "current_label" in line.keys():
- if line["address"] == self.start_address:
- output += "thing: ; " + hex(line["address"]) + "\n"
- else:
- output += "." + line["current_label"] + "\@ ; " + hex(line["address"]) + "\n"
-
- # show the actual line
- if line.has_key("formatted_with_labels"):
- output += spacing + line["formatted_with_labels"]
- elif line.has_key("formatted"):
- output += spacing + line["formatted"]
- #output += " ; to " +
- output += "\n"
-
- # show the next address after this chunk
- output += "; " + hex(self.end_address)
-
- return output
+ raise NotImplementedError("DisAsm was removed and never worked; hook up another disassembler please.")
+ #return DisAsm(start_address=start_address, end_address=end_address, size=size, max_size=max_size, debug=debug, rom=self)
class AsmList(list):
""" Simple wrapper to prevent all asm lines from being shown on screen.
diff --git a/test_dump_sections.py b/test_dump_sections.py
new file mode 100644
index 0000000..b73b86f
--- /dev/null
+++ b/test_dump_sections.py
@@ -0,0 +1,74 @@
+# -*- coding: utf-8 -*-
+
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+# check for things we need in unittest
+if not hasattr(unittest.TestCase, 'setUpClass'):
+ sys.stderr.write("The unittest2 module or Python 2.7 is required to run this script.")
+ sys.exit(1)
+
+from dump_sections import (
+ upper_hex,
+ format_bank_number,
+ calculate_bank_quantity,
+ dump_section,
+ dump_incbin_for_section,
+)
+
+class TestDumpSections(unittest.TestCase):
+ def test_upper_hex(self):
+ number = 0x52
+ self.assertEquals(number, int("0x" + upper_hex(number), 16))
+
+ number = 0x1
+ self.assertEquals(number, int("0x" + upper_hex(number), 16))
+
+ number = 0x0
+ self.assertEquals(number, int("0x" + upper_hex(number), 16))
+
+ number = 0xAA
+ self.assertEquals(number, int("0x" + upper_hex(number), 16))
+
+ number = 0xFFFFAAA0000
+ self.assertEquals(number, int("0x" + upper_hex(number), 16))
+
+ def test_format_bank_number(self):
+ address = 0x0
+ self.assertEquals("0", format_bank_number(address))
+
+ address = 0x4000
+ self.assertEquals("1", format_bank_number(address))
+
+ address = 0x1FC000
+ self.assertEquals("7F", format_bank_number(address))
+
+ def test_dump_section(self):
+ self.assertIn("SECTION", dump_section(str(0)))
+ self.assertIn("HOME", dump_section(str(0)))
+ self.assertNotIn("HOME", dump_section(str(1)))
+ self.assertIn("DATA", dump_section(str(2)))
+ self.assertIn("BANK", dump_section(str(40)))
+ self.assertNotIn("BANK", dump_section(str(0)))
+
+ def test_dump_incbin_for_section(self):
+ self.assertIn("INCBIN", dump_incbin_for_section(0))
+
+ def test_dump_incbin_for_section_separator(self):
+ separator = "\n\n"
+ self.assertIn(separator, dump_incbin_for_section(0, separator=separator))
+
+ separator = "\t\t" # dumb
+ self.assertIn(separator, dump_incbin_for_section(0, separator=separator))
+
+ def test_dump_incbin_for_section_default(self):
+ rom = "baserom.gbc"
+ self.assertIn(rom, dump_incbin_for_section(0))
+
+ rom = "baserom"
+ self.assertIn(rom, dump_incbin_for_section(0x4000))
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests.py b/tests.py
new file mode 100644
index 0000000..61f46d6
--- /dev/null
+++ b/tests.py
@@ -0,0 +1,1015 @@
+# -*- coding: utf-8 -*-
+
+import os
+import sys
+import inspect
+from copy import copy
+import hashlib
+import random
+import json
+
+from interval_map import IntervalMap
+from chars import chars, jap_chars
+
+from romstr import (
+ RomStr,
+ AsmList,
+)
+
+from item_constants import (
+ item_constants,
+ find_item_label_by_id,
+ generate_item_constants,
+)
+
+from pointers import (
+ calculate_bank,
+ calculate_pointer,
+)
+
+from pksv import (
+ pksv_gs,
+ pksv_crystal,
+)
+
+from labels import (
+ remove_quoted_text,
+ line_has_comment_address,
+ line_has_label,
+ get_label_from_line,
+)
+
+from crystal import (
+ rom,
+ load_rom,
+ rom_until,
+ direct_load_rom,
+ parse_script_engine_script_at,
+ parse_text_engine_script_at,
+ parse_text_at2,
+ find_all_text_pointers_in_script_engine_script,
+ SingleByteParam,
+ HexByte,
+ MultiByteParam,
+ PointerLabelParam,
+ ItemLabelByte,
+ DollarSignByte,
+ DecimalParam,
+ rom_interval,
+ map_names,
+ Label,
+ scan_for_predefined_labels,
+ all_labels,
+ write_all_labels,
+ parse_map_header_at,
+ old_parse_map_header_at,
+ process_00_subcommands,
+ parse_all_map_headers,
+ translate_command_byte,
+ map_name_cleaner,
+ load_map_group_offsets,
+ load_asm,
+ asm,
+ is_valid_address,
+ index,
+ how_many_until,
+ grouper,
+ get_pokemon_constant_by_id,
+ generate_map_constant_labels,
+ get_map_constant_label_by_id,
+ get_id_for_map_constant_label,
+ calculate_pointer_from_bytes_at,
+ isolate_incbins,
+ process_incbins,
+ get_labels_between,
+ generate_diff_insert,
+ find_labels_without_addresses,
+ rom_text_at,
+ get_label_for,
+ split_incbin_line_into_three,
+ reset_incbins,
+)
+
+# for testing all this crap
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+# check for things we need in unittest
+if not hasattr(unittest.TestCase, 'setUpClass'):
+ sys.stderr.write("The unittest2 module or Python 2.7 is required to run this script.")
+ sys.exit(1)
+
+class TestCram(unittest.TestCase):
+ "this is where i cram all of my unit tests together"
+
+ @classmethod
+ def setUpClass(cls):
+ global rom
+ cls.rom = direct_load_rom()
+ rom = cls.rom
+
+ @classmethod
+ def tearDownClass(cls):
+ del cls.rom
+
+ def test_generic_useless(self):
+ "do i know how to write a test?"
+ self.assertEqual(1, 1)
+
+ def test_map_name_cleaner(self):
+ name = "hello world"
+ cleaned_name = map_name_cleaner(name)
+ self.assertNotEqual(name, cleaned_name)
+ self.failUnless(" " not in cleaned_name)
+ name = "Some Random Pokémon Center"
+ cleaned_name = map_name_cleaner(name)
+ self.assertNotEqual(name, cleaned_name)
+ self.failIf(" " in cleaned_name)
+ self.failIf("é" in cleaned_name)
+
+ def test_grouper(self):
+ data = range(0, 10)
+ groups = grouper(data, count=2)
+ self.assertEquals(len(groups), 5)
+ data = range(0, 20)
+ groups = grouper(data, count=2)
+ self.assertEquals(len(groups), 10)
+ self.assertNotEqual(data, groups)
+ self.assertNotEqual(len(data), len(groups))
+
+ def test_direct_load_rom(self):
+ rom = self.rom
+ self.assertEqual(len(rom), 2097152)
+ self.failUnless(isinstance(rom, RomStr))
+
+ def test_load_rom(self):
+ global rom
+ rom = None
+ load_rom()
+ self.failIf(rom == None)
+ rom = RomStr(None)
+ load_rom()
+ self.failIf(rom == RomStr(None))
+
+ def test_load_asm(self):
+ asm = load_asm()
+ joined_lines = "\n".join(asm)
+ self.failUnless("SECTION" in joined_lines)
+ self.failUnless("bank" in joined_lines)
+ self.failUnless(isinstance(asm, AsmList))
+
+ def test_rom_file_existence(self):
+ "ROM file must exist"
+ self.failUnless("baserom.gbc" in os.listdir("../"))
+
+ def test_rom_md5(self):
+ "ROM file must have the correct md5 sum"
+ rom = self.rom
+ correct = "9f2922b235a5eeb78d65594e82ef5dde"
+ md5 = hashlib.md5()
+ md5.update(rom)
+ md5sum = md5.hexdigest()
+ self.assertEqual(md5sum, correct)
+
+ def test_bizarre_http_presence(self):
+ rom_segment = self.rom[0x112116:0x112116+8]
+ self.assertEqual(rom_segment, "HTTP/1.0")
+
+ def test_rom_interval(self):
+ address = 0x100
+ interval = 10
+ correct_strings = ['0x0', '0xc3', '0x6e', '0x1', '0xce',
+ '0xed', '0x66', '0x66', '0xcc', '0xd']
+ byte_strings = rom_interval(address, interval, strings=True)
+ self.assertEqual(byte_strings, correct_strings)
+ correct_ints = [0, 195, 110, 1, 206, 237, 102, 102, 204, 13]
+ ints = rom_interval(address, interval, strings=False)
+ self.assertEqual(ints, correct_ints)
+
+ def test_rom_until(self):
+ address = 0x1337
+ byte = 0x13
+ bytes = rom_until(address, byte, strings=True)
+ self.failUnless(len(bytes) == 3)
+ self.failUnless(bytes[0] == '0xd5')
+ bytes = rom_until(address, byte, strings=False)
+ self.failUnless(len(bytes) == 3)
+ self.failUnless(bytes[0] == 0xd5)
+
+ def test_how_many_until(self):
+ how_many = how_many_until(chr(0x13), 0x1337)
+ self.assertEqual(how_many, 3)
+
+ def test_calculate_bank(self):
+ self.failUnless(calculate_bank(0x8000) == 2)
+ self.failUnless(calculate_bank("0x9000") == 2)
+ self.failUnless(calculate_bank(0) == 0)
+ for address in [0x4000, 0x5000, 0x6000, 0x7000]:
+ self.assertRaises(Exception, calculate_bank, address)
+
+ def test_calculate_pointer(self):
+ # for offset <= 0x4000
+ self.assertEqual(calculate_pointer(0x0000), 0x0000)
+ self.assertEqual(calculate_pointer(0x3FFF), 0x3FFF)
+ # for 0x4000 <= offset <= 0x7FFFF
+ self.assertEqual(calculate_pointer(0x430F, bank=5), 0x1430F)
+ # for offset >= 0x7FFF
+ self.assertEqual(calculate_pointer(0x8FFF, bank=6), calculate_pointer(0x8FFF, bank=7))
+
+ def test_calculate_pointer_from_bytes_at(self):
+ addr1 = calculate_pointer_from_bytes_at(0x100, bank=False)
+ self.assertEqual(addr1, 0xc300)
+ addr2 = calculate_pointer_from_bytes_at(0x100, bank=True)
+ self.assertEqual(addr2, 0x2ec3)
+
+ def test_rom_text_at(self):
+ self.assertEquals(rom_text_at(0x112116, 8), "HTTP/1.0")
+
+ def test_translate_command_byte(self):
+ self.failUnless(translate_command_byte(crystal=0x0) == 0x0)
+ self.failUnless(translate_command_byte(crystal=0x10) == 0x10)
+ self.failUnless(translate_command_byte(crystal=0x40) == 0x40)
+ self.failUnless(translate_command_byte(gold=0x0) == 0x0)
+ self.failUnless(translate_command_byte(gold=0x10) == 0x10)
+ self.failUnless(translate_command_byte(gold=0x40) == 0x40)
+ self.assertEqual(translate_command_byte(gold=0x0), translate_command_byte(crystal=0x0))
+ self.failUnless(translate_command_byte(gold=0x52) == 0x53)
+ self.failUnless(translate_command_byte(gold=0x53) == 0x54)
+ self.failUnless(translate_command_byte(crystal=0x53) == 0x52)
+ self.failUnless(translate_command_byte(crystal=0x52) == None)
+ self.assertRaises(Exception, translate_command_byte, None, gold=0xA4)
+
+ def test_pksv_integrity(self):
+ "does pksv_gs look okay?"
+ self.assertEqual(pksv_gs[0x00], "2call")
+ self.assertEqual(pksv_gs[0x2D], "givepoke")
+ self.assertEqual(pksv_gs[0x85], "waitbutton")
+ self.assertEqual(pksv_crystal[0x00], "2call")
+ self.assertEqual(pksv_crystal[0x86], "waitbutton")
+ self.assertEqual(pksv_crystal[0xA2], "credits")
+
+ def test_chars_integrity(self):
+ self.assertEqual(chars[0x80], "A")
+ self.assertEqual(chars[0xA0], "a")
+ self.assertEqual(chars[0xF0], "¥")
+ self.assertEqual(jap_chars[0x44], "ぱ")
+
+ def test_map_names_integrity(self):
+ def map_name(map_group, map_id): return map_names[map_group][map_id]["name"]
+ self.assertEqual(map_name(2, 7), "Mahogany Town")
+ self.assertEqual(map_name(3, 0x34), "Ilex Forest")
+ self.assertEqual(map_name(7, 0x11), "Cerulean City")
+
+ def test_load_map_group_offsets(self):
+ addresses = load_map_group_offsets()
+ self.assertEqual(len(addresses), 26, msg="there should be 26 map groups")
+ addresses = load_map_group_offsets()
+ self.assertEqual(len(addresses), 26, msg="there should still be 26 map groups")
+ self.assertIn(0x94034, addresses)
+ for address in addresses:
+ self.assertGreaterEqual(address, 0x4000)
+ self.failIf(0x4000 <= address <= 0x7FFF)
+ self.failIf(address <= 0x4000)
+
+ def test_index(self):
+ self.assertTrue(index([1,2,3,4], lambda f: True) == 0)
+ self.assertTrue(index([1,2,3,4], lambda f: f==3) == 2)
+
+ def test_get_pokemon_constant_by_id(self):
+ x = get_pokemon_constant_by_id
+ self.assertEqual(x(1), "BULBASAUR")
+ self.assertEqual(x(151), "MEW")
+ self.assertEqual(x(250), "HO_OH")
+
+ def test_find_item_label_by_id(self):
+ x = find_item_label_by_id
+ self.assertEqual(x(249), "HM_07")
+ self.assertEqual(x(173), "BERRY")
+ self.assertEqual(x(45), None)
+
+ def test_generate_item_constants(self):
+ x = generate_item_constants
+ r = x()
+ self.failUnless("HM_07" in r)
+ self.failUnless("EQU" in r)
+
+ def test_get_label_for(self):
+ global all_labels
+ temp = copy(all_labels)
+ # this is basd on the format defined in get_labels_between
+ all_labels = [{"label": "poop", "address": 0x5,
+ "offset": 0x5, "bank": 0,
+ "line_number": 2
+ }]
+ self.assertEqual(get_label_for(5), "poop")
+ all_labels = temp
+
+ def test_generate_map_constant_labels(self):
+ ids = generate_map_constant_labels()
+ self.assertEqual(ids[0]["label"], "OLIVINE_POKECENTER_1F")
+ self.assertEqual(ids[1]["label"], "OLIVINE_GYM")
+
+ def test_get_id_for_map_constant_label(self):
+ global map_internal_ids
+ map_internal_ids = generate_map_constant_labels()
+ self.assertEqual(get_id_for_map_constant_label("OLIVINE_GYM"), 1)
+ self.assertEqual(get_id_for_map_constant_label("OLIVINE_POKECENTER_1F"), 0)
+
+ def test_get_map_constant_label_by_id(self):
+ global map_internal_ids
+ map_internal_ids = generate_map_constant_labels()
+ self.assertEqual(get_map_constant_label_by_id(0), "OLIVINE_POKECENTER_1F")
+ self.assertEqual(get_map_constant_label_by_id(1), "OLIVINE_GYM")
+
+ def test_is_valid_address(self):
+ self.assertTrue(is_valid_address(0))
+ self.assertTrue(is_valid_address(1))
+ self.assertTrue(is_valid_address(10))
+ self.assertTrue(is_valid_address(100))
+ self.assertTrue(is_valid_address(1000))
+ self.assertTrue(is_valid_address(10000))
+ self.assertFalse(is_valid_address(2097153))
+ self.assertFalse(is_valid_address(2098000))
+ addresses = [random.randrange(0,2097153) for i in range(0, 9+1)]
+ for address in addresses:
+ self.assertTrue(is_valid_address(address))
+
+class TestIntervalMap(unittest.TestCase):
+ def test_intervals(self):
+ i = IntervalMap()
+ first = "hello world"
+ second = "testing 123"
+ i[0:5] = first
+ i[5:10] = second
+ self.assertEqual(i[0], first)
+ self.assertEqual(i[1], first)
+ self.assertNotEqual(i[5], first)
+ self.assertEqual(i[6], second)
+ i[3:10] = second
+ self.assertEqual(i[3], second)
+ self.assertNotEqual(i[4], first)
+
+ def test_items(self):
+ i = IntervalMap()
+ first = "hello world"
+ second = "testing 123"
+ i[0:5] = first
+ i[5:10] = second
+ results = list(i.items())
+ self.failUnless(len(results) == 2)
+ self.assertEqual(results[0], ((0, 5), "hello world"))
+ self.assertEqual(results[1], ((5, 10), "testing 123"))
+
+class TestRomStr(unittest.TestCase):
+ """RomStr is a class that should act exactly like str()
+ except that it never shows the contents of it string
+ unless explicitly forced"""
+ sample_text = "hello world!"
+ sample = None
+
+ def setUp(self):
+ if self.sample == None:
+ self.__class__.sample = RomStr(self.sample_text)
+
+ def test_equals(self):
+ "check if RomStr() == str()"
+ self.assertEquals(self.sample_text, self.sample)
+
+ def test_not_equal(self):
+ "check if RomStr('a') != RomStr('b')"
+ self.assertNotEqual(RomStr('a'), RomStr('b'))
+
+ def test_appending(self):
+ "check if RomStr()+'a'==str()+'a'"
+ self.assertEquals(self.sample_text+'a', self.sample+'a')
+
+ def test_conversion(self):
+ "check if RomStr() -> str() works"
+ self.assertEquals(str(self.sample), self.sample_text)
+
+ def test_inheritance(self):
+ self.failUnless(issubclass(RomStr, str))
+
+ def test_length(self):
+ self.assertEquals(len(self.sample_text), len(self.sample))
+ self.assertEquals(len(self.sample_text), self.sample.length())
+ self.assertEquals(len(self.sample), self.sample.length())
+
+ def test_rom_interval(self):
+ global rom
+ load_rom()
+ address = 0x100
+ interval = 10
+ correct_strings = ['0x0', '0xc3', '0x6e', '0x1', '0xce',
+ '0xed', '0x66', '0x66', '0xcc', '0xd']
+ byte_strings = rom.interval(address, interval, strings=True)
+ self.assertEqual(byte_strings, correct_strings)
+ correct_ints = [0, 195, 110, 1, 206, 237, 102, 102, 204, 13]
+ ints = rom.interval(address, interval, strings=False)
+ self.assertEqual(ints, correct_ints)
+
+ def test_rom_until(self):
+ global rom
+ load_rom()
+ address = 0x1337
+ byte = 0x13
+ bytes = rom.until(address, byte, strings=True)
+ self.failUnless(len(bytes) == 3)
+ self.failUnless(bytes[0] == '0xd5')
+ bytes = rom.until(address, byte, strings=False)
+ self.failUnless(len(bytes) == 3)
+ self.failUnless(bytes[0] == 0xd5)
+
+class TestAsmList(unittest.TestCase):
+ """AsmList is a class that should act exactly like list()
+ except that it never shows the contents of its list
+ unless explicitly forced"""
+
+ def test_equals(self):
+ base = [1,2,3]
+ asm = AsmList(base)
+ self.assertEquals(base, asm)
+ self.assertEquals(asm, base)
+ self.assertEquals(base, list(asm))
+
+ def test_inheritance(self):
+ self.failUnless(issubclass(AsmList, list))
+
+ def test_length(self):
+ base = range(0, 10)
+ asm = AsmList(base)
+ self.assertEquals(len(base), len(asm))
+ self.assertEquals(len(base), asm.length())
+ self.assertEquals(len(base), len(list(asm)))
+ self.assertEquals(len(asm), asm.length())
+
+ def test_remove_quoted_text(self):
+ x = remove_quoted_text
+ self.assertEqual(x("hello world"), "hello world")
+ self.assertEqual(x("hello \"world\""), "hello ")
+ input = 'hello world "testing 123"'
+ self.assertNotEqual(x(input), input)
+ input = "hello world 'testing 123'"
+ self.assertNotEqual(x(input), input)
+ self.failIf("testing" in x(input))
+
+ def test_line_has_comment_address(self):
+ x = line_has_comment_address
+ self.assertFalse(x(""))
+ self.assertFalse(x(";"))
+ self.assertFalse(x(";;;"))
+ self.assertFalse(x(":;"))
+ self.assertFalse(x(":;:"))
+ self.assertFalse(x(";:"))
+ self.assertFalse(x(" "))
+ self.assertFalse(x("".join(" " * 5)))
+ self.assertFalse(x("".join(" " * 10)))
+ self.assertFalse(x("hello world"))
+ self.assertFalse(x("hello_world"))
+ self.assertFalse(x("hello_world:"))
+ self.assertFalse(x("hello_world:;"))
+ self.assertFalse(x("hello_world: ;"))
+ self.assertFalse(x("hello_world: ; "))
+ self.assertFalse(x("hello_world: ;" + "".join(" " * 5)))
+ self.assertFalse(x("hello_world: ;" + "".join(" " * 10)))
+ self.assertTrue(x(";1"))
+ self.assertTrue(x(";F"))
+ self.assertTrue(x(";$00FF"))
+ self.assertTrue(x(";0x00FF"))
+ self.assertTrue(x("; 0x00FF"))
+ self.assertTrue(x(";$3:$300"))
+ self.assertTrue(x(";0x3:$300"))
+ self.assertTrue(x(";$3:0x300"))
+ self.assertTrue(x(";3:300"))
+ self.assertTrue(x(";3:FFAA"))
+ self.assertFalse(x('hello world "how are you today;0x1"'))
+ self.assertTrue(x('hello world "how are you today:0x1";1'))
+ returnable = {}
+ self.assertTrue(x("hello_world: ; 0x4050", returnable=returnable, bank=5))
+ self.assertTrue(returnable["address"] == 0x14050)
+
+ def test_line_has_label(self):
+ x = line_has_label
+ self.assertTrue(x("hi:"))
+ self.assertTrue(x("Hello: "))
+ self.assertTrue(x("MyLabel: ; test xyz"))
+ self.assertFalse(x(":"))
+ self.assertFalse(x(";HelloWorld:"))
+ self.assertFalse(x("::::"))
+ self.assertFalse(x(":;:;:;:::"))
+
+ def test_get_label_from_line(self):
+ x = get_label_from_line
+ self.assertEqual(x("HelloWorld: "), "HelloWorld")
+ self.assertEqual(x("HiWorld:"), "HiWorld")
+ self.assertEqual(x("HiWorld"), None)
+
+ def test_find_labels_without_addresses(self):
+ global asm
+ asm = ["hello_world: ; 0x1", "hello_world2: ;"]
+ labels = find_labels_without_addresses()
+ self.failUnless(labels[0]["label"] == "hello_world2")
+ asm = ["hello world: ;1", "hello_world: ;2"]
+ labels = find_labels_without_addresses()
+ self.failUnless(len(labels) == 0)
+ asm = None
+
+ def test_get_labels_between(self):
+ global asm
+ x = get_labels_between#(start_line_id, end_line_id, bank)
+ asm = ["HelloWorld: ;1",
+ "hi:",
+ "no label on this line",
+ ]
+ labels = x(0, 2, 0x12)
+ self.assertEqual(len(labels), 1)
+ self.assertEqual(labels[0]["label"], "HelloWorld")
+ del asm
+
+ # this test takes a lot of time :(
+ def xtest_scan_for_predefined_labels(self):
+ # label keys: line_number, bank, label, offset, address
+ load_asm()
+ all_labels = scan_for_predefined_labels()
+ label_names = [x["label"] for x in all_labels]
+ self.assertIn("GetFarByte", label_names)
+ self.assertIn("AddNTimes", label_names)
+ self.assertIn("CheckShininess", label_names)
+
+ def test_write_all_labels(self):
+ """dumping json into a file"""
+ filename = "test_labels.json"
+ # remove the current file
+ if os.path.exists(filename):
+ os.system("rm " + filename)
+ # make up some labels
+ labels = []
+ # fake label 1
+ label = {"line_number": 5, "bank": 0, "label": "SomeLabel", "address": 0x10}
+ labels.append(label)
+ # fake label 2
+ label = {"line_number": 15, "bank": 2, "label": "SomeOtherLabel", "address": 0x9F0A}
+ labels.append(label)
+ # dump to file
+ write_all_labels(labels, filename=filename)
+ # open the file and read the contents
+ file_handler = open(filename, "r")
+ contents = file_handler.read()
+ file_handler.close()
+ # parse into json
+ obj = json.read(contents)
+ # begin testing
+ self.assertEqual(len(obj), len(labels))
+ self.assertEqual(len(obj), 2)
+ self.assertEqual(obj, labels)
+
+ def test_isolate_incbins(self):
+ global asm
+ asm = ["123", "456", "789", "abc", "def", "ghi",
+ 'INCBIN "baserom.gbc",$12DA,$12F8 - $12DA',
+ "jkl",
+ 'INCBIN "baserom.gbc",$137A,$13D0 - $137A']
+ lines = isolate_incbins()
+ self.assertIn(asm[6], lines)
+ self.assertIn(asm[8], lines)
+ for line in lines:
+ self.assertIn("baserom", line)
+
+ def test_process_incbins(self):
+ global incbin_lines, processed_incbins, asm
+ incbin_lines = ['INCBIN "baserom.gbc",$12DA,$12F8 - $12DA',
+ 'INCBIN "baserom.gbc",$137A,$13D0 - $137A']
+ asm = copy(incbin_lines)
+ asm.insert(1, "some other random line")
+ processed_incbins = process_incbins()
+ self.assertEqual(len(processed_incbins), len(incbin_lines))
+ self.assertEqual(processed_incbins[0]["line"], incbin_lines[0])
+ self.assertEqual(processed_incbins[2]["line"], incbin_lines[1])
+
+ def test_reset_incbins(self):
+ global asm, incbin_lines, processed_incbins
+ # temporarily override the functions
+ global load_asm, isolate_incbins, process_incbins
+ temp1, temp2, temp3 = load_asm, isolate_incbins, process_incbins
+ def load_asm(): pass
+ def isolate_incbins(): pass
+ def process_incbins(): pass
+ # call reset
+ reset_incbins()
+ # check the results
+ self.assertTrue(asm == [] or asm == None)
+ self.assertTrue(incbin_lines == [])
+ self.assertTrue(processed_incbins == {})
+ # reset the original functions
+ load_asm, isolate_incbins, process_incbins = temp1, temp2, temp3
+
+ def test_find_incbin_to_replace_for(self):
+ global asm, incbin_lines, processed_incbins
+ asm = ['first line', 'second line', 'third line',
+ 'INCBIN "baserom.gbc",$90,$200 - $90',
+ 'fifth line', 'last line']
+ isolate_incbins()
+ process_incbins()
+ line_num = find_incbin_to_replace_for(0x100)
+ # must be the 4th line (the INBIN line)
+ self.assertEqual(line_num, 3)
+
+ def test_split_incbin_line_into_three(self):
+ global asm, incbin_lines, processed_incbins
+ asm = ['first line', 'second line', 'third line',
+ 'INCBIN "baserom.gbc",$90,$200 - $90',
+ 'fifth line', 'last line']
+ isolate_incbins()
+ process_incbins()
+ content = split_incbin_line_into_three(3, 0x100, 10)
+ # must end up with three INCBINs in output
+ self.failUnless(content.count("INCBIN") == 3)
+
+ def test_analyze_intervals(self):
+ global asm, incbin_lines, processed_incbins
+ asm, incbin_lines, processed_incbins = None, [], {}
+ asm = ['first line', 'second line', 'third line',
+ 'INCBIN "baserom.gbc",$90,$200 - $90',
+ 'fifth line', 'last line',
+ 'INCBIN "baserom.gbc",$33F,$4000 - $33F']
+ isolate_incbins()
+ process_incbins()
+ largest = analyze_intervals()
+ self.assertEqual(largest[0]["line_number"], 6)
+ self.assertEqual(largest[0]["line"], asm[6])
+ self.assertEqual(largest[1]["line_number"], 3)
+ self.assertEqual(largest[1]["line"], asm[3])
+
+ def test_generate_diff_insert(self):
+ global asm
+ asm = ['first line', 'second line', 'third line',
+ 'INCBIN "baserom.gbc",$90,$200 - $90',
+ 'fifth line', 'last line',
+ 'INCBIN "baserom.gbc",$33F,$4000 - $33F']
+ diff = generate_diff_insert(0, "the real first line", debug=False)
+ self.assertIn("the real first line", diff)
+ self.assertIn("INCBIN", diff)
+ self.assertNotIn("No newline at end of file", diff)
+ self.assertIn("+"+asm[1], diff)
+
+class TestMapParsing(unittest.TestCase):
+ def test_parse_all_map_headers(self):
+ global parse_map_header_at, old_parse_map_header_at, counter
+ counter = 0
+ for k in map_names.keys():
+ if "offset" not in map_names[k].keys():
+ map_names[k]["offset"] = 0
+ temp = parse_map_header_at
+ temp2 = old_parse_map_header_at
+ def parse_map_header_at(address, map_group=None, map_id=None, debug=False):
+ global counter
+ counter += 1
+ return {}
+ old_parse_map_header_at = parse_map_header_at
+ parse_all_map_headers(debug=False)
+ # parse_all_map_headers is currently doing it 2x
+ # because of the new/old map header parsing routines
+ self.assertEqual(counter, 388 * 2)
+ parse_map_header_at = temp
+ old_parse_map_header_at = temp2
+
+class TestTextScript(unittest.TestCase):
+ """for testing 'in-script' commands, etc."""
+ #def test_to_asm(self):
+ # pass # or raise NotImplementedError, bryan_message
+ #def test_find_addresses(self):
+ # pass # or raise NotImplementedError, bryan_message
+ #def test_parse_text_at(self):
+ # pass # or raise NotImplementedError, bryan_message
+
+class TestEncodedText(unittest.TestCase):
+ """for testing chars-table encoded text chunks"""
+
+ def test_process_00_subcommands(self):
+ g = process_00_subcommands(0x197186, 0x197186+601, debug=False)
+ self.assertEqual(len(g), 42)
+ self.assertEqual(len(g[0]), 13)
+ self.assertEqual(g[1], [184, 174, 180, 211, 164, 127, 20, 231, 81])
+
+ def test_parse_text_at2(self):
+ oakspeech = parse_text_at2(0x197186, 601, debug=False)
+ self.assertIn("encyclopedia", oakspeech)
+ self.assertIn("researcher", oakspeech)
+ self.assertIn("dependable", oakspeech)
+
+ def test_parse_text_engine_script_at(self):
+ p = parse_text_engine_script_at(0x197185, debug=False)
+ self.assertEqual(len(p.commands), 2)
+ self.assertEqual(len(p.commands[0]["lines"]), 41)
+
+ # don't really care about these other two
+ def test_parse_text_from_bytes(self): pass
+ def test_parse_text_at(self): pass
+
+class TestScript(unittest.TestCase):
+ """for testing parse_script_engine_script_at and script parsing in
+ general. Script should be a class."""
+ #def test_parse_script_engine_script_at(self):
+ # pass # or raise NotImplementedError, bryan_message
+
+ def test_find_all_text_pointers_in_script_engine_script(self):
+ address = 0x197637 # 0x197634
+ script = parse_script_engine_script_at(address, debug=False)
+ bank = calculate_bank(address)
+ r = find_all_text_pointers_in_script_engine_script(script, bank=bank, debug=False)
+ results = list(r)
+ self.assertIn(0x197661, results)
+
+class TestLabel(unittest.TestCase):
+ def test_label_making(self):
+ line_number = 2
+ address = 0xf0c0
+ label_name = "poop"
+ l = Label(name=label_name, address=address, line_number=line_number)
+ self.failUnless(hasattr(l, "name"))
+ self.failUnless(hasattr(l, "address"))
+ self.failUnless(hasattr(l, "line_number"))
+ self.failIf(isinstance(l.address, str))
+ self.failIf(isinstance(l.line_number, str))
+ self.failUnless(isinstance(l.name, str))
+ self.assertEqual(l.line_number, line_number)
+ self.assertEqual(l.name, label_name)
+ self.assertEqual(l.address, address)
+
+class TestByteParams(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ load_rom()
+ cls.address = 10
+ cls.sbp = SingleByteParam(address=cls.address)
+
+ @classmethod
+ def tearDownClass(cls):
+ del cls.sbp
+
+ def test__init__(self):
+ self.assertEqual(self.sbp.size, 1)
+ self.assertEqual(self.sbp.address, self.address)
+
+ def test_parse(self):
+ self.sbp.parse()
+ self.assertEqual(str(self.sbp.byte), str(45))
+
+ def test_to_asm(self):
+ self.assertEqual(self.sbp.to_asm(), "$2d")
+ self.sbp.should_be_decimal = True
+ self.assertEqual(self.sbp.to_asm(), str(45))
+
+ # HexByte and DollarSignByte are the same now
+ def test_HexByte_to_asm(self):
+ h = HexByte(address=10)
+ a = h.to_asm()
+ self.assertEqual(a, "$2d")
+
+ def test_DollarSignByte_to_asm(self):
+ d = DollarSignByte(address=10)
+ a = d.to_asm()
+ self.assertEqual(a, "$2d")
+
+ def test_ItemLabelByte_to_asm(self):
+ i = ItemLabelByte(address=433)
+ self.assertEqual(i.byte, 54)
+ self.assertEqual(i.to_asm(), "COIN_CASE")
+ self.assertEqual(ItemLabelByte(address=10).to_asm(), "$2d")
+
+ def test_DecimalParam_to_asm(self):
+ d = DecimalParam(address=10)
+ x = d.to_asm()
+ self.assertEqual(x, str(0x2d))
+
+class TestMultiByteParam(unittest.TestCase):
+ def setup_for(self, somecls, byte_size=2, address=443, **kwargs):
+ self.cls = somecls(address=address, size=byte_size, **kwargs)
+ self.assertEqual(self.cls.address, address)
+ self.assertEqual(self.cls.bytes, rom_interval(address, byte_size, strings=False))
+ self.assertEqual(self.cls.size, byte_size)
+
+ def test_two_byte_param(self):
+ self.setup_for(MultiByteParam, byte_size=2)
+ self.assertEqual(self.cls.to_asm(), "$f0c0")
+
+ def test_three_byte_param(self):
+ self.setup_for(MultiByteParam, byte_size=3)
+
+ def test_PointerLabelParam_no_bank(self):
+ self.setup_for(PointerLabelParam, bank=None)
+ # assuming no label at this location..
+ self.assertEqual(self.cls.to_asm(), "$f0c0")
+ global all_labels
+ # hm.. maybe all_labels should be using a class?
+ all_labels = [{"label": "poop", "address": 0xf0c0,
+ "offset": 0xf0c0, "bank": 0,
+ "line_number": 2
+ }]
+ self.assertEqual(self.cls.to_asm(), "poop")
+
+class TestPostParsing: #(unittest.TestCase):
+ """tests that must be run after parsing all maps"""
+ @classmethod
+ def setUpClass(cls):
+ run_main()
+
+ def test_signpost_counts(self):
+ self.assertEqual(len(map_names[1][1]["signposts"]), 0)
+ self.assertEqual(len(map_names[1][2]["signposts"]), 2)
+ self.assertEqual(len(map_names[10][5]["signposts"]), 7)
+
+ def test_warp_counts(self):
+ self.assertEqual(map_names[10][5]["warp_count"], 9)
+ self.assertEqual(map_names[18][5]["warp_count"], 3)
+ self.assertEqual(map_names[15][1]["warp_count"], 2)
+
+ def test_map_sizes(self):
+ self.assertEqual(map_names[15][1]["height"], 18)
+ self.assertEqual(map_names[15][1]["width"], 10)
+ self.assertEqual(map_names[7][1]["height"], 4)
+ self.assertEqual(map_names[7][1]["width"], 4)
+
+ def test_map_connection_counts(self):
+ self.assertEqual(map_names[7][1]["connections"], 0)
+ self.assertEqual(map_names[10][1]["connections"], 12)
+ self.assertEqual(map_names[10][2]["connections"], 12)
+ self.assertEqual(map_names[11][1]["connections"], 9) # or 13?
+
+ def test_second_map_header_address(self):
+ self.assertEqual(map_names[11][1]["second_map_header_address"], 0x9509c)
+ self.assertEqual(map_names[1][5]["second_map_header_address"], 0x95bd0)
+
+ def test_event_address(self):
+ self.assertEqual(map_names[17][5]["event_address"], 0x194d67)
+ self.assertEqual(map_names[23][3]["event_address"], 0x1a9ec9)
+
+ def test_people_event_counts(self):
+ self.assertEqual(len(map_names[23][3]["people_events"]), 4)
+ self.assertEqual(len(map_names[10][3]["people_events"]), 9)
+
+class TestMetaTesting(unittest.TestCase):
+ """test whether or not i am finding at least
+ some of the tests in this file"""
+ tests = None
+
+ def setUp(self):
+ if self.tests == None:
+ self.__class__.tests = assemble_test_cases()
+
+ def test_assemble_test_cases_count(self):
+ "does assemble_test_cases find some tests?"
+ self.failUnless(len(self.tests) > 0)
+
+ def test_assemble_test_cases_inclusion(self):
+ "is this class found by assemble_test_cases?"
+ # i guess it would have to be for this to be running?
+ self.failUnless(self.__class__ in self.tests)
+
+ def test_assemble_test_cases_others(self):
+ "test other inclusions for assemble_test_cases"
+ self.failUnless(TestRomStr in self.tests)
+ self.failUnless(TestCram in self.tests)
+
+ def test_check_has_test(self):
+ self.failUnless(check_has_test("beaver", ["test_beaver"]))
+ self.failUnless(check_has_test("beaver", ["test_beaver_2"]))
+ self.failIf(check_has_test("beaver_1", ["test_beaver"]))
+
+ def test_find_untested_methods(self):
+ untested = find_untested_methods()
+ # the return type must be an iterable
+ self.failUnless(hasattr(untested, "__iter__"))
+ #.. basically, a list
+ self.failUnless(isinstance(untested, list))
+
+ def test_find_untested_methods_method(self):
+ """create a function and see if it is found"""
+ # setup a function in the global namespace
+ global some_random_test_method
+ # define the method
+ def some_random_test_method(): pass
+ # first make sure it is in the global scope
+ members = inspect.getmembers(sys.modules[__name__], inspect.isfunction)
+ func_names = [functuple[0] for functuple in members]
+ self.assertIn("some_random_test_method", func_names)
+ # test whether or not it is found by find_untested_methods
+ untested = find_untested_methods()
+ self.assertIn("some_random_test_method", untested)
+ # remove the test method from the global namespace
+ del some_random_test_method
+
+ def test_load_tests(self):
+ loader = unittest.TestLoader()
+ suite = load_tests(loader, None, None)
+ suite._tests[0]._testMethodName
+ membership_test = lambda member: \
+ inspect.isclass(member) and issubclass(member, unittest.TestCase)
+ tests = inspect.getmembers(sys.modules[__name__], membership_test)
+ classes = [x[1] for x in tests]
+ for test in suite._tests:
+ self.assertIn(test.__class__, classes)
+
+ def test_report_untested(self):
+ untested = find_untested_methods()
+ output = report_untested()
+ if len(untested) > 0:
+ self.assertIn("NOT TESTED", output)
+ for name in untested:
+ self.assertIn(name, output)
+ elif len(untested) == 0:
+ self.assertNotIn("NOT TESTED", output)
+
+def assemble_test_cases():
+ """finds classes that inherit from unittest.TestCase
+ because i am too lazy to remember to add them to a
+ global list of tests for the suite runner"""
+ classes = []
+ clsmembers = inspect.getmembers(sys.modules[__name__], inspect.isclass)
+ for (name, some_class) in clsmembers:
+ if issubclass(some_class, unittest.TestCase):
+ classes.append(some_class)
+ return classes
+
+def load_tests(loader, tests, pattern):
+ suite = unittest.TestSuite()
+ for test_class in assemble_test_cases():
+ tests = loader.loadTestsFromTestCase(test_class)
+ suite.addTests(tests)
+ return suite
+
+def check_has_test(func_name, tested_names):
+ """checks if there is a test dedicated to this function"""
+ if "test_"+func_name in tested_names:
+ return True
+ for name in tested_names:
+ if "test_"+func_name in name:
+ return True
+ return False
+
+def find_untested_methods():
+ """finds all untested functions in this module
+ by searching for method names in test case
+ method names."""
+ untested = []
+ avoid_funcs = ["main", "run_tests", "run_main", "copy", "deepcopy"]
+ test_funcs = []
+ # get a list of all classes in this module
+ classes = inspect.getmembers(sys.modules[__name__], inspect.isclass)
+ # for each class..
+ for (name, klass) in classes:
+ # only look at those that have tests
+ if issubclass(klass, unittest.TestCase):
+ # look at this class' methods
+ funcs = inspect.getmembers(klass, inspect.ismethod)
+ # for each method..
+ for (name2, func) in funcs:
+ # store the ones that begin with test_
+ if "test_" in name2 and name2[0:5] == "test_":
+ test_funcs.append([name2, func])
+ # assemble a list of all test method names (test_x, test_y, ..)
+ tested_names = [funcz[0] for funcz in test_funcs]
+ # now get a list of all functions in this module
+ funcs = inspect.getmembers(sys.modules[__name__], inspect.isfunction)
+ # for each function..
+ for (name, func) in funcs:
+ # we don't care about some of these
+ if name in avoid_funcs: continue
+ # skip functions beginning with _
+ if name[0] == "_": continue
+ # check if this function has a test named after it
+ has_test = check_has_test(name, tested_names)
+ if not has_test:
+ untested.append(name)
+ return untested
+
+def report_untested():
+ """
+ This reports about untested functions in the global namespace. This was
+ originally in the crystal module, where it would list out the majority of
+ the functions. Maybe it should be moved back.
+ """
+ untested = find_untested_methods()
+ output = "NOT TESTED: ["
+ first = True
+ for name in untested:
+ if first:
+ output += name
+ first = False
+ else: output += ", "+name
+ output += "]\n"
+ output += "total untested: " + str(len(untested))
+ return output
+
+def run_tests(): # rather than unittest.main()
+ loader = unittest.TestLoader()
+ suite = load_tests(loader, None, None)
+ unittest.TextTestRunner(verbosity=2).run(suite)
+ print report_untested()
+
+# run the unit tests when this file is executed directly
+if __name__ == "__main__":
+ run_tests()
+
diff --git a/type_constants.py b/type_constants.py
new file mode 100644
index 0000000..da89b0b
--- /dev/null
+++ b/type_constants.py
@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+type_constants = {
+ "NORMAL": 0x00,
+ "FIGHTING": 0x01,
+ "FLYING": 0x02,
+ "POISON": 0x03,
+ "GROUND": 0x04,
+ "ROCK": 0x05,
+ "BUG": 0x07,
+ "GHOST": 0x08,
+ "STEEL": 0x09,
+ "CURSE_T": 0x13,
+ "FIRE": 0x14,
+ "WATER": 0x15,
+ "GRASS": 0x16,
+ "ELECTRIC": 0x17,
+ "PSYCHIC": 0x18,
+ "ICE": 0x19,
+ "DRAGON": 0x1A,
+ "DARK": 0x1B,
+}