1
0
mirror of https://gitlab.os-k.eu/os-k-team/kvisc.git synced 2023-08-25 14:05:46 +02:00
kvisc/as/k-as.py
2019-09-07 16:45:03 +02:00

744 lines
19 KiB
Python
Executable File

#!/usr/bin/python3
# The OS/K Team licenses this file to you under the MIT license.
# See the LICENSE file in the project root for more information.
import re
import os
import sys
import subprocess
from array import array
from tempfile import TemporaryFile
from collections import OrderedDict
#print("k-as command line: '{}'".format(sys.argv))
# When True, gentext() prints the token list of every intermediate line.
WANT_DISASM = False

if len(sys.argv) != 5:
    print("Usage: {} (source file) (memory entry point) (output file) (symbols file)"
          .format(sys.argv[0]))
    sys.exit(1)

# Validate the entry point before touching any file, so that a malformed
# address cannot leave truncated output files behind.
start_addr = int(sys.argv[2], base=0)

# Scratch files: the intermediate instruction listing and the raw bytes of
# the .data and .text sections.
instrs = TemporaryFile(mode="w+")
b_data = TemporaryFile(mode="w+b")
b_text = TemporaryFile(mode="w+b")

# Inputs first, outputs last: a missing input must not clobber the outputs.
lst_instrs = open(os.path.join(sys.path[0], "instrs.lst"))
source = open(sys.argv[1])
b_out = open(sys.argv[3], "wb")
b_sym = open(sys.argv[4], "w")
# os.chdir(os.path.dirname(sys.argv[1]))
def leave(i):
    """Close every file the assembler holds open, then exit with status *i*."""
    handles = (source, instrs, b_out, b_sym, b_data, b_text, lst_instrs)
    for handle in handles:
        handle.close()
    sys.exit(i)
#-------------------------------------------------------------------------------
# Preprocessor definitions: NAME -> replacement text
pdefs = {}

# Instruction names loaded from the instruction list
pinstrs = []

# Labels, mapped to their offset inside their own section
plabels_text = OrderedDict()
plabels_data = OrderedDict()

# Running size of the .data section, in bytes
pdata = 0

# Running size of the .text section, in bytes
ptext = 0

# Most recent non-local label; local ('.name') labels are prefixed with it
plastlabel = ''

# File currently being parsed
pcurfile = sys.argv[1]
# after parse() is done, pdata and ptext are never modified
#-------------------------------------------------------------------------------
# Entry N lists every spelling accepted for register number N, generic
# '$rN' name first.
_preg_names = (
    ('$r0', '$eax', '$rax'),
    ('$r1', '$ebx', '$rbx'),
    ('$r2', '$ecx', '$rcx'),
    ('$r3', '$edx', '$rdx'),
    ('$r4', '$esi', '$rsi'),
    ('$r5', '$edi', '$rdi'),
    ('$r6', '$ax0'),
    ('$r7', '$ax1'),
    ('$r8', '$ax2'),
    ('$r9', '$ax3'),
    ('$r10', '$ax4'),
    ('$r11', '$ax5'),
    ('$r12', '$nx0'),
    ('$r13', '$nx1'),
    ('$r14', '$nx2'),
    ('$r15', '$nx3'),
    ('$r16', '$nx4'),
    ('$r17', '$nx5'),
    ('$r18', '$nx6'),
    ('$r19', '$nx7'),
    ('$r20', '$nx8'),
    ('$r21', '$grp'),
    ('$r22', '$trp'),
    ('$r23', '$srp'),
    ('$r24', '$tmp'),
    ('$r25', '$rad'),
    ('$r26', '$cr0'),
    ('$r27', '$cr1'),
    ('$r28', '$eip', '$rip'),
    ('$r29', '$ebp', '$rbp'),
    ('$r30', '$esp', '$rsp'),
    ('$r31', '$nul', '$zero'),
)

# Register name (with its '$' sigil) -> register number
pregs = {
    name: number
    for number, names in enumerate(_preg_names)
    for name in names
}
def parse_lst_instrs():
    """Append every line of the instruction list, stripped, to pinstrs."""
    pinstrs.extend(entry.strip() for entry in lst_instrs)
