sms07 commited on
Commit
ef239ad
1 Parent(s): 3748c64

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -8,29 +8,29 @@ from transformers import (
8
  )
9
 
10
  # Function to load VQA pipeline
11
- @st.cache(allow_output_mutation=True)
12
  def load_vqa_pipeline():
13
  return pipeline(task="visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa")
14
 
15
  # Function to load BERT-based pipeline
16
- @st.cache(allow_output_mutation=True)
17
  def load_bbu_pipeline():
18
  return pipeline(task="fill-mask", model="bert-base-uncased")
19
 
20
  # Function to load Blenderbot model
21
- @st.cache(allow_output_mutation=True)
22
  def load_blenderbot_model():
23
  model_name = "facebook/blenderbot-400M-distill"
24
  tokenizer = BlenderbotTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
25
  return BlenderbotForConditionalGeneration.from_pretrained(pretrained_model_name_or_path=model_name)
26
 
27
  # Function to load GPT-2 pipeline
28
- @st.cache(allow_output_mutation=True)
29
  def load_gpt2_pipeline():
30
  return pipeline(task="text-generation", model="gpt2")
31
 
32
  # Function to load BERTopic models
33
- @st.cache(allow_output_mutation=True)
34
  def load_topic_models():
35
  topic_model_1 = BERTopic.load(path="davanstrien/chat_topics")
36
  topic_model_2 = BERTopic.load(path="MaartenGr/BERTopic_ArXiv")
 
8
  )
9
 
10
  # Function to load VQA pipeline
11
+ @st.cache_resource(allow_output_mutation=True)
12
  def load_vqa_pipeline():
13
  return pipeline(task="visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa")
14
 
15
  # Function to load BERT-based pipeline
16
+ @st.cache_resource(allow_output_mutation=True)
17
  def load_bbu_pipeline():
18
  return pipeline(task="fill-mask", model="bert-base-uncased")
19
 
20
  # Function to load Blenderbot model
21
+ @st.cache_resource(allow_output_mutation=True)
22
  def load_blenderbot_model():
23
  model_name = "facebook/blenderbot-400M-distill"
24
  tokenizer = BlenderbotTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
25
  return BlenderbotForConditionalGeneration.from_pretrained(pretrained_model_name_or_path=model_name)
26
 
27
  # Function to load GPT-2 pipeline
28
+ @st.cache_resource(allow_output_mutation=True)
29
  def load_gpt2_pipeline():
30
  return pipeline(task="text-generation", model="gpt2")
31
 
32
  # Function to load BERTopic models
33
+ @st.cache_resource(allow_output_mutation=True)
34
  def load_topic_models():
35
  topic_model_1 = BERTopic.load(path="davanstrien/chat_topics")
36
  topic_model_2 = BERTopic.load(path="MaartenGr/BERTopic_ArXiv")