# coding: utf-8 import re import sys import argparse from gbz import * from gbaddr import * sys.path.insert(0, 'extras/pokemontools') from gbz80disasm import * rom = bytearray(open('baserom.gbc', 'r').read()) def naive_parse(start, end): """ Linear parsing at arbitrary addresses. Prone to interpreting data as asm. """ address = start while address < end: output, offset = output_bank_opcodes_plus_wram(address) if offset < end: label = find_label(get_local_address(address), address/0x4000) if not label: label = 'Function%x' % address print '%s: ' % (label) + output + '\n' address = offset def parse_unknowns(start, results={}, debug=True, default_bank=1): """ Parse any unknown functions starting from some address. Only follows calls/jumps. """ if results.has_key(start): # already been here return results top_bank = start / 0x4000 if top_bank == 0: top_bank = default_bank output, offset = output_bank_opcodes_plus_wram(start) label = find_label(get_local_address(start), top_bank) if not label: label = 'Function%x' % start results[start] = label + ': ' + output else: results[start] = None # don't look into stuff that's already done return results # call/jp/jr bank = top_bank patterns = [] for addr_pattern in [r'(\$[0-9a-fA-F]+)', r'([\._0-9a-zA-Z]+)']: for pat in ['call', 'jp', 'jr']: for condition in ['', 'z\, ', 'nz\, ', 'c\, ', 'nc\, ']: patterns += [r'('+pat+r')' + r' ' + condition + addr_pattern] matches = [] for pattern in patterns: matches = re.finditer(pattern, output) for match in matches: bank = top_bank # rst Bankswitch # how do you get just the last one? bankswitch = r'ld a, (\$[0-9a-fA-F]+)' + r'.*?\n\t*?' + r'rst [\$10|Bankswitch]' for m in re.finditer(bankswitch, output[:match.start()]): bank = eval(m.expand(r'\1').replace('$','0x')) addr = match.expand(r'\2').replace('$','0x') if addr in ['z', 'nz', 'c', 'nc']: continue # wrong regex if debug: sys.stdout.write(match.expand(r'\1') + ' ' + str(bank) + ' ' + addr + '\n') sys.stdout.flush() try: # address addr = eval(addr) if addr > 0x8000: if debug: sys.stdout.write(output) sys.stdout.flush() continue address = addr if addr < 0x4000 else (addr + bank * 0x4000 - (0x4000 if bank > 0 else 0)) except: # label address = find_address_from_label(addr) if address is None: continue # has this address been parsed already? if not results.has_key(address): results = parse_unknowns(address, results, debug, default_bank) # if a new function is getting parsed there's no reason to not use the new label if results[start]: results[start] = results[start].replace( '${:04x}'.format(get_local_address(address)), 'Function{:x}'.format(address) ) # rst FarCall bank = top_bank # by the way, this doesn't follow callba/callab patterns = [] for addr_pattern in [r'(\$[0-9a-fA-F]+)', r'([\._0-9a-zA-Z]+)']: newline = r'.*?\n\t*?' a = 'ld a, ' + addr_pattern + newline hl = 'ld hl, ' + addr_pattern + newline farcall = r'rst [\$8|FarCall]' patterns += [a + hl + farcall] patterns += [hl + a + farcall] matches = [] for pattern in patterns: matches = re.finditer(pattern, output) for match in matches: bank, addr = match.expand(r'\1').replace('$','0x'), match.expand(r'\2').replace('$','0x') if pattern.startswith('ld hl'): addr, bank = bank, addr if debug: sys.stdout.write('FarCall ' + str(bank) + ' ' + addr + '\n') sys.stdout.flush() try: # address addr = eval(addr) if addr > 0x8000: if debug: sys.stdout.write(output) sys.stdout.flush() continue bank = eval(bank) address = addr if addr < 0x4000 else (addr + bank * 0x4000 - (0x4000 if bank > 0 else 0)) except: # label address = find_address_from_label(addr) if address is None: continue # has this address been parsed already? if not results.has_key(address): results = parse_unknowns(address, results, debug, default_bank) # if a farcall is getting parsed there's no reason to not use the new label if results[start]: results[start] = results[start].replace( 'ld a, %s\n\tld hl, %s\n\trst FarCall' % (match.expand(r'\1'), match.expand(r'\2')), 'callba Function%x' % (address) ).replace( 'ld hl, %s\n\tld a, %s\n\trst FarCall' % (match.expand(r'\1'), match.expand(r'\2')), 'callab Function%x' % (address) ) # callab/callba bank = top_bank patterns = [] for addr_pattern in [r'([\._0-9a-zA-Z_]+)']: # r'(\$[0-9a-fA-F]+)', patterns += [r'(callab) ' + addr_pattern] patterns += [r'(callba) ' + addr_pattern] matches = [] for pattern in patterns: matches = re.finditer(pattern, output) for match in matches: addr = match.expand(r'\2').replace('$','0x') if debug: sys.stdout.write(match.expand(r'\1') + ' ' + str(bank) + ' ' + addr + '\n') sys.stdout.flush() address = find_address_from_label(addr) if address is None: continue # has this address been parsed already? if not results.has_key(address): results = parse_unknowns(address, results, debug, default_bank) # rst JumpTable bank = top_bank patterns = [] for addr_pattern in [r'(\$[0-9a-fA-F]+)', r'([\._0-9a-zA-Z]+)']: newline = r'.*?\n\t*?' hl = 'ld hl, ' + addr_pattern jumptable = r'rst [\$28|JumpTable]' patterns += [hl + newline + jumptable] manual_jumptable = """ add hl, de add hl, de ld a, [hli] ld h, [hl] ld l, a jp [hl]""".replace('[', r'\[').replace(']', r'\]') patterns += [hl + manual_jumptable] matches = [] for pattern in patterns: matches = re.finditer(pattern, output) for match in matches: addr_string = match.expand(r'\1') addr = addr_string.replace('$','0x') if debug: sys.stdout.write('JumpTable ' + str(bank) + ' ' + addr + '\n') sys.stdout.flush() try: # address addr = eval(addr) if addr > 0x8000: if debug: sys.stdout.write(output) sys.stdout.flush() continue address = addr if addr < 0x4000 else (addr + bank * 0x4000 - (0x4000 if bank > 0 else 0)) except: # label address = find_address_from_label(addr) if address is None: continue # has this address been parsed already? if not results.has_key(address): results = parse_unknown_table(address, bank, results, debug, default_bank=default_bank) if results[start]: label = 'Jumptable_%x' % address results[start] = results[start].replace(addr_string, label) return results def parse_unknown_table(start_address, bank, results={}, debug=True, default_bank=1): # tag this as read results[start_address] = None address = start_address label = find_label(get_local_address(address), bank) if not label: label = 'Jumptable_%x' % address def pointer_at(a): return rom[a] + rom[a+1] * 0x100 def expand_pointer(a, b): if b == 0 or a < 0x4000: return a return a + b * 0x4000 - 0x4000 table = '' table += '%s: ; %x\n' % (label, address) # assume at least one pointer first_pointer = pointer_at(address) print hex(first_pointer) pointers = [] while len(pointers) < 0x100: if get_local_address(address) == first_pointer: break if pointer_at(address) >= 0x8000: break if bank > 0 and pointer_at(address) < 0x4000: break pointers += [pointer_at(address)] first_pointer = min(first_pointer, pointer_at(address)) address += 2 for pointer in pointers: pointer_address = expand_pointer(pointer, bank) pointer_label = find_label(pointer, bank) if pointer_label is None: pointer_label = 'Function%x' % pointer_address table += '\tdw %s\n' % pointer_label results = parse_unknowns(pointer_address, results, debug=debug, default_bank=default_bank) table += '; %x' % address if debug: print table sys.stdout.flush() results[start_address] = table return results def spit_out_unknowns(start): last_address = start for address, asm in sorted(parse_unknowns(start).items()): if asm: if address != last_address: print '\nINCBIN "baserom.gbc", $%x, $%x - $%x\n\n' % (address, last_address, address) print asm + '\n' last_address = int(asm.split('\n')[-1].replace(';','').strip(), 16) def insert_unknowns_into_incbins(start, f='main.asm', dry=False, **kwargs): unknowns = sorted(parse_unknowns(start, **kwargs).items(),) if not dry: insert_asm_into_incbins(unknowns, f) def insert_unknown_table_into_incbins(start, f='main.asm', dry=False, **kwargs): unknowns = sorted(parse_unknown_table(start, start/0x4000, **kwargs).items()) if not dry: insert_asm_into_incbins(unknowns, f) def insert_asm_into_incbins(asms, f='main.asm'): text = open(f, 'r').read() text = text.decode('utf8') lines = text.split('\n') for address, asm in asms: if asm: asm = asm.decode('utf8') for i, line in enumerate(lines): if 'INCBIN "baserom.gbc"' in line: start, length = [eval(s.replace('$','0x')) for s in line.split(',')[1:]] end = start + length if start <= address < end: try: last_address = int(asm.split('\n')[-1].replace(';','').strip(), 16) except: print 'not an address' print asm sys.stdout.flush() break output = '' if address > start: output += 'INCBIN "baserom.gbc", $%x, $%x - $%x\n' % (start, address, start) output += '\n' output += asm if last_address < end: output += '\n\n' output += 'INCBIN "baserom.gbc", $%x, $%x - $%x' % (last_address, end, last_address) elif last_address > end: # back out break try: lines.remove(line) for l in output.split('\n')[::-1]: lines.insert(i, l) except UnicodeWarning: print 'UnicodeWarning' print line sys.stdout.flush() break text = '\n'.join(lines) text = text.encode('utf8') with open(f, 'w') as out: out.write(text) if __name__ == '__main__': ap = argparse.ArgumentParser() ap.add_argument('-t', '--table', action='store_true') ap.add_argument('--dry-run', dest='dry', action='store_true') ap.add_argument('-o', dest='filename', default='main.asm') ap.add_argument('--default-bank', dest='default_bank', default='1') ap.add_argument('offset') args = ap.parse_args() offset = args.offset filename = args.filename table = args.table dry = args.dry default_bank = int(args.default_bank, 16) offset = gbaddr_int(offset) if table: insert_unknown_table_into_incbins(offset, filename, dry, default_bank=default_bank) else: insert_unknowns_into_incbins(offset, filename, dry, default_bank=default_bank)