#-------------------------------------------------------------------------------
def name_valid(name):
    """Return True if *name* is usable as a label name.

    A valid name is non-empty and, compared case-insensitively, contains
    only letters, digits and the characters ``[ $ . _ + ] = ,``.  The empty
    string is rejected because callers index the first character of a name
    they have accepted.
    """
    if not name:
        return False
    allowed = 'abcdefghijklmnopqrstuvwxyz0123456789[$._+]=,'
    return all(c in allowed for c in name.lower())
def is_number(s):
    """Tell whether ``int(s, base=0)`` accepts the string *s*."""
    try:
        int(s, base=0)
    except ValueError:
        accepted = False
    else:
        accepted = True
    return accepted
# A constant expression: integer literals (decimal or 0x-hex) joined by the
# binary operators | & ^ + - * << >>, with optional blanks around operators.
_ARITH_NUM = r'(?:0x[0-9A-Fa-f]+|[0-9]+)'
_ARITH_OP = r'(?:<<|>>|[|&^+\-*])'
arith_expr = re.compile(
    _ARITH_NUM + r'(?:\s*' + _ARITH_OP + r'\s*' + _ARITH_NUM + r')+\s*')


def arith_eval(s):
    """Fold a constant arithmetic expression into a decimal string.

    If the *whole* of *s* matches arith_expr, it is evaluated with Python
    operator precedence and the result is returned as a string.  Any other
    input is returned unchanged.

    The complete string must match before it is evaluated: matching only a
    prefix would hand arbitrary trailing source text to eval().
    """
    if not arith_expr.fullmatch(s):
        return s
    try:
        # The pattern admits only digits, hex letters, operators and blanks;
        # builtins are withheld anyway as a second line of defence.
        return str(eval(s, {"__builtins__": {}}, {}))
    except (SyntaxError, ValueError, OverflowError):
        # e.g. a literal with a leading zero ("010"): leave the operand as
        # written so it is handled like any other non-numeric operand.
        return s
#-------------------------------------------------------------------------------
def _strip_comment(line):
    """Cut *line* at the first comment character found outside a string.

    Returns ``(text, unterminated)``.  ``text`` is the line up to (not
    including) the first of ``# ; @ ! /`` outside quotes, right-stripped.
    ``unterminated`` is True when a quote is still open at the end of the
    line, in which case ``text`` is the line unchanged.

    A string is closed only by the same quote character that opened it, and
    a backslash inside a string protects the character that follows it.
    """
    open_quote = ''
    skip_next = False
    for i, c in enumerate(line):
        if skip_next:
            skip_next = False
            continue
        if open_quote:
            if c == '\\':
                skip_next = True
            elif c == open_quote:
                open_quote = ''
            continue
        if c in "'\"":
            open_quote = c
        elif c in '#;@!/':
            return line[:i].rstrip(), False
    return line, bool(open_quote)


def parse():
    """Read the source file and dispatch every line.

    * Lines starting with ``#`` are line markers; their third field, minus
      its surrounding quotes, becomes the current file name.
    * Comments are removed; blank lines are skipped.
    * Indented lines are instructions: they are handed to parse_instr(),
      prefixed in the intermediate listing with the current file name and
      .text offset, and ptext grows by the size parse_instr() returns.
    * Lines ending in ``:`` define a .text label at the current offset;
      a name starting with ``.`` is local to the last non-local label.
    * Everything else goes to parse_preproc().

    Malformed input is reported and the assembler exits through leave().
    """
    global ptext
    global pcurfile
    global plastlabel

    source.seek(0)
    pcurfile = sys.argv[1]

    for line in source:
        line = line.rstrip()
        if len(line) == 0:
            continue

        # Line marker: `# <number> "<file>" ...`
        if line[0] == '#':
            tok = line.split()
            if len(tok) < 3:
                print("Invalid # directive: {}".format(line))
                leave(-1)
            pcurfile = tok[2][1:-1]

        line, unterminated = _strip_comment(line)
        if unterminated:
            print("Unterminated string in line: {}".format(line))
            leave(1)

        if len(line) == 0:
            continue

        # Instruction
        if line[0] == ' ' or line[0] == '\t':
            line = line.lstrip()
            instrs.write(pcurfile + ' ' + hex(ptext) + ' ')
            ptext += parse_instr(line)
            instrs.write("\n")
            continue

        # Label
        if line[-1] == ':':
            label = line[:-1]
            # An empty name must be rejected here: label[0] is read below.
            if label and name_valid(label):
                if label[0] == '.':
                    label = plastlabel + label
                else:
                    plastlabel = label
                plabels_text[label] = ptext
            else:
                print("Bad label name: {}".format(label))
                leave(1)
            continue

        # Preprocessor, .data, or invalid
        parse_preproc(line)
