``` from transformers import AutoModelForSequenceClassification, AutoTokenizer import torch import pandas as pd # Load the merged model and tokenizer model_path='POLLCHECK/Llama3.1-bias-sequence-classifier' tokenizer = AutoTokenizer.from_pretrained(model_path) model = AutoModelForSequenceClassification.from_pretrained(model_path) model.eval() # Function to classify text and return probabilities def classify_text(text): # Tokenize the input text and convert it to lower case inputs = tokenizer(text.lower(), return_tensors="pt", truncation=True, max_length=512) inputs = {k: v.to(model.device) for k, v in inputs.items()} # Ensure inputs are on the correct device # Perform inference without gradient calculation with torch.no_grad(): outputs = model(**inputs) # Extract logits from the model output logits = outputs.logits # Compute probabilities using the softmax function probabilities = torch.nn.functional.softmax(logits, dim=1).squeeze().cpu().numpy() # Get the index of the class with the highest probability predicted_class = torch.argmax(logits, dim=1).item() # Extract the confidence score for the predicted class confidence = probabilities[predicted_class] # Map class indices to class labels class_mapping = {0: "Biased", 1: "Unbiased"} predicted_label = class_mapping[predicted_class] return predicted_label, confidence, probabilities # Load the CSV file df = pd.read_csv('/h/sraza/news-media-bias-plus/classifiers/LLM/data/clean_data.csv') texts = df['text_content'].tolist() labels = df['text_label'].tolist() # Convert labels to lower case for case-insensitive comparison labels = [label.lower() for label in labels] # Classify a few sample texts and display ground truth along with probabilities for text, ground_truth in zip(texts[:5], labels[:5]): # Classify first 5 texts as an example predicted_label, confidence, probabilities = classify_text(text) print(f"Text: {text[:100]}...") # Print first 100 characters print(f"Ground Truth: {ground_truth}") print(f"Predicted class: {predicted_label}") print(f"Confidence: {confidence:.2f}") print(f"Probabilities: {probabilities}") print("---")