kvisc/as/k-as.py

893 lines
21 KiB
Python
Executable File

#!/usr/bin/python3
# The OS/K Team licenses this file to you under the MIT license.
# See the LICENSE file in the project root for more information.
import re
import os
import sys
import subprocess
from array import array
from tempfile import TemporaryFile
from collections import OrderedDict
#print("k-as command line: '{}'".format(sys.argv))
# When True, gentext() prints the token list of every intermediate line.
WANT_DISASM = False
# Exactly four arguments are required: source file, load address,
# output image and symbol file.
if len(sys.argv) != 5:
    print("Usage: {} (source file) (memory entry point) (output file) (symbols file)"
        .format(sys.argv[0]))
    sys.exit(1)
# Scratch files used between the passes:
#   source - main file with include directives expanded (text)
#   instrs - one intermediate line per instruction, written by parse() (text)
#   b_data - bytes of the .data section (binary)
#   b_text - bytes of the .text section (binary)
source = TemporaryFile(mode="w+")
instrs = TemporaryFile(mode="w+")
b_data = TemporaryFile(mode="w+b")
b_text = TemporaryFile(mode="w+b")
# Register and instruction tables live next to the script (sys.path[0]).
lst_regs = open(os.path.join(sys.path[0], "regs.lst"))
lst_instrs = open(os.path.join(sys.path[0], "instrs.lst"))
main_src = open(sys.argv[1])
b_out = open(sys.argv[3], "wb")
b_sym = open(sys.argv[4], "w")
# Address the image is assembled for; base=0 accepts 0x/0o/0b prefixes.
start_addr = int(sys.argv[2], base=0)
# os.chdir(os.path.dirname(sys.argv[1]))
def leave(i):
    """Close every file the assembler opened, then exit with status *i*."""
    handles = (source, instrs, b_out, b_sym, b_data, b_text,
               main_src, lst_regs, lst_instrs)
    for handle in handles:
        handle.close()
    sys.exit(i)
#-------------------------------------------------------------------------------
# Assembler state shared by the passes below.

# Preprocessor defines: name -> replacement text
pdefs = {}
# Registers: name -> register index (filled by parse_lst_regs)
pregs = {}
# Instruction mnemonics; a mnemonic's list position is its opcode
pinstrs = []
# Labels: name -> byte offset inside the .text / .data section
plabels_text = OrderedDict()
plabels_data = OrderedDict()
# Running sizes of the .data and .text sections, in bytes
pdata = 0
ptext = 0
# Most recent non-local label; prepended to ".local" names
plastlabel = ''
# File currently being parsed (tag written in front of each instruction)
pcurfile = sys.argv[1]
# after parse() is done, pdata and ptext are never modified
#-------------------------------------------------------------------------------
def name_valid(name):
    """Return True when every character of *name* (case-insensitive) may appear in a label."""
    allowed = 'abcdefghijklmnopqrstuvwxyz0123456789[$._+]=,'
    return all(ch in allowed for ch in name.lower())
def is_number(s):
    """Return True if *s* parses as a Python integer literal (base auto-detected)."""
    try:
        int(s, 0)
    except ValueError:
        return False
    else:
        return True
# Two or more integer literals (0x-hex, or decimal without leading zeros)
# joined by | & ^ + - * << >>, with optional whitespace around operators.
# It is matched against the WHOLE operand, so eval() below only ever sees
# strings made of literals, operators and whitespace.
arith_expr = re.compile(
    r'((0x[0-9A-Fa-f]+|[1-9][0-9]*|0+)\s*([|&^+\-*]|<<|>>)\s*)+'
    r'(0x[0-9A-Fa-f]+|[1-9][0-9]*|0+)')
def arith_eval(s):
    """Fold a constant arithmetic operand like "0x10 << 2" into a decimal string.

    The expression is evaluated with Python operator precedence.  Any string
    that is not entirely such an expression (registers, labels, single
    numbers, memory operands, ...) is returned unchanged.
    """
    if arith_expr.fullmatch(s):
        # Safe: fullmatch guarantees s contains only literals and operators.
        return str(eval(s))
    return s
