Shankarm08 commited on
Commit
96f0bc8
·
verified ·
1 Parent(s): a472326

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -11
app.py CHANGED
@@ -10,6 +10,10 @@ model = BertModel.from_pretrained(model_name)
10
 
11
  # Function to get BERT embeddings
12
  def get_embeddings(text):
 
 
 
 
13
  # Ensure that text length does not exceed BERT's maximum input length
14
  inputs = tokenizer.encode_plus(
15
  text,
@@ -22,9 +26,13 @@ def get_embeddings(text):
22
 
23
  with torch.no_grad(): # Disable gradient calculation for inference
24
  outputs = model(**inputs)
25
-
26
- # Extract the embeddings from the last hidden state
27
- return outputs.last_hidden_state[:, 0, :].detach().cpu().numpy() # Move to CPU before converting to numpy
 
 
 
 
28
 
29
  # Extract text from PDF
30
  def extract_text_from_pdf(pdf_file):
@@ -46,8 +54,11 @@ pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"])
46
 
47
  if pdf_file:
48
  pdf_text = extract_text_from_pdf(pdf_file)
49
- pdf_embeddings = get_embeddings(pdf_text)
50
- st.success("PDF loaded successfully!")
 
 
 
51
 
52
  # User input for chatbot
53
  user_input = st.text_input("Ask a question about the PDF:")
@@ -57,10 +68,12 @@ if st.button("Get Response"):
57
  st.warning("Please upload a PDF file first.")
58
  else:
59
  # Get embeddings for user input
60
- user_embeddings = get_embeddings(user_input)
61
-
62
- # For demonstration, simply return the PDF text.
63
- # Implement similarity matching logic here as needed.
64
- st.write("### Response:")
65
- st.write(pdf_text) # For simplicity, returning all text
 
 
66
 
 
10
 
11
  # Function to get BERT embeddings
12
  def get_embeddings(text):
13
+ # Check if input text is empty
14
+ if not text.strip():
15
+ raise ValueError("Input text is empty.")
16
+
17
  # Ensure that text length does not exceed BERT's maximum input length
18
  inputs = tokenizer.encode_plus(
19
  text,
 
26
 
27
  with torch.no_grad(): # Disable gradient calculation for inference
28
  outputs = model(**inputs)
29
+
30
+ # Check if the output contains the last hidden state
31
+ if hasattr(outputs, 'last_hidden_state'):
32
+ # Extract the embeddings from the last hidden state
33
+ return outputs.last_hidden_state[:, 0, :].detach().cpu().numpy() # Move to CPU before converting to numpy
34
+ else:
35
+ raise ValueError("Model output does not contain 'last_hidden_state'. Please check the model configuration.")
36
 
37
  # Extract text from PDF
38
  def extract_text_from_pdf(pdf_file):
 
54
 
55
  if pdf_file:
56
  pdf_text = extract_text_from_pdf(pdf_file)
57
+ try:
58
+ pdf_embeddings = get_embeddings(pdf_text)
59
+ st.success("PDF loaded successfully!")
60
+ except Exception as e:
61
+ st.error(f"Error while processing PDF: {e}")
62
 
63
  # User input for chatbot
64
  user_input = st.text_input("Ask a question about the PDF:")
 
68
  st.warning("Please upload a PDF file first.")
69
  else:
70
  # Get embeddings for user input
71
+ try:
72
+ user_embeddings = get_embeddings(user_input)
73
+ # For demonstration, simply return the PDF text.
74
+ # Implement similarity matching logic here as needed.
75
+ st.write("### Response:")
76
+ st.write(pdf_text) # For simplicity, returning all text
77
+ except Exception as e:
78
+ st.error(f"Error while processing user input: {e}")
79