"""Gradio demo: lower user-supplied PyTorch code through MLIR dialects.

The submitted script is executed with the interpreter running this app and
must save a TorchScript module to the path it receives as ``sys.argv[1]``
(also exported as ``$TORCHSCRIPT_OUTPUT``). That module is then imported as
Torch-dialect MLIR and lowered Torch -> Linalg -> GPU -> LLVM, and the IR
after each stage is shown in its own output box.

WARNING: this app executes arbitrary submitted Python code on the host and
listens on all interfaces (0.0.0.0). Only deploy it inside a sandbox.
"""

import os
import subprocess
import sys
import tempfile

import gradio as gr
import torch  # noqa: F401 -- not referenced directly below

# Output box labels, in the same order process_code returns its results.
STAGE_LABELS = ("Torch Dialect", "Linalg Dialect", "GPU Dialect", "LLVM Dialect")

# Environment variable (in addition to sys.argv[1]) that tells the submitted
# script where to save its TorchScript module.
TORCHSCRIPT_ENV_VAR = "TORCHSCRIPT_OUTPUT"

# NOTE(review): the importer executable and the pass flags below are kept
# unchanged from the original. Confirm they exist in the installed MLIR /
# torch-mlir build (torch-mlir's own passes are normally run through
# `torch-mlir-opt` rather than upstream `mlir-opt`).
TORCHSCRIPT_IMPORTER = "torchscript-to-mlir"
MLIR_OPT = "mlir-opt"
# mlir-opt pass flags applied in order after the Torch-dialect import.
# Entry i produces the IR shown in STAGE_LABELS[i + 1].
LOWERING_PASSES = (
    "-convert-torch-to-linalg",
    "-convert-linalg-to-gpu",
    "-convert-gpu-to-llvm",
)


def _run_and_read(cmd, out_path):
    """Run *cmd* (an argument list, no shell) and return the text of *out_path*.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero; its
            stderr is captured on the exception.
        OSError: if the executable is missing or *out_path* cannot be read.
    """
    subprocess.run(cmd, check=True, capture_output=True)
    with open(out_path, "r", encoding="utf-8") as file:
        return file.read()


def _format_error(exc):
    """Return the user-facing message for a failed pipeline step."""
    if isinstance(exc, subprocess.CalledProcessError):
        # capture_output=True makes stderr bytes. Decode leniently so a
        # non-UTF-8 tool message cannot mask the real failure.
        detail = (exc.stderr or b"").decode(errors="replace")
        return f"An error occurred: {detail}"
    return f"An error occurred: {exc}"


def process_code(pytorch_code):
    """Run *pytorch_code* and lower the TorchScript it saves through MLIR.

    Args:
        pytorch_code: Python source text. It must save a TorchScript module
            to the path given as ``sys.argv[1]`` / ``$TORCHSCRIPT_OUTPUT``.
            ``None`` is treated as empty source.

    Returns:
        A tuple of four strings, one per entry in STAGE_LABELS (Torch,
        Linalg, GPU, LLVM dialect text). Stages that were not reached stay
        empty. If a step fails, its error message is placed in the box of
        the stage that failed, and earlier stages keep their IR.
    """
    results = [""] * len(STAGE_LABELS)
    # Index of the output box that receives an error message. The lowering
    # loop below rebinds it so failures are reported in the right box.
    stage = 0
    try:
        # All intermediate files live in one directory that is removed
        # afterwards, so repeated requests do not leak temp files.
        with tempfile.TemporaryDirectory() as workdir:
            code_path = os.path.join(workdir, "user_code.py")
            script_path = os.path.join(workdir, "model.pt")
            with open(code_path, "w", encoding="utf-8") as code_file:
                code_file.write(pytorch_code or "")

            # Step 1: run the user's code so it saves TorchScript to
            # script_path. sys.executable is the interpreter serving this
            # app (which imports torch), unlike whatever `python3` is on
            # PATH. cwd=workdir keeps stray files inside the temp dir.
            subprocess.run(
                [sys.executable, code_path, script_path],
                check=True,
                capture_output=True,
                cwd=workdir,
                env={**os.environ, TORCHSCRIPT_ENV_VAR: script_path},
            )
            if not os.path.isfile(script_path):
                results[0] = (
                    "An error occurred: the script did not save a TorchScript "
                    f"module to sys.argv[1] (${TORCHSCRIPT_ENV_VAR})."
                )
                return tuple(results)
            print("Generated TorchScript.")

            # Step 2: import the TorchScript module as Torch-dialect MLIR.
            current_path = os.path.join(workdir, "stage0.mlir")
            results[0] = _run_and_read(
                [TORCHSCRIPT_IMPORTER, "-i", script_path, "-o", current_path],
                current_path,
            )

            # Steps 3-5: Torch -> Linalg -> GPU -> LLVM. Each pass reads the
            # previous stage's output file.
            for stage, pass_flag in enumerate(LOWERING_PASSES, start=1):
                out_path = os.path.join(workdir, f"stage{stage}.mlir")
                results[stage] = _run_and_read(
                    [MLIR_OPT, current_path, pass_flag, "-o", out_path],
                    out_path,
                )
                current_path = out_path
    except (subprocess.CalledProcessError, OSError) as exc:
        # OSError covers a missing tool executable (FileNotFoundError).
        results[stage] = _format_error(exc)
    # Gradio maps a returned tuple positionally onto the output components.
    return tuple(results)


# Gradio interface
iface = gr.Interface(
    fn=process_code,
    inputs=gr.Textbox(label="PyTorch code", lines=15),
    outputs=[gr.Textbox(label=label, lines=10) for label in STAGE_LABELS],
    title="PyTorch to MLIR Lowering",
    description=(
        "Input PyTorch code for matrix multiplication to see each lowering "
        "step in MLIR. The code must save a TorchScript module to the path "
        "given as sys.argv[1], e.g. torch.jit.script(model).save(sys.argv[1])."
    ),
)

if __name__ == "__main__":
    iface.queue()
    iface.launch(server_name="0.0.0.0", server_port=7860)