ADKU committed
Commit d4c27ab · verified · 1 Parent(s): 03623ec

Update app.py

Files changed (1)
  1. app.py +89 -9
app.py CHANGED
@@ -1,26 +1,106 @@
  import os
+ import faiss
+ import numpy as np
+ from rank_bm25 import BM25Okapi
+ import torch
+ import pandas as pd
  from fastapi import FastAPI
  from pydantic import BaseModel
- from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
- import torch
+ from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast, AutoTokenizer, AutoModel

+ # Set Hugging Face cache directory
  os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface"

  app = FastAPI()

- model = DistilBertForSequenceClassification.from_pretrained("ADKU/ResearchGPT_model", cache_dir="/tmp/huggingface")
- tokenizer = DistilBertTokenizerFast.from_pretrained("ADKU/ResearchGPT_model", cache_dir="/tmp/huggingface")
+ # Load dataset
+ df = pd.read_json("springer_papers_DL.json")
+
+ # Clean text function
+ def clean_text(text):
+     return text.strip().lower()
+
+ df['cleaned_abstract'] = df['abstract'].apply(clean_text)
+
+ # Precompute BM25 Index
+ tokenized_corpus = [paper.split() for paper in df["cleaned_abstract"]]
+ bm25 = BM25Okapi(tokenized_corpus)
+
+ # Load FAISS model
+ embedding_model = "allenai/scibert_scivocab_uncased"
+ tokenizer = AutoTokenizer.from_pretrained(embedding_model)
+ model = AutoModel.from_pretrained(embedding_model)
+
+ # Generate embeddings using SciBERT
+ def generate_embeddings_sci_bert(texts, batch_size=32):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.to(device)
+
+     all_embeddings = []
+     for i in range(0, len(texts), batch_size):
+         batch = texts[i:i + batch_size]
+         inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
+         inputs = {key: val.to(device) for key, val in inputs.items()}
+
+         with torch.no_grad():
+             outputs = model(**inputs)
+
+         embeddings = outputs.last_hidden_state.mean(dim=1)
+         all_embeddings.append(embeddings.cpu().numpy())
+
+     return np.concatenate(all_embeddings, axis=0)
+
+ # Compute document embeddings
+ abstracts = df["cleaned_abstract"].tolist()
+ embeddings = generate_embeddings_sci_bert(abstracts, batch_size=32)

+ # Initialize FAISS index
+ dimension = embeddings.shape[1]
+ faiss_index = faiss.IndexFlatL2(dimension)
+ faiss_index.add(embeddings.astype(np.float32))
+
+ # API Request Model
  class InputText(BaseModel):
-     inputs: str
+     query: str
+     top_k: int = 5

  @app.post("/predict/")
  async def predict(data: InputText):
-     inputs = tokenizer(data.inputs, return_tensors="pt", padding=True, truncation=True)
-     outputs = model(**inputs)
-     prediction = torch.argmax(outputs.logits, dim=-1).item()
-     return {"prediction": prediction}
+     query = data.query
+     top_k = data.top_k
+
+     if not query.strip():
+         return {"error": "Query is empty. Please enter a valid search query."}
+
+     # 1️⃣ Generate embedding for query
+     query_embedding = generate_embeddings_sci_bert([query], batch_size=1)
+
+     # 2️⃣ Perform FAISS similarity search
+     distances, indices = faiss_index.search(query_embedding.astype(np.float32), top_k)
+
+     # 3️⃣ Perform BM25 keyword search
+     tokenized_query = query.split()
+     bm25_scores = bm25.get_scores(tokenized_query)
+     bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
+
+     # 4️⃣ Combine FAISS and BM25 results
+     combined_indices = list(set(indices[0]) | set(bm25_top_indices))
+     ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
+
+     # 5️⃣ Retrieve research papers
+     relevant_papers = []
+     for i, index in enumerate(ranked_results[:top_k]):
+         paper = df.iloc[index]
+         relevant_papers.append({
+             "rank": i + 1,
+             "title": paper["title"],
+             "authors": paper["authors"],
+             "abstract": paper["cleaned_abstract"]
+         })
+
+     return {"results": relevant_papers}

+ # Run FastAPI
  if __name__ == "__main__":
      import uvicorn
      uvicorn.run(app, host="0.0.0.0")
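
For quick testing, a minimal client sketch for the updated endpoint. It assumes the app was started with python app.py (uvicorn serves on its default port 8000), that the requests package is installed, and the query string is only an illustrative placeholder; field names match the InputText model in the diff above.

import requests

# Hypothetical example query; "query" and "top_k" follow the InputText request model.
payload = {"query": "deep learning for protein structure prediction", "top_k": 3}

resp = requests.post("http://localhost:8000/predict/", json=payload)
resp.raise_for_status()

# The endpoint returns {"results": [...]}; each entry carries rank, title, authors, abstract.
for paper in resp.json()["results"]:
    print(paper["rank"], paper["title"])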