# -*- coding: utf-8 -*-

import os
from copy import copy
import hashlib
import random
import json

from pokemontools.interval_map import IntervalMap
from pokemontools.chars import chars, jap_chars

from pokemontools.romstr import (
    RomStr,
    AsmList,
)

from pokemontools.item_constants import (
    item_constants,
    find_item_label_by_id,
    generate_item_constants,
)

from pokemontools.pointers import (
    calculate_bank,
    calculate_pointer,
)

from pokemontools.pksv import (
    pksv_gs,
    pksv_crystal,
)

from pokemontools.labels import (
    remove_quoted_text,
    line_has_comment_address,
    line_has_label,
    get_label_from_line,
)

from pokemontools.helpers import (
    grouper,
    index,
)

from pokemontools.crystalparts.old_parsers import (
    old_parse_map_header_at,
)

from pokemontools.crystal import (
    script_parse_table,
    load_rom,
    rom_until,
    direct_load_rom,
    parse_script_engine_script_at,
    parse_text_engine_script_at,
    parse_text_at2,
    find_all_text_pointers_in_script_engine_script,
    SingleByteParam,
    HexByte,
    MultiByteParam,
    PointerLabelParam,
    ItemLabelByte,
    DollarSignByte,
    DecimalParam,
    rom_interval,
    map_names,
    Label,
    scan_for_predefined_labels,
    all_labels,
    write_all_labels,
    parse_map_header_at,
    process_00_subcommands,
    parse_all_map_headers,
    translate_command_byte,
    map_name_cleaner,
    load_map_group_offsets,
    load_asm,
    asm,
    is_valid_address,
    how_many_until,
    get_pokemon_constant_by_id,
    generate_map_constant_labels,
    get_map_constant_label_by_id,
    get_id_for_map_constant_label,
    calculate_pointer_from_bytes_at,
    isolate_incbins,
    process_incbins,
    get_labels_between,
    rom_text_at,
    get_label_for,
    split_incbin_line_into_three,
    reset_incbins,
    parse_rom,

    # globals
    engine_flags,
)

import pokemontools.wram

import unittest
import mock


# Expected bytes of baserom.gbc starting at 0x100, shared by the
# rom_interval tests below (module-level helper and RomStr method).
_INTERVAL_ADDRESS = 0x100
_INTERVAL_LENGTH = 10
_INTERVAL_STRINGS = ['0x0', '0xc3', '0x6e', '0x1', '0xce',
                     '0xed', '0x66', '0x66', '0xcc', '0xd']
_INTERVAL_INTS = [0, 195, 110, 1, 206, 237, 102, 102, 204, 13]