#-------------------------------------------------------------------------------
def parse_lst_regs():
    """Fill pregs from regs.lst.

    Every non-blank line defines the next register index, counting from 0;
    each whitespace-separated name on that line becomes an alias for it.
    """
    index = 0
    for raw in lst_regs:
        names = raw.split()
        if not names:
            continue
        pregs.update(dict.fromkeys(names, index))
        index += 1
def parse_lst_instrs():
    """Append every line of instrs.lst, stripped, to pinstrs.

    Blank lines are kept as empty entries, because a mnemonic's position in
    pinstrs is the opcode that gentext() emits for it.
    """
    pinstrs.extend(entry.strip() for entry in lst_instrs)
#-------------------------------------------------------------------------------
inc_depth = 0
inc_depth_max = 16


def _file_marker(name):
    """Return the "$file: NAME:" line that tells parse() which file follows.

    Whitespace is removed from NAME because parse() writes the current file
    name as the first whitespace-separated token of each instruction line.
    """
    return "$file: {}:\n".format(''.join(str(name).split()))


# Quickly goes through source file and resolves "include" directives ONLY
def do_includes(fi):
    """Copy the lines of *fi* into `source`, expanding include directives.

    An `include "path"` (or 'path') line is replaced by the contents of that
    file, recursively, preceded by a "$file:" marker for the included file
    and followed by one for the including file, so parse() attributes every
    line correctly.  Paths are opened as written, i.e. relative to the
    current working directory.  Blank lines are dropped and trailing
    whitespace is stripped.  A malformed directive, an unopenable file, or
    reaching inc_depth_max nesting levels aborts assembly via leave(1).
    """
    global inc_depth
    for line in fi:
        line = line.rstrip()
        tok = line.split(None, 1)
        if len(tok) == 0:
            continue
        if tok[0] != "include":
            source.write("{}\n".format(line))
            continue
        if len(tok) == 1:
            print("Missing parameter for include directive")
            leave(1)
        if tok[1][0] not in "'\"" or tok[1][-1] != tok[1][0]:
            print("Invalid format for include directive: {}".format(line))
            leave(1)
        inc = tok[1][1:-1]
        try:
            new_fi = open(inc, "r")
        except (OSError, ValueError):
            print("Couldn't open file: {}".format(line))
            leave(1)
        # `with` guarantees the included file is closed, even on leave().
        with new_fi:
            inc_depth += 1
            if inc_depth >= inc_depth_max:
                print("Maximal include depth reached: {}".format(line))
                leave(1)
            source.write(_file_marker(inc))
            do_includes(new_fi)
            inc_depth -= 1
        # Lines after the directive belong to the including file again.
        parent = getattr(fi, "name", None)
        if parent is not None:
            source.write(_file_marker(parent))
#-------------------------------------------------------------------------------
def _strip_comment(line):
    """Split a trailing comment off *line*.

    A comment starts at the first '#', ';', '@', '!' or '/' that is outside a
    quoted string.  Strings are delimited by matching quotes, so a quote of
    the other kind inside a string does not end it, and a backslash inside a
    string escapes the next character (e.g. "a\"#b" contains no comment).

    Returns (code, unterminated).  *code* is the text before the comment with
    trailing whitespace removed, or the whole line if there is no comment.
    *unterminated* is True when the line ends inside a string.
    """
    quote = None
    escaping = False
    for i, c in enumerate(line):
        if quote is not None:
            if escaping:
                escaping = False
            elif c == '\\':
                escaping = True
            elif c == quote:
                quote = None
        elif c in "'\"":
            quote = c
        elif c in '#;@!/':
            return line[:i].rstrip(), False
    return line, quote is not None


