SocialAnalyzer / model.py
F-allahmoradi's picture
Upload 3 files
f60ef34 verified
raw
history blame
4.66 kB
# -*- coding: utf-8 -*-
"""model.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1lKXL4Cdum5DiSbczUsadXc0F8j46NM_m
# in the name of **allah**
"""
import torch
from transformers import AutoTokenizer, BertForSequenceClassification
from datasets import Dataset
import pandas as pd
import re
from hazm import Normalizer, Lemmatizer, word_tokenize, stopwords_list
# Initialize Hazm components
normalizer = Normalizer()
lemmatizer = Lemmatizer()
stopwords = stopwords_list()
# Load the BERT model for sentiment analysis
dataset = Dataset.from_pandas(pd.DataFrame({"Comment": []}))
tokenizer = AutoTokenizer.from_pretrained("HooshvareLab/bert-fa-base-uncased")
model = BertForSequenceClassification.from_pretrained("HooshvareLab/bert-fa-base-uncased", num_labels=3)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Tokenization function for sentiment analysis
def tokenize_function(examples):
return tokenizer(examples["Comment"], padding="max_length", truncation=True, max_length=256, return_tensors='pt')
# Sentiment prediction function
def predict_sentiment(batch):
input_ids = torch.tensor(batch['input_ids']).to(device)
attention_mask = torch.tensor(batch['attention_mask']).to(device)
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
predictions = torch.argmax(outputs.logits, dim=-1)
return {'sentiment': predictions.cpu()}
# Mapping sentiment labels
sentiment_labels_en = {0: 'منفی', 1: 'خنثی', 2: 'مثبت'}
# Adding sentiment prediction to tokenized dataset
def predict_sentiment_labels(text):
dataset = Dataset.from_dict({"Comment": [text]})
tokenized_dataset = dataset.map(tokenize_function, batched=True)
predicted_sentiments = tokenized_dataset.map(predict_sentiment, batched=True)
sentiment = predicted_sentiments[0]['sentiment']
return sentiment_labels_en.get(sentiment, 'نامشخص')
# Functions from your original code for classifying sentence type and cleaning
imperative_verbs = [
'بیا', 'برو', 'بخواب', 'کن', 'باش', 'بذار', 'فراموش کن', 'بخور',
'بپوش', 'ببخش', 'بنویس', 'دقت کن', 'دست بردار', 'سکوت کن',
'اجازه بده', 'نکن', 'پیش برو', 'خواب بمان', 'توجه کن', 'خوش آمدید',
'حواس‌جمع باش', 'در نظر بگیر', 'بخشید', 'بکش', 'نگذار', 'سعی کن',
'تلاش کن', 'ببین', 'نرو', 'بگیر', 'بگو', 'شک نکن', 'فکر کن',
'عادت کن', 'بیانداز', 'حرکت کن', 'شکایت نکن', 'عاشق شو', 'بخند',
'برگرد', 'بزن', 'آشپزی کن', 'بپذیر', 'شیرینی بپز', 'درس بخوان',
'کلاس بگذار', 'کمک کن', 'بمان', 'راهنمایی کن', 'لطفا'
]
def classify_sentence(sentence):
sentence = sentence.strip()
sentence_type = 'خبری'
if re.search(r'چرا|چطور|کجا|آیا|چه|چی|چند|کدام|کی|چندم|چیست|چیه|چندمین|چجوری|کی|چیست|چگونه|؟', sentence) or sentence.endswith('?'):
sentence_type = 'پرسشی'
elif re.search(r'\b(?:' + '|'.join(imperative_verbs) + r')\b', sentence):
sentence_type = 'امری'
return sentence_type
def clean_text(text):
text = re.sub(r'https://\S+|www\.\S+', '', text)
text = re.sub(r'[^ا-ی0-9\s#@_؟]', ' ', text)
text = re.sub(r'\s+', ' ', text).strip()
words = word_tokenize(text)
words = [word for word in words if word not in stopwords]
words = [lemmatizer.lemmatize(word) for word in words]
return ' '.join(words)
def process_sentence(sentence):
cleaned = clean_text(sentence)
sentence_type = classify_sentence(cleaned)
sentiment = predict_sentiment_labels(sentence)
return f"Type: {sentence_type}\nSentiment: {sentiment}\nCleaned Text: {cleaned}"
# Function to process file
def process_file(file):
try:
df = pd.read_csv(file.name)
if 'Comment' not in df.columns:
return "Error: No 'Comment' column found in the file."
# Process comments
df['Cleaned_Comment'] = df['Comment'].apply(clean_text)
df['Type'] = df['Comment'].apply(classify_sentence)
df['Sentiment'] = df['Comment'].apply(predict_sentiment_labels)
output_path = "processed_file.csv"
df.to_csv(output_path, index=False)
return f"File processed successfully! Download it [here](./{output_path})"
except Exception as e:
return str(e)