#!/usr/bin/python3
# The OS/K Team licenses this file to you under the MIT license.
# See the LICENSE file in the project root for more information.

"""k-as: assembler driver.

Usage: k-as (source file) (memory entry point) (output file) (symbols file)

Pipeline (see the calls at the bottom of this file):
  1. parse_lst_instrs() / parse_lst_regs() load instrs.lst and regs.lst
     from the directory this script lives in.
  2. do_includes() flattens "include" directives into the `source` temp file.
  3. parse() records labels, defines and .data, and rewrites every
     instruction into a token line in the `instrs` temp file while
     accumulating the .text size (ptext).
  4. gentext() encodes those token lines into the `b_text` buffer.
  5. genout() writes .text, alignment padding and .data to the output file.
  6. gensym() writes "label address" lines, sorted by address, to the
     symbols file.
"""

import re
import os
import sys
import subprocess
from array import array
from tempfile import TemporaryFile
from collections import OrderedDict

#print("k-as command line: '{}'".format(sys.argv))

# Set to True to print every tokenized instruction line during gentext()
WANT_DISASM = False

if len(sys.argv) != 5:
    print("Usage: {} (source file) (memory entry point) (output file) (symbols file)"
            .format(sys.argv[0]))
    sys.exit(1)

# Intermediate buffers: flattened source text, tokenized instruction lines,
# and the binary .data / .text sections.
source = TemporaryFile(mode="w+")
instrs = TemporaryFile(mode="w+")
b_data = TemporaryFile(mode="w+b")
b_text = TemporaryFile(mode="w+b")

# Register and instruction tables are read from the script's own directory
lst_regs = open(os.path.join(sys.path[0], "regs.lst"))
lst_instrs = open(os.path.join(sys.path[0], "instrs.lst"))

main_src = open(sys.argv[1])
b_out = open(sys.argv[3], "wb")
b_sym = open(sys.argv[4], "w")

# base=0 accepts decimal, 0x.., 0o.. and 0b.. literals
start_addr = int(sys.argv[2], base=0)

# os.chdir(os.path.dirname(sys.argv[1]))

def leave(i):
    """Close every file opened at module level, then exit with status i."""
    source.close()
    instrs.close()
    b_out.close()
    b_sym.close()
    b_data.close()
    b_text.close()
    main_src.close()
    lst_regs.close()
    lst_instrs.close()
    sys.exit(i)

#-------------------------------------------------------------------------------

# Defines (name -> replacement text), filled by "name := value" lines
pdefs = dict()

# registers (name -> index), filled from regs.lst
pregs = dict()

# instructions (index in this list == encoded instruction number)
pinstrs = list()

# labels (name -> offset inside their section)
plabels_text = OrderedDict()
plabels_data = OrderedDict()

# size of .data section
pdata = 0

# size of .text section
ptext = 0

# for local labels
plastlabel = ''

# file currently being parsed
pcurfile = sys.argv[1]

# after parse() is done, pdata and ptext are never modified

#-------------------------------------------------------------------------------

def name_valid(name):
    """Return True if every character of name (case-insensitive) is allowed
    in a label."""
    allowed = 'abcdefghijklmnopqrstuvwxyz0123456789[$._+]=,'
    return all(c in allowed for c in name.lower())

def is_number(s):
    """Return True if s parses with int(s, base=0)."""
    try:
        int(s, base=0)
    except ValueError:
        return False
    return True

# Two or more integer literals (decimal or 0x-hex) joined by
# | & ^ + - * << >>, with optional whitespace around the operators.
# No parentheses and no names, so a full match is safe to hand to eval().
arith_expr = re.compile(r'((0x[0-9A-Fa-f]+|[0-9]+)\s*([|&^+\-*]|<<|>>)\s*)+(0x[0-9A-Fa-f]+|[0-9]+)')

def arith_eval(s):
    """Fold a constant arithmetic expression into its decimal string.

    s is returned unchanged unless the WHOLE string matches arith_expr.
    Requiring a full match ensures text after a valid arithmetic prefix
    is never passed to eval().
    """
    if arith_expr.fullmatch(s):
        return str(eval(s, {"__builtins__": {}}))
    return s

#-------------------------------------------------------------------------------

def parse_lst_regs():
    """Fill pregs from regs.lst.

    Register names are whitespace-separated (several may share a line;
    blank lines are ignored) and receive consecutive indices in file order.
    """
    global pregs

    index = 0
    for line in lst_regs:
        line = line.strip()

        if len(line) == 0:
            continue

        for reg in line.split():
            pregs[reg] = index
            index += 1