#-------------------------------------------------------------------------------
# Character following a backslash in string data -> character it stands for
escape_dict = dict((
    ('0', '\0'),
    ('n', '\n'),
    ('t', '\t'),
    ('r', '\r'),
    ('v', '\v'),
    ('f', '\f'),
    ('"', '"'),
    ('\'', '\''),
    ('\\', '\\'),
))
def parse_preproc(line):
    """Handle one line that is neither an instruction nor a label.

    Two forms are recognised:

    * ``NAME := VALUE`` stores VALUE in pdefs under NAME.
    * ``LABEL = VALUE`` records LABEL at the current .data offset and
      appends VALUE to b_data.  VALUE is one of:

      - an integer: written as 8 little-endian bytes (two's complement
        when negative);
      - ``[N]``: N zero bytes, N being an integer or a pdefs name;
      - a quoted string: one byte per character, backslash escapes taken
        from escape_dict, followed by a NUL byte.  ``LABEL_len`` is added
        to pdefs with the character count (NUL excluded).

    Names starting with ``.`` are prefixed with plastlabel.  pdata grows by
    the number of bytes written.  Any other input is reported and the
    assembler exits through leave(1).
    """
    global pdata

    tok = line.split(None, 2)

    # preprocessor
    if len(tok) > 1 and tok[1] == ':=':
        if len(tok) < 3:
            print("Invalid format: {}".format(line))
            leave(1)
        s = tok[0]
        if s in pdefs:
            s = pdefs[s]
        if s[0] == '.':
            s = plastlabel + s
        pdefs[s] = tok[2]
        return

    # .data
    if len(tok) > 1 and tok[1] == '=':
        if len(tok) < 3:
            print("Invalid format: {}".format(line))
            leave(1)
        label = tok[0]
        if label[0] == '.':
            label = plastlabel + label
        plabels_data[label] = pdata
        value = tok[2]

        # number data
        if is_number(value):
            n = int(value, base=0)
            try:
                raw = n.to_bytes(8, byteorder='little', signed=(n < 0))
            except OverflowError:
                # does not fit in 64 bits
                print("Invalid format: {}".format(line))
                leave(1)
            pdata += b_data.write(raw)

        # buffer / bss
        elif value[0] == '[':
            if value[-1] != ']':
                print("Invalid bss format: {}".format(line))
                leave(1)
            s = value[1:-1].strip()
            if s and s[0] == '.':
                s = plastlabel + s
            if s in pdefs:
                s = pdefs[s]
            # is_number('') is False, so an empty `[]` is rejected here too
            if not is_number(s) or int(s, base=0) < 0:
                print("Invalid bss format: {}".format(line))
                leave(1)
            pdata += b_data.write(bytearray(int(s, base=0)))

        # string data
        elif value[0] in "'\"":
            s = value.strip()
            if len(s) < 2 or s[-1] != value[0]:
                print("Invalid format: {}".format(line))
                leave(1)
            s = s[1:-1]

            payload = bytearray()
            escaping = False
            for c in s:
                # escape sequences
                if not escaping and c == '\\':
                    escaping = True
                    continue
                if escaping:
                    escaping = False
                    if c in escape_dict:
                        c = escape_dict[c]
                    else:
                        print("Unrecognized escape sequence: {}".format(line))
                        leave(1)
                # each character is stored in exactly one byte
                if ord(c) > 0xFF:
                    print("Non 8-bit character in string: {}".format(line))
                    leave(1)
                payload.append(ord(c))

            real_len = len(payload)
            payload.append(0)  # null terminator
            pdata += b_data.write(payload)
            pdefs[label + "_len"] = str(real_len)

        else:
            print("Invalid format: {}".format(line))
            leave(1)
        return

    print("Unrecognized directive: {}".format(line))
    leave(1)
#-------------------------------------------------------------------------------
# Operand format bytes.  parse_instr() ORs the access length (in bytes)
# into the register/memory formats.
fmts = {
    "r": 0 << 5,
    "m_r": 1 << 5,
    "m_rr": 2 << 5,
    "m_rriw": 3 << 5,
    "m_rrid": 4 << 5,
    "m_rrii": 5 << 5,
    "m_riq": 6 << 5,
    "imm8": 0xE0 | 1,
    "imm16": 0xE0 | 2,
    "imm32": 0xE0 | 4,
    "imm64": 0xE0 | 8,
}

