Vageesh1 commited on
Commit
655168b
1 Parent(s): 35004f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -5,6 +5,7 @@ import skimage.io as io
5
  import streamlit as st
6
  from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
7
  from model import generate2,ClipCaptionModel
 
8
 
9
  #model loading code
10
 
@@ -25,8 +26,6 @@ coco_model.load_state_dict(torch.load('COCO_model.h5',map_location=torch.device(
25
  model = model.eval()
26
 
27
 
28
-
29
-
30
  def ui():
31
  st.markdown("# Image Captioning")
32
  uploaded_file = st.file_uploader("Upload an Image", type=['png', 'jpeg', 'jpg'])
@@ -36,10 +35,9 @@ def ui():
36
  pil_image = PIL.Image.fromarray(image)
37
  image = preprocess(pil_image).unsqueeze(0).to(device)
38
 
39
- option = st.selectbox('Please select the Model',('Model', 'COCO Model'))
40
 
41
  if option=='Model':
42
-
43
  with torch.no_grad():
44
  prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
45
  prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
@@ -57,6 +55,11 @@ def ui():
57
  st.image(uploaded_file, width = 500, channels = 'RGB')
58
  st.markdown("**PREDICTION:** " + generated_text_prefix)
59
 
 
 
 
 
 
60
 
61
  if __name__ == '__main__':
62
  ui()
 
5
  import streamlit as st
6
  from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
7
  from model import generate2,ClipCaptionModel
8
+ from engine import inference
9
 
10
  #model loading code
11
 
 
26
  model = model.eval()
27
 
28
 
 
 
29
  def ui():
30
  st.markdown("# Image Captioning")
31
  uploaded_file = st.file_uploader("Upload an Image", type=['png', 'jpeg', 'jpg'])
 
35
  pil_image = PIL.Image.fromarray(image)
36
  image = preprocess(pil_image).unsqueeze(0).to(device)
37
 
38
+ option = st.selectbox('Please select the Model',('Model', 'COCO Model','PreTrained Model'))
39
 
40
  if option=='Model':
 
41
  with torch.no_grad():
42
  prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
43
  prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
 
55
  st.image(uploaded_file, width = 500, channels = 'RGB')
56
  st.markdown("**PREDICTION:** " + generated_text_prefix)
57
 
58
+ elif option=='PreTrained Model':
59
+ out = inference(uploaded_file)
60
+ st.image(uploaded_file, width = 500, channels = 'RGB')
61
+ st.markdown("**PREDICTION:** " + out)
62
+
63
 
64
  if __name__ == '__main__':
65
  ui()