def parse_lst_instrs():
    """Fill pinstrs from instrs.lst, one stripped entry per line (blank lines
    included, so list positions match line numbers)."""
    global pinstrs

    for line in lst_instrs:
        pinstrs.append(line.strip())

#-------------------------------------------------------------------------------

inc_depth = 0
inc_depth_max = 16

def do_includes(fi):
    """Copy fi into `source`, recursively expanding include directives.

    Quickly goes through the file and resolves "include" directives ONLY.
    Every included file is bracketed by "$file: <name>:" marker lines so
    that parse() can keep track of the current file name.
    Blank lines are dropped.
    """
    global inc_depth
    global pcurfile

    for line in fi:
        line = line.rstrip()
        tok = line.split(None, 1)

        if len(tok) == 0:
            continue

        if tok[0] != "include":
            source.write("{}\n".format(line))
            continue

        if len(tok) == 1:
            print("Missing parameter for include directive")
            leave(1)

        # The file name must be enclosed in matching ' or " quotes
        if tok[1][0] not in "'\"" or tok[1][-1] != tok[1][0]:
            print("Invalid format for include directive: {}".format(line))
            leave(1)

        old_curf = pcurfile
        pcurfile = tok[1][1:-1]

        try:
            new_fi = open(pcurfile, "r")
        except OSError:
            print("Couldn't open file: {}".format(line))
            leave(1)

        # `with` guarantees the included file is closed after expansion
        with new_fi:
            inc_depth += 1
            if inc_depth >= inc_depth_max:
                print("Maximal include depth reached: {}".format(line))
                leave(1)

            source.write("$file: {}:\n".format(pcurfile.replace(' ', '')))
            do_includes(new_fi)

        pcurfile = old_curf
        source.write("$file: {}:\n".format(pcurfile.replace(' ', '')))

    # Balances the increment done by the caller before recursing into us
    inc_depth -= 1

#-------------------------------------------------------------------------------

def parse():
    """First pass over the flattened source.

    - "$file: " marker lines update pcurfile.
    - Comments (# ; @ ! / outside quotes) are stripped.
    - Indented lines are instructions: parse_instr() writes their token line
      to `instrs` (prefixed with the file name) and its returned size is
      added to ptext.
    - "name:" lines define .text labels; names starting with '.' are local
      and get the last non-local label prepended.
    - Anything else goes to parse_preproc().
    """
    global ptext
    global pcurfile
    global plastlabel

    source.seek(0)
    pcurfile = sys.argv[1]

    for line in source:
        line = line.rstrip()

        if len(line) == 0:
            continue

        # len("$file: ") == 7
        if len(line) > 7 and line[:7] == "$file: ":
            pcurfile = line[7:]
            continue

        quote = False
        for i, ch in enumerate(line):
            if ch in "'\"":
                quote = not quote

            if ch in '#;@!/' and not quote:
                line = line[:i].rstrip()
                break

        if quote:
            print("Unterminated string in line: {}".format(line))
            leave(1)

        if len(line) == 0:
            continue

        if line[0] == ' ' or line[0] == '\t':
            line = line.lstrip()
            # The file name is the first token of every instrs line and is
            # skipped by gentext(), so it must not contain spaces.
            instrs.write(pcurfile.replace(' ', '') + ' ')
            ptext += parse_instr(line)
            instrs.write("\n")
            continue

        # Label?
        if line[-1] == ':':
            if name_valid(line[:-1]):
                label = line[:-1]
                if label[0] == '.':
                    label = plastlabel + label
                else:
                    plastlabel = label
                plabels_text[label] = ptext
            else:
                print("Bad label name: {}".format(line[:-1]))
                leave(1)
            continue

        # Preprocessor, .data, or invalid
        parse_preproc(line)

#-------------------------------------------------------------------------------

escape_dict = {
    'n': '\n',
    't': '\t',
    'r': '\r',
    'v': '\v',
    'f': '\f',
    '"': '"',
    '\'': '\'',
    '\\': '\\',
}

