nielsr (HF staff) committed
Commit 76c8f3a
1 parent: 51d259a

Update app.py

Files changed (1):
  app.py: +9 -5
app.py CHANGED
@@ -1,5 +1,5 @@
 import gradio as gr
-from transformers import AutoProcessor, AutoImageProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, VisionEncoderDecoderModel
+from transformers import AutoProcessor, AutoTokenizer, AutoImageProcessor, AutoModelForCausalLM, BlipForConditionalGeneration, VisionEncoderDecoderModel
 import torch
 
 torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
@@ -13,6 +13,7 @@ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image
 
 vitgpt_processor = AutoImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
 vitgpt_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+vitgpt_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -20,12 +21,15 @@ git_model.to(device)
 blip_model.to(device)
 vitgpt_model.to(device)
 
-def generate_caption(processor, model, image):
+def generate_caption(processor, model, image, tokenizer=None):
     inputs = processor(images=image, return_tensors="pt").to(device)
 
     generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
-
-    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+    if tokenizer is not None:
+        generated_caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    else:
+        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
     return generated_caption
 
@@ -35,7 +39,7 @@ def generate_captions(image):
 
     caption_blip = generate_caption(blip_processor, blip_model, image)
 
-    caption_vitgpt = generate_caption(vitgpt_processor, vitgpt_model, image)
+    caption_vitgpt = generate_caption(vitgpt_processor, vitgpt_model, image, vitgpt_tokenizer)
 
     return caption_git, caption_blip, caption_vitgpt
 
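
The GIT and BLIP checkpoints are loaded through processors that bundle a tokenizer, so processor.batch_decode works for them; the ViT-GPT2 checkpoint's AutoImageProcessor only preprocesses images and has no batch_decode, which is why generate_caption now accepts an optional tokenizer and generate_captions passes vitgpt_tokenizer for that model. A minimal standalone sketch of the ViT-GPT2 path (not part of the commit; assumes transformers, Pillow, requests, and network access):

# Standalone sketch (not part of the commit): the ViT-GPT2 captioning path.
# AutoImageProcessor here is an image-only processor with no batch_decode,
# so the generated ids are decoded with the checkpoint's AutoTokenizer,
# mirroring the tokenizer branch added to generate_caption above.
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoTokenizer, VisionEncoderDecoderModel

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning").to(device)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt").to(device)
generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
caption = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(caption)  # e.g. a short sentence describing the two cats in the test image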