D0k-tor committed
Commit f595b41 · 1 Parent(s): 9f60554

Update app.py

Files changed (1)
  1. app.py +30 -13
app.py CHANGED
@@ -9,21 +9,38 @@ import tensorflow as tf
 os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
 
 device='cpu'
-encoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
-decoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
-model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
+# encoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
+# decoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
+# model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
 
-feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
-tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
-model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
+# feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
+# tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
+# model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
 
-def predict(image, max_length=64, num_beams=4):
-    image = image.convert('RGB')
-    image = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
-    clean_text = lambda x: x.replace('<|endoftext|>','').split('\n')[0]
-    caption_ids = model.generate(image, max_length = max_length)[0]
-    caption_text = clean_text(tokenizer.decode(caption_ids))
-    return caption_text
+# def predict(image, max_length=64, num_beams=4):
+#     image = image.convert('RGB')
+#     image = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
+#     clean_text = lambda x: x.replace('<|endoftext|>','').split('\n')[0]
+#     caption_ids = model.generate(image, max_length = max_length)[0]
+#     caption_text = clean_text(tokenizer.decode(caption_ids))
+#     return caption_text
+
+model_id = "nttdataspain/vit-gpt2-coco-lora"
+model = VisionEncoderDecoderModel.from_pretrained(model_id)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)
+
+# Predict function
+def predict_prompts(image):
+    img = image.convert('RGB')
+    model.eval()
+    pixel_values = feature_extractor(images=[img], return_tensors="pt").pixel_values
+    with torch.no_grad():
+        output_ids = model.generate(pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True).sequences
+
+    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+    preds = [pred.strip() for pred in preds]
+    return preds[0]
 
 input = gr.inputs.Image(label="Upload any Image", type = 'pil', optional=True)
 output = gr.outputs.Textbox(type="text",label="Captions")
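
For context, here is a minimal sketch of how the updated pieces plausibly fit together as a runnable Spaces app. The import lines and the gr.Interface/launch wiring sit outside this hunk, so they are assumptions rather than the author's actual code; the commit itself uses the legacy gr.inputs/gr.outputs API (removed in Gradio 3.x), and the sketch substitutes the current gr.Image/gr.Textbox equivalents.

```python
# Minimal runnable sketch of the post-commit app.py flow.
# Assumptions: the imports and the Interface/launch lines are not visible
# in the diff hunk; gr.Image/gr.Textbox stand in for the legacy
# gr.inputs/gr.outputs API used in the commit.
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel

model_id = "nttdataspain/vit-gpt2-coco-lora"
model = VisionEncoderDecoderModel.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)
model.eval()  # inference only, no gradient updates

def predict_prompts(image: Image.Image) -> str:
    # Gradio passes a PIL image when type="pil", so no Image.open() is needed.
    pixel_values = feature_extractor(images=[image.convert("RGB")],
                                     return_tensors="pt").pixel_values
    with torch.no_grad():
        output_ids = model.generate(pixel_values, max_length=16, num_beams=4)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return preds[0].strip()

demo = gr.Interface(
    fn=predict_prompts,
    inputs=gr.Image(type="pil", label="Upload any Image"),
    outputs=gr.Textbox(label="Captions"),
)

if __name__ == "__main__":
    demo.launch()
```

Because gr.Image(type="pil") hands the prediction function a PIL image directly, the function can call .convert('RGB') on its argument with no file handling; whether the actual app.py launches via gr.Interface or some other layout is not visible in this hunk, so the launch lines above are illustrative only.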