"""Gradio demo for Find3D: open-vocabulary part segmentation & localization on point clouds."""

import os
import re

import gradio as gr
from huggingface_hub import login

from utils import read_pcd, render_point_cloud, render_pcd_file, set_seed
from inference.utils import get_legend
from inference.inference import segment_obj, get_heatmap

# Build/install the pointops extension at startup — presumably because the
# hosting environment (e.g. HF Spaces) has no separate build step. NOTE(review):
# this shells out and mutates CWD at import time; confirm it is required here.
os.chdir("Pointcept/libs/pointops")
os.system("python setup.py install")
os.chdir("../../../")

login(token=os.getenv('hfkey'))

# Delimiters accepted between part names in the "queries" textbox. Shared by
# both inference modes so the segmentation splitter and the localization
# single-part guard can never disagree (the original guard missed "|").
_QUERY_DELIMITERS = r'[,;.|]'

# Sample part queries shown when an example object is selected in a gallery.
parts_dict = {
    "fireplug": "bonnet of a fireplug,side cap of a fireplug,barrel of a fireplug,base of a fireplug",
    "mickey": "ear,head,arms,hands,body,legs",
    "motorvehicle": "wheel of a motor vehicle,seat of a motor vehicle,handle of a motor vehicle",
    "teddy": "head,body,arms,legs",
    "lamppost": "lighting of a lamppost,pole of a lamppost",
    "shirt": "sleeve of a shirt,collar of a shirt,body of a shirt",
    "capybara": "hat worn by a capybara,head,body,feet",
    "corgi": "head,leg,body,ear",
    "pushcar": "wheel,body,handle",
    "plant": "pot,plant",
    "chair": "back of chair,leg,seat",
}

# Which example subfolder (examples/<source>/) each object's .pcd lives in.
source_dict = {
    "fireplug": "objaverse",
    "mickey": "objaverse",
    "motorvehicle": "objaverse",
    "teddy": "objaverse",
    "lamppost": "objaverse",
    "shirt": "objaverse",
    "capybara": "wild",
    "corgi": "wild",
    "pushcar": "wild",
    "plant": "wild",
    "chair": "wild",
}


def predict(pcd_path, inference_mode, part_queries):
    """Run Find3D on a point-cloud file and render the result.

    Args:
        pcd_path: Path to a .pcd file.
        inference_mode: "Segmentation" (multiple parts) or "Localization"
            (a single part heatmap).
        part_queries: Delimiter-separated part names for segmentation, or a
            single part name for localization.

    Returns:
        A rendered point-cloud figure, or None for an unrecognized mode.

    Raises:
        gr.Error: If the number of queries does not match the selected mode.
    """
    set_seed()
    xyz, rgb, normal = read_pcd(pcd_path)
    if inference_mode == "Segmentation":
        # Split on any accepted delimiter and drop empty fragments so a
        # trailing delimiter ("a,b,") cannot inject a blank part name.
        parts = [part.strip() for part in re.split(_QUERY_DELIMITERS, part_queries)]
        parts = [part for part in parts if part]
        if len(parts) < 2:
            raise gr.Error("For segmentation mode, please provide 2 or more parts", duration=5)
        seg_rgb = segment_obj(xyz, rgb, normal, parts).cpu().numpy()
        legend = get_legend(parts)
        return render_point_cloud(xyz, seg_rgb, legend=legend)
    elif inference_mode == "Localization":
        # Reject any query delimiter, including "|" (the original check only
        # covered "," ";" "." and so disagreed with the segmentation splitter).
        if re.search(_QUERY_DELIMITERS, part_queries):
            raise gr.Error("For localization mode, please provide only one part", duration=5)
        heatmap_rgb = get_heatmap(xyz, rgb, normal, part_queries).cpu().numpy()
        return render_point_cloud(xyz, heatmap_rgb)
    else:
        return None


def on_select(evt: gr.SelectData):
    """Gallery click handler: load the selected example's .pcd and its sample queries."""
    # Strip the image extension to recover the object key. splitext is robust
    # to any extension length (the original sliced a fixed 4 characters).
    obj_name = os.path.splitext(evt.value['image']['orig_name'])[0]
    src = source_dict[obj_name]
    return [f"examples/{src}/{obj_name}.pcd", parts_dict[obj_name]]