def parse():
    """First pass over the include-expanded `source` file.

    - "$file: NAME" marker lines switch pcurfile.
    - Indented lines are instructions: each is written to `instrs` (prefixed
      by the current file tag) and its encoded size is added to ptext.
    - "name:" lines define .text labels at the current ptext offset; names
      starting with '.' are local and get plastlabel prepended.
    - Anything else is handed to parse_preproc().
    Comments are stripped first; an unterminated string aborts via leave(1).
    """
    global ptext
    global pcurfile
    global plastlabel
    source.seek(0)
    for line in source:
        line = line.rstrip()
        if len(line) == 0:
            continue
        # len("$file: ") == 7
        if len(line) > 7 and line[:7] == "$file: ":
            pcurfile = line[7:]
            continue
        code, unterminated = _strip_comment(line)
        if unterminated:
            print("Unterminated string in line: {}".format(line))
            leave(1)
        line = code
        if len(line) == 0:
            continue
        if line[0] in ' \t':
            line = line.lstrip()
            # gentext() discards exactly one leading token, so the file tag
            # must not contain whitespace (the main file path may).
            instrs.write(''.join(pcurfile.split()) + ' ')
            ptext += parse_instr(line)
            instrs.write("\n")
            continue
        # Label?
        if line[-1] == ':':
            label = line[:-1]
            if not label or not name_valid(label):
                print("Bad label name: {}".format(label))
                leave(1)
            if label[0] == '.':
                label = plastlabel + label
            else:
                plastlabel = label
            plabels_text[label] = ptext
            continue
        # Preprocessor, .data, or invalid
        parse_preproc(line)
#-------------------------------------------------------------------------------
escape_dict = {
    'n': '\n',
    't': '\t',
    'r': '\r',
    'v': '\v',
    'f': '\f',
    '"': '"',
    '\'': '\'',
    '\\': '\\',
}


def _data_number(text, line):
    """Return the 8 little-endian bytes for a numeric .data item.

    Values in [-2**63, 2**64) are accepted: negative values are stored in
    two's complement, non-negative ones unsigned.  Others abort assembly.
    """
    value = int(text, base=0)
    if not -(1 << 63) <= value < (1 << 64):
        print("Number out of 64-bit range: {}".format(line))
        leave(1)
    return value.to_bytes(8, byteorder='little', signed=value < 0)


def _data_bss(text, line):
    """Return the zero bytes reserved by a "[size]" .data item.

    *size* may be a number, a define, or a local ('.') define name; it is
    rounded up to a multiple of 8.  A malformed or negative size aborts.
    """
    if text[-1] != ']':
        print("Invalid bss format: {}".format(line))
        leave(1)
    s = text[1:-1].strip()
    if s and s[0] == '.':
        s = plastlabel + s
    if s in pdefs:
        s = pdefs[s]
    if not is_number(s) or int(s, base=0) < 0:
        print("Invalid bss format: {}".format(line))
        leave(1)
    size = int(s, base=0)
    return bytes(size + (-size) % 8)


def _data_string(text, line):
    """Return the bytes of quoted string *text*, without terminator or padding.

    Backslash escapes are resolved through escape_dict.  An unknown escape, a
    trailing lone backslash, or a character above 0xFF (which does not fit in
    one byte) aborts assembly.
    """
    if len(text) < 2 or text[-1] != text[0]:
        print("Invalid format: {}".format(line))
        leave(1)
    out = bytearray()
    escaping = False
    for c in text[1:-1]:
        if escaping:
            escaping = False
            if c not in escape_dict:
                print("Unrecognized escape sequence: {}".format(line))
                leave(1)
            c = escape_dict[c]
        elif c == '\\':
            escaping = True
            continue
        if ord(c) > 0xFF:
            print("Character does not fit in one byte: {}".format(line))
            leave(1)
        out.append(ord(c))
    if escaping:
        print("Unrecognized escape sequence: {}".format(line))
        leave(1)
    return bytes(out)


def parse_preproc(line):
    """Handle a top-level line that is neither an instruction nor a label.

    NAME := VALUE   records a textual define in pdefs.
    LABEL = VALUE   emits .data at the current pdata offset, where VALUE is
                    a number (8 bytes), "[size]" (zeroed, 8-aligned) or a
                    quoted string (NUL-terminated, zero-padded to 8 bytes;
                    its length without NUL is defined as LABEL_len).
    Anything else is reported and aborts assembly via leave(1).
    """
    global pdata
    tok = line.split(None, 2)
    # preprocessor
    if len(tok) > 1 and tok[1] == ':=':
        if len(tok) < 3:
            print("Invalid format: {}".format(line))
            leave(1)
        s = tok[0]
        # Redefining an existing define uses its current value as the key.
        if s in pdefs:
            s = pdefs[s]
        if s[0] == '.':
            s = plastlabel + s
        pdefs[s] = tok[2]
        return
    # .data
    if len(tok) > 1 and tok[1] == '=':
        if len(tok) < 3:
            print("Invalid format: {}".format(line))
            leave(1)
        label = tok[0]
        if label[0] == '.':
            label = plastlabel + label
        plabels_data[label] = pdata
        value = tok[2].strip()
        if is_number(value):
            data = _data_number(value, line)
        elif value[0] == '[':
            data = _data_bss(value, line)
        elif value[0] in "'\"":
            text = _data_string(value, line)
            pdefs[label + "_len"] = str(len(text))
            # NUL terminator, then zero padding up to an 8-byte boundary.
            data = text + bytes(1 + (-(len(text) + 1)) % 8)
        else:
            print("Invalid format: {}".format(line))
            leave(1)
        pdata += b_data.write(data)
        return
    print("Unrecognized directive: {}".format(line))
    leave(1)
