File size: 3,118 Bytes
c80ed73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import subprocess
import sys
import tempfile

import gradio as gr
import torch

def process_code(pytorch_code):
    """Run user PyTorch code and lower its TorchScript output through MLIR.

    The source text is written to a scratch script and executed with one
    command-line argument: the path where the script is expected to save a
    TorchScript file. That file is then passed through ``torchscript-to-mlir``
    and three ``mlir-opt`` conversion passes, and the text produced by each
    stage is collected. The external tools must be available on ``PATH``.

    All intermediate files live in a private temporary directory that is
    removed before the function returns, whether or not a stage failed.

    Args:
        pytorch_code: Python source text to execute.

    Returns:
        A dict mapping stage names ("Torch Dialect", "Linalg Dialect",
        "GPU Dialect", "LLVM Dialect") to the MLIR text of each completed
        stage. If a stage fails, the stages completed so far are kept and an
        "Error" key holds a description of the failure.
    """
    output = {}

    # NOTE(security): this executes arbitrary caller-supplied Python with the
    # privileges of this process. Only expose it to trusted users, or run it
    # inside a sandbox.
    with tempfile.TemporaryDirectory() as workdir:
        code_path = os.path.join(workdir, "model.py")
        script_output_path = os.path.join(workdir, "model.pt")
        torch_mlir_path = os.path.join(workdir, "torch.mlir")
        linalg_mlir_path = os.path.join(workdir, "linalg.mlir")
        gpu_mlir_path = os.path.join(workdir, "gpu.mlir")
        llvm_mlir_path = os.path.join(workdir, "llvm.mlir")

        # (stage name, command to run, file the command writes its result to)
        lowering_stages = (
            (
                "Torch Dialect",
                ["torchscript-to-mlir", "-i", script_output_path, "-o", torch_mlir_path],
                torch_mlir_path,
            ),
            (
                "Linalg Dialect",
                ["mlir-opt", torch_mlir_path, "-convert-torch-to-linalg", "-o", linalg_mlir_path],
                linalg_mlir_path,
            ),
            (
                "GPU Dialect",
                ["mlir-opt", linalg_mlir_path, "-convert-linalg-to-gpu", "-o", gpu_mlir_path],
                gpu_mlir_path,
            ),
            (
                "LLVM Dialect",
                ["mlir-opt", gpu_mlir_path, "-convert-gpu-to-llvm", "-o", llvm_mlir_path],
                llvm_mlir_path,
            ),
        )

        try:
            # Write and close the script before running it so the child
            # process never reads a file that is still open for writing.
            with open(code_path, "wb") as code_file:
                code_file.write(pytorch_code.encode())

            # Use the interpreter running this app so the child sees the same
            # installed packages; fall back to PATH lookup if it is unknown.
            subprocess.run(
                [sys.executable or "python3", code_path, script_output_path],
                check=True,
                capture_output=True,
            )

            for stage_name, command, result_path in lowering_stages:
                subprocess.run(command, check=True, capture_output=True)
                with open(result_path, "r") as file:
                    output[stage_name] = file.read()

        except subprocess.CalledProcessError as e:
            # stderr is raw bytes from an external tool; never let a stray
            # non-UTF-8 byte turn the error report into a second exception.
            output["Error"] = f"An error occurred: {e.stderr.decode(errors='replace')}"
        except FileNotFoundError as e:
            # Raised when an external tool is not installed or a stage did
            # not produce its expected output file.
            output["Error"] = f"An error occurred: {e}"

    return output

# Gradio interface
# Stage names, in display order; these are the keys process_code fills in.
_STAGE_LABELS = ("Torch Dialect", "Linalg Dialect", "GPU Dialect", "LLVM Dialect")


def _lowering_outputs(pytorch_code):
    """Adapt process_code's dict result to one value per output Textbox.

    gr.Interface expects the function to return one value per output
    component, in order, whereas process_code returns a dict keyed by stage
    name. Completed stages are passed through; if process_code reported an
    "Error", its message is shown in the first stage that produced no output
    and any later stages are left empty.

    Args:
        pytorch_code: Python source text entered by the user.

    Returns:
        A 4-tuple of strings ordered like ``_STAGE_LABELS``.
    """
    results = process_code(pytorch_code)
    pending_error = results.get("Error", "")
    values = []
    for label in _STAGE_LABELS:
        if label in results:
            values.append(results[label])
        else:
            # Show the error once, at the stage where lowering stopped.
            values.append(pending_error)
            pending_error = ""
    return tuple(values)


iface = gr.Interface(
    fn=_lowering_outputs,
    inputs="text",
    outputs=[gr.Textbox(label=label) for label in _STAGE_LABELS],
    title="PyTorch to MLIR Lowering",
    description="Input PyTorch code for matrix multiplication to see each lowering step in MLIR.",
)

if __name__ == "__main__":
    iface.launch()