Spaces:
Sleeping
Sleeping
File size: 3,118 Bytes
c80ed73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import os
import subprocess
import sys
import tempfile

import gradio as gr
import torch


def process_code(pytorch_code):
    """Execute user PyTorch code and lower its TorchScript output through MLIR.

    The submitted source is saved to a temporary ``.py`` file and run as a
    script with the current interpreter (``sys.executable``). The script gets
    one extra command-line argument: the path where it is expected to save a
    TorchScript module, e.g. ``torch.jit.script(model).save(sys.argv[1])``.
    That module is then converted and lowered by a chain of external
    command-line tools, and the text of each intermediate ``.mlir`` file is
    collected.

    SECURITY: this runs arbitrary caller-supplied Python with the privileges
    of this process. Only expose it where that is acceptable (e.g. an
    isolated sandbox).

    Args:
        pytorch_code: Python source text to execute. ``None`` is treated as
            an empty script.

    Returns:
        A dict mapping each completed stage name ("Torch Dialect",
        "Linalg Dialect", "GPU Dialect", "LLVM Dialect") to that stage's MLIR
        text. Stages run in that order and stop at the first failure. On
        failure an "Error" entry describes what went wrong, and the stages
        that already succeeded are kept.
    """
    output = {}
    # Every intermediate file lives in one scratch directory that is removed
    # when the block exits, so nothing is left behind on disk.
    with tempfile.TemporaryDirectory() as workdir:
        code_path = os.path.join(workdir, "model.py")
        script_output_path = os.path.join(workdir, "model.pt")
        torch_mlir_path = os.path.join(workdir, "torch.mlir")
        linalg_mlir_path = os.path.join(workdir, "linalg.mlir")
        gpu_mlir_path = os.path.join(workdir, "gpu.mlir")
        llvm_mlir_path = os.path.join(workdir, "llvm.mlir")

        with open(code_path, "wb") as code_file:
            code_file.write((pytorch_code or "").encode())

        # Each step: (command, output key or None, file whose text is stored).
        # NOTE(review): verify "torchscript-to-mlir" and the pass names below
        # against the installed toolchain -- e.g. -convert-torch-to-linalg is
        # a torch-mlir pass and may need torch-mlir-opt rather than mlir-opt.
        steps = [
            # Step 1: run the user code; it should write TorchScript to argv[1].
            ([sys.executable, code_path, script_output_path], None, None),
            # Step 2: TorchScript -> MLIR (torch dialect).
            (
                ["torchscript-to-mlir", "-i", script_output_path, "-o", torch_mlir_path],
                "Torch Dialect",
                torch_mlir_path,
            ),
            # Step 3: torch dialect -> linalg dialect.
            (
                ["mlir-opt", torch_mlir_path, "-convert-torch-to-linalg", "-o", linalg_mlir_path],
                "Linalg Dialect",
                linalg_mlir_path,
            ),
            # Step 4: linalg dialect -> GPU dialect.
            (
                ["mlir-opt", linalg_mlir_path, "-convert-linalg-to-gpu", "-o", gpu_mlir_path],
                "GPU Dialect",
                gpu_mlir_path,
            ),
            # Step 5: GPU dialect -> LLVM dialect.
            (
                ["mlir-opt", gpu_mlir_path, "-convert-gpu-to-llvm", "-o", llvm_mlir_path],
                "LLVM Dialect",
                llvm_mlir_path,
            ),
        ]

        try:
            for command, stage, result_path in steps:
                subprocess.run(command, check=True, capture_output=True)
                if stage is not None:
                    with open(result_path, "r", encoding="utf-8", errors="replace") as result_file:
                        output[stage] = result_file.read()
        except subprocess.CalledProcessError as e:
            # stderr is bytes (capture_output without text=True); decode
            # leniently so odd output bytes cannot raise a second error.
            stderr = (e.stderr or b"").decode(errors="replace")
            output["Error"] = f"An error occurred: {stderr}"
        except FileNotFoundError as e:
            # A tool is missing from PATH, or a step exited successfully
            # without writing its expected output file.
            output["Error"] = f"An error occurred: {e}"
    return output
# Gradio interface
# Stage names produced by process_code, in the order their textboxes appear.
_STAGES = ("Torch Dialect", "Linalg Dialect", "GPU Dialect", "LLVM Dialect")


def _process_code_for_ui(pytorch_code):
    """Return process_code's results as one value per output textbox.

    Gradio assigns a returned tuple to the ``outputs`` list positionally,
    whereas process_code returns a dict keyed by stage name. Stages that did
    not run become empty strings; the final element is the "Error" entry
    ("" when there was no error).
    """
    result = process_code(pytorch_code)
    stage_texts = tuple(result.get(stage, "") for stage in _STAGES)
    return stage_texts + (result.get("Error", ""),)


iface = gr.Interface(
    fn=_process_code_for_ui,
    inputs=gr.Textbox(
        lines=15,
        label="PyTorch code",
        placeholder="Build a model and save it with torch.jit.script(model).save(sys.argv[1])",
    ),
    outputs=[gr.Textbox(label=stage) for stage in _STAGES] + [gr.Textbox(label="Error")],
    title="PyTorch to MLIR Lowering",
    description="Input PyTorch code for matrix multiplication to see each lowering step in MLIR.",
)
def _main():
    """Start serving the module-level Gradio interface ``iface``."""
    iface.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    _main()
|