1
0
mirror of https://gitlab.os-k.eu/os-k-team/kvisc.git synced 2023-08-25 14:05:46 +02:00
kvisc/as/k-as.py
2019-08-06 22:47:39 +02:00

870 lines
22 KiB
Python
Executable File

#!/usr/bin/python3
# The OS/K Team licenses this file to you under the MIT license.
# See the LICENSE file in the project root for more information.
import re
import os
import sys
import subprocess
from array import array
from tempfile import TemporaryFile
from collections import OrderedDict
#print("k-as command line: '{}'".format(sys.argv))
# When true, gentext() prints the token list of every line it encodes.
WANT_DISASM = False

# The script needs exactly four arguments after its own name.
if len(sys.argv) != 5:
    usage = "Usage: {} (source file) (memory entry point) (output file) (symbols file)"
    print(usage.format(sys.argv[0]))
    sys.exit(1)

# Scratch files: flattened source, intermediate token stream, and the two
# binary sections that genout() later concatenates.
source = TemporaryFile(mode="w+")
instrs = TemporaryFile(mode="w+")
b_data = TemporaryFile(mode="w+b")
b_text = TemporaryFile(mode="w+b")

# Register and instruction tables are looked up in sys.path[0].
lst_regs = open(os.path.join(sys.path[0], "regs.lst"))
lst_instrs = open(os.path.join(sys.path[0], "instrs.lst"))

# Command-line supplied input and outputs.
main_src = open(sys.argv[1])
b_out = open(sys.argv[3], "wb")
b_sym = open(sys.argv[4], "w")

# base=0 lets the entry point be given as decimal, 0x..., 0o... or 0b...
start_addr = int(sys.argv[2], base=0)

# os.chdir(os.path.dirname(sys.argv[1]))
def leave(i):
    """Close every file handle opened at start-up, then exit with status *i*."""
    handles = (
        source,
        instrs,
        b_out,
        b_sym,
        b_data,
        b_text,
        main_src,
        lst_regs,
        lst_instrs,
    )
    for handle in handles:
        handle.close()

    sys.exit(i)
#-------------------------------------------------------------------------------
# Preprocessor defines (name -> replacement text)
pdefs = {}

# Registers (name -> index)
pregs = {}

# Instruction names, in table order
pinstrs = []

# Labels (name -> offset inside the section)
plabels_text = OrderedDict()
plabels_data = OrderedDict()

# Size of the .data section
pdata = 0

# Size of the .text section
ptext = 0

# Last non-local label seen; used as prefix for local ('.'-prefixed) labels
plastlabel = ''

# File currently being parsed
pcurfile = sys.argv[1]

# After parse() is done, pdata and ptext are never modified
#-------------------------------------------------------------------------------
def name_valid(name):
    """Return True when every character of *name* is allowed in a label.

    The check is case-insensitive; an empty name is accepted.
    """
    allowed = 'abcdefghijklmnopqrstuvwxyz0123456789[$._+]=,'
    return all(ch in allowed for ch in name.lower())
def is_number(s):
    """Tell whether *s* parses as an integer literal with int(s, base=0)."""
    try:
        int(s, base=0)
    except ValueError:
        parsed = False
    else:
        parsed = True
    return parsed
# A chain of integer literals (decimal or 0x hex) joined by | & ^ + - * << >>
arith_expr = re.compile(r'((0x[0-9A-Fa-f]+|[0-9]+)\s*([|&^+\-*]|<<|>>))+\s*(0x[0-9A-Fa-f]+|[0-9]+)')


def arith_eval(s):
    """Fold a constant arithmetic expression into its decimal string value.

    *s* is evaluated only when the WHOLE string is a chain of integer
    literals and the operators accepted by ``arith_expr``; anything else is
    returned unchanged.

    The pattern must cover the entire string: matching just a prefix would
    hand an unvalidated tail to eval(), which both crashes on operands such
    as ``1+2+label`` and allows arbitrary code from the assembly source to
    run.
    """
    if arith_expr.fullmatch(s):
        # Only digits, hex letters, whitespace and operators can be present
        # here; the empty builtins mapping is an extra safety net.
        return str(eval(s, {"__builtins__": {}}, {}))

    return s
