Spaces:
Running
Running
import gradio as gr | |
from paths import * | |
import numpy as np | |
from vision_tower import DINOv2_MLP | |
from transformers import AutoImageProcessor | |
import torch | |
import os | |
from PIL import Image | |
import torch.nn.functional as F | |
from utils import * | |
from huggingface_hub import hf_hub_download | |
# ---------------------------------------------------------------------------
# One-time model setup, executed at import time.
# ---------------------------------------------------------------------------

# Download (or reuse from the local cache under ./) the checkpoint file.
ckpt_path = hf_hub_download(
    repo_id="Viglong/OriNet",
    filename="celarge/dino_weight.pt",
    repo_type="model",
    cache_dir='./',
    resume_download=True,
)
print(ckpt_path)

save_path = './'
device = 'cpu'

# The prediction vector is sliced further down in this file as
# 360 + 180 + 60 classification bins followed by 2 confidence logits.
dino = DINOv2_MLP(
    dino_mode='large',
    in_dim=1024,
    out_dim=360 + 180 + 60 + 2,
    evaluate=True,
    mask_dino=False,
    frozen_back=False,
).to(device)
dino.eval()
print('model create')

# NOTE(review): torch.load unpickles the downloaded file; consider
# weights_only=True if the checkpoint is a plain state dict -- confirm.
dino.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
print('weight loaded')

# DINO_LARGE is not defined in this file; presumably it comes from one of the
# star imports (paths / utils) -- verify.
val_preprocess = AutoImageProcessor.from_pretrained(DINO_LARGE, cache_dir='./')
def get_3angle(image):
    """Predict the three orientation angles and a confidence for one image.

    Args:
        image: Image passed straight to ``val_preprocess``; the caller in this
            file supplies a PIL image.

    Returns:
        A float tensor of length 4:
        ``[azimuth_bin, polar_bin - 90, rotation_bin - 30, confidence]``,
        where each ``*_bin`` is the argmax over its slice of the model output
        (360, 180 and 60 bins respectively) and ``confidence`` is the first
        softmax probability of the last two logits of the first batch item.
    """
    # Boundaries of the three classification heads inside the output vector.
    az_end = 360
    pl_end = az_end + 180
    ro_end = pl_end + 60

    batch = val_preprocess(images=image)
    pixel_array = np.array(batch['pixel_values'])
    batch['pixel_values'] = torch.from_numpy(pixel_array).to(device)

    with torch.no_grad():
        logits = dino(batch)

    azimuth_bin = torch.argmax(logits[:, 0:az_end], dim=-1)
    polar_bin = torch.argmax(logits[:, az_end:pl_end], dim=-1)
    rotation_bin = torch.argmax(logits[:, pl_end:ro_end], dim=-1)
    conf_probs = F.softmax(logits[:, -2:], dim=-1)

    result = torch.zeros(4)
    result[0] = azimuth_bin
    result[1] = polar_bin - 90      # shift bin index to a signed value
    result[2] = rotation_bin - 30   # shift bin index to a signed value
    result[3] = conf_probs[0][0]
    return result
def get_3angle_infer_aug(origin_img, rm_bkg_img):
    """Predict angles using crops of both the original and background-removed image.

    Three crops are requested from each input via ``get_crop_images`` and run
    through the model as one batch. The per-crop bin predictions are then
    combined with ``remove_outliers_and_average_circular`` (azimuth) and
    ``remove_outliers_and_average`` (polar, rotation); both helpers are defined
    outside this file.

    Args:
        origin_img: The unmodified input image.
        rm_bkg_img: The same image after background preprocessing.

    Returns:
        A float tensor of length 4:
        ``[azimuth, polar - 90, rotation - 30, confidence]``, where
        ``confidence`` is the batch mean of the first softmax probability of
        the last two logits.
    """
    crops = get_crop_images(origin_img, num=3) + get_crop_images(rm_bkg_img, num=3)

    batch = val_preprocess(images=crops)
    pixel_array = np.array(batch['pixel_values'])
    batch['pixel_values'] = torch.from_numpy(pixel_array).to(device)

    with torch.no_grad():
        logits = dino(batch)

    # Per-crop argmax over each head, cast to float so they can be averaged.
    azimuth_bins = torch.argmax(logits[:, 0:360], dim=-1).to(torch.float32)
    polar_bins = torch.argmax(logits[:, 360:360 + 180], dim=-1).to(torch.float32)
    rotation_bins = torch.argmax(logits[:, 360 + 180:360 + 180 + 60], dim=-1).to(torch.float32)

    azimuth = remove_outliers_and_average_circular(azimuth_bins)
    polar = remove_outliers_and_average(polar_bins)
    rotation = remove_outliers_and_average(rotation_bins)

    mean_conf_probs = F.softmax(logits[:, -2:], dim=-1).mean(dim=0)

    result = torch.zeros(4)
    result[0] = azimuth
    result[1] = polar - 90
    result[2] = rotation - 30
    result[3] = mean_conf_probs[0]
    return result
def infer_func(img, do_rm_bkg, do_infer_aug):
    """Gradio callback: estimate orientation and render an axis overlay.

    Args:
        img: The uploaded image as an array accepted by ``Image.fromarray``.
            Gradio passes ``None`` when nothing was uploaded.
        do_rm_bkg: Flag forwarded to ``background_preprocess``. It is only
            honoured when ``do_infer_aug`` is false; with augmentation the
            background preprocessing is always called with ``True``.
        do_infer_aug: If true, use ``get_3angle_infer_aug`` on the original
            and background-processed image instead of a single forward pass.

    Returns:
        A list ``[result_image, azimuth, polar, rotation, confidence]`` where
        the four numbers are Python floats rounded to 2 decimals.

    Raises:
        gr.Error: If no image was provided.
    """
    # Without this guard Image.fromarray(None) fails with an opaque error.
    if img is None:
        raise gr.Error("Please upload an image before running inference.")

    origin_img = Image.fromarray(img)

    if do_infer_aug:
        rm_bkg_img = background_preprocess(origin_img, True)
        angles = get_3angle_infer_aug(origin_img, rm_bkg_img)
    else:
        rm_bkg_img = background_preprocess(origin_img, do_rm_bkg)
        angles = get_3angle(rm_bkg_img)

    phi = np.radians(angles[0])
    theta = np.radians(angles[1])
    # NOTE(review): rotation is passed on without a radians conversion,
    # unlike phi/theta -- confirm render_3D_axis expects it this way.
    gamma = angles[2]

    render_axis = render_3D_axis(phi, theta, gamma)
    res_img = overlay_images_with_scaling(render_axis, rm_bkg_img)

    return [res_img, *[round(float(value), 2) for value in angles]]
# Gradio UI: one image plus two checkboxes in; the overlay image and the four
# numeric results (shown in text boxes) out. Launched at import time.
server = gr.Interface(
    fn=infer_func,
    inputs=[
        gr.Image(height=512, width=512, label="upload your image"),
        gr.Checkbox(label="Remove Background", value=True),
        gr.Checkbox(label="Inference time augmentation", value=False),
    ],
    outputs=[
        gr.Image(height=512, width=512, label="result image"),
        # gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
        gr.Textbox(lines=1, label='Azimuth(0~360°)'),
        gr.Textbox(lines=1, label='Polar(-90~90°)'),
        # NOTE(review): the rotation head in this file has 60 bins shifted by
        # -30, so values fall in -30..29; the label's range looks stale.
        gr.Textbox(lines=1, label='Rotation(-90~90°)'),
        gr.Textbox(lines=1, label='Confidence(0~1)'),
    ],
    flagging_mode='never',
)

server.launch()