Dhanush4149 commited on
Commit
038a86f
·
verified ·
1 Parent(s): 0330532

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -4
app.py CHANGED
@@ -1,18 +1,54 @@
 
1
  import streamlit as st
2
  from transformers import pipeline
3
  import traceback
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  def load_pipelines():
6
  """
7
- Load summarization pipelines with error handling.
8
 
9
  Returns:
10
  dict: Dictionary of model pipelines
11
  """
12
  try:
13
- bart_pipeline = pipeline("summarization", model="facebook/bart-large-cnn")
14
- t5_pipeline = pipeline("summarization", model="t5-large")
15
- pegasus_pipeline = pipeline("pegasus-cnn_dailymail")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  return {
17
  'BART': bart_pipeline,
18
  'T5': t5_pipeline,
@@ -52,6 +88,9 @@ def generate_summary(pipeline, text, model_name):
52
  def main():
53
  st.title("Text Summarization with Pre-trained Models")
54
 
 
 
 
55
  # Text input
56
  text_input = st.text_area("Enter text to summarize:")
57
 
 
1
+ import os
2
  import streamlit as st
3
  from transformers import pipeline
4
  import traceback
5
 
6
+ # Use Hugging Face Spaces' recommended persistent storage
7
+ CACHE_DIR = os.path.join(os.getcwd(), "model_cache")
8
+
9
+ def ensure_cache_dir():
10
+ """
11
+ Ensure the cache directory exists.
12
+
13
+ Returns:
14
+ str: Path to the cache directory
15
+ """
16
+ os.makedirs(CACHE_DIR, exist_ok=True)
17
+ return CACHE_DIR
18
+
19
  def load_pipelines():
20
  """
21
+ Load summarization pipelines with persistent caching.
22
 
23
  Returns:
24
  dict: Dictionary of model pipelines
25
  """
26
  try:
27
+ # Ensure cache directory exists
28
+ cache_dir = ensure_cache_dir()
29
+
30
+ # Define model paths within the cache directory
31
+ bart_cache = os.path.join(cache_dir, "bart-large-cnn")
32
+ t5_cache = os.path.join(cache_dir, "t5-large")
33
+ pegasus_cache = os.path.join(cache_dir, "pegasus-cnn_dailymail")
34
+
35
+ # Load pipelines with explicit cache directories
36
+ bart_pipeline = pipeline(
37
+ "summarization",
38
+ model="facebook/bart-large-cnn",
39
+ cache_dir=bart_cache
40
+ )
41
+ t5_pipeline = pipeline(
42
+ "summarization",
43
+ model="t5-large",
44
+ cache_dir=t5_cache
45
+ )
46
+ pegasus_pipeline = pipeline(
47
+ "summarization",
48
+ model="google/pegasus-cnn_dailymail",
49
+ cache_dir=pegasus_cache
50
+ )
51
+
52
  return {
53
  'BART': bart_pipeline,
54
  'T5': t5_pipeline,
 
88
  def main():
89
  st.title("Text Summarization with Pre-trained Models")
90
 
91
+ # Display cache directory info (optional)
92
+ st.info(f"Models will be cached in: {CACHE_DIR}")
93
+
94
  # Text input
95
  text_input = st.text_area("Enter text to summarize:")
96