import cv2
import numpy as np
import torch
import threading
from torchvision import transforms
from clip.clipseg import CLIPDensePredT

from roop.typing import Face, Frame
from roop.utilities import resolve_relative_path

THREAD_LOCK_CLIP = threading.Lock()


class Mask_Clip2Seg():
    # Text-prompted mask processor: uses CLIPSeg (CLIPDensePredT) to build a
    # soft segmentation mask from comma-separated keywords.
    model_clip = None

    processorname = 'clip2seg'
    type = 'mask'


    def Initialize(self, devicename):
        if self.model_clip is None:
            # Lazily build the CLIPSeg model and load the refined rd64 weights.
            self.model_clip = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
            self.model_clip.load_state_dict(torch.load('models/CLIP/rd64-uni-refined.pth', map_location=torch.device('cpu')), strict=False)
            self.model_clip.eval()

        device = torch.device(devicename)
        self.model_clip.to(device)


    def Run(self, img1, keywords: str) -> Frame:
        if keywords is None or len(keywords) < 1 or img1 is None:
            return img1

        source_image_small = cv2.resize(img1, (256, 256))

        # Base mask: a white rectangle inset by one pixel on each side, blurred
        # to soften the border, then scaled to the 0..1 range.
        img_mask = np.full((source_image_small.shape[0], source_image_small.shape[1]), 0, dtype=np.float32)
        mask_border = 1
        l = 0
        t = 0
        r = 1
        b = 1

        mask_blur = 5
        clip_blur = 5

        img_mask = cv2.rectangle(img_mask, (mask_border + int(l), mask_border + int(t)),
                                 (256 - mask_border - int(r), 256 - mask_border - int(b)), (255, 255, 255), -1)
        img_mask = cv2.GaussianBlur(img_mask, (mask_blur * 2 + 1, mask_blur * 2 + 1), 0)
        img_mask /= 255
        # Preprocess the resized frame for CLIPSeg: tensor conversion,
        # ImageNet normalization and a resize to the model's 256x256 input.
        input_image = source_image_small

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize((256, 256)),
        ])
        img = transform(input_image).unsqueeze(0)

        thresh = 0.5
        prompts = keywords.split(',')
        with THREAD_LOCK_CLIP:
            with torch.no_grad():
                # Run every prompt against the same image in a single batch.
                preds = self.model_clip(img.repeat(len(prompts), 1, 1, 1), prompts)[0]

        # Sum the per-prompt probability maps into one combined mask.
        clip_mask = torch.sigmoid(preds[0][0])
        for i in range(len(prompts) - 1):
            clip_mask += torch.sigmoid(preds[i + 1][0])

        clip_mask = clip_mask.data.cpu().numpy()
        clip_mask = np.clip(clip_mask, 0, 1)

        # Binarize at the threshold, grow slightly and feather the edges.
        clip_mask[clip_mask > thresh] = 1.0
        clip_mask[clip_mask <= thresh] = 0.0
        kernel = np.ones((5, 5), np.float32)
        clip_mask = cv2.dilate(clip_mask, kernel, iterations=1)
        clip_mask = cv2.GaussianBlur(clip_mask, (clip_blur * 2 + 1, clip_blur * 2 + 1), 0)

        # Combine with the border mask and clamp any negative values.
        img_mask *= clip_mask
        img_mask[img_mask < 0.0] = 0.0
        return img_mask
       


    def Release(self):
        self.model_clip = None
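

# Usage sketch (illustrative only; the device string, input file and keywords
# below are assumptions, not part of this module):
#
#   masker = Mask_Clip2Seg()
#   masker.Initialize('cuda')                   # or 'cpu'
#   frame = cv2.imread('example.jpg')           # hypothetical BGR input frame
#   mask = masker.Run(frame, 'hair,glasses')    # 256x256 float32 mask in the 0..1 range
#   masker.Release()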