# Space status (captured together with this source): Runtime error
import pandas as pd
import streamlit as st
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Download (on first use) and load the model and tokenizer.
model_name = 'bert-base-uncased'


@st.cache_resource
def _load_model_and_tokenizer(name):
    """Load and return ``(tokenizer, model)`` for the checkpoint *name*.

    Streamlit re-executes this whole script on every user interaction, so
    loading at module level without caching re-instantiates the model on
    each rerun. ``st.cache_resource`` keeps a single shared instance per
    distinct *name* instead.
    """
    tok = AutoTokenizer.from_pretrained(name)
    mdl = AutoModelForSequenceClassification.from_pretrained(name)
    mdl.eval()  # inference only: make sure dropout is disabled
    return tok, mdl


# NOTE(review): 'bert-base-uncased' is a pre-training checkpoint without a
# fine-tuned sequence-classification head; transformers initialises that head
# randomly (and logs a warning), so the predicted labels are not meaningful
# until a checkpoint fine-tuned for the target task is configured here.
tokenizer, model = _load_model_and_tokenizer(model_name)
# Define a function to classify a single text
def classify_text(text):
    """Return the predicted class index for a single piece of text.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        text: The text to classify. Missing values (``None`` / ``NaN``, e.g.
            empty CSV cells) are treated as an empty string, and any other
            non-string value is converted with ``str`` so that one odd cell
            cannot crash the whole column.

    Returns:
        int: Index of the highest-scoring label (``argmax`` over the logits).
    """
    if not isinstance(text, str):
        text = '' if pd.isna(text) else str(text)
    # Tokenize with special tokens added; truncate explicitly so inputs longer
    # than 512 tokens are cut to the model limit instead of relying on the
    # implicit (warning-emitting) default.
    inputs = tokenizer(
        text,
        add_special_tokens=True,
        truncation=True,
        max_length=512,
        return_tensors='pt',
    )
    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs[0]  # shape (1, num_labels) for this single input
    # Batch of one, so .item() turns the single argmax into a Python int.
    return torch.argmax(logits, dim=1).item()
# Define the Streamlit app
def main():
    """Render the app: upload a CSV, classify its ``text`` column, show it.

    The uploaded file must contain a column named ``text``. A
    ``predicted_label`` column (one class index per row, from
    ``classify_text``) is added and the resulting table is displayed.
    Unreadable files and a missing ``text`` column are reported on the page
    instead of crashing the app.
    """
    st.title('Text Classification with BERT')
    # Allow the user to upload a CSV file
    uploaded_file = st.file_uploader('Upload a CSV file', type='csv')
    if uploaded_file is None:
        return  # nothing uploaded yet

    try:
        data = pd.read_csv(uploaded_file)
    except (
        pd.errors.EmptyDataError,
        pd.errors.ParserError,
        UnicodeDecodeError,
    ) as err:
        st.error(f'Could not read the uploaded CSV: {err}')
        return

    if 'text' not in data.columns:
        st.error("The CSV must contain a column named 'text'.")
        return

    # Create a new column for the predicted labels. Empty cells are read as
    # NaN, so normalise to strings first; the tokenizer rejects non-str input.
    with st.spinner('Classifying...'):
        texts = data['text'].fillna('').astype(str)
        data['predicted_label'] = texts.apply(classify_text)
    st.write(data)


if __name__ == '__main__':
    main()