#-------------------------------------------------------------------------------
pconds = {
    'c': 0b00001,
    'o': 0b00010,
    'z': 0b00011,
    'e': 0b00011,
    's': 0b00100,
    'p': 0b00101,
    'a': 0b00110,
    'ae': 0b00111,
    'b': 0b01000,
    'be': 0b01001,
    'g': 0b01010,
    'ge': 0b01011,
    'l': 0b01100,
    'le': 0b01101,
    'cxz': 0b01110,
    'cxnz': 0b11110,
}
def get_cond_mask(cond, line):
    """Return the 5-bit condition code for suffix *cond*.

    A leading 'n' sets bit 4 (0b10000) on top of the code of the rest of the
    suffix.  An unknown suffix is reported and aborts assembly.
    """
    negate = 0
    if cond[0] == 'n':
        negate, cond = 0b10000, cond[1:]
    code = pconds.get(cond)
    if code is None:
        print("Invalid condition suffix: {}".format(line))
        leave(1)
    return negate | code
#-------------------------------------------------------------------------------
pfts = {
    "reg": 0b00001,
    "imm64": 0b00010,
    "bimm64": 0b00100,
    "brr": 0b00101,
    "brri": 0b00110,
    "brrii": 0b00111,
    "wimm64": 0b01000,
    "wrr": 0b01001,
    "wrri": 0b01010,
    "wrrii": 0b01011,
    "limm64": 0b01100,
    "lrr": 0b01101,
    "lrri": 0b01110,
    "lrrii": 0b01111,
    "qimm64": 0b10000,
    "qrr": 0b10001,
    "qrri": 0b10010,
    "qrrii": 0b10011,
}
def get_fts_mask(ft, line):
    """Return the 5-bit operand-format code for format string *ft*.

    An unknown format is reported and aborts assembly via leave(1).
    Previously it fell through and returned None, which made the caller
    crash on `None << n`.
    """
    if ft not in pfts:
        print("Invalid operand format ({}): {}".format(ft, line))
        leave(1)
    return pfts[ft]
#-------------------------------------------------------------------------------
def _mem_operand_error(line):
    """Report a malformed memory operand in *line* and abort assembly."""
    print("Invalid memory operand: {}".format(line))
    leave(1)


def _split_scaled(part, line):
    """Split a scaled index "reg*imm16" into its (reg, imm16) strings."""
    pieces = part.split('*')
    if len(pieces) != 2:
        _mem_operand_error(line)
    return pieces[0], pieces[1]


