skytnt commited on
Commit
c583569
·
1 Parent(s): 5af8e1d

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +34 -0
README.md CHANGED
@@ -7,5 +7,39 @@ Model convert from [https://github.com/KichangKim/DeepDanbooru](https://github.c
7
  Usage:
8
 
9
  ```python
 
 
 
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  ```
 
7
  Usage:
8
 
9
  ```python
10
+ import cv2
11
+ import numpy as np
12
+ import onnxruntime as rt
13
+ from huggingface_hub import hf_hub_download
14
 
15
+ tagger_model_path = hf_hub_download(repo_id="skytnt/deepdanbooru_onnx", filename="deepdanbooru.onnx")
16
+
17
+ tagger_model = rt.InferenceSession(tagger_model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
18
+ tagger_model_meta = tagger_model.get_modelmeta().custom_metadata_map
19
+ tagger_tags = eval(tagger_model_meta['tags'])
20
+
21
+ def tagger_predict(image, score_threshold):
22
+ h, w = image.shape[:2]
23
+ r = min(512 / w, 512 / h)
24
+ h, w = int(h * r), int(w * r)
25
+ image = cv2.resize(image, (w, h))
26
+ pdx = 512 - w
27
+ pdy = 512 - h
28
+ img_new = np.full([512, 512, 3], 1, dtype=np.float32)
29
+ img_new[pdy // 2:pdy // 2 + h, pdx // 2:pdx // 2 + w] = image
30
+ image = img_new[np.newaxis, :]
31
+ probs = tagger_model.run(None, {"input_1": image})[0][0]
32
+ probs = probs.astype(np.float32)
33
+ res = []
34
+ for prob, label in zip(probs.tolist(), tagger_tags):
35
+ if prob < score_threshold:
36
+ continue
37
+ res.append(label)
38
+ return res
39
+
40
+ img = cv2.imread("test.jpg")
41
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
42
+ img = img.astype(np.float32) / 255
43
+ tags = tagger_predict(img, 0.5)
44
+ print(tags)
45
  ```