Peijie commited on
Commit
1fb7158
1 Parent(s): 24dce8b

load model inside gr.Block

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -1,16 +1,10 @@
1
  import os
2
- import gradio as gr
3
- print(f"Gradio version {gr.__version__}")
4
- # if gr.__version__ != '4.28.2':
5
- # os.system("pip uninstall -y gradio")
6
- # os.system("pip install gradio==4.28.2")
7
- # print(f"Gradio version: {gr.__version__}")
8
-
9
  import io
10
 
11
  import torch
12
  import json
13
  import base64
 
14
  import numpy as np
15
  from pathlib import Path
16
  from PIL import Image
@@ -20,7 +14,12 @@ from utils.load_model import load_xclip
20
  from utils.predict import xclip_pred
21
 
22
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
23
- XCLIP, OWLVIT_PRECESSOR = load_xclip(DEVICE)
 
 
 
 
 
24
  XCLIP_DESC_PATH = "data/jsons/bs_cub_desc.json"
25
  XCLIP_DESC = json.load(open(XCLIP_DESC_PATH, "r"))
26
  PREPROCESS = lambda x: OWLVIT_PRECESSOR(images=x, return_tensors='pt')
@@ -383,6 +382,7 @@ custom_css = """
383
 
384
  # Define the Gradio interface
385
  with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="PEEB") as demo:
 
386
  current_image = gr.State("")
387
  current_predicted_class = gr.State("")
388
  gt_class = gr.State("")
 
1
  import os
 
 
 
 
 
 
 
2
  import io
3
 
4
  import torch
5
  import json
6
  import base64
7
+ import gradio as gr
8
  import numpy as np
9
  from pathlib import Path
10
  from PIL import Image
 
14
  from utils.predict import xclip_pred
15
 
16
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
17
+ XCLIP, OWLVIT_PRECESSOR = None, None
18
+ def initialize_model():
19
+ global XCLIP, OWLVIT_PRECESSOR
20
+ if XCLIP is None or OWLVIT_PRECESSOR is None:
21
+ XCLIP, OWLVIT_PRECESSOR = load_xclip(DEVICE)
22
+
23
  XCLIP_DESC_PATH = "data/jsons/bs_cub_desc.json"
24
  XCLIP_DESC = json.load(open(XCLIP_DESC_PATH, "r"))
25
  PREPROCESS = lambda x: OWLVIT_PRECESSOR(images=x, return_tensors='pt')
 
382
 
383
  # Define the Gradio interface
384
  with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="PEEB") as demo:
385
+ initialize_model()
386
  current_image = gr.State("")
387
  current_predicted_class = gr.State("")
388
  gt_class = gr.State("")