AkiKagura commited on
Commit
24da135
·
1 Parent(s): 6c2c213

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -6
app.py CHANGED
@@ -24,25 +24,36 @@ img_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("AkiKagura/mkgen-diffu
24
  img_pipe.safety_checker = empty_checker
25
  img_pipe.to(device)
26
 
27
- source_img = gr.Image(source="upload", type="filepath", label="init_img | 512*512 px")
28
  gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(grid=[1], height="auto")
29
 
30
- def resize(value,img):
31
  #baseheight = value
32
  img = Image.open(img)
33
  #hpercent = (baseheight/float(img.size[1]))
34
  #wsize = int((float(img.size[0])*float(hpercent)))
35
  #img = img.resize((wsize,baseheight), Image.Resampling.LANCZOS)
36
- img = img.resize((value,value), Image.Resampling.LANCZOS)
37
- return img
 
 
 
 
 
 
 
 
 
 
 
38
 
39
 
40
  def infer(source_img, prompt, guide, steps, seed, strength):
41
  generator = torch.Generator('cpu').manual_seed(seed)
42
 
43
- source_image = resize(512, source_img)
44
  source_image.save('source.png')
45
- images_list = img_pipe([prompt] * 1, init_image=source_image, strength=strength, guidance_scale=guide, num_inference_steps=steps)
46
  images = []
47
 
48
  for i, image in enumerate(images_list["images"]):
 
24
  img_pipe.safety_checker = empty_checker
25
  img_pipe.to(device)
26
 
27
+ source_img = gr.Image(source="upload", type="filepath", label="init_img")
28
  gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(grid=[1], height="auto")
29
 
30
+ def resize(img):
31
  #baseheight = value
32
  img = Image.open(img)
33
  #hpercent = (baseheight/float(img.size[1]))
34
  #wsize = int((float(img.size[0])*float(hpercent)))
35
  #img = img.resize((wsize,baseheight), Image.Resampling.LANCZOS)
36
+ hsize = img.size[1]
37
+ wsize = img.size[0]
38
+ if 6*wsize <= 5*hsize:
39
+ wsize = 512
40
+ hsize = 768
41
+ elif 4*wsize >= 5*hsize:
42
+ wsize = 768
43
+ hsize = 512
44
+ else:
45
+ wsize = 512
46
+ hsize = 512
47
+ img = img.resize((wsize,hsize), Image.Resampling.LANCZOS)
48
+ return img, wsize, hsize
49
 
50
 
51
  def infer(source_img, prompt, guide, steps, seed, strength):
52
  generator = torch.Generator('cpu').manual_seed(seed)
53
 
54
+ source_image, img_w, img_h = resize(source_img)
55
  source_image.save('source.png')
56
+ images_list = img_pipe([prompt] * 1, init_image=source_image, strength=strength, guidance_scale=guide, num_inference_steps=steps, width=img_w, height=img_h)
57
  images = []
58
 
59
  for i, image in enumerate(images_list["images"]):