ADKU committed
Commit 12e1b40 · verified · 1 Parent(s): 3bee96e

Update app.py

Files changed (1)
  1. app.py +127 -80
app.py CHANGED
@@ -6,111 +6,155 @@ import torch
  import pandas as pd
  import gradio as gr
  from transformers import AutoTokenizer, AutoModel, GPT2LMHeadModel, GPT2Tokenizer
+ import logging
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)

  # Set cache directory for Hugging Face models
  os.environ["HF_HOME"] = "/tmp/huggingface"

- # Load dataset
+ # Load dataset with error handling
  DATASET_PATH = os.path.join(os.getcwd(), "springer_papers_DL.json")
- if not os.path.exists(DATASET_PATH):
-     raise FileNotFoundError(f"Dataset file not found at {DATASET_PATH}")
- df = pd.read_json(DATASET_PATH)
+ try:
+     if not os.path.exists(DATASET_PATH):
+         raise FileNotFoundError(f"Dataset file not found at {DATASET_PATH}")
+     df = pd.read_json(DATASET_PATH)
+     logger.info("Dataset loaded successfully")
+ except Exception as e:
+     logger.error(f"Failed to load dataset: {e}")
+     raise

  # Clean text
  def clean_text(text):
-     return text.strip().lower()
+     return text.strip().lower() if isinstance(text, str) else ""

  df["cleaned_abstract"] = df["abstract"].apply(clean_text)

  # Precompute BM25 Index
- tokenized_corpus = [paper.split() for paper in df["cleaned_abstract"]]
- bm25 = BM25Okapi(tokenized_corpus)
+ try:
+     tokenized_corpus = [paper.split() for paper in df["cleaned_abstract"]]
+     bm25 = BM25Okapi(tokenized_corpus)
+     logger.info("BM25 index created")
+ except Exception as e:
+     logger.error(f"BM25 index creation failed: {e}")
+     raise

- # Load SciBERT for embeddings (preloaded globally)
- sci_bert_tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
- sci_bert_model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
+ # Load models with error handling
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- sci_bert_model.to(device)
- sci_bert_model.eval()
+ logger.info(f"Using device: {device}")

- # Load GPT-2 for QA (using distilgpt2 for efficiency)
- gpt2_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2", cache_dir="/tmp/huggingface")
- gpt2_model = GPT2LMHeadModel.from_pretrained("distilgpt2", cache_dir="/tmp/huggingface")
- gpt2_model.to(device)
- gpt2_model.eval()
+ try:
+     # SciBERT for embeddings
+     sci_bert_tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
+     sci_bert_model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
+     sci_bert_model.to(device)
+     sci_bert_model.eval()
+     logger.info("SciBERT loaded")
+
+     # DistilGPT-2 for QA
+     gpt2_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2", cache_dir="/tmp/huggingface")
+     gpt2_model = GPT2LMHeadModel.from_pretrained("distilgpt2", cache_dir="/tmp/huggingface")
+     gpt2_model.to(device)
+     gpt2_model.eval()
+     logger.info("DistilGPT-2 loaded")
+ except Exception as e:
+     logger.error(f"Model loading failed: {e}")
+     raise

  # Generate SciBERT embeddings
  def generate_embeddings_sci_bert(texts, batch_size=32):
-     all_embeddings = []
-     for i in range(0, len(texts), batch_size):
-         batch = texts[i:i + batch_size]
-         inputs = sci_bert_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
-         inputs = {key: val.to(device) for key, val in inputs.items()}
-         with torch.no_grad():
-             outputs = sci_bert_model(**inputs)
-         embeddings = outputs.last_hidden_state.mean(dim=1)
-         all_embeddings.append(embeddings.cpu().numpy())
-         torch.cuda.empty_cache()
-     return np.concatenate(all_embeddings, axis=0)
+     try:
+         all_embeddings = []
+         for i in range(0, len(texts), batch_size):
+             batch = texts[i:i + batch_size]
+             inputs = sci_bert_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
+             inputs = {key: val.to(device) for key, val in inputs.items()}
+             with torch.no_grad():
+                 outputs = sci_bert_model(**inputs)
+             embeddings = outputs.last_hidden_state.mean(dim=1)
+             all_embeddings.append(embeddings.cpu().numpy())
+             torch.cuda.empty_cache()
+         return np.concatenate(all_embeddings, axis=0)
+     except Exception as e:
+         logger.error(f"Embedding generation failed: {e}")
+         return np.zeros((len(texts), 768))  # Fallback to zero embeddings

  # Precompute embeddings and FAISS index
- abstracts = df["cleaned_abstract"].tolist()
- embeddings = generate_embeddings_sci_bert(abstracts)
- dimension = embeddings.shape[1]
- faiss_index = faiss.IndexFlatL2(dimension)
- faiss_index.add(embeddings.astype(np.float32))
+ try:
+     abstracts = df["cleaned_abstract"].tolist()
+     embeddings = generate_embeddings_sci_bert(abstracts)
+     dimension = embeddings.shape[1]
+     faiss_index = faiss.IndexFlatL2(dimension)
+     faiss_index.add(embeddings.astype(np.float32))
+     logger.info("FAISS index created")
+ except Exception as e:
+     logger.error(f"FAISS index creation failed: {e}")
+     raise

  # Hybrid search function
  def get_relevant_papers(query, top_k=5):
      if not query.strip():
          return []