def parse_preproc(line):
    """Handle a non-instruction, non-label line.

    "name := value" records a define in pdefs.
    "label = value" emits .data at the current pdata offset, where value is
    one of:
      - an integer literal: 8 little-endian bytes;
      - "[size]": size zero bytes, rounded up to a multiple of 8;
      - a quoted string: its bytes (with escapes) plus a NUL terminator,
        padded to a multiple of 8; also defines "<label>_len" as the
        string length without the terminator.
    Anything else aborts with an error.
    """
    global pdata

    tok = line.split(None, 2)

    # preprocessor
    if len(tok) > 1 and tok[1] == ':=':
        if len(tok) < 3:
            print("Invalid format: {}".format(line))
            leave(1)
        s = tok[0]
        if s in pdefs:
            s = pdefs[s]
        if s[0] == '.':
            s = plastlabel + s
        pdefs[s] = tok[2]
        return

    if len(tok) <= 1 or tok[1] != '=':
        print("Unrecognized directive: {}".format(line))
        leave(1)

    # .data
    if len(tok) < 3:
        print("Invalid format: {}".format(line))
        leave(1)

    label = tok[0]
    if label[0] == '.':
        label = plastlabel + label

    plabels_data[label] = pdata

    # number data
    if is_number(tok[2]):
        written = b_data.write(int(tok[2], base=0).to_bytes(8, byteorder='little', signed=False))
        assert(written == 8)
        pdata += written

    # buffer / bss
    elif tok[2][0] == '[':
        assert(tok[2][-1] == ']')
        s = tok[2][1:-1].strip()
        if len(s) > 0 and s[0] == '.':
            s = plastlabel + s
        if s in pdefs:
            s = pdefs[s]
        if not is_number(s):
            print("Invalid bss format: {}".format(line))
            leave(1)
        i = int(s, base=0)
        if (i % 8) != 0:
            i = i + (8 - i % 8)
        written = b_data.write(bytearray(i))
        assert(written == i)
        pdata += written

    # string data
    elif tok[2][0] in "'\"":
        s = tok[2].strip()
        assert(s[-1] == tok[2][0])
        s = s[1:-1]

        real_len = 0
        escaping = False

        for c in s:
            # escape sequences
            if not escaping and c == '\\':
                escaping = True
                continue
            if escaping:
                escaping = False
                if c in escape_dict:
                    c = escape_dict[c]
                else:
                    print("Unrecognized escape sequence: {}".format(line))
                    leave(1)

            b_data.write(ord(c).to_bytes(1, byteorder='little', signed=False))
            real_len += 1
            pdata += 1

        # null terminator
        b_data.write(int(0).to_bytes(1, byteorder='little', signed=False))
        pdata += 1

        l = real_len + 1 # s + null-term

        # align to 8 bytes
        if (l % 8) != 0:
            pad = 8 - l % 8
            b_data.write(bytes(pad))
            pdata += pad

        pdefs[label + "_len"] = str(real_len)

    else:
        print("Invalid format: {}".format(line))
        leave(1)

#-------------------------------------------------------------------------------

# Condition suffix -> 5-bit condition code; get_cond_mask() adds 0b10000
# for an 'n' (negated) prefix.
pconds = {
    'c': 0b00001,
    'o': 0b00010,
    'z': 0b00011,
    'e': 0b00011,
    's': 0b00100,
    'pe': 0b00101,
    'po': 0b10101,
    'b': 0b00001,
    'be': 0b00110,
    'l': 0b00111,
    'le': 0b01000,
    'a': 0b10110, # nbe
    'ae': 0b10001, # nb
    'g': 0b11000, # nle
    'ge': 0b10111, # nl
    'axz': 0b01001,
    'bxz': 0b01010,
    'cxz': 0b01011,
    'dxz': 0b01100,
    'axnz': 0b11001,
    'bxnz': 0b11010,
    'cxnz': 0b11011,
    'dxnz': 0b11100,
}

def get_cond_mask(cond, line):
    """Return the condition code for suffix cond (leading 'n' negates);
    abort on an unknown suffix."""
    mask = 0

    if cond[0] == 'n':
        cond = cond[1:]
        mask = 0b10000

    if cond not in pconds:
        print("Invalid condition suffix: {}".format(line))
        leave(1)

    return (mask | pconds[cond])

#-------------------------------------------------------------------------------

# Operand format string (as built by parse_instr) -> 5-bit FTn code
pfts = {
    "reg": 0b00001,
    "imm64": 0b00010,

    "bimm64": 0b00100,
    "brr": 0b00101,
    "brri": 0b00110,
    "brrii": 0b00111,

    "wimm64": 0b01000,
    "wrr": 0b01001,
    "wrri": 0b01010,
    "wrrii": 0b01011,

    "limm64": 0b01100,
    "lrr": 0b01101,
    "lrri": 0b01110,
    "lrrii": 0b01111,

    "qimm64": 0b10000,
    "qrr": 0b10001,
    "qrri": 0b10010,
    "qrrii": 0b10011,
}

