# Spaces: Sleeping
import os
import subprocess
import tempfile

import gradio as gr
import torch
def process_code(pytorch_code):
    """Run PyTorch code and lower its output through a chain of MLIR dialects.

    The submitted code is saved to a scratch file and executed with
    ``python3``; the path of a ``.pt`` file is passed as its only
    command-line argument (presumably the script is expected to save a
    TorchScript module there).  That file is then fed through four external
    commands in turn, and the text each one writes is collected.

    All intermediate files live in one temporary directory that is removed
    when the function returns, so repeated calls do not accumulate files or
    open handles.

    NOTE(security): this executes arbitrary user-supplied Python. Only expose
    it where running untrusted code is acceptable (e.g. a sandbox).

    NOTE(review): the tool and pass names below are kept exactly as found;
    verify that they exist in the installed MLIR / torch-mlir toolchain.

    Args:
        pytorch_code: Python source text to execute.

    Returns:
        A dict mapping each completed stage name ("Torch Dialect",
        "Linalg Dialect", "GPU Dialect", "LLVM Dialect") to the text produced
        by that stage.  If a command fails or cannot be started, the stages
        completed so far are kept and an "Error" entry describes the failure.
    """
    output = {}
    with tempfile.TemporaryDirectory() as workdir:
        code_path = os.path.join(workdir, "input.py")
        script_output_path = os.path.join(workdir, "model.pt")
        torch_mlir_path = os.path.join(workdir, "torch.mlir")
        linalg_mlir_path = os.path.join(workdir, "linalg.mlir")
        gpu_mlir_path = os.path.join(workdir, "gpu.mlir")
        llvm_mlir_path = os.path.join(workdir, "llvm.mlir")

        # Step 1: Write the input code to a scratch file. The file is closed
        # before the interpreter is launched so the child sees all of it.
        with open(code_path, "wb") as code_file:
            code_file.write(pytorch_code.encode())

        # Steps 3-6: (result key, command line, file the command writes).
        stages = [
            (
                "Torch Dialect",
                ["torchscript-to-mlir", "-i", script_output_path, "-o", torch_mlir_path],
                torch_mlir_path,
            ),
            (
                "Linalg Dialect",
                ["mlir-opt", torch_mlir_path, "-convert-torch-to-linalg", "-o", linalg_mlir_path],
                linalg_mlir_path,
            ),
            (
                "GPU Dialect",
                ["mlir-opt", linalg_mlir_path, "-convert-linalg-to-gpu", "-o", gpu_mlir_path],
                gpu_mlir_path,
            ),
            (
                "LLVM Dialect",
                ["mlir-opt", gpu_mlir_path, "-convert-gpu-to-llvm", "-o", llvm_mlir_path],
                llvm_mlir_path,
            ),
        ]

        try:
            # Step 2: Run the PyTorch code to generate TorchScript.
            subprocess.run(
                ["python3", code_path, script_output_path],
                check=True,
                capture_output=True,
            )
            for label, command, result_path in stages:
                subprocess.run(command, check=True, capture_output=True)
                with open(result_path, "r") as file:
                    output[label] = file.read()
        except subprocess.CalledProcessError as e:
            # stderr is bytes because capture_output is used without text
            # mode; undecodable bytes must not mask the real failure.
            stderr_text = e.stderr.decode(errors="replace")
            output["Error"] = f"An error occurred: {stderr_text}"
        except OSError as e:
            # Raised when an executable is missing or a result file cannot
            # be read; report it the same way as a failing command.
            output["Error"] = f"An error occurred: {e}"
    return output
# Gradio interface
# Stage names, in pipeline order. They are both the keys of the dict returned
# by process_code and the labels of the output textboxes.
_STAGE_LABELS = ("Torch Dialect", "Linalg Dialect", "GPU Dialect", "LLVM Dialect")


def _lowering_outputs(pytorch_code):
    """Adapt process_code's dict result to the interface's four outputs.

    The interface declares one textbox per stage, so the callback has to
    return one value per textbox, in order.  process_code returns a dict
    that may lack stages (and carry an "Error" entry) when a step fails; any
    stage that is missing shows the error text instead, or an empty string
    if there is no error entry.

    Args:
        pytorch_code: Python source text entered by the user.

    Returns:
        A 4-tuple of strings ordered like ``_STAGE_LABELS``.
    """
    result = process_code(pytorch_code)
    error_text = result.get("Error", "")
    return tuple(result.get(label, error_text) for label in _STAGE_LABELS)


iface = gr.Interface(
    fn=_lowering_outputs,
    inputs="text",
    outputs=[gr.Textbox(label=label) for label in _STAGE_LABELS],
    title="PyTorch to MLIR Lowering",
    description="Input PyTorch code for matrix multiplication to see each lowering step in MLIR.",
)

# Start the web UI only when the file is run as a script, not on import.
if __name__ == "__main__":
    iface.launch()