Dhanush4149 committed
Commit de13ccb · verified · 1 Parent(s): 038a86f

Update app.py

Files changed (1): app.py (+80 -76)
app.py CHANGED
@@ -1,6 +1,6 @@
 import os
 import streamlit as st
-from transformers import pipeline
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import traceback

 # Use Hugging Face Spaces' recommended persistent storage
@@ -16,72 +16,80 @@ def ensure_cache_dir():
     os.makedirs(CACHE_DIR, exist_ok=True)
     return CACHE_DIR

-def load_pipelines():
+def load_model_and_tokenizer(model_name):
     """
-    Load summarization pipelines with persistent caching.
+    Load model and tokenizer with persistent caching.
+
+    Args:
+        model_name (str): Name of the model to load

     Returns:
-        dict: Dictionary of model pipelines
+        tuple: (model, tokenizer)
     """
     try:
         # Ensure cache directory exists
         cache_dir = ensure_cache_dir()

-        # Define model paths within the cache directory
-        bart_cache = os.path.join(cache_dir, "bart-large-cnn")
-        t5_cache = os.path.join(cache_dir, "t5-large")
-        pegasus_cache = os.path.join(cache_dir, "pegasus-cnn_dailymail")
+        # Construct full cache path for this model
+        model_cache_path = os.path.join(cache_dir, model_name.replace('/', '_'))

-        # Load pipelines with explicit cache directories
-        bart_pipeline = pipeline(
-            "summarization",
-            model="facebook/bart-large-cnn",
-            cache_dir=bart_cache
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            cache_dir=model_cache_path
         )
-        t5_pipeline = pipeline(
-            "summarization",
-            model="t5-large",
-            cache_dir=t5_cache
-        )
-        pegasus_pipeline = pipeline(
-            "summarization",
-            model="google/pegasus-cnn_dailymail",
-            cache_dir=pegasus_cache
+
+        # Load model
+        model = AutoModelForSeq2SeqLM.from_pretrained(
+            model_name,
+            cache_dir=model_cache_path
         )

-        return {
-            'BART': bart_pipeline,
-            'T5': t5_pipeline,
-            'Pegasus': pegasus_pipeline
-        }
+        return model, tokenizer
     except Exception as e:
-        st.error(f"Error loading models: {str(e)}")
+        st.error(f"Error loading {model_name}: {str(e)}")
         st.error(traceback.format_exc())
-        return {}
+        return None, None

-def generate_summary(pipeline, text, model_name):
+def generate_summary(model, tokenizer, text, max_length=150):
     """
-    Generate summary for a specific model with error handling.
+    Generate summary using a specific model and tokenizer.

     Args:
-        pipeline: Hugging Face summarization pipeline
+        model: Hugging Face model
+        tokenizer: Hugging Face tokenizer
         text (str): Input text to summarize
-        model_name (str): Name of the model
+        max_length (int): Maximum length of summary

     Returns:
-        str: Generated summary or error message
+        str: Generated summary
     """
     try:
-        prompt = "Summarize the below paragraph"
-        summary = pipeline(f"{prompt}\n{text}",
-                           max_length=150,
-                           min_length=50,
-                           length_penalty=2.0,
-                           num_beams=4,
-                           early_stopping=True)[0]['summary_text']
+        # Prepare input
+        inputs = tokenizer(
+            f"summarize: {text}",
+            max_length=512,
+            return_tensors="pt",
+            truncation=True
+        )
+
+        # Generate summary
+        summary_ids = model.generate(
+            inputs.input_ids,
+            num_beams=4,
+            max_length=max_length,
+            early_stopping=True
+        )
+
+        # Decode summary
+        summary = tokenizer.decode(
+            summary_ids[0],
+            skip_special_tokens=True
+        )
+
         return summary
     except Exception as e:
-        error_msg = f"Error in {model_name} summarization: {str(e)}"
+        error_msg = f"Error in summarization: {str(e)}"
         st.error(error_msg)
         return error_msg

@@ -91,6 +99,13 @@ def main():
     # Display cache directory info (optional)
     st.info(f"Models will be cached in: {CACHE_DIR}")

+    # Define models
+    models_to_load = {
+        'BART': 'facebook/bart-large-cnn',
+        'T5': 't5-large',
+        'Pegasus': 'google/pegasus-cnn_dailymail'
+    }
+
     # Text input
     text_input = st.text_area("Enter text to summarize:")

@@ -100,44 +115,33 @@
             st.error("Please enter text to summarize.")
             return

-        # Load pipelines
-        pipelines = load_pipelines()
-        if not pipelines:
-            st.error("Failed to load models. Please check your internet connection or try again later.")
-            return
-
         # Create columns for progressive display
         bart_col, t5_col, pegasus_col = st.columns(3)

-        # BART Summary
-        with bart_col:
-            with st.spinner('Generating BART Summary...'):
-                bart_progress = st.progress(0)
-                bart_progress.progress(50)
-                bart_summary = generate_summary(pipelines['BART'], text_input, 'BART')
-                bart_progress.progress(100)
-                st.subheader("BART Summary")
-                st.write(bart_summary)
-
-        # T5 Summary
-        with t5_col:
-            with st.spinner('Generating T5 Summary...'):
-                t5_progress = st.progress(0)
-                t5_progress.progress(50)
-                t5_summary = generate_summary(pipelines['T5'], text_input, 'T5')
-                t5_progress.progress(100)
-                st.subheader("T5 Summary")
-                st.write(t5_summary)
+        # Function to process each model
+        def process_model(col, model_name, model_path):
+            with col:
+                with st.spinner(f'Generating {model_name} Summary...'):
+                    progress = st.progress(0)
+                    progress.progress(50)
+
+                    # Load model and tokenizer
+                    model, tokenizer = load_model_and_tokenizer(model_path)
+
+                    if model and tokenizer:
+                        # Generate summary
+                        summary = generate_summary(model, tokenizer, text_input)
+
+                        progress.progress(100)
+                        st.subheader(f"{model_name} Summary")
+                        st.write(summary)
+                    else:
+                        st.error(f"Failed to load {model_name} model")

-        # Pegasus Summary
-        with pegasus_col:
-            with st.spinner('Generating Pegasus Summary...'):
-                pegasus_progress = st.progress(0)
-                pegasus_progress.progress(50)
-                pegasus_summary = generate_summary(pipelines['Pegasus'], text_input, 'Pegasus')
-                pegasus_progress.progress(100)
-                st.subheader("Pegasus Summary")
-                st.write(pegasus_summary)
+        # Process each model
+        process_model(bart_col, 'BART', 'facebook/bart-large-cnn')
+        process_model(t5_col, 'T5', 't5-large')
+        process_model(pegasus_col, 'Pegasus', 'google/pegasus-cnn_dailymail')

 if __name__ == "__main__":
     main()
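
For reference, a minimal standalone sketch of the new load/generate flow introduced by this commit (not part of the commit itself; it assumes transformers and torch are installed, and uses t5-small plus placeholder text purely as lightweight stand-ins for the models and input in app.py):

# Standalone sketch of the commit's new flow: load a seq2seq model and
# tokenizer, then beam-search a summary. "t5-small" and the sample text
# are illustrative placeholders, not values from app.py.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

def summarize(model_name, text, max_length=150):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    # The "summarize: " task prefix is a T5 convention; BART and Pegasus
    # are fine-tuned for summarization directly and simply see the prefix
    # as part of the input text.
    inputs = tokenizer(f"summarize: {text}", max_length=512,
                       return_tensors="pt", truncation=True)
    summary_ids = model.generate(inputs.input_ids, num_beams=4,
                                 max_length=max_length, early_stopping=True)
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

if __name__ == "__main__":
    print(summarize("t5-small", "Some long article text to condense..."))

Passing a per-model cache_dir, as the commit does, keeps each download under its own folder in the Space's persistent storage, so restarts reuse the cached weights instead of re-downloading them.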