#-------------------------------------------------------------------------------
def parse_lst_regs():
    """Fill ``pregs`` from the register list file.

    Every non-blank line describes one register; all whitespace-separated
    names on that line map to the same index. Indices count non-blank lines
    starting at 0.
    """
    global pregs

    index = 0
    for raw in lst_regs:
        names = raw.split()
        if not names:
            # Blank lines do not consume an index.
            continue

        for name in names:
            pregs[name] = index

        index += 1
def parse_lst_instrs():
    """Append every (stripped) line of the instruction list file to ``pinstrs``.

    The position of a name in ``pinstrs`` is what gentext() later encodes.
    """
    global pinstrs

    pinstrs.extend(entry.strip() for entry in lst_instrs)
#-------------------------------------------------------------------------------
# Current nesting level of "include" directives and the level at which
# do_includes() gives up.
inc_depth = 0
inc_depth_max = 16


# Quickly goes through source file and resolves "include" directives ONLY
def do_includes(fi):
    """Copy *fi* into the ``source`` scratch file, expanding includes.

    Blank lines are dropped and every other line is written right-stripped.
    An ``include "path"`` (or ``'path'``) line is replaced by the content of
    that file, bracketed by ``$file: <name>:`` marker lines so that parse()
    can track which file a line came from.

    Any malformed directive, unreadable file or too deep nesting prints a
    message and terminates through leave(1).
    """
    global inc_depth
    global pcurfile

    for line in fi:
        line = line.rstrip()
        tok = line.split(None, 1)

        if not tok:
            continue

        if tok[0] != "include":
            source.write("{}\n".format(line))
            continue

        if len(tok) == 1:
            print("Missing parameter for include directive")
            leave(1)

        target = tok[1]
        if target[0] not in "'\"" or target[-1] != target[0]:
            print("Invalid format for include directive: {}".format(line))
            leave(1)

        old_curf = pcurfile
        pcurfile = target[1:-1]

        try:
            new_fi = open(pcurfile, "r")
        except (OSError, ValueError):
            # ValueError covers paths open() rejects outright (embedded NUL).
            print("Couldn't open file: {}".format(line))
            leave(1)

        # The included file is closed as soon as it has been copied, even
        # when leave() raises SystemExit underneath.
        with new_fi:
            inc_depth += 1
            if inc_depth >= inc_depth_max:
                print("Maximal include depth reached: {}".format(line))
                leave(1)

            source.write("$file: {}:\n".format(pcurfile.replace(' ', '')))
            do_includes(new_fi)

        pcurfile = old_curf
        source.write("$file: {}:\n".format(pcurfile.replace(' ', '')))

    # Balances the increment done by whoever included this file.
    inc_depth -= 1
#-------------------------------------------------------------------------------
def parse():
    """First pass over the flattened source.

    For every line of ``source``:
      * ``$file: `` markers update ``pcurfile``;
      * comments (introduced by one of ``#;@!/`` outside quotes) are removed;
      * indented lines are instructions: they are handed to parse_instr()
        and their size is added to ``ptext``;
      * lines ending in ``:`` define a text label at the current ``ptext``;
      * everything else goes to parse_preproc().

    Errors are printed and terminate through leave(1).
    """
    global ptext
    global pcurfile
    global plastlabel

    source.seek(0)
    pcurfile = sys.argv[1]

    for line in source:
        line = line.rstrip()

        if len(line) == 0:
            continue

        # len("$file: ") == 7
        if len(line) > 7 and line[:7] == "$file: ":
            pcurfile = line[7:]
            continue

        # Cut the line at the first comment character found outside quotes.
        quote = False
        for i, ch in enumerate(line):
            if ch in "'\"":
                quote = not quote
            if ch in '#;@!/' and not quote:
                line = line[:i].rstrip()
                break

        if quote:
            print("Unterminated string in line: {}".format(line))
            leave(1)

        if len(line) == 0:
            continue

        # Instruction
        if line[0] == ' ' or line[0] == '\t':
            line = line.lstrip()

            instrs.write(pcurfile + ' ' + hex(ptext) + ' ')
            ptext += parse_instr(line)
            instrs.write("\n")

            continue

        # Preprocessor or label?
        if line[-1] == ':':
            label = line[:-1]

            # An empty name (a line that is just ':') used to slip through
            # name_valid() and crash on label[0] below.
            if not label or not name_valid(label):
                print("Bad label name: {}".format(label))
                leave(1)

            if label[0] == '.':
                # Local label: qualified by the last global label
                label = plastlabel + label
            else:
                plastlabel = label

            plabels_text[label] = ptext
            continue

        # Preprocessor, .data, or invalid
        parse_preproc(line)
