import gradio as gr
import re
from utils import read_pcd, render_point_cloud, render_pcd_file, set_seed
from inference.utils import get_legend
from inference.inference import segment_obj, get_heatmap
from huggingface_hub import login
import os


# Compile and install the pointops extension (custom CUDA/C++ ops) that the
# Pointcept backbone depends on, then return to the repository root.
os.chdir("Pointcept/libs/pointops")
os.system("python setup.py install")
os.chdir("../../../")

# Authenticate with the Hugging Face Hub so model weights can be downloaded.
# NOTE(review): assumes the 'hfkey' env var is set in the deployment
# environment — os.getenv returns None otherwise; confirm in the Space config.
login(token=os.getenv('hfkey'))

# Sample part queries for each example object, keyed by the example's file
# stem (e.g. "lamppost" for examples/objaverse/lamppost.pcd). Values are
# comma-separated query strings used to pre-fill the "Part Queries" textbox
# when the corresponding gallery thumbnail is clicked.
parts_dict = {
    "fireplug": "bonnet of a fireplug,side cap of a fireplug,barrel of a fireplug,base of a fireplug",
    "mickey": "ear,head,arms,hands,body,legs",
    "motorvehicle": "wheel of a motor vehicle,seat of a motor vehicle,handle of a motor vehicle",
    "teddy": "head,body,arms,legs",
    "lamppost": "lighting of a lamppost,pole of a lamppost",
    "shirt": "sleeve of a shirt,collar of a shirt,body of a shirt",
    "capybara": "hat worn by a capybara,head,body,feet",
    "corgi": "head,leg,body,ear",
    "pushcar": "wheel,body,handle",
    "plant": "pot,plant",
    "chair": "back of chair,leg,seat"
}

# Maps each example object to the subdirectory of examples/ that holds its
# .pcd file: "objaverse" for Objaverse assets, "wild" for in-the-wild
# reconstructions (iPhone photos / AI-generated images).
source_dict = {
    "fireplug":"objaverse",
    "mickey":"objaverse",
    "motorvehicle":"objaverse",
    "teddy":"objaverse",
    "lamppost":"objaverse",
    "shirt":"objaverse",
    "capybara": "wild",
    "corgi": "wild",
    "pushcar": "wild",
    "plant": "wild",
    "chair": "wild"
}

def predict(pcd_path, inference_mode, part_queries):
    """Run Find3D inference on an uploaded point cloud.

    Args:
        pcd_path: Path to a .pcd file readable by ``read_pcd``.
        inference_mode: ``"Segmentation"`` (multiple part queries) or
            ``"Localization"`` (exactly one query).
        part_queries: Query text; in segmentation mode, parts separated by
            any of ``, ; . |``.

    Returns:
        A rendered point-cloud figure, or None for an unrecognized mode.

    Raises:
        gr.Error: If the number of queries does not match the selected mode.
    """
    set_seed()
    xyz, rgb, normal = read_pcd(pcd_path)
    if inference_mode == "Segmentation":
        # Split on any supported delimiter and drop empty entries so that
        # trailing/duplicated separators (e.g. "a,b,") don't yield blank
        # part queries or inflate the part count.
        parts = [part.strip() for part in re.split(r'[,;.|]', part_queries)]
        parts = [part for part in parts if part]
        if len(parts) < 2:
            raise gr.Error("For segmentation mode, please provide 2 or more parts", duration=5)
        seg_rgb = segment_obj(xyz, rgb, normal, parts).cpu().numpy()
        legend = get_legend(parts)
        return render_point_cloud(xyz, seg_rgb, legend=legend)
    elif inference_mode == "Localization":
        # Reject multi-part input using the same delimiter set segmentation
        # mode splits on ("|" was previously not checked here).
        if any(sep in part_queries for sep in (",", ";", ".", "|")):
            raise gr.Error("For localization mode, please provide only one part", duration=5)
        heatmap_rgb = get_heatmap(xyz, rgb, normal, part_queries).cpu().numpy()
        return render_point_cloud(xyz, heatmap_rgb)
    else:
        return None

def on_select(evt: gr.SelectData):
    """Gallery click handler: map a clicked thumbnail to its example data.

    Args:
        evt: Gradio select event; ``evt.value['image']['orig_name']`` holds
            the thumbnail's original file name (e.g. "lamppost.jpg").

    Returns:
        ``[pcd_path, sample_queries]`` to populate the file-upload component
        and the part-queries textbox.
    """
    # splitext handles extensions of any length (.jpg, .png, .jpeg, ...);
    # the previous fixed [:-4] slice assumed exactly four trailing chars.
    obj_name = os.path.splitext(evt.value['image']['orig_name'])[0]
    src = source_dict[obj_name]
    return [f"examples/{src}/{obj_name}.pcd", parts_dict[obj_name]]