def _split_mem_operand(word, line):
    """Decompose the inside of a '[...]' operand that contains '+' or '*'.

    *word* must already have its spaces removed and every '-' rewritten as
    '+-', so it is a '+'-separated list of terms.  Accepted shapes:
    [reg*imm16], [reg+reg], [reg+imm16], [reg*imm16+imm16], [reg+reg*imm16],
    [reg+reg+imm16] and [reg+reg*imm16+imm16].

    Returns (reg1, reg2, imm1, imm2) for the operand written as
    reg1 + reg2*imm1 + imm2.  Absent parts default to register "inv",
    imm1 '1' and imm2 '0'.  Unknown registers, non-numeric immediates and
    other shapes are reported here instead of crashing later in gentext().
    """
    reg1 = "inv"
    reg2 = "inv"
    imm1 = '1'
    imm2 = '0'
    wtok = word.split('+')
    if len(wtok) == 1:
        # [reg*imm16] (no '+', so there is a '*')
        reg2, imm1 = _split_scaled(wtok[0], line)
    elif len(wtok) == 2:
        if '*' in wtok[0]:
            # [reg*imm16+imm16]
            reg2, imm1 = _split_scaled(wtok[0], line)
            imm2 = wtok[1]
        elif '*' in wtok[1]:
            # [reg+reg*imm16]
            reg1 = wtok[0]
            reg2, imm1 = _split_scaled(wtok[1], line)
        elif is_number(wtok[1]):
            # [reg+imm16]
            reg1 = wtok[0]
            imm2 = wtok[1]
        else:
            # [reg+reg]
            reg1 = wtok[0]
            reg2 = wtok[1]
    elif len(wtok) == 3:
        # [reg+reg+imm16], [reg+reg*imm16+imm16]
        reg1 = wtok[0]
        imm2 = wtok[2]
        if '*' in wtok[1]:
            reg2, imm1 = _split_scaled(wtok[1], line)
        else:
            reg2 = wtok[1]
    else:
        _mem_operand_error(line)
    if reg1 not in pregs or reg2 not in pregs:
        _mem_operand_error(line)
    if not (is_number(imm1) and is_number(imm2)):
        _mem_operand_error(line)
    return reg1, reg2, imm1, imm2


def parse_instr(line):
    """Write the intermediate form of one instruction line to `instrs`.

    The written text is the (suffixed) mnemonic, an optional condition
    immediate, the "%%imm16 W2" format word and the operand tokens, with
    "%%cond " in front when a condition suffix is present.  W2 holds the REP
    flag (0x8000) and up to three 5-bit operand-format codes: FT1 in bits
    0-4, FT2 in bits 5-9, FT3 in bits 10-14.  gentext() later encodes every
    token into bytes.

    Returns the number of bytes the instruction will occupy in .text
    (0 for an empty line).  Syntax errors abort assembly via leave(1).
    """
    if not line:
        return 0
    tok = line.split(None, 1)
    instr = tok[0].strip()
    params = tok[1].strip() if len(tok) > 1 else None
    fellthrough = False
    # opcode (2 bytes) + format word W2 (2 bytes)
    size = 4
    # Word 2 (rep|ft3|ft2|ft1)
    w2 = 0
    cond = None
    if len(instr) > 2 and '.' in instr:
        instr, suf = instr.split('.', 1)
        if len(instr) == 0:
            print("Missing instruction name before suffixes: {}".format(line))
            leave(1)
        if len(suf) > 2 and suf[:3] == "rep":
            if len(suf) > 3:
                suf = suf[3:]
                if suf[0] == '.':
                    suf = suf[1:]
            else:
                suf = ''
            w2 |= 0x8000 # REP
        if len(suf) > 0:
            instrs.write("%%cond ")
            cond = "%%imm16 {}".format(get_cond_mask(suf, line))
    instr_name = instr
    instr_args = ''
    if not params:
        instrs.write("{} ".format(instr_name))
        if cond is not None:
            size += 2
            instrs.write("{} ".format(cond))
        instrs.write("%%imm16 {}".format(w2))
        return size
    # FTn, one space-separated format string per operand
    fts = ''
    #
    # Parse operands, generating fts along the way
    #
    for word in params.split(','):
        word = word.strip()
        instr_args += ' '
        gotPref = False
        if len(fts) != 0:
            fts += ' '
        if len(word) == 0:
            print("Wrong syntax in line: '{}'".format(line))
            leave(1)
        # local labels
        if word[0] == '.':
            word = plastlabel + word
        # preprocessor
        if word in pdefs:
            word = pdefs[word]
        # arithmetic expressions
        word = arith_eval(word)
        # a character literal such as '[' is never a memory operand
        is_char = len(word) == 3 and word[0] == word[-1] == "'"
        # memory length prefixes (b/w/l/q before '[')
        if len(word) > 2 and '[' in word and not is_char:
            if word[0] not in 'bwlq':
                print("Bad memory length prefix: {}".format(line))
                leave(1)
            fts += word[0]
            gotPref = True
            word = word[1:].strip()
            if word[0] != '[':
                _mem_operand_error(line)
        #
        # Determine memory format and save it into fts
        #
        if word[0] == '[':
            if word[-1] != ']':
                _mem_operand_error(line)
            word = word[1:-1]
            # preprocessor, again
            if word in pdefs:
                word = pdefs[word]
            if not gotPref:
                print("Missing access length modifier: {}".format(line))
                leave(1)
            instr_name += "_m"
            # cheap way of getting [reg - imm] to work; then drop all spaces
            word = word.replace('-', '+ -').replace(' ', '')
            if not word:
                _mem_operand_error(line)
            if '+' in word or '*' in word:
                reg1, reg2, imm1, imm2 = _split_mem_operand(word, line)
                instr_args += "{}:{} ".format(reg2, reg1)
                size += 2
                if imm1 == '1':
                    if imm2 == '0':
                        fts += 'rr'
                    else:
                        fts += 'rri'
                        size += 2
                        instr_args += "%%imm16 {}".format(imm2)
                else:
                    size += 4
                    fts += 'rrii'
                    instr_args += "%%imm16 {} %%imm16 {}".format(imm1, imm2)
                continue
            # [imm64] or [reg]: encoded like a plain operand, in memory form
            fellthrough = True
        # preprocessor, yet again
        if word in pdefs:
            word = pdefs[word]
        # characters 'c'
        if len(word) == 3 and word[0] == word[-1] == "'":
            word = str(ord(word[1]))
        # register index $reg
        if len(word) == 4 and word[0] == '$' and word[1:] in pregs:
            word = str(pregs[word[1:]])
        # for now every immediate is 64-bit
        if is_number(word):
            size += 8
            if not fellthrough:
                instr_name += "_i"
            fts += "imm64"
            instr_args += "%%imm64 " + word
            fellthrough = False
            continue
        # register
        if word in pregs:
            size += 2
            if not fellthrough:
                instr_name += "_r"
                fts += "reg"
            else:
                fts += "rr"
            instr_args += word
            fellthrough = False
            continue
        # anything else is a label, i.e. a 64-bit immediate
        size += 8
        if not fellthrough:
            instr_name += "_i"
        fts += "imm64"
        instr_args += "%%imm64 "
        if word[0] == '.':
            instr_args += plastlabel
        instr_args += word
        fellthrough = False
    #
    # Compute FTn
    #
    fts_list = fts.split()
    if len(fts_list) > 3:
        print("Too many operands: {}".format(line))
        leave(1)
    for shift, ft in zip((0, 5, 10), fts_list):
        w2 |= get_fts_mask(ft, line) << shift
    if cond is None:
        instrs.write("{} %%imm16 {}{}".format(instr_name, w2, instr_args))
    else:
        size += 2
        instrs.write("{} {} %%imm16 {}{}".format(instr_name, cond, w2, instr_args))
    return size
