Shankarm08 committed on
Commit
9e487ab
·
verified ·
1 Parent(s): 5fa3c44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -12
app.py CHANGED
@@ -1,7 +1,9 @@
1
  import streamlit as st
2
  import torch
 
3
  from transformers import BertTokenizer, BertModel
4
  import pdfplumber
 
5
 
6
  # Load the pre-trained BERT model and tokenizer once
7
  model_name = "bert-base-uncased"
@@ -27,12 +29,11 @@ def get_embeddings(text):
27
  with torch.no_grad(): # Disable gradient calculation for inference
28
  outputs = model(**inputs)
29
 
30
- # Check if the output contains the last hidden state
31
  if hasattr(outputs, 'last_hidden_state'):
32
- # Extract the embeddings from the last hidden state
33
  return outputs.last_hidden_state[:, 0, :].detach().cpu().numpy() # Move to CPU before converting to numpy
34
  else:
35
- raise ValueError("Model output does not contain 'last_hidden_state'. Please check the model configuration.")
36
 
37
  # Extract text from PDF
38
  def extract_text_from_pdf(pdf_file):
@@ -42,9 +43,9 @@ def extract_text_from_pdf(pdf_file):
42
  text += page.extract_text() + "\n" # Add newline for better separation
43
  return text
44
 
45
- # Store the PDF text and embeddings
46
- pdf_text = ""
47
- pdf_embeddings = None
48
 
49
  # Streamlit app
50
  st.title("PDF Chatbot using BERT")
@@ -52,10 +53,15 @@ st.title("PDF Chatbot using BERT")
52
  # PDF file upload
53
  pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])
54
 
 
 
 
 
55
  if pdf_file:
56
  pdf_text = extract_text_from_pdf(pdf_file)
57
  try:
58
- pdf_embeddings = get_embeddings(pdf_text)
 
59
  st.success("PDF loaded successfully!")
60
  except Exception as e:
61
  st.error(f"Error while processing PDF: {e}")
@@ -64,16 +70,23 @@ if pdf_file:
64
  user_input = st.text_input("Ask a question about the PDF:")
65
 
66
  if st.button("Get Response"):
67
- if pdf_text == "":
68
  st.warning("Please upload a PDF file first.")
 
 
69
  else:
70
- # Get embeddings for user input
71
  try:
72
  user_embeddings = get_embeddings(user_input)
73
- # For demonstration, simply return the PDF text.
74
- # Implement similarity matching logic here as needed.
 
 
 
 
 
75
  st.write("### Response:")
76
- st.write(pdf_text) # For simplicity, returning all text
 
77
  except Exception as e:
78
  st.error(f"Error while processing user input: {e}")
79
 
 
1
  import streamlit as st
2
  import torch
3
+ import numpy as np
4
  from transformers import BertTokenizer, BertModel
5
  import pdfplumber
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
 
8
  # Load the pre-trained BERT model and tokenizer once
9
  model_name = "bert-base-uncased"
 
29
  with torch.no_grad(): # Disable gradient calculation for inference
30
  outputs = model(**inputs)
31
 
32
+ # Extract the embeddings from the last hidden state
33
  if hasattr(outputs, 'last_hidden_state'):
 
34
  return outputs.last_hidden_state[:, 0, :].detach().cpu().numpy() # Move to CPU before converting to numpy
35
  else:
36
+ raise ValueError("Model output does not contain 'last_hidden_state'.")
37
 
38
  # Extract text from PDF
39
  def extract_text_from_pdf(pdf_file):
 
43
  text += page.extract_text() + "\n" # Add newline for better separation
44
  return text
45
 
46
+ # Split text into newline-delimited chunks (a rough stand-in for sentences) for better matching
47
+ def split_text_into_sentences(text):
48
+ return text.split('\n') # Split by newlines; adjust as needed
49
 
50
  # Streamlit app
51
  st.title("PDF Chatbot using BERT")
 
53
  # PDF file upload
54
  pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])
55
 
56
+ # Store the PDF text and embeddings
57
+ pdf_text = ""
58
+ pdf_embeddings = None
59
+
60
  if pdf_file:
61
  pdf_text = extract_text_from_pdf(pdf_file)
62
  try:
63
+ pdf_sentences = split_text_into_sentences(pdf_text) # Split PDF text into sentences
64
+ pdf_embeddings = np.array([get_embeddings(sentence) for sentence in pdf_sentences]) # Get embeddings for each sentence
65
  st.success("PDF loaded successfully!")
66
  except Exception as e:
67
  st.error(f"Error while processing PDF: {e}")
 
70
  user_input = st.text_input("Ask a question about the PDF:")
71
 
72
  if st.button("Get Response"):
73
+ if not pdf_sentences:
74
  st.warning("Please upload a PDF file first.")
75
+ elif not user_input.strip():
76
+ st.warning("Please enter a question.")
77
  else:
 
78
  try:
79
  user_embeddings = get_embeddings(user_input)
80
+ user_embeddings = user_embeddings.reshape(1, -1) # Reshape for cosine similarity calculation
81
+
82
+ # Calculate cosine similarity between user input and PDF sentence embeddings
83
+ similarities = cosine_similarity(user_embeddings, pdf_embeddings)
84
+ best_match_index = np.argmax(similarities) # Get the index of the best match
85
+
86
+ # Display the most relevant sentence
87
  st.write("### Response:")
88
+ st.write(pdf_sentences[best_match_index]) # Return the most relevant sentence
89
+
90
  except Exception as e:
91
  st.error(f"Error while processing user input: {e}")
92