ADKU committed · verified
Commit 9699ac9 · Parent: a91f0db

Update app.py

Files changed (1): app.py (+48 -27)
app.py CHANGED
@@ -8,7 +8,9 @@ from fastapi import FastAPI
 from pydantic import BaseModel
 from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast, AutoTokenizer, AutoModel

-# Set Hugging Face cache directory
+# Set cache directory to /tmp/huggingface (fixes permission error)
+os.environ["HF_HOME"] = "/tmp/huggingface"
+os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
 os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface"

 app = FastAPI()
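Note: transformers typically resolves its cache paths when the library is first imported, so environment variables set after the `from transformers import ...` line (as in this file) may be ignored on some versions. Passing `cache_dir=` explicitly, as the new code below also does, sidesteps that either way. A minimal sketch of the safer ordering (an illustration, not part of this commit):

import os

# Set the cache locations before transformers is imported, so they are
# picked up when the library initializes its default paths.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface"

from transformers import AutoModel, AutoTokenizer  # import after the env is set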
@@ -22,63 +24,77 @@ if not os.path.exists(DATASET_PATH):
 # Load dataset
 df = pd.read_json(DATASET_PATH)

-# Clean text function
+# Clean text function
 def clean_text(text):
     return text.strip().lower()

-df['cleaned_abstract'] = df['abstract'].apply(clean_text)
+df["cleaned_abstract"] = df["abstract"].apply(clean_text)

-# Precompute BM25 Index
+# Precompute BM25 Index
 tokenized_corpus = [paper.split() for paper in df["cleaned_abstract"]]
 bm25 = BM25Okapi(tokenized_corpus)

-# Load FAISS model
-embedding_model = "allenai/scibert_scivocab_uncased"
-tokenizer = AutoTokenizer.from_pretrained(embedding_model)
-model = AutoModel.from_pretrained(embedding_model)
+# Load embedding models
+embedding_models = {
+    "BERT": "bert-base-uncased",
+    "DistilBERT": "distilbert-base-uncased",
+    "Sentence-BERT": "all-MiniLM-L6-v2",
+    "MiniLM": "sentence-transformers/all-MiniLM-L12-v2",
+    "SciBERT": "allenai/scibert_scivocab_uncased",
+}
+
+BATCH_SIZE = 32  # Batch size for processing
+
+# ✅ Function to clear GPU memory
+def clear_gpu_memory():
+    torch.cuda.empty_cache()
+
+# ✅ Generate embeddings using SciBERT
+def generate_embeddings_sci_bert(texts, batch_size=BATCH_SIZE):
+    model_name = "allenai/scibert_scivocab_uncased"
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/tmp/huggingface")
+    model = AutoModel.from_pretrained(model_name, cache_dir="/tmp/huggingface")

-# Generate embeddings using SciBERT
-def generate_embeddings_sci_bert(texts, batch_size=32):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
-
+
     all_embeddings = []
     for i in range(0, len(texts), batch_size):
-        batch = texts[i:i + batch_size]
+        batch = texts[i : i + batch_size]
         inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
         inputs = {key: val.to(device) for key, val in inputs.items()}
-
+
         with torch.no_grad():
             outputs = model(**inputs)
-
+
         embeddings = outputs.last_hidden_state.mean(dim=1)
         all_embeddings.append(embeddings.cpu().numpy())
-
+
+    clear_gpu_memory()
+
     return np.concatenate(all_embeddings, axis=0)

-# Compute document embeddings
+# Compute embeddings
 abstracts = df["cleaned_abstract"].tolist()
-embeddings = generate_embeddings_sci_bert(abstracts, batch_size=32)
+embeddings = generate_embeddings_sci_bert(abstracts, batch_size=BATCH_SIZE)

-# Initialize FAISS index
+# Initialize FAISS index
 dimension = embeddings.shape[1]
 faiss_index = faiss.IndexFlatL2(dimension)
 faiss_index.add(embeddings.astype(np.float32))

-# API Request Model
+# API Request Model
 class InputText(BaseModel):
     query: str
     top_k: int = 5

-@app.post("/predict/")
-async def predict(data: InputText):
-    query = data.query
-    top_k = data.top_k
-
+# ✅ Hybrid Search Function
+def get_relevant_papers(query, top_k=5):
     if not query.strip():
         return {"error": "Query is empty. Please enter a valid search query."}

-    # 1️⃣ Generate embedding for query
+    # 1️⃣ Generate query embedding
     query_embedding = generate_embeddings_sci_bert([query], batch_size=1)

     # 2️⃣ Perform FAISS similarity search
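Note: as rewritten, generate_embeddings_sci_bert calls from_pretrained on every invocation, so SciBERT is reloaded from disk for every /predict/ request as well as for the startup pass over all abstracts. A sketch of one way to keep the cache_dir fix but load only once (the helper name get_scibert is hypothetical, not from this commit):

from functools import lru_cache

from transformers import AutoModel, AutoTokenizer

@lru_cache(maxsize=1)
def get_scibert():
    # Load the tokenizer and model once; later calls return the cached pair.
    model_name = "allenai/scibert_scivocab_uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/tmp/huggingface")
    model = AutoModel.from_pretrained(model_name, cache_dir="/tmp/huggingface")
    model.eval()  # inference only
    return tokenizer, model

generate_embeddings_sci_bert would then begin with `tokenizer, model = get_scibert()` instead of loading fresh each call.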
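Separately, `outputs.last_hidden_state.mean(dim=1)` averages over padding positions too, so short abstracts in a padded batch get diluted embeddings. A common refinement is an attention-mask-aware mean; a sketch against the same tokenizer outputs (the helper masked_mean is illustrative):

import torch

def masked_mean(last_hidden_state, attention_mask):
    # Zero out padded positions, then average over real tokens only.
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)  # (batch, seq, 1)
    summed = (last_hidden_state * mask).sum(dim=1)                  # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1e-9)                        # token counts, avoids div-by-zero
    return summed / counts

# Inside the batch loop this would replace the plain mean:
# embeddings = masked_mean(outputs.last_hidden_state, inputs["attention_mask"])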
@@ -93,7 +109,7 @@ async def predict(data: InputText):
     combined_indices = list(set(indices[0]) | set(bm25_top_indices))
     ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])

-    # 5️⃣ Retrieve research papers
+    # 5️⃣ Retrieve relevant papers
     relevant_papers = []
     for i, index in enumerate(ranked_results[:top_k]):
         paper = df.iloc[index]
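The context above references steps 2️⃣ through 4️⃣ without showing them; they fall outside this diff. For orientation only, a hybrid FAISS + BM25 retrieval of this shape typically looks like the sketch below, written against the names visible elsewhere in the file (faiss_index, bm25, query_embedding); the actual hidden lines may differ:

# Illustrative reconstruction; the real lines are not part of this diff.
# 2️⃣ FAISS similarity search over the query embedding
distances, indices = faiss_index.search(query_embedding.astype(np.float32), top_k)

# 3️⃣ BM25 scores for the tokenized query
bm25_scores = bm25.get_scores(query.split())

# 4️⃣ Top BM25 hits, to be merged with the FAISS hits above
bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]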
@@ -101,11 +117,16 @@ async def predict(data: InputText):
             "rank": i + 1,
             "title": paper["title"],
             "authors": paper["authors"],
-            "abstract": paper["cleaned_abstract"]
+            "abstract": paper["cleaned_abstract"],
         })

     return {"results": relevant_papers}

+# ✅ FastAPI Endpoint
+@app.post("/predict/")
+async def predict(data: InputText):
+    return get_relevant_papers(data.query, data.top_k)
+
 # Run FastAPI
 if __name__ == "__main__":
     import uvicorn
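With the refactor, /predict/ is a thin wrapper around get_relevant_papers, so the endpoint can be smoke-tested as below (URL and query are placeholders, assuming a local run on the default port):

import requests

# Hypothetical smoke test against a locally running instance.
resp = requests.post(
    "http://127.0.0.1:8000/predict/",
    json={"query": "graph neural networks for molecules", "top_k": 3},
)
print(resp.json())  # {"results": [{"rank": 1, "title": "...", ...}]}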
 
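The trailing context cuts off right after `import uvicorn`; the actual launch line is outside the diff. A typical pattern for a Space-hosted FastAPI app is shown below (host and port are assumptions, not taken from this commit):

# Assumed launch block; the real one is outside this diff's context.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)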