# --- UI layout and event wiring -------------------------------------------
# Three-column layout: controls (file upload, mode, queries, run button),
# input visualization, and output visualization; below it, two example
# galleries (Objaverse and in-the-wild) that pre-fill the inputs on click.
with gr.Blocks(theme=gr.themes.Default(text_size="lg", radius_size="none")) as demo:
    # Intro / usage instructions shown at the top of the page.
    gr.HTML(
        '''<h1 text-align="center">Find Any Part in 3D</h1>
        <p style='font-size: 16px;'>This is a demo for Find3D: Find Any Part in 3D! Two modes are supported: <b>segmentation</b> and <b>localization</b>.
        <br>
        For <b>segmentation mode</b>, please provide multiple part queries in the "queries" text box, in the format of comma-separated string, such as "part1,part2,part3".
        After hitting "Run", the model will segment the object into the provided parts.
        <br>
        For <b>localization mode</b>, please only provide <b>one query string</b> in the "queries" text box. After hitting "Run", the model will generate a heatmap for the provided query text.
        Please click on the buttons below "Objaverse" and "In the Wild" for some examples. You can also upload your own .pcd files.</p>
        <p style='font-size: 16px;'>Hint: 
        When uploading your own point cloud, please first close the existing point cloud by clicking on the "x" button.
        <br>
        We show some sample queries for the provided examples. When working with your own point cloud, feel free to rephrase the query (e.g. "part" vs "part of a object") to achieve better performance!</p>
        '''
    )

    with gr.Row(variant="panel"):
        # Left column: inference controls, pre-populated with the lamppost
        # example so the demo works immediately on load.
        with gr.Column(scale=4):
            file_upload = gr.File(
                label="Upload Point Cloud File",
                type="filepath",
                file_types=[".pcd"],
                value="examples/objaverse/lamppost.pcd"
            )
            inference_mode = gr.Radio(
                choices=["Segmentation", "Localization"],
                label="Inference Mode",
                value="Segmentation",
            )
            part_queries = gr.Textbox(
                label="Part Queries",
                value="lighting of a lamppost,pole of a lamppost",
            )
            run_button = gr.Button(
                value="Run",
                variant="primary",
            )

        # Middle column: visualization of the uploaded point cloud.
        # input_image is created hidden (visible=False).
        with gr.Column(scale=4):
            input_image = gr.Image(label="Input Image", visible=False, type='pil', image_mode='RGBA', height=290)
            input_point_cloud = gr.Plot(label="Input Point Cloud")

        # Right column: segmentation/heatmap result.
        with gr.Column(scale=4):
            output_point_cloud = gr.Plot(label="Output Result")

    with gr.Row(variant="panel"):
        # Objaverse example gallery; clicking a thumbnail fills the file
        # upload and the sample part queries via on_select.
        with gr.Column(scale=6):
            title = gr.HTML('''<h1 text-align="center">Objaverse</h1>
        <p style='font-size: 16px;'>Online 3D assets from Objaverse!</p>
        ''')
            gallery_objaverse = gr.Gallery([("examples/objaverse/lamppost.jpg", "lamppost"),
                                  ("examples/objaverse/fireplug.jpg", "fireplug"),
                                  ("examples/objaverse/mickey.jpg", "Mickey"),
                                  ("examples/objaverse/motorvehicle.jpg", "motor vehicle"),
                                  ("examples/objaverse/teddy.jpg", "teddy bear"),
                                  ("examples/objaverse/shirt.jpg", "shirt")],
                                  columns=3,
                                  allow_preview=False)
            gallery_objaverse.select(fn=on_select, 
                           inputs=None, 
                           outputs=[file_upload, part_queries])
        # In-the-wild example gallery, wired identically.
        with gr.Column(scale=6):
            title = gr.HTML("""<h1 text-align="center">In the Wild</h1>
        <p style='font-size: 16px;'>Challenging in-the-wild reconstructions from iPhone photos & AI-generated images!</p>
        """)
            gallery_wild = gr.Gallery([("examples/wild/capybara.png", "DALLE-capybara"),
                                  ("examples/wild/corgi.jpg", "DALLE-corgi"),
                                  ("examples/wild/plant.jpg", "iPhone-plant"),
                                  ("examples/wild/pushcar.jpg", "iPhone-pushcar"),
                                  ("examples/wild/chair.jpg", "iPhone-chair")],
                                  columns=3,
                                  allow_preview=False)
            gallery_wild.select(fn=on_select, 
                           inputs=None, 
                           outputs=[file_upload, part_queries])

    # Re-render the input visualization whenever the uploaded file changes.
    file_upload.change(
        fn=render_pcd_file,
        inputs=[file_upload],
        outputs=[input_point_cloud],
    )
    # Run inference on button click.
    run_button.click(
        fn=predict,
        inputs=[file_upload, inference_mode, part_queries],
        outputs=[output_point_cloud],
    )
    # Render the default example once when the page first loads.
    demo.load(
        fn=render_pcd_file,
        inputs=[file_upload],
        outputs=[input_point_cloud]) # initialize

demo.launch()