def get_fts_mask(ft, line):
    """Return the FTn code for operand format ft; abort on an unknown format
    (the result is shifted into w2, so returning None would crash)."""
    if ft not in pfts:
        print("Invalid operand format ({}): {}".format(ft, line))
        leave(1)
    return pfts[ft]

#-------------------------------------------------------------------------------

def _split_mem_operand(word):
    """Split a space-free memory operand body into (reg1, reg2, imm1, imm2).

    The source syntax maps as reg1 + reg2*imm1 + imm2; parts that are not
    written default to "rzx", '1' and '0' respectively.
    """
    reg1 = "rzx"
    reg2 = "rzx"
    imm1 = '1'
    imm2 = '0'

    wtok = word.split('+')

    #
    # [reg] and [reg*imm16]
    #
    if len(wtok) == 1:
        if '*' in wtok[0]:
            assert(len(wtok[0].split('*')) == 2)
            reg2, imm1 = wtok[0].split('*', 1)
        else:
            reg1 = wtok[0]

    #
    # [reg+reg], [reg+imm16], [reg*imm16+imm16], [reg+reg*imm16]
    #
    elif len(wtok) == 2:
        # Must be [reg*imm16+imm16]
        if '*' in wtok[0]:
            assert(len(wtok[0].split('*')) == 2)
            assert(is_number(wtok[1].strip()))
            reg2, imm1 = wtok[0].split('*', 1)
            imm2 = wtok[1]

        # Must be [reg+reg*imm16]
        elif '*' in wtok[1]:
            assert(len(wtok[1].split('*')) == 2)
            reg1 = wtok[0]
            reg2, imm1 = wtok[1].split('*', 1)

        # [reg+imm16]
        elif is_number(wtok[1].strip()):
            reg1 = wtok[0]
            imm2 = wtok[1]

        # Must be [reg+reg]
        else:
            reg1 = wtok[0]
            reg2 = wtok[1]

    #
    # [reg+reg+imm16], [reg+reg*imm16+imm16]
    #
    else:
        assert(len(wtok) == 3)
        reg1 = wtok[0]
        imm2 = wtok[2]

        if '*' in wtok[1]:
            assert(len(wtok[1].split('*')) == 2)
            reg2, imm1 = wtok[1].split('*', 1)
        else:
            reg2 = wtok[1]

    return reg1, reg2, imm1, imm2

