Jonny001 commited on
Commit
f5c908e
1 Parent(s): ac016ec

Update roop/predicter.py

Browse files
Files changed (1) hide show
  1. roop/predicter.py +8 -12
roop/predicter.py CHANGED
@@ -1,8 +1,8 @@
1
  import threading
2
- import numpy
3
- import opennsfw2
4
  from PIL import Image
5
- from keras import Model
 
6
 
7
  from roop.typing import Frame
8
 
@@ -10,7 +10,6 @@ PREDICTOR = None
10
  THREAD_LOCK = threading.Lock()
11
  MAX_PROBABILITY = 0.85
12
 
13
-
14
  def get_predictor() -> Model:
15
  global PREDICTOR
16
 
@@ -19,25 +18,22 @@ def get_predictor() -> Model:
19
  PREDICTOR = opennsfw2.make_open_nsfw_model()
20
  return PREDICTOR
21
 
22
-
23
  def clear_predictor() -> None:
24
  global PREDICTOR
25
 
26
- PREDICTOR = None
27
-
28
 
29
  def predict_frame(target_frame: Frame) -> bool:
30
  image = Image.fromarray(target_frame)
31
- image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO)
32
- views = numpy.expand_dims(image, axis=0)
33
  _, probability = get_predictor().predict(views)[0]
34
  return probability > MAX_PROBABILITY
35
 
36
-
37
  def predict_image(target_path: str) -> bool:
38
  return opennsfw2.predict_image(target_path) > MAX_PROBABILITY
39
 
40
-
41
  def predict_video(target_path: str) -> bool:
42
  _, probabilities = opennsfw2.predict_video_frames(video_path=target_path, frame_interval=100)
43
- return any(probability > MAX_PROBABILITY for probability in probabilities)
 
1
  import threading
2
+ import numpy as np
 
3
  from PIL import Image
4
+ from keras.models import Model
5
+ import opennsfw2
6
 
7
  from roop.typing import Frame
8
 
 
10
  THREAD_LOCK = threading.Lock()
11
  MAX_PROBABILITY = 0.85
12
 
 
13
  def get_predictor() -> Model:
14
  global PREDICTOR
15
 
 
18
  PREDICTOR = opennsfw2.make_open_nsfw_model()
19
  return PREDICTOR
20
 
 
21
  def clear_predictor() -> None:
22
  global PREDICTOR
23
 
24
+ with THREAD_LOCK:
25
+ PREDICTOR = None
26
 
27
  def predict_frame(target_frame: Frame) -> bool:
28
  image = Image.fromarray(target_frame)
29
+ processed_image = opennsfw2.preprocess_image(image, opennsfw2.Preprocessing.YAHOO)
30
+ views = np.expand_dims(processed_image, axis=0)
31
  _, probability = get_predictor().predict(views)[0]
32
  return probability > MAX_PROBABILITY
33
 
 
34
  def predict_image(target_path: str) -> bool:
35
  return opennsfw2.predict_image(target_path) > MAX_PROBABILITY
36
 
 
37
  def predict_video(target_path: str) -> bool:
38
  _, probabilities = opennsfw2.predict_video_frames(video_path=target_path, frame_interval=100)
39
+ return any(probability > MAX_PROBABILITY for probability in probabilities)