# Memory access prefix letter -> access length in bytes
pref2len = dict(b=1, w=2, d=4, l=4, q=8)
def _split_scaled(term, line):
    """Split a ``reg*imm`` term into ``(reg, imm)``; abort if it is malformed."""
    parts = term.split('*')
    if len(parts) != 2:
        print("Wrong syntax in line: '{}'".format(line))
        leave(1)
    return parts[0], parts[1]


def _fits_imm16(text):
    """True if *text* is a number gentext() can store in a 2-byte field."""
    # NOTE(review): 0x8000..0xFFFF is kept in the 16-bit form to leave the
    # output of existing sources unchanged - confirm how the VM reads it.
    return is_number(text) and -0x8000 <= int(text, base=0) <= 0xFFFF


def _encode_mem_operand(word, mlen, line):
    """Encode the inside of a ``[...]`` operand.

    *word* is the text between the brackets, *mlen* the access length in
    bytes, *line* the instruction (for error messages).  Returns
    ``(text, nbytes)``: the tokens to append to the intermediate listing and
    the number of bytes they will occupy in .text.
    """
    # cheap way of getting [reg - imm] to work
    word = word.replace('-', '+ -')
    # remove every space
    word = word.replace(' ', '')

    if len(word) == 0:
        print("Wrong syntax in line: '{}'".format(line))
        leave(1)

    # [reg]
    if '+' not in word and '*' not in word and '$' + word in pregs:
        return "%%imm8 {} {}".format(fmts["m_r"] | mlen, '$' + word), 2

    # [imm32], converted to [zero+zero+imm32]
    if '+' not in word and '*' not in word:
        return ("%%imm8 {} $zero $zero %%signed %%imm32 {}"
                .format(fmts["m_rrid"] | mlen, word), 7)

    #
    # Offsets: reg1 + reg2*imm1 + imm2
    #
    reg1 = "zero"
    reg2 = "zero"
    imm1 = '1'
    imm2 = '0'
    wtok = word.split('+')

    # [reg*imm]
    if len(wtok) == 1:
        reg2, imm1 = _split_scaled(wtok[0], line)

    # [reg+reg], [reg+imm], [reg*imm+imm], [reg+reg*imm]
    elif len(wtok) == 2:
        # Must be [reg*imm+imm]
        if '*' in wtok[0]:
            if not is_number(wtok[1].strip()):
                print("Wrong syntax in line: '{}'".format(line))
                leave(1)
            reg2, imm1 = _split_scaled(wtok[0], line)
            imm2 = wtok[1]
        # Must be [reg+reg*imm]
        elif '*' in wtok[1]:
            reg1 = wtok[0]
            reg2, imm1 = _split_scaled(wtok[1], line)
        # [reg+imm]
        elif is_number(wtok[1].strip()):
            reg1 = wtok[0]
            imm2 = wtok[1]
        # Must be [reg+reg]
        else:
            reg1 = wtok[0]
            reg2 = wtok[1]

    # [reg+reg+imm], [reg+reg*imm8+imm]
    elif len(wtok) == 3:
        reg1 = wtok[0]
        imm2 = wtok[2]
        if '*' in wtok[1]:
            reg2, imm1 = _split_scaled(wtok[1], line)
        else:
            reg2 = wtok[1]

    else:
        print("Wrong syntax in line: '{}'".format(line))
        leave(1)

    if imm1 == '1':
        # [reg+reg]
        if imm2 == '0':
            return ("%%imm8 {} {} {}"
                    .format(fmts["m_rr"] | mlen, '$' + reg1, '$' + reg2), 3)
        # [reg+imm16], only when the offset really fits in two bytes
        if reg2 == 'zero' and _fits_imm16(imm2):
            return ("%%imm8 {} {} %%signed %%imm16 {}"
                    .format(fmts["m_rriw"] | mlen, '$' + reg1, imm2), 4)
        # [reg+reg+imm32]
        return ("%%imm8 {} {} {} %%signed %%imm32 {}"
                .format(fmts["m_rrid"] | mlen, '$' + reg1, '$' + reg2, imm2), 7)

    # [reg+reg*imm8+imm32]
    return ("%%imm8 {} {} {} %%imm8 {} %%signed %%imm32 {}"
            .format(fmts["m_rrii"] | mlen, '$' + reg1, '$' + reg2, imm1, imm2), 8)