#-------------------------------------------------------------------------------
special_syms = {
    "%%cond",
    "%%imm16",
    "%%imm32",
    "%%imm64",
    "%%signed",
}
# Byte width each immediate-size marker sets for the numbers that follow it.
_imm_widths = {
    "%%imm16": 2,
    "%%imm32": 4,
    "%%imm64": 8,
    "%%signed": 2,
}
def gentext():
    """Second pass: encode every intermediate line of `instrs` into `b_text`.

    Token encodings (all little-endian):
      register name        -> register index, 2 bytes
      "reg2:reg1"          -> (index(reg1) << 8) | index(reg2), 2 bytes
      instruction name     -> index in pinstrs, OR'ed with bit 13 when a
                              "%%cond" marker preceded it, 2 bytes
      .text / .data label  -> absolute address, 8 bytes
      "%%imm16/32/64"      -> width (2/4/8 bytes) of the following numbers
      "%%signed"           -> same width as "%%imm16"
      number               -> signed iff written with a leading '-'
    Unknown tokens, numbers before any width marker, and numbers that do not
    fit in their width are reported and abort assembly via leave(1).
    """
    text_start = start_addr
    data_start = text_start + ptext
    if (data_start % 8) != 0:
        data_start += (8 - data_start % 8)
    instrs.seek(0)
    cond_mask = 0
    lastimm = None
    for line in instrs:
        tok = line.strip().split()
        if WANT_DISASM:
            print(tok)
        # tok[0] is the source-file tag written by parse(); it emits nothing.
        for word in tok[1:]:
            if word in pregs:
                b_text.write(pregs[word].to_bytes(2, byteorder='little', signed=False))
                continue
            if ':' in word:
                reg2, reg1 = word.split(':', 1)
                if reg1 not in pregs or reg2 not in pregs:
                    print("Assembly error, unknown register in '{}' in line: {}".format(word, line))
                    leave(1)
                pair = (pregs[reg1] << 8) | pregs[reg2]
                b_text.write(pair.to_bytes(2, byteorder='little', signed=False))
                continue
            if word in pinstrs:
                idx = pinstrs.index(word) | cond_mask
                cond_mask = 0
                b_text.write(idx.to_bytes(2, byteorder='little', signed=False))
                continue
            if word in plabels_text:
                addr = text_start + plabels_text[word]
                b_text.write(addr.to_bytes(8, byteorder='little', signed=False))
                continue
            if word in plabels_data:
                addr = data_start + plabels_data[word]
                b_text.write(addr.to_bytes(8, byteorder='little', signed=False))
                continue
            if word in special_syms:
                if word == "%%cond":
                    cond_mask = (1 << 13)
                else:
                    lastimm = _imm_widths[word]
                continue
            if is_number(word):
                if lastimm is None:
                    print("Assembly error, immediate '{}' without size in line: {}".format(word, line))
                    leave(1)
                try:
                    encoded = int(word, base=0).to_bytes(
                        lastimm, byteorder='little', signed=(word[0] == '-'))
                except OverflowError:
                    print("Assembly error, immediate '{}' does not fit in {} bytes in line: {}"
                          .format(word, lastimm, line))
                    leave(1)
                b_text.write(encoded)
                continue
            print("Assembly error, unknown token '{}' in line: {}".format(word, line))
            leave(1)
