mindspark121 commited on
Commit
23943df
·
verified ·
1 Parent(s): a49f6a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -58
app.py CHANGED
@@ -1,85 +1,71 @@
1
- import torch
2
- import pandas as pd
3
- import faiss
4
- import numpy as np
5
- from sentence_transformers import SentenceTransformer
6
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
7
  from fastapi import FastAPI
8
  from pydantic import BaseModel
 
 
 
 
9
 
10
- # 🔹 Initialize FastAPI
11
  app = FastAPI()
12
- @app.get("/")
13
- def home():
14
- return {"message": "Welcome to the AI Psychiatry API!"}
15
 
16
- # 🔹 Load AI Models
17
  similarity_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
18
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
19
  summarization_model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base")
20
  summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
21
 
22
- # 🔹 Load Datasets (Ensure files are uploaded to Hugging Face Space)
23
- try:
24
- recommendations_df = pd.read_csv("treatment_recommendations.csv")
25
- questions_df = pd.read_csv("symptom_questions.csv")
26
- except FileNotFoundError:
27
- recommendations_df = pd.DataFrame(columns=["Disorder", "Treatment Recommendation"])
28
- questions_df = pd.DataFrame(columns=["Questions"])
29
 
30
- # 🔹 Create FAISS Index for Treatment Retrieval
31
- if not recommendations_df.empty:
32
- treatment_embeddings = similarity_model.encode(recommendations_df["Disorder"].tolist(), convert_to_numpy=True)
33
- index = faiss.IndexFlatIP(treatment_embeddings.shape[1])
34
- index.add(treatment_embeddings)
35
- else:
36
- index = None
37
 
38
- # 🔹 Create FAISS Index for Question Retrieval
39
- if not questions_df.empty:
40
- question_embeddings = embedding_model.encode(questions_df["Questions"].tolist(), convert_to_numpy=True)
41
- question_index = faiss.IndexFlatL2(question_embeddings.shape[1])
42
- question_index.add(question_embeddings)
43
- else:
44
- question_index = None
45
 
46
- # 🔹 API Request Model
47
  class ChatRequest(BaseModel):
48
  message: str
49
 
50
- @app.post("/detect_disorders")
51
- def detect_disorders(request: ChatRequest):
52
- """ Detect psychiatric disorders from user input """
53
- if index is None:
54
- return {"error": "Dataset is missing or empty"}
55
-
56
- text_embedding = similarity_model.encode([request.message], convert_to_numpy=True)
57
- distances, indices = index.search(text_embedding, 3)
58
- disorders = [recommendations_df["Disorder"].iloc[i] for i in indices[0]]
59
- return {"disorders": disorders}
60
-
61
- @app.post("/get_treatment")
62
- def get_treatment(request: ChatRequest):
63
- """ Retrieve treatment recommendations """
64
- detected_disorders = detect_disorders(request)["disorders"]
65
- treatments = {disorder: recommendations_df[recommendations_df["Disorder"] == disorder]["Treatment Recommendation"].values[0] for disorder in detected_disorders}
66
- return {"treatments": treatments}
67
 
68
  @app.post("/get_questions")
69
  def get_recommended_questions(request: ChatRequest):
70
- """Retrieve the most relevant diagnostic questions based on patient symptoms."""
71
- if question_index is None:
72
- return {"error": "Questions dataset is missing or empty"}
73
-
74
  input_embedding = embedding_model.encode([request.message], convert_to_numpy=True)
75
  distances, indices = question_index.search(input_embedding, 3)
76
  retrieved_questions = [questions_df["Questions"].iloc[i] for i in indices[0]]
77
  return {"questions": retrieved_questions}
78
 
79
  @app.post("/summarize_chat")
80
- def summarize_chat(request: ChatRequest):
81
- """ Summarize chat logs using LongT5 """
82
- inputs = summarization_tokenizer("summarize: " + request.message, return_tensors="pt", max_length=4096, truncation=True)
 
83
  summary_ids = summarization_model.generate(inputs.input_ids, max_length=500, num_beams=4, early_stopping=True)
84
  summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
85
- return {"summary": summary}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
+ from sentence_transformers import SentenceTransformer
4
+ import faiss
5
+ import pandas as pd
6
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
7
 
 
8
  app = FastAPI()
 
 
 
9
 
10
+ # Load AI Models
11
  similarity_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
12
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
13
  summarization_model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base")
14
  summarization_tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
15
 
16
+ # Load datasets
17
+ recommendations_df = pd.read_csv("treatment_recommendations.csv")
18
+ questions_df = pd.read_csv("symptom_questions.csv")
 
 
 
 
19
 
20
+ # FAISS Index for disorder detection
21
+ treatment_embeddings = similarity_model.encode(recommendations_df["Disorder"].tolist(), convert_to_numpy=True)
22
+ index = faiss.IndexFlatIP(treatment_embeddings.shape[1])
23
+ index.add(treatment_embeddings)
 
 
 
24
 
25
+ # FAISS Index for Question Retrieval
26
+ question_embeddings = embedding_model.encode(questions_df["Questions"].tolist(), convert_to_numpy=True)
27
+ question_index = faiss.IndexFlatL2(question_embeddings.shape[1])
28
+ question_index.add(question_embeddings)
 
 
 
29
 
30
+ # Request Model
31
  class ChatRequest(BaseModel):
32
  message: str
33
 
34
+ class SummaryRequest(BaseModel):
35
+ chat_history: list # List of messages
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  @app.post("/get_questions")
38
  def get_recommended_questions(request: ChatRequest):
39
+ """Retrieve the most relevant diagnostic questions."""
 
 
 
40
  input_embedding = embedding_model.encode([request.message], convert_to_numpy=True)
41
  distances, indices = question_index.search(input_embedding, 3)
42
  retrieved_questions = [questions_df["Questions"].iloc[i] for i in indices[0]]
43
  return {"questions": retrieved_questions}
44
 
45
  @app.post("/summarize_chat")
46
+ def summarize_chat(request: SummaryRequest):
47
+ """Summarize full chat session at the end."""
48
+ chat_text = " ".join(request.chat_history)
49
+ inputs = summarization_tokenizer("summarize: " + chat_text, return_tensors="pt", max_length=4096, truncation=True)
50
  summary_ids = summarization_model.generate(inputs.input_ids, max_length=500, num_beams=4, early_stopping=True)
51
  summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
52
+ return {"summary": summary}
53
+
54
+ @app.post("/detect_disorders")
55
+ def detect_disorders(request: SummaryRequest):
56
+ """Detect psychiatric disorders from full chat history at the end."""
57
+ full_chat_text = " ".join(request.chat_history)
58
+ text_embedding = similarity_model.encode([full_chat_text], convert_to_numpy=True)
59
+ distances, indices = index.search(text_embedding, 3)
60
+ disorders = [recommendations_df["Disorder"].iloc[i] for i in indices[0]]
61
+ return {"disorders": disorders}
62
+
63
+ @app.post("/get_treatment")
64
+ def get_treatment(request: SummaryRequest):
65
+ """Retrieve treatment recommendations based on detected disorders."""
66
+ detected_disorders = detect_disorders(request)["disorders"]
67
+ treatments = {
68
+ disorder: recommendations_df[recommendations_df["Disorder"] == disorder]["Treatment Recommendation"].values[0]
69
+ for disorder in detected_disorders
70
+ }
71
+ return {"treatments": treatments}