def parse_instr(line):
    """Translate one instruction line into the intermediate listing.

    *line* is ``mnemonic [operand, operand, ...]`` with comments already
    removed.  The mnemonic, suffixed with ``_<operand count>``, and the
    encoded operands are written to instrs.  Operands may be registers,
    numbers, character literals, ``$reg`` register indices, constant
    expressions, pdefs names, memory references ``<b|w|d|l|q>[...]`` or
    labels.

    Returns the number of bytes the instruction will occupy in .text
    (0 for an empty line).  Syntax errors are reported and the assembler
    exits through leave(1).
    """
    if line is None or len(line.strip()) == 0:
        return 0

    tok = line.split(None, 1)
    instr_name = tok[0].strip()
    params = tok[1].strip() if len(tok) > 1 else None

    size = 1
    instr_args = ''

    if params is None or len(params) == 0:
        instrs.write("{}_0".format(instr_name))
        return size

    tok = params.split(',')
    instr_name += "_{}".format(len(tok))

    #
    # Parse operands
    #
    for word in tok:
        word = word.strip()
        instr_args += ' '
        mlen = 0

        if len(word) == 0:
            print("Wrong syntax in line: '{}'".format(line))
            leave(1)

        # local labels
        if word[0] == '.':
            word = plastlabel + word

        # preprocessor
        if word in pdefs:
            word = pdefs[word]

        # arithmetic expressions
        word = arith_eval(word)

        # memory length prefixes
        if len(word) > 2 and '[' in word:
            if word[0] in 'bwldq':
                mlen = pref2len[word[0]]
            else:
                print("Bad memory length prefix: {}".format(line))
                leave(1)
            word = word[1:].strip()
            if word[0] != '[':
                print("Bad memory length prefix: {}".format(line))
                leave(1)

        #
        # Memory operand
        #
        if word[0] in '[(':
            if word[-1] not in '])':
                print("Wrong syntax in line: '{}'".format(line))
                leave(1)
            word = word[1:-1]

            # preprocessor, again
            if word in pdefs:
                word = pdefs[word]

            # Make sure we got an access length prefix
            if mlen == 0:
                print("Missing access length modifier: {}".format(line))
                leave(1)

            args, nbytes = _encode_mem_operand(word, mlen, line)
            instr_args += args
            size += nbytes
            continue

        # preprocessor, yet again
        if word in pdefs:
            word = pdefs[word]

        # characters 'c'
        if len(word) == 3 and word[0] == word[-1] == "'":
            word = str(ord(word[1]))

        # register index $reg: any known register name, whatever its length
        if word[0] == '$' and word in pregs:
            word = str(pregs[word])

        # immediates: smallest format holding the value; negative values
        # always take the 64-bit form
        if is_number(word):
            n = int(word, base=0)
            if n < 0 or n > 0xFFFFFFFF:
                size += 9
                instr_args += "%%imm8 {} ".format(fmts["imm64"])
                instr_args += "%%imm64 {}".format(word)
            elif n > 0xFFFF:
                size += 5
                instr_args += "%%imm8 {} ".format(fmts["imm32"])
                instr_args += "%%imm32 {}".format(word)
            elif n > 0xFF:
                size += 3
                instr_args += "%%imm8 {} ".format(fmts["imm16"])
                instr_args += "%%imm16 {}".format(word)
            else:
                size += 2
                instr_args += "%%imm8 {} ".format(fmts["imm8"])
                instr_args += "%%imm8 {}".format(word)
            continue

        # register
        if '$' + word in pregs:
            size += 1
            instr_args += '$' + word
            continue

        # it's a label (a 32-bit immediate): format byte + imm32
        size += 5
        instr_args += "%%imm8 {} ".format(fmts["imm32"])
        if word[0] == '.':
            instr_args += plastlabel
        instr_args += word

    instrs.write("{}{}".format(instr_name, instr_args))
    return size
#-------------------------------------------------------------------------------
# Markers written by parse_instr() that are not emitted as bytes themselves
special_syms = {
    "%%imm8",
    "%%imm16",
    "%%imm32",
    "%%imm64",
    "%%signed",
}