#-------------------------------------------------------------------------------
# Character following a backslash in string data -> character emitted
escape_dict = {
    'n': '\n',
    't': '\t',
    'r': '\r',
    'v': '\v',
    'f': '\f',
    '"': '"',
    '\'': '\'',
    '\\': '\\',
}


def parse_preproc(line):
    """Handle a non-indented, non-label line.

    Two forms are recognised:
      * ``name := text``  stores a define in ``pdefs``;
      * ``name = value``  emits data into ``b_data`` and records ``name`` in
        ``plabels_data``. ``value`` is a number (8 bytes, little-endian),
        ``[size]`` (that many zero bytes) or a quoted string (one byte per
        character plus a NUL terminator; ``<name>_len`` is defined as the
        string length without the terminator).

    Malformed input prints a message and terminates through leave(1)
    instead of raising.
    """
    global pdata

    tok = line.split(None, 2)

    # preprocessor
    if len(tok) > 1 and tok[1] == ':=':
        if len(tok) < 3:
            print("Invalid format: {}".format(line))
            leave(1)

        s = tok[0]
        if s in pdefs:
            s = pdefs[s]
        if s[0] == '.':
            s = plastlabel + s

        pdefs[s] = tok[2]
        return

    # .data
    if len(tok) > 1 and tok[1] == '=':
        if len(tok) < 3:
            print("Invalid format: {}".format(line))
            leave(1)

        label = tok[0]
        if label[0] == '.':
            label = plastlabel + label

        plabels_data[label] = pdata
        value = tok[2]

        # number data
        if is_number(value):
            n = int(value, base=0)
            try:
                # Negative values are stored as 64-bit two's complement.
                raw = n.to_bytes(8, byteorder='little', signed=(n < 0))
            except OverflowError:
                print("Invalid format: {}".format(line))
                leave(1)

            pdata += b_data.write(raw)

        # buffer / bss
        elif value[0] == '[':
            s = value[1:-1].strip()
            if value[-1] != ']' or len(s) == 0:
                print("Invalid bss format: {}".format(line))
                leave(1)

            if s[0] == '.':
                s = plastlabel + s
            if s in pdefs:
                s = pdefs[s]

            if not is_number(s):
                print("Invalid bss format: {}".format(line))
                leave(1)

            i = int(s, base=0)
            if i < 0:
                print("Invalid bss format: {}".format(line))
                leave(1)

            pdata += b_data.write(bytearray(i))

        # string data
        elif value[0] in "'\"":
            s = value.strip()
            # Need both an opening and a matching closing quote.
            if len(s) < 2 or s[-1] != s[0]:
                print("Invalid format: {}".format(line))
                leave(1)

            s = s[1:-1]
            real_len = 0
            escaping = False

            for c in s:
                # escape sequences
                if not escaping and c == '\\':
                    escaping = True
                    continue

                if escaping:
                    escaping = False
                    if c in escape_dict:
                        c = escape_dict[c]
                    else:
                        print("Unrecognized escape sequence: {}".format(line))
                        leave(1)

                # Each character is stored in exactly one byte.
                if ord(c) > 0xFF:
                    print("Invalid format: {}".format(line))
                    leave(1)

                b_data.write(ord(c).to_bytes(1, byteorder='little', signed=False))
                real_len += 1
                pdata += 1

            # null terminator
            b_data.write(int(0).to_bytes(1, byteorder='little', signed=False))
            pdata += 1

            pdefs[label + "_len"] = str(real_len)

        else:
            print("Invalid format: {}".format(line))
            leave(1)

        return

    print("Unrecognized directive: {}".format(line))
    leave(1)
