AisingioroHao0 commited on
Commit
cd61183
1 Parent(s): ae84eb8
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -3,7 +3,8 @@ import gradio as gr
3
  from stable_diffusion_reference_only.pipelines.stable_diffusion_reference_only_pipeline import (
4
  StableDiffusionReferenceOnlyPipeline,
5
  )
6
- import anime_segmentation
 
7
  from diffusers.schedulers import UniPCMultistepScheduler
8
  from PIL import Image
9
  import cv2
@@ -12,7 +13,9 @@ import os
12
  import torch
13
 
14
  if __name__ == "__main__":
 
15
  print(f"Is CUDA available: {torch.cuda.is_available()}")
 
16
  if torch.cuda.is_available():
17
  device = "cuda"
18
  else:
@@ -25,14 +28,14 @@ if __name__ == "__main__":
25
  automatic_coloring_pipeline.scheduler.config
26
  )
27
 
28
- segment_model = anime_segmentation.get_model(
29
  model_path=huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.ckpt")
30
  ).to(device)
31
 
32
  def character_segment(img):
33
  if img is None:
34
  return None
35
- img = anime_segmentation.character_segment(segment_model, img)
36
  img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
37
  return img
38
 
 
3
  from stable_diffusion_reference_only.pipelines.stable_diffusion_reference_only_pipeline import (
4
  StableDiffusionReferenceOnlyPipeline,
5
  )
6
+ from anime_segmentation import get_model as get_anime_segmentation_model
7
+ from anime_segmentation import character_segment as anime_character_segment
8
  from diffusers.schedulers import UniPCMultistepScheduler
9
  from PIL import Image
10
  import cv2
 
13
  import torch
14
 
15
  if __name__ == "__main__":
16
+
17
  print(f"Is CUDA available: {torch.cuda.is_available()}")
18
+ print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
19
  if torch.cuda.is_available():
20
  device = "cuda"
21
  else:
 
28
  automatic_coloring_pipeline.scheduler.config
29
  )
30
 
31
+ segment_model = get_anime_segmentation_model(
32
  model_path=huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.ckpt")
33
  ).to(device)
34
 
35
  def character_segment(img):
36
  if img is None:
37
  return None
38
+ img = anime_character_segment(segment_model, img)
39
  img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
40
  return img
41