class BasicTestCase(unittest.TestCase):
    """Sanity checks for the ROM/ASM loaders and low-level ROM helpers.

    ``setUpClass`` loads the ROM once via ``direct_load_rom`` and also
    publishes it as the module-level global ``rom``.
    """

    @classmethod
    def setUpClass(cls):
        global rom
        cls.rom = direct_load_rom()
        rom = cls.rom

    @classmethod
    def tearDownClass(cls):
        del cls.rom

    def test_direct_load_rom(self):
        rom = self.rom
        self.assertEqual(len(rom), 2097152)
        self.assertIsInstance(rom, RomStr)

    def test_load_rom(self):
        rom = load_rom()
        self.assertNotEqual(rom, None)
        rom = load_rom()
        self.assertNotEqual(rom, RomStr(None))

    def test_load_asm(self):
        # Named asm_lines so the imported ``asm`` global is not shadowed.
        asm_lines = load_asm()
        joined_lines = "\n".join(asm_lines)
        self.assertIn("SECTION", joined_lines)
        self.assertIn("bank", joined_lines)
        self.assertIsInstance(asm_lines, AsmList)

    def test_rom_file_existence(self):
        "ROM file must exist"
        dirname = os.path.dirname(__file__)
        filenames = os.listdir(os.path.join(os.path.abspath(dirname), "../../"))
        self.assertIn("baserom.gbc", filenames)

    def test_rom_md5(self):
        "ROM file must have the correct md5 sum"
        rom = self.rom
        correct = "9f2922b235a5eeb78d65594e82ef5dde"
        md5 = hashlib.md5()
        md5.update(rom)
        md5sum = md5.hexdigest()
        self.assertEqual(md5sum, correct)

    def test_bizarre_http_presence(self):
        rom_segment = self.rom[0x112116:0x112116 + 8]
        self.assertEqual(rom_segment, "HTTP/1.0")

    def test_rom_interval(self):
        byte_strings = rom_interval(_INTERVAL_ADDRESS, _INTERVAL_LENGTH,
                                    rom=self.rom, strings=True)
        self.assertEqual(byte_strings, _INTERVAL_STRINGS)
        ints = rom_interval(_INTERVAL_ADDRESS, _INTERVAL_LENGTH,
                            rom=self.rom, strings=False)
        self.assertEqual(ints, _INTERVAL_INTS)

    def test_rom_until(self):
        address = 0x1337
        byte = 0x13
        # ``found`` rather than ``bytes`` to avoid shadowing the builtin.
        found = rom_until(address, byte, rom=self.rom, strings=True)
        self.assertEqual(len(found), 3)
        self.assertEqual(found[0], '0xd5')
        found = rom_until(address, byte, rom=self.rom, strings=False)
        self.assertEqual(len(found), 3)
        self.assertEqual(found[0], 0xd5)

    def test_how_many_until(self):
        how_many = how_many_until(chr(0x13), 0x1337, self.rom)
        self.assertEqual(how_many, 3)

    def test_calculate_pointer_from_bytes_at(self):
        addr1 = calculate_pointer_from_bytes_at(0x100, bank=False)
        self.assertEqual(addr1, 0xc300)
        addr2 = calculate_pointer_from_bytes_at(0x100, bank=True)
        self.assertEqual(addr2, 0x2ec3)

    def test_rom_text_at(self):
        self.assertEqual(rom_text_at(0x112116, 8), "HTTP/1.0")


class TestRomStr(unittest.TestCase):
    """Tests for the ``interval``/``until`` methods on the loaded ROM object."""

    sample_text = "hello world!"
    sample = None

    def test_rom_interval(self):
        # Use load_rom()'s return value instead of the module-level ``rom``
        # global, which only exists once BasicTestCase.setUpClass has run;
        # otherwise this test raises NameError when run on its own.
        rom = load_rom()
        byte_strings = rom.interval(_INTERVAL_ADDRESS, _INTERVAL_LENGTH, strings=True)
        self.assertEqual(byte_strings, _INTERVAL_STRINGS)
        ints = rom.interval(_INTERVAL_ADDRESS, _INTERVAL_LENGTH, strings=False)
        self.assertEqual(ints, _INTERVAL_INTS)

    def test_rom_until(self):
        rom = load_rom()
        address = 0x1337
        byte = 0x13
        found = rom.until(address, byte, strings=True)
        self.assertEqual(len(found), 3)
        self.assertEqual(found[0], '0xd5')
        found = rom.until(address, byte, strings=False)
        self.assertEqual(len(found), 3)
        self.assertEqual(found[0], 0xd5)


class TestAsmList(unittest.TestCase):
    # this test takes a lot of time :(
    def xtest_scan_for_predefined_labels(self):
        # label keys: line_number, bank, label, offset, address
        load_asm()
        labels = scan_for_predefined_labels()
        label_names = [x["label"] for x in labels]
        self.assertIn("GetFarByte", label_names)
        self.assertIn("AddNTimes", label_names)
        self.assertIn("CheckShininess", label_names)