#-------------------------------------------------------------------------------
# Condition suffix -> 5-bit condition code. Bit 4 (0b10000) negates the
# condition held in the low four bits.
pconds = {
    'c':    0b00001,
    'o':    0b00010,
    'z':    0b00011,
    'e':    0b00011,
    's':    0b00100,
    'pe':   0b00101,
    'po':   0b10101,
    'b':    0b00001,
    'be':   0b00110,
    'l':    0b00111,
    'le':   0b01000,
    'a':    0b10110, # nbe
    'ae':   0b10001, # nb
    'g':    0b11000, # nle
    'ge':   0b10111, # nl
    'axz':  0b01001,
    'bxz':  0b01010,
    'cxz':  0b01011,
    'dxz':  0b01100,
    'axnz': 0b11001,
    'bxnz': 0b11010,
    'cxnz': 0b11011,
    'dxnz': 0b11100,
}


def get_cond_mask(cond, line):
    """Return the 5-bit condition code for the suffix *cond*.

    A leading ``n`` negates the condition named by the rest of the suffix.
    *line* is only used in the error message; an unknown suffix terminates
    through leave(1).

    The negation bit is toggled rather than OR-ed in, because several table
    entries ('a', 'ae', 'g', 'ge', 'po', '?xnz') already carry it: with OR,
    'na' would encode the same thing as 'a' instead of 'be'.
    """
    mask = 0
    if cond[0] == 'n':
        cond = cond[1:]
        mask = 0b10000

    if cond not in pconds:
        print("Invalid condition suffix: {}".format(line))
        leave(1)

    return (mask ^ pconds[cond])
#-------------------------------------------------------------------------------
# Operand format byte. Memory formats get the access length OR-ed in.
fmts = {
    "r":        0b00000000,
    "m_r":      0b00100000,
    "m_rr":     0b01000000,
    "m_rriw":   0b01100000,
    "m_rrid":   0b10000000,
    "m_rrii":   0b10100000,
    "m_riq":    0b11000000,
    "imm8":     0b11100001,
    "imm16":    0b11100010,
    "imm32":    0b11100100,
    "imm64":    0b11101000,
}

# Memory access prefix letter -> access length in bytes
pref2len = {
    "b" : 1,
    "w" : 2,
    "d" : 4,
    "l" : 4,
    "q" : 8,
}


def _require_syntax(ok, line):
    """Terminate with a syntax error for *line* unless *ok* is true."""
    if not ok:
        print("Wrong syntax in line: '{}'".format(line))
        leave(1)