with gr.Blocks(theme=gr.themes.Default(text_size="lg", radius_size="none")) as demo:
    gr.HTML(
        '''<h1 text-align="center">Find Any Part in 3D</h1>
        <p style='font-size: 16px;'>This is a demo for Find3D: Find Any Part in 3D! Two modes are supported: <b>segmentation</b> and <b>localization</b>. <br>
        For <b>segmentation mode</b>, please provide multiple part queries in the "queries" text box, in the format of comma-separated string, such as "part1,part2,part3". After hitting "Run", the model will segment the object into the provided parts. <br>
        For <b>localization mode</b>, please only provide <b>one query string</b> in the "queries" text box. After hitting "Run", the model will generate a heatmap for the provided query text.
        Please click on the buttons below "Objaverse" and "In the Wild" for some examples. You can also upload your own .pcd files.</p>
        <p style='font-size: 16px;'>Hint: When uploading your own point cloud, please first close the existing point cloud by clicking on the "x" button. <br>
        We show some sample queries for the provided examples. When working with your own point cloud, feel free to rephrase the query (e.g. "part" vs "part of a object") to achieve better performance!</p>
        '''
    )
    # Left column: inputs. Middle: rendered input cloud. Right: model output.
    with gr.Row(variant="panel"):
        with gr.Column(scale=4):
            file_upload = gr.File(
                label="Upload Point Cloud File",
                type="filepath",
                file_types=[".pcd"],
                value="examples/objaverse/lamppost.pcd"
            )
            inference_mode = gr.Radio(
                choices=["Segmentation", "Localization"],
                label="Inference Mode",
                value="Segmentation",
            )
            part_queries = gr.Textbox(
                label="Part Queries",
                value="lighting of a lamppost,pole of a lamppost",
            )
            run_button = gr.Button(
                value="Run",
                variant="primary",
            )
        with gr.Column(scale=4):
            input_image = gr.Image(label="Input Image", visible=False, type='pil', image_mode='RGBA', height=290)
            input_point_cloud = gr.Plot(label="Input Point Cloud")
        with gr.Column(scale=4):
            output_point_cloud = gr.Plot(label="Output Result")

    # Example galleries: clicking a thumbnail loads its .pcd and sample queries.
    with gr.Row(variant="panel"):
        with gr.Column(scale=6):
            title = gr.HTML('''<h1 text-align="center">Objaverse</h1>
            <p style='font-size: 16px;'>Online 3D assets from Objaverse!</p>
            ''')
            gallery_objaverse = gr.Gallery(
                [("examples/objaverse/lamppost.jpg", "lamppost"),
                 ("examples/objaverse/fireplug.jpg", "fireplug"),
                 ("examples/objaverse/mickey.jpg", "Mickey"),
                 ("examples/objaverse/motorvehicle.jpg", "motor vehicle"),
                 ("examples/objaverse/teddy.jpg", "teddy bear"),
                 ("examples/objaverse/shirt.jpg", "shirt")],
                columns=3, allow_preview=False)
            gallery_objaverse.select(fn=on_select, inputs=None, outputs=[file_upload, part_queries])
        with gr.Column(scale=6):
            title = gr.HTML("""<h1 text-align="center">In the Wild</h1>
            <p style='font-size: 16px;'>Challenging in-the-wild reconstructions from iPhone photos & AI-generated images!</p>
            """)
            gallery_wild = gr.Gallery(
                [("examples/wild/capybara.png", "DALLE-capybara"),
                 ("examples/wild/corgi.jpg", "DALLE-corgi"),
                 ("examples/wild/plant.jpg", "iPhone-plant"),
                 ("examples/wild/pushcar.jpg", "iPhone-pushcar"),
                 ("examples/wild/chair.jpg", "iPhone-chair")],
                columns=3, allow_preview=False)
            gallery_wild.select(fn=on_select, inputs=None, outputs=[file_upload, part_queries])

    # Re-render the input cloud whenever a new file is chosen, and run the
    # model on button click.
    file_upload.change(
        fn=render_pcd_file,
        inputs=[file_upload],
        outputs=[input_point_cloud],
    )
    run_button.click(
        fn=predict,
        inputs=[file_upload, inference_mode, part_queries],
        outputs=[output_point_cloud],
    )
    demo.load(
        fn=render_pcd_file,
        inputs=[file_upload],
        outputs=[input_point_cloud])  # initialize

demo.launch()