Shankarm08 committed
Commit c5608b5 · verified · 1 Parent(s): e3d3e2d

Update app.py

Files changed (1):
  1. app.py +29 -44
app.py CHANGED
@@ -1,25 +1,26 @@
- import gradio as gr
+ import streamlit as st
import torch
from transformers import BertTokenizer, BertModel
- from fastapi import FastAPI, HTTPException
- from pydantic import BaseModel
import pdfplumber

- app = FastAPI()
+ # Load the pre-trained BERT model and tokenizer outside the function for efficiency
+ model_name = "bert-base-uncased"
+ tokenizer = BertTokenizer.from_pretrained(model_name)
+ model = BertModel.from_pretrained(model_name)

- class TextClassificationRequest(BaseModel):
-     text: str
-
- @app.post("/classify")
- async def classify_text(request: TextClassificationRequest):
-     # Load the pre-trained BERT model and tokenizer
-     model_name = "bert-base-uncased"
-     tokenizer = BertTokenizer.from_pretrained(model_name)
-     model = BertModel.from_pretrained(model_name)
+ # Define a function to extract text from a PDF
+ def extract_text_from_pdf(pdf_file):
+     with pdfplumber.open(pdf_file) as pdf:
+         text = ""
+         for page in pdf.pages:
+             text += page.extract_text()
+         return text

+ # Define a function to classify the extracted text
+ def classify_text(text):
    # Preprocess the input text
    inputs = tokenizer.encode_plus(
-         request.text,
+         text,
        add_special_tokens=True,
        max_length=512,
        return_attention_mask=True,
@@ -32,38 +33,22 @@ async def classify_text(request: TextClassificationRequest):
    # Extract the features
    features = outputs.last_hidden_state[:, 0, :]

-     # Return the features as a list
-     return {"features": features.tolist()}
+     return features.tolist()

- # Define a function to extract text from a PDF
- def extract_text_from_pdf(pdf_file):
-     with pdfplumber.open(pdf_file) as pdf:
-         text = ""
-         for page in pdf.pages:
-             text += page.extract_text()
-         return text
+ # Streamlit app setup
+ st.title("PDF Text Classification")
+ st.write("Upload a PDF file to classify its text using BERT")
+
+ # File uploader for PDFs
+ pdf_file = st.file_uploader("Choose a PDF file", type="pdf")

- # Create a Gradio interface for handling PDF input
- def classify_pdf(pdf_file):
+ if pdf_file is not None:
    # Extract text from the uploaded PDF
    extracted_text = extract_text_from_pdf(pdf_file)
-
-     # Create the request for FastAPI
-     request = TextClassificationRequest(text=extracted_text)
-
-     # Simulate calling the FastAPI endpoint
-     output = classify_text(request)
-
-     return output
-
- # Create a Gradio interface
- interface = gr.Interface(
-     fn=classify_pdf,
-     inputs="file",  # Expecting PDF file input
-     outputs="json",  # Outputs a JSON dictionary
-     title="PDF Text Classification",
-     description="Upload a PDF file to classify its text using BERT"
- )
+     st.write("Extracted Text:")
+     st.write(extracted_text)

- # Launch the Gradio interface
- interface.launch(server_port=7861)
+     # Classify the extracted text
+     if st.button("Classify"):
+         features = classify_text(extracted_text)
+         st.json({"features": features})  # Display the features in JSON format