gabrielaltay committed
Commit e3a3f96 · 1 Parent(s): f40eab1

update model load

Files changed (1)
1. app.py: +15 -3
app.py CHANGED
@@ -1,3 +1,4 @@
+import os
 import tempfile
 
 from colpali_engine.models.paligemma_colbert_architecture import ColPali
@@ -14,6 +15,7 @@ from torch.utils.data import DataLoader
 from transformers import AutoProcessor
 
 
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
 SS = st.session_state
 
 
@@ -55,7 +57,11 @@ def load_colpali_model():
     device = get_device()
     dtype = get_dtype(device)
 
-    model = ColPali.from_pretrained(paligemma_model_name, torch_dtype=dtype).eval()
+    model = ColPali.from_pretrained(
+        paligemma_model_name,
+        torch_dtype=dtype,
+        token=st.secrets["hf_access_token"],
+    ).eval()
     model.load_adapter(colpali_model_name)
     model.to(device)
     processor = AutoProcessor.from_pretrained(colpali_model_name)
@@ -188,8 +194,14 @@ with st.container(border=True):
     ] + page_images
 
     genai.configure(api_key=st.secrets["google_genai_api_key"])
-    # gen_model = genai.GenerativeModel(model_name="gemini-1.5-flash")
-    gen_model = genai.GenerativeModel(model_name="gemini-1.5-pro")
+    # genai_model_name = "gemini-1.5-flash"
+    genai_model_name = "gemini-1.5-pro"
+    gen_model = genai.GenerativeModel(
+        model_name=genai_model_name,
+        generation_config=genai.GenerationConfig(
+            temperature=0.1,
+        ),
+    )
     response = gen_model.generate_content(prompt)
     text = response.candidates[0].content.parts[0].text
     SS["response"] = text