krishnv committed on
Commit
31e8f8b
·
verified ·
1 Parent(s): ce396c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -22
app.py CHANGED
@@ -1,38 +1,41 @@
1
  from PIL import Image
2
- from transformers import VisionEncoderDecoderModel , ViTFeatureExtractor , PreTrainedTokenizerFast
3
  import gradio as gr
4
 
5
- model = VisionEncoderDecoderModel.from_pretrained("ydshieh/vit-gpt2-coco-en")
6
- vit_feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch32-224-in21k")
7
- tokenizer = PreTrainedTokenizerFast.from_pretrained("distilgpt2")
8
-
9
 
 
10
  def caption_images(image):
11
- pixel_values = vit_feature_extractor(images=image,return_tensors="pt").pixel_values
12
- encoder_outputs = model.generate(pixel_values.to('cpu'),num_beams=5)
13
- generated_sentence = tokenizer.batch_decode(encoder_outputs,skip_special_tokens=True)
14
-
15
- return (generated_sentence[0].strip())
16
-
17
-
 
18
  inputs = [
19
- gr.components.Image(type='pil',label='Original Image')
20
  ]
21
 
22
  outputs = [
23
- gr.components.Textbox(label='Caption')
24
  ]
25
 
26
- title = "Simple Image captioning Application"
 
27
  description = "Upload an image to see the caption generated"
28
- example =['messi.jpg']
29
 
 
30
  gr.Interface(
31
- caption_images,
32
- inputs,
33
- outputs,
34
  title=title,
35
- description = description,
36
- examples = example,
37
  ).launch(debug=True)
38
-
 
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, PreTrainedTokenizerFast
import gradio as gr

# Load the captioning model and its matching preprocessors.
# FIX: "microsoft/git-base" is a GIT checkpoint, not a VisionEncoderDecoder
# one, so loading it through VisionEncoderDecoderModel is wrong. Use a
# checkpoint actually trained for this architecture; it ships its own
# feature-extractor and tokenizer configs, so all three load from one repo.
model = VisionEncoderDecoderModel.from_pretrained("ydshieh/vit-gpt2-coco-en")
feature_extractor = ViTFeatureExtractor.from_pretrained("ydshieh/vit-gpt2-coco-en")
tokenizer = PreTrainedTokenizerFast.from_pretrained("ydshieh/vit-gpt2-coco-en")


def caption_images(image):
    """Generate a caption for an uploaded image.

    Args:
        image: PIL.Image.Image supplied by the Gradio ``Image`` component.

    Returns:
        str: the generated caption with surrounding whitespace stripped.
    """
    # Convert the image into the normalized pixel tensor the ViT encoder expects.
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    # Beam search (5 beams) gives noticeably better captions than greedy decoding.
    output_ids = model.generate(pixel_values.to("cpu"), num_beams=5)
    captions = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return captions[0].strip()


# Gradio interface components.
# FIX: gr.inputs / gr.outputs were removed in Gradio 4.x; gr.components is
# the supported namespace (and what the previous revision already used).
inputs = [
    gr.components.Image(type="pil", label="Original Image")
]

outputs = [
    gr.components.Textbox(label="Caption")
]

# App metadata.
title = "Simple Image Captioning Application"
description = "Upload an image to see the caption generated"
example = ["messi.jpg"]  # Path must exist next to app.py for the example to load.

# Create and launch the Gradio interface.
gr.Interface(
    fn=caption_images,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=example,
).launch(debug=True)