class TestEncodedText(unittest.TestCase):
    """for testing chars-table encoded text chunks"""

    def test_process_00_subcommands(self):
        g = process_00_subcommands(0x197186, 0x197186 + 601, debug=False)
        self.assertEqual(len(g), 42)
        self.assertEqual(len(g[0]), 13)
        self.assertEqual(g[1], [184, 174, 180, 211, 164, 127, 20, 231, 81])

    def test_parse_text_at2(self):
        oakspeech = parse_text_at2(0x197186, 601, debug=False)
        self.assertIn("encyclopedia", oakspeech)
        self.assertIn("researcher", oakspeech)
        self.assertIn("dependable", oakspeech)

    def test_parse_text_engine_script_at(self):
        p = parse_text_engine_script_at(0x197185, debug=False)
        self.assertEqual(len(p.commands), 1)
        self.assertEqual(p.commands[0].to_asm().count("\n"), 40)


class TestScript(unittest.TestCase):
    """for testing parse_script_engine_script_at and script parsing in
    general. Script should be a class."""

    def test_find_all_text_pointers_in_script_engine_script(self):
        address = 0x197637  # 0x197634
        script = parse_script_engine_script_at(address, debug=False)
        bank = calculate_bank(address)
        r = find_all_text_pointers_in_script_engine_script(script, bank=bank, debug=False)
        results = list(r)
        self.assertIn(0x197661, results)