def parse_instr(line):
    """Tokenize one instruction and return its encoded size in bytes.

    Writes a token line to `instrs` for gentext():
        [%%cond] <name> [%%imm16 <cond>] %%imm16 <w2> <operands...>
    where <name> gets an "_m"/"_i"/"_r" suffix per operand kind and w2 packs
    the REP bit (0x8000) and up to three 5-bit operand format codes
    (ft3 << 10 | ft2 << 5 | ft1).

    Suffixes: "instr.rep", "instr.<cond>" and "instr.rep.<cond>".
    """
    if line is None or len(line) == 0:
        return 0

    tok = line.split(None, 1)

    instr = tok[0].strip()

    if len(tok) > 1:
        params = tok[1].strip()
    else:
        params = None

    fellthrough = False

    # instruction number (2) + w2 (2)
    size = 4

    # Word 2 (rep|ft3|ft2|ft1)
    w2 = 0

    cond = None

    if len(instr) > 2 and '.' in instr:
        instr, suf = instr.split('.', 1)

        if len(instr) == 0:
            print("Missing instruction name before suffixes: {}".format(line))
            leave(1)

        if len(suf) > 2 and suf[:3] == "rep":
            if len(suf) > 3:
                suf = suf[3:]
                if len(suf) > 0 and suf[0] == '.':
                    suf = suf[1:]
            else:
                suf = ''

            w2 |= 0x8000 # REP

        if len(suf) > 0:
            instrs.write("%%cond ")
            cond = "%%imm16 {}".format(get_cond_mask(suf, line))

    instr_name = instr
    instr_args = ''

    if params is None or len(params) == 0:
        instrs.write("{} ".format(instr_name))

        if cond is not None:
            size += 2
            instrs.write("{} ".format(cond))

        instrs.write("%%imm16 {}".format(w2))

        return size

    tok = params.split(',')

    # FTn
    fts = ''

    #
    # Parse operands, generating fts along the way
    #
    for word in tok:
        word = word.strip()

        instr_args += ' '

        gotPref = False

        if len(fts) != 0:
            fts += ' '

        if len(word) == 0:
            print("Wrong syntax in line: '{}'".format(line))
            leave(1)

        # local labels
        if word[0] == '.':
            word = plastlabel + word

        # preprocessor
        if word in pdefs:
            word = pdefs[word]
            # Fall through

        # arithmetic expressions
        word = arith_eval(word)

        # memory length prefixes (b/w/l/q before '[')
        if len(word) > 2 and '[' in word:
            if word[0] in 'bwlq':
                fts += word[0]
                gotPref = True
            else:
                print("Bad memory length prefix: {}".format(line))
                leave(1)

            word = word[1:].strip()
            assert(word[0] == '[')

        #
        # Determine memory format and save it into fts
        #
        if word[0] == '[':
            assert(word[-1] == ']')
            word = word[1:-1]

            # preprocessor, again
            if word in pdefs:
                word = pdefs[word]
                # Fall through

            #
            # Make sure we got an access length prefix
            #
            if not gotPref:
                print("Missing access length modifier: {}".format(line))
                leave(1)

            instr_name += "_m"

            # cheap way of getting [reg - imm] to work
            word = word.replace('-', '+ -')

            # remove every spaces!
            word = word.replace(' ', '')

            #
            # Offsets and/or scaled index ('*' alone, as in [reg*imm16],
            # must take this path too)
            #
            if '+' in word or '*' in word:
                reg1, reg2, imm1, imm2 = _split_mem_operand(word)

                #
                # Update fts and instr_args
                #
                instr_args += "{}:{} ".format(reg2, reg1)
                size += 2

                if imm1 == '1':
                    if imm2 == '0':
                        fts += 'rr'
                    else:
                        fts += 'rri'
                        size += 2
                        instr_args += "%%imm16 {}".format(imm2)
                else:
                    size += 4
                    fts += 'rrii'
                    instr_args += "%%imm16 {} %%imm16 {}".format(imm1, imm2)

                continue

            #
            # [imm64] or [reg]
            #
            fellthrough = True
            # FALLTHROUGH

        # preprocessor, yet again
        if word in pdefs:
            word = pdefs[word]
            # Fall through

        # characters 'c'
        if len(word) == 3 and word[0] == word[-1] == "'":
            word = str(ord(word[1]))

        # register index $reg
        if len(word) == 4 and word[0] == '$':
            if word[1:] in pregs:
                word = str(pregs[word[1:]])

        # for now every immediate is 64-bit
        if is_number(word):
            # +8 for immediate
            size += 8

            if not fellthrough:
                instr_name += "_i"

            fts += "imm64"
            instr_args += "%%imm64 "
            instr_args += word
            fellthrough = False
            continue

        # register
        if word in pregs:
            size += 2

            if not fellthrough:
                instr_name += "_r"
                fts += "reg"
            else:
                fts += "rr"

            instr_args += word
            fellthrough = False
            continue

        # it's a label (a 64-bit immediate)
        # +8 for immediate
        size += 8

        if not fellthrough:
            instr_name += "_i"

        fts += "imm64"
        instr_args += "%%imm64 "

        if word[0] == '.':
            instr_args += plastlabel

        instr_args += word
        fellthrough = False

    #
    # Compute FTn
    #
    ft_list = fts.split()
    l = len(ft_list)

    if l == 3:
        ft1, ft2, ft3 = ft_list
        w2 |= get_fts_mask(ft3, line) << 10
        w2 |= get_fts_mask(ft2, line) << 5
        w2 |= get_fts_mask(ft1, line)

    elif l == 2:
        ft1, ft2 = ft_list
        w2 |= get_fts_mask(ft2, line) << 5
        w2 |= get_fts_mask(ft1, line)

    else:
        if l != 1:
            print("Too many operands: {}".format(line))
            leave(1)
        w2 |= get_fts_mask(fts, line)

    if cond is None:
        instrs.write("{} %%imm16 {}{}".format(instr_name, w2, instr_args))
    else:
        size += 2
        instrs.write("{} {} %%imm16 {}{}".format(instr_name, cond, w2, instr_args))

    return size

#-------------------------------------------------------------------------------

special_syms = {
    "%%cond",
    "%%imm16",
    "%%imm32",
    "%%imm64",
    "%%signed",
}

def _data_start():
    """Return the absolute address of .data: right after .text (loaded at
    start_addr), rounded up to a multiple of 8. Valid once parse() is done."""
    addr = start_addr + ptext
    if (addr % 8) != 0:
        addr += (8 - addr % 8)
    return addr