#-------------------------------------------------------------------------------
def sort_by_list(dict_, list_):
    """Move each key of *list_*, in order, to the end of OrderedDict *dict_* (in place)."""
    move = dict_.move_to_end
    for key in list_:
        move(key)
def gensym():
    """Write "name address" lines for every label to the symbol file, by address.

    .text labels are relocated to start_addr, and .data labels to the first
    8-byte-aligned address after .text.  This is the same computation as
    gentext().  If a name is both a .text and a .data label, the .data
    address wins.  The global label tables are left untouched (they keep
    section offsets), unlike the previous in-place relocation.
    """
    text_start = start_addr
    data_start = text_start + ptext
    if (data_start % 8) != 0:
        data_start += (8 - data_start % 8)
    symbols = OrderedDict(
        (label, text_start + offset) for label, offset in plabels_text.items())
    for label, offset in plabels_data.items():
        symbols[label] = data_start + offset
    for name, addr in sorted(symbols.items(), key=lambda item: item[1]):
        b_sym.write("{} {}\n".format(name, addr))
#-------------------------------------------------------------------------------
def genout():
    """Write the final image to b_out: .text, zero padding, then .data.

    The padding puts .data at the first 8-byte-aligned *absolute* address
    after .text, i.e. align8(start_addr + ptext).  gentext() and gensym()
    resolve .data labels to that same address.  Previously the padding only
    aligned ptext, which disagreed with them whenever start_addr % 8 != 0.
    """
    b_text.seek(0)
    b_data.seek(0)
    b_out.write(b_text.read())
    b_out.write(bytes(-(start_addr + ptext) % 8))
    b_out.write(b_data.read())
#-------------------------------------------------------------------------------
# Driver.  The order is significant: the register/instruction tables must be
# loaded before parsing, parse() must finish (fixing ptext and every label
# offset) before gentext() resolves addresses, and leave() closes all files.
parse_lst_instrs()      # opcode table: list position -> mnemonic
parse_lst_regs()        # register table: name -> index
do_includes(main_src)   # expand include directives into `source`
parse()                 # pass 1: sizes, labels, .data bytes, intermediate instrs
gentext()               # pass 2: encode intermediate instrs into .text bytes
genout()                # write .text, padding and .data to the output image
gensym()                # write "name address" lines, sorted by address
leave(0)