class TestByteParams(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        load_rom()
        cls.address = 10
        cls.sbp = SingleByteParam(address=cls.address)

    @classmethod
    def tearDownClass(cls):
        del cls.sbp

    def test__init__(self):
        self.assertEqual(self.sbp.size, 1)
        self.assertEqual(self.sbp.address, self.address)

    def test_parse(self):
        self.sbp.parse()
        self.assertEqual(str(self.sbp.byte), str(45))

    def test_to_asm(self):
        self.assertEqual(self.sbp.to_asm(), "$2d")
        # Flip the decimal flag on a fresh instance so the shared class
        # fixture is not mutated for tests that run afterwards.
        decimal_sbp = SingleByteParam(address=self.address)
        decimal_sbp.should_be_decimal = True
        self.assertEqual(decimal_sbp.to_asm(), str(45))

    # HexByte and DollarSignByte are the same now
    def test_HexByte_to_asm(self):
        h = HexByte(address=10)
        a = h.to_asm()
        self.assertEqual(a, "$2d")

    def test_DollarSignByte_to_asm(self):
        d = DollarSignByte(address=10)
        a = d.to_asm()
        self.assertEqual(a, "$2d")

    def test_ItemLabelByte_to_asm(self):
        i = ItemLabelByte(address=433)
        self.assertEqual(i.byte, 54)
        self.assertEqual(i.to_asm(), "COIN_CASE")
        self.assertEqual(ItemLabelByte(address=10).to_asm(), "$2d")

    def test_DecimalParam_to_asm(self):
        d = DecimalParam(address=10)
        x = d.to_asm()
        self.assertEqual(x, str(0x2d))


class TestMultiByteParam(unittest.TestCase):
    def setup_for(self, somecls, byte_size=2, address=443, **kwargs):
        """Build ``somecls`` at ``address`` and check its basic attributes."""
        self.rom = load_rom()
        self.cls = somecls(address=address, size=byte_size, **kwargs)
        self.assertEqual(self.cls.address, address)
        self.assertEqual(self.cls.bytes,
                         rom_interval(address, byte_size, rom=self.rom, strings=False))
        self.assertEqual(self.cls.size, byte_size)

    def test_two_byte_param(self):
        self.setup_for(MultiByteParam, byte_size=2)
        self.assertEqual(self.cls.to_asm(), "$f0c0")

    def test_three_byte_param(self):
        self.setup_for(MultiByteParam, byte_size=3)

    def test_PointerLabelParam_no_bank(self):
        self.setup_for(PointerLabelParam, bank=None)
        # assuming no label at this location..
        self.assertEqual(self.cls.to_asm(), "$f0c0")
        # NOTE(review): this label stays in the shared script_parse_table
        # after the test finishes; other tests may observe it.
        script_parse_table[0xf0c0:0xf0c0 + 1] = {"label": "poop", "bank": 0, "line_number": 2}
        self.assertEqual(self.cls.to_asm(), "poop")


class TestPostParsing(unittest.TestCase):
    """tests that must be run after parsing all maps"""

    @classmethod
    def setUpClass(cls):
        cls.rom = direct_load_rom()
        pokemontools.wram.wram_labels = {}
        with mock.patch("pokemontools.crystal.read_engine_flags", return_value={}):
            with mock.patch("pokemontools.crystal.read_event_flags", return_value={}):
                with mock.patch("pokemontools.crystal.setup_wram_labels", return_value={}):
                    parse_rom(rom=cls.rom, _skip_wram_labels=True,
                              _parse_map_header_at=old_parse_map_header_at,
                              debug=False)

    def test_signpost_counts(self):
        self.assertEqual(len(map_names[1][1]["header_new"]["event_header"]["signposts"]), 0)
        self.assertEqual(len(map_names[1][2]["header_new"]["event_header"]["signposts"]), 2)
        self.assertEqual(len(map_names[10][5]["header_new"]["event_header"]["signposts"]), 7)

    def test_warp_counts(self):
        self.assertEqual(map_names[10][5]["header_new"]["event_header"]["warp_count"], 9)
        self.assertEqual(map_names[18][5]["header_new"]["event_header"]["warp_count"], 3)
        self.assertEqual(map_names[15][1]["header_new"]["event_header"]["warp_count"], 2)

    def test_map_sizes(self):
        self.assertEqual(map_names[15][1]["header_new"]["second_map_header"]["height"], 18)
        self.assertEqual(map_names[15][1]["header_new"]["second_map_header"]["width"], 10)
        self.assertEqual(map_names[7][1]["header_new"]["second_map_header"]["height"], 4)
        self.assertEqual(map_names[7][1]["header_new"]["second_map_header"]["width"], 4)

    def test_map_connection_counts(self):
        self.assertEqual(map_names[7][1]["header_new"]["second_map_header"]["connections"], 0)
        self.assertEqual(map_names[10][1]["header_new"]["second_map_header"]["connections"], 12)
        self.assertEqual(map_names[10][2]["header_new"]["second_map_header"]["connections"], 12)
        # NOTE: original author was unsure whether this should be 9 or 13.
        self.assertEqual(map_names[11][1]["header_new"]["second_map_header"]["connections"], 9)

    def test_second_map_header_address(self):
        self.assertEqual(map_names[11][1]["header_new"]["second_map_header_address"], 0x9509c)
        self.assertEqual(map_names[1][5]["header_new"]["second_map_header_address"], 0x95bd0)

    def test_event_address(self):
        self.assertEqual(map_names[17][5]["header_new"]["second_map_header"]["event_address"], 0x194d67)
        self.assertEqual(map_names[23][3]["header_new"]["second_map_header"]["event_address"], 0x1a9ec9)

    def test_people_event_counts(self):
        self.assertEqual(len(map_names[23][3]["header_new"]["event_header"]["people_events"]), 4)
        self.assertEqual(len(map_names[10][3]["header_new"]["event_header"]["people_events"]), 9)


class TestMapParsing(unittest.TestCase):
    def xtest_parse_all_map_headers(self):
        # NOTE(review): this rebinds names in *this* module only;
        # parse_all_map_headers presumably looks them up in
        # pokemontools.crystal, so the stubs may never be called -- verify
        # before re-enabling (the "x" prefix keeps unittest from running it).
        global parse_map_header_at, old_parse_map_header_at, counter
        counter = 0
        for group in map_names.values():
            if "offset" not in group:
                group["offset"] = 0
        saved_parse = parse_map_header_at
        saved_old_parse = old_parse_map_header_at

        def parse_map_header_at(address, map_group=None, map_id=None, debug=False):
            global counter
            counter += 1
            return {}

        old_parse_map_header_at = parse_map_header_at
        try:
            parse_all_map_headers(debug=False)
            # parse_all_map_headers is currently doing it 2x
            # because of the new/old map header parsing routines
            self.assertEqual(counter, 388 * 2)
        finally:
            # Always restore the real parsers, even if the assertion fails.
            parse_map_header_at = saved_parse
            old_parse_map_header_at = saved_old_parse


if __name__ == "__main__":
    unittest.main()