def parse_instr(line):
    """Translate one instruction line into the intermediate token stream.

    The tokens are written to ``instrs`` (without a trailing newline) and
    later turned into bytes by gentext().

    Returns the size in bytes the instruction is accounted for (0 for an
    empty line); parse() adds it to ``ptext``.

    Syntax errors print a message and terminate through leave(1).
    """
    if line is None or len(line) == 0:
        return 0

    tok = line.split(None, 1)
    instr = tok[0].strip()

    if len(tok) > 1:
        params = tok[1].strip()
    else:
        params = None

    size = 1

    # Byte 2 (rep|lock|0|cond)
    b2 = 0

    if len(instr) > 2 and '.' in instr:
        instr, suf = instr.split('.', 1)

        if len(instr) == 0:
            print("Missing instruction name before suffixes: {}".format(line))
            # Without this the line would be emitted with no opcode at all.
            leave(1)

        if len(suf) > 2 and suf[:3] == "rep":
            if len(suf) > 3:
                suf = suf[3:]
                if len(suf) > 0 and suf[0] == '.':
                    suf = suf[1:]
            else:
                suf = ''

            b2 |= 1<<7 # REP

        if len(suf) > 0:
            b2 |= get_cond_mask(suf, line)

    instr_name = instr
    instr_args = ''

    # Instruction without operands
    if params is None or len(params) == 0:
        if b2 == 0:
            instrs.write("{}".format(instr_name))
        else:
            size += 1
            instrs.write("%%suff {} %%imm8 {}".format(instr_name, b2))

        return size

    tok = params.split(',')

    # 'call' special case... temporary
    if instr_name == 'call':
        if len(tok) == 2:
            instr_name = 'xcall2'
        elif len(tok) == 3:
            instr_name = 'xcall3'

    #
    # Parse operands
    #
    for word in tok:
        word = word.strip()

        instr_args += ' '

        mlen = 0

        if len(word) == 0:
            print("Wrong syntax in line: '{}'".format(line))
            leave(1)

        # local labels
        if word[0] == '.':
            word = plastlabel + word

        # preprocessor
        if word in pdefs:
            word = pdefs[word]
            # Fall through

        # arithmetic expressions
        word = arith_eval(word)

        # memory length prefixes
        if len(word) > 2 and '[' in word:
            if word[0] in 'bwldq':
                mlen = pref2len[word[0]]
            else:
                print("Bad memory length prefix: {}".format(line))
                leave(1)

            word = word[1:].strip()
            _require_syntax(word[0] == '[', line)

        #
        # Determine memory format
        #
        if word[0] in '[(':
            _require_syntax(word[-1] in '])', line)
            word = word[1:-1]

            # preprocessor, again
            if word in pdefs:
                word = pdefs[word]
                # Fall through

            # Make sure we got an access length prefix
            if mlen == 0:
                print("Missing access length modifier: {}".format(line))
                leave(1)

            # cheap way of getting [reg - imm] to work
            word = word.replace('-', '+ -')

            # remove every spaces!
            word = word.replace(' ', '')

            #
            # Offsets
            #
            if '+' in word:
                reg1 = "zero"
                reg2 = "zero"
                imm1 = '1'
                imm2 = '0'

                wtok = word.split('+')

                #
                # [reg] and [reg*imm]
                #
                # NOTE(review): unreachable as written -- '+' in word
                # guarantees that split('+') yields at least two parts.
                if len(wtok) == 1:
                    if '*' in wtok[0]:
                        _require_syntax(len(wtok[0].split('*')) == 2, line)
                        reg2, imm1 = wtok[0].split('*', 1)
                    else:
                        reg1 = wtok[0]

                #
                # [reg+reg], [reg+imm], [reg*imm+imm], [reg+reg*imm]
                #
                elif len(wtok) == 2:
                    # Must be [reg*imm+imm]
                    if '*' in wtok[0]:
                        _require_syntax(len(wtok[0].split('*')) == 2, line)
                        _require_syntax(is_number(wtok[1].strip()), line)
                        reg2, imm1 = wtok[0].split('*', 1)
                        imm2 = wtok[1]

                    # Must be [reg+reg*imm]
                    elif '*' in wtok[1]:
                        _require_syntax(len(wtok[1].split('*')) == 2, line)
                        reg1 = wtok[0]
                        reg2, imm1 = wtok[1].split('*', 1)

                    elif is_number(wtok[1].strip()):
                        reg1 = wtok[0]
                        imm2 = wtok[1]

                    # Must be [reg+reg]
                    else:
                        reg1 = wtok[0]
                        reg2 = wtok[1]

                #
                # [reg+reg+imm], [reg+reg*imm8+imm]
                #
                else:
                    _require_syntax(len(wtok) == 3, line)
                    reg1 = wtok[0]
                    imm2 = wtok[2]

                    if '*' in wtok[1]:
                        _require_syntax(len(wtok[1].split('*')) == 2, line)
                        reg2, imm1 = wtok[1].split('*', 1)
                    else:
                        reg2 = wtok[1]

                #
                # Update instr_args
                #
                if imm1 == '1':
                    # [reg+reg]
                    if imm2 == '0':
                        instr_args += "%%imm8 {} {} {}".format(fmts["m_rr"]|mlen, '$'+reg1, '$'+reg2)
                        size += 3

                    # [reg+reg+imm]
                    else:
                        instr_args += "%%imm8 {} {} {} %%signed %%imm16 {}"\
                            .format(fmts["m_rriw"]|mlen, '$'+reg1, '$'+reg2, imm2)
                        size += 5

                # [reg+reg*imm+imm]
                else:
                    instr_args += "%%imm8 {} {} {} %%imm8 {} %%signed %%imm32 {}"\
                        .format(fmts["m_rrii"]|mlen, '$'+reg1, '$'+reg2, imm1, imm2)
                    size += 8

            # [reg]
            elif '$'+word in pregs:
                instr_args += "%%imm8 {} {}".format(fmts["m_r"]|mlen, '$'+word)
                size += 2

            # [imm], converted to [zero+imm]
            else:
                instr_args += "%%imm8 {} $zero $zero %%signed %%imm32 {}".format(fmts["m_rrid"]|mlen, word)
                size += 7

            continue

        # preprocessor, yet again
        if word in pdefs:
            word = pdefs[word]
            # fallthrough

        # characters 'c'
        if len(word) == 3 and word[0] == word[-1] == "'":
            word = str(ord(word[1]))
            # fallthrough

        # register index $reg
        if len(word) == 4 and word[0] == '$':
            if '$'+word[1:] in pregs:
                word = str(pregs['$'+word[1:]])
                # fallthrough

        # immediates: the smallest encoding that holds the value is chosen;
        # negative values always take the 64-bit form
        if is_number(word):
            n = int(word, base=0)

            if n < 0 or n > 0xFFFFFFFF:
                size += 9
                instr_args += "%%imm8 {} ".format(fmts["imm64"])
                instr_args += "%%imm64 {}".format(word)

            elif n > 0xFFFF:
                size += 5
                instr_args += "%%imm8 {} ".format(fmts["imm32"])
                instr_args += "%%imm32 {}".format(word)

            elif n > 0xFF:
                size += 3
                instr_args += "%%imm8 {} ".format(fmts["imm16"])
                instr_args += "%%imm16 {}".format(word)

            else:
                size += 2
                instr_args += "%%imm8 {} ".format(fmts["imm8"])
                instr_args += "%%imm8 {}".format(word)

            continue

        # register
        elif '$'+word in pregs:
            size += 1
            instr_args += '$'+word
            continue

        # it's a label (a 32-bit immediate)
        # ModRM + imm
        size += 5
        instr_args += "%%imm8 {} ".format(fmts["imm32"])

        if word[0] == '.':
            instr_args += plastlabel

        instr_args += word

    if b2 == 0:
        instrs.write("{} {}".format(instr_name, instr_args))
    else:
        size += 1
        instrs.write("%%suff {} %%imm8 {} {}".format(instr_name, b2, instr_args))

    return size