def gentext():
    """Turn the intermediate listing into the bytes of the .text section.

    Each listing line is ``<file> <offset> <tokens...>``; the first two
    fields are skipped.  Every remaining token is emitted to b_text as:

    * a register name: its number, 1 byte;
    * an instruction name: its position in pinstrs, 1 byte;
    * a label: its absolute address, 4 bytes.  .text labels are relative to
      start_addr, .data labels to the first multiple of 8 at or after the
      end of .text;
    * ``%%immN``: nothing, but sets the width of the next number;
    * a number: little-endian in the current width, two's complement when
      negative.  ``%%signed`` is accepted but does not affect the encoding.

    Anything else is reported and the assembler exits through leave(1).
    """
    text_start = start_addr
    data_start = text_start + ptext
    if (data_start % 8) != 0:
        data_start += (8 - data_start % 8)

    imm_width = {"%%imm8": 1, "%%imm16": 2, "%%imm32": 4, "%%imm64": 8}

    # Name -> opcode, built once instead of scanning pinstrs for every
    # token.  setdefault keeps the first position of a repeated name, as
    # list.index() would.
    opcodes = {}
    for idx, name in enumerate(pinstrs):
        opcodes.setdefault(name, idx)

    instrs.seek(0)
    for line in instrs:
        tok = line.strip().split()
        if WANT_DISASM:
            print(tok)

        width = None
        for word in tok[2:]:
            if word in pregs:
                b_text.write(pregs[word].to_bytes(1, byteorder='little', signed=False))
                continue

            if word in opcodes:
                b_text.write(opcodes[word].to_bytes(1, byteorder='little', signed=False))
                continue

            if word in plabels_text:
                addr = text_start + plabels_text[word]
                b_text.write(addr.to_bytes(4, byteorder='little', signed=False))
                continue

            if word in plabels_data:
                addr = data_start + plabels_data[word]
                b_text.write(addr.to_bytes(4, byteorder='little', signed=False))
                continue

            if word in special_syms:
                width = imm_width.get(word, width)
                continue

            # A number is only meaningful once a width marker has been seen
            if width is not None and is_number(word):
                value = int(word, base=0)
                b_text.write(value.to_bytes(width, byteorder='little', signed=(value < 0)))
                continue

            print("Assembly error, unknown token '{}' in line: {}".format(word, line))
            leave(1)
#-------------------------------------------------------------------------------
def sort_by_list(dict_, list_):
    """Move each key named in *list_*, in order, to the end of *dict_*."""
    relocate = dict_.move_to_end
    for name in list_:
        relocate(name)
def gensym():
    """Write the symbols file: one ``<label> <address>`` line per label.

    .text labels are offset by start_addr, .data labels by the first
    multiple of 8 at or after the end of .text.  Lines are sorted by
    address.  A name present in both tables is written once, with its
    .data address.

    The label tables are only read, never rebased in place, so the result
    does not depend on how often or in which order this runs relative to
    the other passes.
    """
    text_start = start_addr
    data_start = text_start + ptext
    if (data_start % 8) != 0:
        data_start += (8 - data_start % 8)

    symbols = OrderedDict()
    for label, offset in plabels_text.items():
        symbols[label] = text_start + offset
    for label, offset in plabels_data.items():
        symbols[label] = data_start + offset

    for key, value in sorted(symbols.items(), key=lambda item: item[1]):
        b_sym.write("{} {}\n".format(key, value))
#-------------------------------------------------------------------------------
def genout():
    """Write the output image: .text, zero padding, then .data.

    The padding brings ``start_addr + ptext`` up to a multiple of 8.  That
    is the address gentext() and gensym() use as the base of .data labels,
    so the bytes land where the labels point even when start_addr itself is
    not a multiple of 8.
    """
    b_text.seek(0)
    b_data.seek(0)

    b_out.write(b_text.read())

    padding = (-(start_addr + ptext)) % 8
    if padding:
        b_out.write(bytes(padding))

    b_out.write(b_data.read())
#-------------------------------------------------------------------------------
# Run the passes in order: load the instruction list, parse the source,
# emit .text, write the image, write the symbols.
for stage in (parse_lst_instrs, parse, gentext, genout, gensym):
    stage()
#-------------------------------------------------------------------------------
summary = "Finished producing {}\n\ttext section size: {} bytes\n\tdata section size: {} bytes\n"
print(summary.format(sys.argv[3], ptext, pdata))
leave(0)
#-------------------------------------------------------------------------------