DmitrMakeev commited on
Commit
a40535b
·
1 Parent(s): cacd4e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -40
app.py CHANGED
@@ -2,45 +2,49 @@ from PIL import Image
2
  import torch
3
  import gradio as gr
4
  # Загрузка моделей с разными вариантами настроек
5
- model1_512 = torch.hub.load("susanli2016/animegan2-pytorch", "generator", pretrained="celeba_distill")
6
- model2_512 = torch.hub.load("susanli2016/animegan2-pytorch", "generator", pretrained="face_paint_512_v1")
7
- model3_512 = torch.hub.load("susanli2016/animegan2-pytorch", "generator", pretrained="face_paint_512_v2")
8
- model4_512 = torch.hub.load("susanli2016/animegan2-pytorch", "generator", pretrained="paprika")
9
- model1_1024 = torch.hub.load("susanli2016/animegan2-pytorch", "generator", pretrained="celeba_distill_1024")
10
- model2_1024 = torch.hub.load("susanli2016/animegan2-pytorch", "generator", pretrained="face_paint_1024_v1")
11
- model3_1024 = torch.hub.load("susanli2016/animegan2-pytorch", "generator", pretrained="face_paint_1024_v2")
12
- model4_1024 = torch.hub.load("susanli2016/animegan2-pytorch", "generator", pretrained="paprika_1024")
 
 
 
 
 
13
  def inference(img, ver, size):
14
- if size == '512':
15
- if ver == 'Celebrity Distill':
16
- out = face2paint(model1_512, img)
17
- elif ver == 'Face Paint v1':
18
- out = face2paint(model2_512, img)
19
- elif ver == 'Face Paint v2':
20
- out = face2paint(model3_512, img)
21
- elif ver == 'Paprika':
22
- out = face2paint(model4_512, img)
23
- elif size == '1024':
24
- if ver == 'Celebrity Distill':
25
- out = face2paint(model1_1024, img)
26
- elif ver == 'Face Paint v1':
27
- out = face2paint(model2_1024, img)
28
- elif ver == 'Face Paint v2':
29
- out = face2paint(model3_1024, img)
30
- elif ver == 'Paprika':
31
- out = face2paint(model4_1024, img)
32
  return out
33
- # Создание интерфейса Gradio с выбором размера и версии модели
34
- gr.Interface(inference,
35
- [gr.inputs.Image(type="pil"),
36
- gr.inputs.Radio(['512', '1024'], type="value", default='512', label='Size'),
37
- gr.inputs.Radio(['Celebrity Distill', 'Face Paint v1', 'Face Paint v2', 'Paprika'],
38
- type="value",
39
- default='Celebrity Distill',
40
- label='Version')],
41
- gr.outputs.Image(type="pil"),
42
- title=title,
43
- description=description,
44
- article=article,
45
- allow_flagging=False,
46
- allow_screenshot=False).launch()
 
 
 
 
 
 
 
 
2
  import torch
3
  import gradio as gr
4
  # Загрузка моделей с разными вариантами настроек
5
+ model1 = torch.hub.load("bryandlee/animegan2-pytorch:main", "generator", pretrained="celeba_distill")
6
+ model2 = torch.hub.load("bryandlee/animegan2-pytorch:main", "generator", pretrained="face_paint_512_v1")
7
+ model3 = torch.hub.load("bryandlee/animegan2-pytorch:main", "generator", pretrained="face_paint_512_v2")
8
+ model4 = torch.hub.load("bryandlee/animegan2-pytorch:main", "generator", pretrained="paprika")
9
+ def load_face2paint_model(size):
10
+ global face2paint
11
+ if size == 512:
12
+ face2paint = torch.hub.load('bryandlee/animegan2-pytorch:main', 'face2paint',
13
+ size=size, device="cpu", side_by_side=False)
14
+ elif size == 1024:
15
+ face2paint = torch.hub.load("bryandlee/animegan2-pytorch:main", "face2paint",
16
+ size=size, device="cpu", side_by_side=False)
17
+ load_face2paint_model(512) # Загрузка модели с размером 512 по умолчанию
18
  def inference(img, ver, size):
19
+ if size != 512:
20
+ load_face2paint_model(size) # Загрузка модели с выбранным размером
21
+ if ver == 'Celebrity Distill':
22
+ out = face2paint(model1, img)
23
+ elif ver == 'Face Paint v1':
24
+ out = face2paint(model2, img)
25
+ elif ver == 'Face Paint v2':
26
+ out = face2paint(model3, img)
27
+ elif ver == 'Paprika':
28
+ out = face2paint(model4, img)
 
 
 
 
 
 
 
 
29
  return out
30
+ title = "AnimeGANv2"
31
+ description = "Gradio Demo for AnimeGanv2 Face Portrait. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below. Please use a cropped portrait picture for best results similar to the examples below."
32
+ article = ""
33
+ # Создание интерфейса Gradio с вариантами моделей и размеров
34
+ interface = gr.Interface(inference,
35
+ [gr.inputs.Image(type="pil"),
36
+ gr.inputs.Radio(['Celebrity Distill', 'Face Paint v1', 'Face Paint v2', 'Paprika'],
37
+ type="value",
38
+ default='Celebrity Distill',
39
+ label='Version'),
40
+ gr.inputs.Radio([512, 1024],
41
+ type="value",
42
+ default=512,
43
+ label='Model Size (pixels)')],
44
+ gr.outputs.Image(type="pil"),
45
+ title=title,
46
+ description=description,
47
+ article=article,
48
+ allow_flagging=False,
49
+ allow_screenshot=False)
50
+ interface.launch()