#-------------------------------------------------------------------------------
# Marker tokens of the intermediate stream (everything else is a register,
# instruction, label or number)
special_syms = {
    "%%suff",
    "%%imm8",
    "%%imm16",
    "%%imm32",
    "%%imm64",
    "%%signed",
}


def gentext():
    """Second pass: encode the intermediate stream into ``b_text``.

    Each line of ``instrs`` is ``<file> <offset> <tokens...>``; the first two
    fields are skipped. Registers and instructions are written as one byte,
    labels as 4-byte little-endian absolute addresses, and numbers with the
    width announced by the preceding ``%%immN`` marker (signed when the
    literal starts with '-'). ``%%suff`` sets bit 7 of the next instruction
    byte.

    An unknown token, or a number that no width marker precedes, prints an
    assembly error and terminates through leave(1).
    """
    text_start = start_addr
    data_start = text_start + ptext

    # The data section starts on the next 8-byte boundary after the text.
    if (data_start % 8) != 0:
        data_start += (8 - data_start % 8)

    # Width in bytes selected by each marker ("%%signed" selects 2, which
    # the "%%immN" marker following it then overrides).
    imm_widths = {
        "%%imm8": 1,
        "%%imm16": 2,
        "%%imm32": 4,
        "%%imm64": 8,
        "%%signed": 2,
    }

    instrs.seek(0)
    suff_mask = 0

    # Stays None until a width marker is seen, so that a stray number is
    # reported as an error rather than raising UnboundLocalError.
    lastimm = None

    for line in instrs:
        tok = line.strip().split()

        if WANT_DISASM:
            print(tok)

        for word in tok[2:]:
            if word in pregs:
                idx = pregs[word]
                b_text.write(idx.to_bytes(1, byteorder='little', signed=False))
                continue

            if word in pinstrs:
                idx = pinstrs.index(word) | suff_mask
                b_text.write(idx.to_bytes(1, byteorder='little', signed=False))
                suff_mask = 0
                continue

            if word in plabels_text:
                addr = text_start + plabels_text[word]
                b_text.write(addr.to_bytes(4, byteorder='little', signed=False))
                continue

            if word in plabels_data:
                addr = data_start + plabels_data[word]
                b_text.write(addr.to_bytes(4, byteorder='little', signed=False))
                continue

            if word in special_syms:
                if word == "%%suff":
                    suff_mask = 1<<7
                else:
                    lastimm = imm_widths[word]
                continue

            if is_number(word) and lastimm is not None:
                value = int(word, base=0)
                b_text.write(value.to_bytes(lastimm, byteorder='little', signed=(word[0] == '-')))
                continue

            print("Assembly error, unknown token '{}' in line: {}".format(word, line))
            leave(1)
