Shankarm08 commited on
Commit
ceb87d2
·
verified ·
1 Parent(s): 0e6176c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -11
app.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  from transformers import BertTokenizer, BertModel
4
  from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
 
6
 
7
  app = FastAPI()
8
 
@@ -25,28 +26,44 @@ async def classify_text(request: TextClassificationRequest):
25
  return_tensors='pt'
26
  )
27
 
28
- # Create a dictionary to store the output
29
- output = {}
30
-
31
  # Use the pre-trained BERT model to extract features from the input text
32
  outputs = model(**inputs)
33
 
34
  # Extract the features
35
  features = outputs.last_hidden_state[:, 0, :]
36
 
37
- # Store the output
38
- output["features"] = features.tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
 
 
 
40
  return output
41
 
42
  # Create a Gradio interface
43
  interface = gr.Interface(
44
- fn=classify_text,
45
- inputs="pdf",
46
- outputs="text",
47
  title="PDF Text Classification",
48
- description="Upload a PDF file to classify its text"
49
  )
50
 
51
- # Launch the interface
52
- interface.launch()
 
3
  from transformers import BertTokenizer, BertModel
4
  from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
6
+ import pdfplumber
7
 
8
  app = FastAPI()
9
 
 
26
  return_tensors='pt'
27
  )
28
 
 
 
 
29
  # Use the pre-trained BERT model to extract features from the input text
30
  outputs = model(**inputs)
31
 
32
  # Extract the features
33
  features = outputs.last_hidden_state[:, 0, :]
34
 
35
+ # Return the features as a list
36
+ return {"features": features.tolist()}
37
+
38
+ # Define a function to extract text from a PDF
39
+ def extract_text_from_pdf(pdf_file):
40
+ with pdfplumber.open(pdf_file) as pdf:
41
+ text = ""
42
+ for page in pdf.pages:
43
+ text += page.extract_text()
44
+ return text
45
+
46
+ # Create a Gradio interface for handling PDF input
47
+ def classify_pdf(pdf_file):
48
+ # Extract text from the uploaded PDF
49
+ extracted_text = extract_text_from_pdf(pdf_file)
50
+
51
+ # Create the request for FastAPI
52
+ request = TextClassificationRequest(text=extracted_text)
53
 
54
+ # Simulate calling the FastAPI endpoint
55
+ output = classify_text(request)
56
+
57
  return output
58
 
59
  # Create a Gradio interface
60
  interface = gr.Interface(
61
+ fn=classify_pdf,
62
+ inputs="file", # Expecting PDF file input
63
+ outputs="json", # Outputs a JSON dictionary
64
  title="PDF Text Classification",
65
+ description="Upload a PDF file to classify its text using BERT"
66
  )
67
 
68
+ # Launch the Gradio interface
69
+ interface.launch()