drhead commited on
Commit
bd1d180
1 Parent(s): ee151b9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +178 -0
app.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ from PIL import Image
4
+ import gradio as gr
5
+ import torch
6
+ from torchvision.transforms import transforms
7
+ from torchvision.transforms import InterpolationMode
8
+ import torchvision.transforms.functional as TF
9
+
10
+ import spaces
11
+
12
+ import huggingface_hub
13
+ import timm
14
+ from timm.models import VisionTransformer
15
+ import safetensors.torch
16
+
17
+
18
+ torch.jit.script = lambda f: f
19
+ torch.set_grad_enabled(False)
20
+
21
+ class Fit(torch.nn.Module):
22
+ def __init__(
23
+ self,
24
+ bounds: tuple[int, int] | int,
25
+ interpolation = InterpolationMode.LANCZOS,
26
+ grow: bool = True,
27
+ pad: float | None = None
28
+ ):
29
+ super().__init__()
30
+
31
+ self.bounds = (bounds, bounds) if isinstance(bounds, int) else bounds
32
+ self.interpolation = interpolation
33
+ self.grow = grow
34
+ self.pad = pad
35
+
36
+ def forward(self, img: Image) -> Image:
37
+ wimg, himg = img.size
38
+ hbound, wbound = self.bounds
39
+
40
+ hscale = hbound / himg
41
+ wscale = wbound / wimg
42
+
43
+ if not self.grow:
44
+ hscale = min(hscale, 1.0)
45
+ wscale = min(wscale, 1.0)
46
+
47
+ scale = min(hscale, wscale)
48
+ if scale == 1.0:
49
+ return img
50
+
51
+ hnew = min(round(himg * scale), hbound)
52
+ wnew = min(round(wimg * scale), wbound)
53
+
54
+ img = TF.resize(img, (hnew, wnew), self.interpolation)
55
+
56
+ if self.pad is None:
57
+ return img
58
+
59
+ hpad = hbound - hnew
60
+ wpad = wbound - wnew
61
+
62
+ tpad = hpad // 2
63
+ bpad = hpad - tpad
64
+
65
+ lpad = wpad // 2
66
+ rpad = wpad - lpad
67
+
68
+ return TF.pad(img, (lpad, tpad, rpad, bpad), self.pad)
69
+
70
+ def __repr__(self) -> str:
71
+ return (
72
+ f"{self.__class__.__name__}(" +
73
+ f"bounds={self.bounds}, " +
74
+ f"interpolation={self.interpolation.value}, " +
75
+ f"grow={self.grow}, " +
76
+ f"pad={self.pad})"
77
+ )
78
+
79
+ class CompositeAlpha(torch.nn.Module):
80
+ def __init__(
81
+ self,
82
+ background: tuple[float, float, float] | float,
83
+ ):
84
+ super().__init__()
85
+
86
+ self.background = (background, background, background) if isinstance(background, float) else background
87
+ self.background = torch.tensor(self.background).unsqueeze(1).unsqueeze(2)
88
+
89
+ def forward(self, img: torch.Tensor) -> torch.Tensor:
90
+ if img.shape[-3] == 3:
91
+ return img
92
+
93
+ alpha = img[..., 3, None, :, :]
94
+
95
+ img[..., :3, :, :] *= alpha
96
+
97
+ background = self.background.expand(-1, img.shape[-2], img.shape[-1])
98
+ if background.ndim == 1:
99
+ background = background[:, None, None]
100
+ elif background.ndim == 2:
101
+ background = background[None, :, :]
102
+
103
+ img[..., :3, :, :] += (1.0 - alpha) * background
104
+ return img[..., :3, :, :]
105
+
106
+ def __repr__(self) -> str:
107
+ return (
108
+ f"{self.__class__.__name__}(" +
109
+ f"background={self.background})"
110
+ )
111
+
112
+ transform = transforms.Compose([
113
+ Fit((384, 384)),
114
+ transforms.ToTensor(),
115
+ CompositeAlpha(0.5),
116
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
117
+ transforms.CenterCrop((384, 384)),
118
+ ])
119
+
120
+ model_file = huggingface_hub.hf_hub_download(
121
+ repo_id="RedRocket/JointTaggerProject",
122
+ filename="JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors",
123
+ subfolder="JTP_PILOT"
124
+ )
125
+
126
+ model = timm.create_model(
127
+ "vit_so400m_patch14_siglip_384.webli",
128
+ pretrained=False,
129
+ num_classes=9083,
130
+ ) # type: VisionTransformer
131
+
132
+ safetensors.torch.load_model(model, model_file)
133
+ model.eval()
134
+
135
+ tags_file = huggingface_hub.hf_hub_download(
136
+ repo_id="RedRocket/JointTaggerProject",
137
+ filename="tags.json",
138
+ subfolder="JTP_PILOT"
139
+ )
140
+
141
+ with open(tags_file, "r") as file:
142
+ tags = json.load(file) # type: dict
143
+ allowed_tags = tags.keys()
144
+
145
+ @spaces.GPU(duration=5)
146
+ def create_tags(image, threshold):
147
+ img = image.convert('RGB')
148
+ tensor = transform(img).unsqueeze(0)
149
+
150
+ with torch.no_grad():
151
+ logits = model(tensor)
152
+ probabilities = torch.nn.functional.sigmoid(logits[0])
153
+ indices = torch.where(probabilities > threshold)[0]
154
+ values = probabilities[indices]
155
+
156
+ temp = []
157
+ tag_score = dict()
158
+ for i in range(indices.size(0)):
159
+ temp.append([allowed_tags[indices[i]], values[i].item()])
160
+ tag_score[allowed_tags[indices[i]]] = values[i].item()
161
+ temp = [t[0] for t in temp]
162
+ text_no_impl = ", ".join(temp)
163
+ return text_no_impl, tag_score
164
+
165
+ with gr.Blocks() as demo:
166
+ with gr.Tab("Single Image"):
167
+ gr.Interface(
168
+ create_tags,
169
+ inputs=[gr.Image(label="Source", sources=['upload', 'webcam'], type='pil'), gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.30, label="Threshold")],
170
+ outputs=[
171
+ gr.Textbox(label="Tag String"),
172
+ gr.Label(label="Tag Predictions", num_top_classes=200),
173
+ ],
174
+ allow_flagging="never",
175
+ )
176
+
177
+ if __name__ == "__main__":
178
+ demo.launch()