-     query_embedding = generate_embeddings_sci_bert([query])
-     distances, indices = faiss_index.search(query_embedding.astype(np.float32), top_k)
-     tokenized_query = query.split()
-     bm25_scores = bm25.get_scores(tokenized_query)
-     bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
-     combined_indices = list(set(indices[0]) | set(bm25_top_indices))
-     ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
-     papers = []
-     for i, index in enumerate(ranked_results[:top_k]):
-         paper = df.iloc[index]
-         papers.append(f"{i+1}. {paper['title']} - Abstract: {paper['cleaned_abstract'][:200]}...")
-     return papers
+     try:
+         query_embedding = generate_embeddings_sci_bert([query])
+         distances, indices = faiss_index.search(query_embedding.astype(np.float32), top_k)
+         tokenized_query = query.split()
+         bm25_scores = bm25.get_scores(tokenized_query)
+         bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
+         combined_indices = list(set(indices[0]) | set(bm25_top_indices))
+         ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
+         papers = []
+         for i, index in enumerate(ranked_results[:top_k]):
+             paper = df.iloc[index]
+             papers.append(f"{i+1}. {paper['title']} - Abstract: {paper['cleaned_abstract'][:200]}...")
+         return papers
+     except Exception as e:
+         logger.error(f"Search failed: {e}")
+         return ["Search failed. Please try again."]

  # GPT-2 QA function
  def answer_question(paper, question, history):
+     if not paper:
+         return [("Please select a paper first!", "")], history
      if not question.strip():
-         return "Please ask a question!", history
+         return [(question, "Please ask a question!")], history
      if question.lower() in ["exit", "done"]:
-         return "Conversation ended. Select a new paper or search again!", []
-
-     # Extract title and abstract from paper string
-     title = paper.split(" - Abstract: ")[0].split(". ", 1)[1]
-     abstract = paper.split(" - Abstract: ")[1].rstrip("...")
-
-     # Build context with history
-     context = f"Title: {title}\nAbstract: {abstract}\n\nPrevious conversation:\n"
-     for user_q, bot_a in history:
-         context += f"User: {user_q}\nAssistant: {bot_a}\n"
-     context += f"User: {question}\nAssistant: "
-
-     # Generate response
-     inputs = gpt2_tokenizer(context, return_tensors="pt", truncation=True, max_length=512)
-     inputs = {key: val.to(device) for key, val in inputs.items()}
-     with torch.no_grad():
-         outputs = gpt2_model.generate(
-             inputs["input_ids"],
-             max_new_tokens=100,
-             do_sample=True,
-             temperature=0.7,
-             top_k=50,
-             pad_token_id=gpt2_tokenizer.eos_token_id
-         )
-     response = gpt2_tokenizer.decode(outputs[0], skip_special_tokens=True)
-     response = response[len(context):].strip()
-
-     history.append((question, response))
-     return response, history
+         return [("Conversation ended. Select a new paper or search again!", "")], []
+
+     try:
+         # Extract title and abstract
+         title = paper.split(" - Abstract: ")[0].split(". ", 1)[1]
+         abstract = paper.split(" - Abstract: ")[1].rstrip("...")
+
+         # Build context with history
+         context = f"Title: {title}\nAbstract: {abstract}\n\nPrevious conversation:\n"
+         for user_q, bot_a in history:
+             context += f"User: {user_q}\nAssistant: {bot_a}\n"
+         context += f"User: {question}\nAssistant: "
+
+         # Generate response
+         inputs = gpt2_tokenizer(context, return_tensors="pt", truncation=True, max_length=512)
+         inputs = {key: val.to(device) for key, val in inputs.items()}
+         with torch.no_grad():
+             outputs = gpt2_model.generate(
+                 inputs["input_ids"],
+                 max_new_tokens=100,
+                 do_sample=True,
+                 temperature=0.7,
+                 top_k=50,
+                 pad_token_id=gpt2_tokenizer.eos_token_id
+             )
+         response = gpt2_tokenizer.decode(outputs[0], skip_special_tokens=True)
+         response = response[len(context):].strip()
+
+         history.append((question, response))
+         return history, history  # Return updated history for Chatbot
+     except Exception as e:
+         logger.error(f"QA failed: {e}")
+         history.append((question, "Sorry, I couldn’t process that. Try again!"))
+         return history, history

  # Gradio UI
  with gr.Blocks(
@@ -148,18 +192,21 @@ with gr.Blocks(

      # Update selected paper
      paper_dropdown.change(
-         fn=lambda x: x,
+         fn=lambda x: (x, []),  # Reset history when new paper selected
          inputs=paper_dropdown,
-         outputs=selected_paper
+         outputs=[selected_paper, history_state]
      )

      # Handle chat
      chat_btn.click(
          fn=answer_question,
          inputs=[selected_paper, question_input, history_state],
-         outputs=[chatbot, history_state],
-         _js="() => {document.querySelector('.chatbot').scrollTop = document.querySelector('.chatbot').scrollHeight;}"
+         outputs=[chatbot, history_state]
+     ).then(
+         fn=lambda: "",
+         inputs=None,
+         outputs=question_input  # Clear question input after sending
      )

  # Launch the app
- demo.launch()
+ demo.launch(server_name="0.0.0.0", server_port=7860)