Spaces:
Build error
Build error
import gradio as gr | |
import shap | |
import transformers | |
import os | |
import re | |
import subprocess | |
import sys | |
import tempfile | |
# Hugging Face Hub id of the classifier; both handles below point at it.
MODEL_ID = "ejschwartz/oo-method-test-model-bylibrary"

# Handle obtained through gr.load(..., src="models"); its .fn is called to
# classify a function's disassembly.
model = gr.load(MODEL_ID, src="models")

# transformers text-classification pipeline for the same model id; this is
# the object handed to shap.Explainer for interpretation.
model_interp = transformers.pipeline("text-classification", MODEL_ID)
def _record_function(func_dis, addr, lines):
    """Store *lines* as the disassembly of *addr*; the first entry wins on duplicates."""
    if addr in func_dis:
        print("Warning: Ignoring multiple functions at the same address")
    else:
        func_dis[addr] = b"\n".join(lines)


def get_all_dis(bname, addrs=None):
    """Disassemble the binary *bname* and return its per-function listings.

    Runs ``bat-ana`` to write an analysis of the binary into a temporary
    ``.bat_ana`` file, runs ``bat-dis`` on that file, and splits the resulting
    listing on its ``;;; function 0x<addr>`` header lines.

    Args:
        bname: Path of the binary to analyze.
        addrs: Optional iterable of addresses. Each one is passed to
            ``bat-ana`` as ``--function-at <addr>``.

    Returns:
        A dict mapping each function's address (``int``) to its disassembly
        (``bytes``), with runs of spaces collapsed to a single space and every
        line containing ``;;`` dropped. If the same address appears more than
        once, the first listing is kept and a warning is printed.

    Raises:
        subprocess.CalledProcessError: If ``bat-ana`` or ``bat-dis`` exits
            with a non-zero status.
        FileNotFoundError: If either tool cannot be found on ``PATH``.
    """
    ana_cmd = ["bat-ana"]
    if addrs is not None:
        for addr in addrs:
            ana_cmd += ["--function-at", str(addr)]

    # The temporary file only reserves a unique name for bat-ana's output; the
    # context manager guarantees it is closed and removed when we are done.
    with tempfile.NamedTemporaryFile(
        prefix=os.path.basename(bname) + "_", suffix=".bat_ana"
    ) as anafile:
        ananame = anafile.name
        ana_cmd += ["--no-post-analysis", "-o", ananame, bname]

        # Argument lists (no shell) so that a file name containing spaces or
        # shell metacharacters is passed through verbatim instead of being
        # interpreted as shell syntax.
        subprocess.check_output(ana_cmd, stderr=subprocess.DEVNULL)
        output = subprocess.check_output(
            ["bat-dis", "--no-insn-address", "--no-bb-cfg-arrows", "--color=off", ananame],
            stderr=subprocess.DEVNULL,
        )

    # Collapse column-alignment padding.
    output = re.sub(b" +", b" ", output)

    func_dis = {}
    last_func = None
    current_output = []
    for line in output.splitlines():
        if line.startswith(b";;; function 0x"):
            # A new function starts here: flush the one collected so far.
            if last_func is not None:
                _record_function(func_dis, last_func, current_output)
            last_func = int(line.split()[2], 16)
            current_output = []

        # Lines containing ";;" (including the header itself) are not part of
        # the returned listing.
        if b";;" not in line:
            current_output.append(line)

    if last_func is not None:
        _record_function(func_dis, last_func, current_output)

    return func_dis
def get_funs(f):
    """Return the addresses of the functions found in *f*, one per line.

    *f* must expose the path of the binary as ``f.name``. Each address is
    rendered as ``0x``-prefixed hexadecimal, in the order in which
    ``get_all_dis`` returns its entries.
    """
    disassembly_by_addr = get_all_dis(f.name)
    hex_addrs = [f"{addr:#x}" for addr in disassembly_by_addr]
    return "\n".join(hex_addrs)