#-------------------------------------------------------------------------------
def sort_by_list(dict_, list_):
    """Move the keys named in *list_* to the end of *dict_*, in that order.

    *dict_* must provide ``move_to_end`` (e.g. an OrderedDict); a key that is
    not present raises KeyError.
    """
    for wanted in list_:
        dict_.move_to_end(wanted, last=True)
def gensym():
    """Write the symbol file: one ``<label> <address>`` line per label.

    Label offsets stored in ``plabels_text`` / ``plabels_data`` are turned
    into absolute addresses IN PLACE, then all labels are written sorted by
    address.
    """
    text_base = start_addr

    # Data starts on the first 8-byte boundary at or after the end of text.
    data_base = text_base + ptext
    misalign = data_base % 8
    if misalign != 0:
        data_base += (8 - misalign)

    for name in plabels_text:
        plabels_text[name] = plabels_text[name] + text_base

    for name in plabels_data:
        plabels_data[name] = plabels_data[name] + data_base

    # A name present in both tables keeps its text position but takes the
    # data address.
    merged = OrderedDict(plabels_text)
    merged.update(plabels_data)

    by_address = sorted(merged.items(), key=lambda pair: pair[1])
    for name, address in by_address:
        b_sym.write("{} {}\n".format(name, address))
#-------------------------------------------------------------------------------
def genout():
    """Write the final image to ``b_out``: text, zero padding, then data.

    The padding brings the data section to the same 8-byte aligned ADDRESS
    that gentext() and gensym() use for data labels, i.e. it is computed
    from ``start_addr + ptext`` and not from ``ptext`` alone. Both agree
    whenever ``start_addr`` is itself a multiple of 8.
    """
    b_text.seek(0)
    b_data.seek(0)

    b_out.write(b_text.read())

    # Number of bytes up to the next multiple of 8 (0 when already aligned)
    data_align = -(start_addr + ptext) % 8
    b_out.write(bytes(data_align))

    b_out.write(b_data.read())
#-------------------------------------------------------------------------------
# Load the instruction and register tables
parse_lst_instrs()
parse_lst_regs()

# Pass 0: flatten includes; pass 1: parse; pass 2: encode
do_includes(main_src)
parse()
gentext()

# Emit the binary image, then the symbol table (gensym() rewrites the label
# tables in place, so it runs last)
genout()
gensym()

#-------------------------------------------------------------------------------
summary = "Finished producing {}\n\ttext section size: {} bytes\n\tdata section size: {} bytes\n"
print(summary.format(sys.argv[3], ptext, pdata))

leave(0)
#-------------------------------------------------------------------------------