def gentext():
    """Second pass: encode every token line of `instrs` into `b_text`.

    Registers -> 2-byte index; "reg2:reg1" -> 2 bytes (idx1 << 8 | idx2);
    instruction names -> 2-byte index (OR'ed with 1 << 13 after %%cond);
    labels -> 8-byte absolute address; numbers -> as many bytes as the
    preceding %%immNN marker, signed when written with a leading '-'.
    The first token of each line (the file name) is skipped.
    """
    text_start = start_addr
    data_start = _data_start()

    # First position wins, matching list.index() on pinstrs
    instr_index = {}
    for idx, name in enumerate(pinstrs):
        instr_index.setdefault(name, idx)

    instrs.seek(0)
    cond_mask = 0

    # Only consulted if a number precedes any %%immNN marker
    lastimm = 8
    isSigned = False

    for line in instrs:
        tok = line.strip().split()

        if WANT_DISASM:
            print(tok)

        tok = tok[1:]

        for word in tok:
            if len(word) == 0:
                continue

            if word in pregs:
                idx = pregs[word]
                b_text.write(idx.to_bytes(2, byteorder='little', signed=False))
                continue

            # register pair from a memory operand
            if ':' in word:
                reg2, reg1 = word.split(':', 1)
                if reg1 not in pregs or reg2 not in pregs:
                    print("Invalid register in memory operand '{}' in line: {}".format(word, line))
                    leave(1)

                idx1 = pregs[reg1]
                idx2 = pregs[reg2]
                b_text.write(((idx1 << 8) | idx2).to_bytes(2, byteorder='little', signed=False))
                continue

            if word in instr_index:
                idx = instr_index[word] | cond_mask
                cond_mask = 0
                b_text.write(idx.to_bytes(2, byteorder='little', signed=False))
                continue

            if word in plabels_text:
                addr = text_start + plabels_text[word]
                b_text.write(addr.to_bytes(8, byteorder='little', signed=False))
                continue

            if word in plabels_data:
                addr = data_start + plabels_data[word]
                b_text.write(addr.to_bytes(8, byteorder='little', signed=False))
                continue

            if word in special_syms:
                if word == "%%imm16":
                    lastimm = 2
                elif word == "%%imm32":
                    lastimm = 4
                elif word == "%%imm64":
                    lastimm = 8
                elif word == "%%cond":
                    cond_mask = (1 << 13)
                elif word == "%%signed":
                    lastimm = 2
                    isSigned = True
                else:
                    isSigned = False

                continue

            if is_number(word):
                isSigned = (word[0] == '-')
                try:
                    encoded = int(word, base=0).to_bytes(lastimm, byteorder='little', signed=isSigned)
                except OverflowError:
                    print("Immediate '{}' does not fit in {} bytes in line: {}".format(word, lastimm, line))
                    leave(1)
                b_text.write(encoded)
                continue

            print("Assembly error, unknown token '{}' in line: {}".format(word, line))
            leave(1)

#-------------------------------------------------------------------------------

def sort_by_list(dict_, list_):
    """Move the keys in list_ to the end of OrderedDict dict_, in list order."""
    for key in list_:
        dict_.move_to_end(key)

def gensym():
    """Write "label address" lines for all labels, sorted by address.

    Converts the label tables to absolute addresses in place, so it must
    run after gentext().
    """
    text_start = start_addr
    data_start = _data_start()

    for label in plabels_text:
        plabels_text[label] += text_start

    for label in plabels_data:
        plabels_data[label] += data_start

    plabels_all = OrderedDict(list(plabels_text.items()) + list(plabels_data.items()))

    for key, value in sorted(plabels_all.items(), key=lambda item: item[1]):
        b_sym.write("{} {}\n".format(key, value))

#-------------------------------------------------------------------------------

def genout():
    """Write .text, padding, then .data to the output file.

    The padding places .data exactly at _data_start(), the same absolute
    address gentext() and gensym() used for data labels (this matters when
    start_addr itself is not 8-byte aligned).
    """
    b_text.seek(0)
    b_data.seek(0)
    b_out.write(b_text.read())

    b_out.write(bytes(_data_start() - (start_addr + ptext)))

    b_out.write(b_data.read())

#-------------------------------------------------------------------------------

parse_lst_instrs()
parse_lst_regs()
do_includes(main_src)
parse()
gentext()
genout()
gensym()
leave(0)