# UI definition. Widgets that only make sense once a binary has been
# disassembled live inside the initially hidden column `col`.
# Markdown text is kept flush-left so the string contents stay unchanged.
with gr.Blocks() as demo:

    # Holds the address -> disassembly mapping produced by get_all_dis for the
    # uploaded binary (None while no file is loaded).
    all_dis_state = gr.State()

    gr.Markdown(
        """
# Function/Method Detector
First, upload a binary.
This model was only trained on 32-bit MSVC++ binaries. You can provide
other types of binaries, but the result will probably be gibberish.
"""
    )

    file_widget = gr.File(label="Binary file")

    with gr.Column(visible=False) as col:

        gr.Markdown("""
Great, you selected an executable! Now pick the function you would like to analyze.
""")

        # The placeholder choice is replaced by file_change_fn.
        fun_dropdown = gr.Dropdown(label="Select a function", choices=["Woohoo!"], interactive=True)

        gr.Markdown("""
Below you can find the selected function's disassembly, and the model's
prediction of whether the function is an object-oriented method or a
regular function.
""")

        with gr.Row(visible=True) as result:
            disassembly = gr.Textbox(label="Disassembly", lines=20)
            with gr.Column():
                clazz = gr.Label()
                interpret_button = gr.Button("Interpret (very slow)")
                interpretation = gr.components.Interpretation(disassembly)

    # Close the directory iterator explicitly and sort the entries so the
    # examples are listed in a deterministic order.
    examples_dir = os.path.join(os.path.dirname(__file__), "examples")
    with os.scandir(examples_dir) as example_entries:
        example_paths = sorted(entry.path for entry in example_entries)

    example_widget = gr.Examples(
        examples=example_paths,
        inputs=file_widget,
        outputs=[all_dis_state, disassembly, clazz],
    )

    def file_change_fn(file, progress=gr.Progress()):
        """Disassemble the uploaded binary and populate the function dropdown.

        Args:
            file: Value of ``file_widget``; ``None`` when no file is selected,
                otherwise an object whose ``name`` is the path of the upload.
            progress: Gradio progress tracker.

        Returns:
            A dict of component updates: the column is hidden and the state
            reset when there is no file; otherwise the column is shown, the
            dropdown lists every function address (first one selected) and the
            state holds the address -> disassembly mapping.

        Raises:
            gr.Error: If the disassembler found no functions in the file.
        """
        if file is None:
            return {col: gr.update(visible=False), all_dis_state: None}

        progress(0, desc="Disassembling executable")
        fun_data = get_all_dis(file.name)
        if not fun_data:
            # Without this guard, addrs[0] below would raise an IndexError.
            raise gr.Error("No functions were found in the uploaded file.")

        addrs = ["%#x" % addr for addr in fun_data]
        return {
            col: gr.update(visible=True),
            fun_dropdown: gr.Dropdown.update(choices=addrs, value=addrs[0]),
            all_dis_state: fun_data,
        }

    def function_change_fn(selected_fun, fun_data):
        """Show the selected function's disassembly and the model's prediction.

        Args:
            selected_fun: Hexadecimal address string chosen in the dropdown.
            fun_data: Address -> disassembly (bytes) mapping from the state.

        Returns:
            A dict of component updates for the disassembly textbox and the
            label (label -> confidence for each entry the model returned).
        """
        # errors="replace" keeps a stray non-UTF-8 byte in the listing from
        # aborting the whole request; valid UTF-8 decodes exactly as before.
        disassembly_str = fun_data[int(selected_fun, 16)].decode("utf-8", errors="replace")
        load_results = model.fn(disassembly_str)
        top_k = {e['label']: e['confidence'] for e in load_results['confidences']}
        return {
            disassembly: gr.Textbox.update(value=disassembly_str),
            clazz: gr.Label.update(top_k),
            # NOTE: `interpretation` is wired as an output of this handler but
            # is deliberately not updated here; no way to hide or reset it was
            # found, so a previous interpretation stays on screen.
        }

    # XXX: Ideally we'd use the gr.load model, which uses the huggingface
    # inference API. But shap library appears to use information in the
    # transformers pipeline, and I don't feel like figuring out how to
    # reimplement that, so we'll just use a regular transformers pipeline here
    # for interpretation.
    def interpretation_function(text, progress=gr.Progress(track_tqdm=True)):
        """Compute SHAP attributions for *text* using the local pipeline.

        Args:
            text: The disassembly shown in the textbox.
            progress: Gradio progress tracker (also tracks tqdm bars).

        Returns:
            ``{"original": text, "interpretation": [(token, score), ...]}``,
            the format expected by ``gr.components.Interpretation``.
        """
        progress(0, desc="Interpreting function")
        explainer = shap.Explainer(model_interp)
        shap_values = explainer([text])

        # Dimensions are (batch size, text size, number of classes).
        # Index 1 selects the attributions for the model's second output
        # class. NOTE(review): which label that is depends on the model's
        # configuration -- confirm it is the class we want to explain.
        scores = list(zip(shap_values.data[0], shap_values.values[0, :, 1]))
        # Scores contains (word, score) pairs
        return {"original": text, "interpretation": scores}

    # Event wiring: upload -> disassemble; selection -> classify;
    # button -> interpret.
    file_widget.change(file_change_fn, file_widget, [col, fun_dropdown, all_dis_state])
    fun_dropdown.change(function_change_fn, [fun_dropdown, all_dis_state], [disassembly, clazz, interpretation])
    interpret_button.click(interpretation_function, disassembly, interpretation)
# Enable request queuing, then serve on all interfaces at port 7860 with a
# share link requested.
demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)