Spaces:
Runtime error
Runtime error
import streamlit as st | |
import pandas as pd | |
import torch | |
import transformers | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# Checkpoint to load.
# NOTE(review): 'bert-base-uncased' is a base checkpoint without a fine-tuned
# classification head, so AutoModelForSequenceClassification initialises that
# head with fresh weights; the predicted labels are not meaningful until a
# fine-tuned checkpoint name is substituted here.
model_name = 'bert-base-uncased'


@st.cache_resource
def _load_tokenizer_and_model(name):
    """Load the tokenizer and sequence-classification model for *name*.

    Streamlit re-executes the whole script on every user interaction. Caching
    the loaded objects as a shared resource means the checkpoint is
    instantiated once per server process instead of once per rerun.

    Args:
        name: Hugging Face model identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    return (
        AutoTokenizer.from_pretrained(name),
        AutoModelForSequenceClassification.from_pretrained(name),
    )


# Module-level names used by classify_text.
tokenizer, model = _load_tokenizer_and_model(model_name)
def classify_text(text: str) -> int:
    """Return the index of the highest-scoring class for a single text.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        text: The string to classify.

    Returns:
        The position of the largest logit in the model output for ``text``.
    """
    # Call the tokenizer directly (the supported replacement for the legacy
    # encode_plus method). truncation=True makes the max_length cap explicit:
    # without it the cap is only applied through a deprecated implicit
    # fallback that emits a warning.
    inputs = tokenizer(
        text,
        add_special_tokens=True,
        truncation=True,
        max_length=512,
        return_tensors='pt',
    )

    # Inference only, so skip gradient tracking. Keyword arguments keep the
    # call independent of the parameter order of the model's forward().
    with torch.no_grad():
        outputs = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
        )

    # The batch holds exactly one text, so argmax over dim=1 yields one value.
    return torch.argmax(outputs.logits, dim=1).item()
def main() -> None:
    """Render the Streamlit page.

    Lets the user upload a CSV file, classifies every value of its ``text``
    column with ``classify_text``, and displays the table with an added
    ``predicted_label`` column. Problems with the upload are reported on the
    page instead of crashing the app.
    """
    st.title('Text Classification with BERT')

    # Allow the user to upload a CSV file; nothing to do until one is provided.
    uploaded_file = st.file_uploader('Upload a CSV file', type='csv')
    if uploaded_file is None:
        return

    try:
        data = pd.read_csv(uploaded_file)
    except (
        pd.errors.EmptyDataError,
        pd.errors.ParserError,
        UnicodeDecodeError,
    ) as exc:
        st.error(f'Could not read the uploaded CSV file: {exc}')
        return

    # The classifier input is read from a column named 'text'; fail with a
    # readable message rather than a KeyError traceback when it is absent.
    if 'text' not in data.columns:
        st.error("The uploaded CSV file must contain a 'text' column.")
        return

    # Blank cells are parsed as NaN and numeric cells as numbers; the tokenizer
    # only accepts strings, so coerce before classifying. The displayed 'text'
    # column itself is left unchanged.
    texts = data['text'].fillna('').astype(str)
    data['predicted_label'] = texts.apply(classify_text)
    st.write(data)
# Launch the app only when this file is executed as a script (for example via
# `streamlit run`), not when it is imported as a module.
if __name__ == "__main__":
    main()