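"""Marketing Content Validator.

Gradio app that checks marketing copy for excessive length, hate speech or
involuntary bias, spelling errors, and negative current-events context, and
generates a downloadable PDF report of the results.
"""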
import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification
)
import os
from pdf_generator import ReportGenerator
from news_checker import NewsChecker
from dotenv import load_dotenv
from spellchecker import SpellChecker
import re

load_dotenv()

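# English contractions to treat as valid words during the spelling check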
CONTRACTIONS = {
    "ain't", "aren't", "can't", "couldn't", "didn't", "doesn't", "don't", "hadn't",
    "hasn't", "haven't", "he'd", "he'll", "he's", "i'd", "i'll", "i'm", "i've",
    "isn't", "let's", "mightn't", "mustn't", "shan't", "she'd", "she'll", "she's",
    "shouldn't", "that's", "there's", "they'd", "they'll", "they're", "they've",
    "we'd", "we're", "we've", "weren't", "what'll", "what're", "what's", "what've",
    "where's", "who'd", "who'll", "who're", "who's", "who've", "won't", "wouldn't",
    "you'd", "you'll", "you're", "you've",
}


def load_models():
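    """Load the hate speech classification model, its tokenizer, and a spell checker."""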
    # RoBERTa hate speech / bias classifier (Dynabench round 4) and its tokenizer
    model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
    hate_tokenizer = AutoTokenizer.from_pretrained(model_name)
    hate_model = AutoModelForSequenceClassification.from_pretrained(model_name)

    spell = SpellChecker()

    return {
        'hate_speech': (hate_model, hate_tokenizer),
        'spell_check': spell
    }


news_checker = NewsChecker()


def check_text_length(text):
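    """Check that the text is within the 1000 character limit."""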
    if len(text) > 1000:
        return {
            'status': 'fail',
            'message': 'Text exceeds 1000 character limit'
        }
    return {
        'status': 'pass',
        'message': 'Text length is within limits'
    }


def check_hate_speech_and_bias(text, model, tokenizer):
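    """Run the hate speech model on the text and map its score to a pass/warning/fail result."""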
    try:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        outputs = model(**inputs)
        predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # predictions[0][1] is the probability of the hateful class;
        # scores above 0.3 fail outright, scores above 0.1 only raise a warning
        if predictions[0][1].item() > 0.3:
            return {
                'status': 'fail',
                'message': 'Content contains potential hate speech or strong bias'
            }
        elif predictions[0][1].item() > 0.1:
            return {
                'status': 'warning',
                'message': 'Content may contain subtle bias or potentially offensive language'
            }
        return {
            'status': 'pass',
            'message': 'No significant bias or hate speech detected'
        }
    except Exception as e:
        return {
            'status': 'error',
            'message': f'Error in hate speech/bias detection: {str(e)}'
        }


def normalize_apostrophes(text):
    """Normalize different types of apostrophes and quotes to the standard straight apostrophe."""
    return (
        text.replace('\u2019', "'")  # right single quotation mark
            .replace('\u2018', "'")  # left single quotation mark
            .replace('`', "'")       # backtick
            .replace('\u00b4', "'")  # acute accent
    )


def check_spelling(text, spell_checker):
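    """Spell check the text, skipping contractions and tokens such as numbers, handles, URLs, and all-caps words."""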
    try:
        text = normalize_apostrophes(text)
        words = text.split()

        misspelled = set()
        for word in words:
            word = normalize_apostrophes(word)

            # Strip leading/trailing punctuation but keep internal apostrophes
            cleaned = re.sub(r"^[^\w']+|[^\w']+$", '', word)
            if not cleaned:
                continue

            # Skip known contractions
            if cleaned.lower() in CONTRACTIONS:
                continue

            # Skip numbers, words containing digits, handles, hashtags, URLs,
            # all-uppercase words, and single characters
            if (cleaned.isdigit() or
                    any(char.isdigit() for char in cleaned) or
                    cleaned.startswith('@') or
                    cleaned.startswith('#') or
                    cleaned.startswith('http') or
                    cleaned.isupper() or
                    len(cleaned) <= 1):
                continue

            if cleaned.lower() not in spell_checker.word_frequency:
                misspelled.add(cleaned)

        if misspelled:
            corrections = []
            for word in misspelled:
                candidates = spell_checker.candidates(word)
                if candidates:
                    suggestions = list(candidates)[:3]
                    if any(sugg.lower() != word.lower() for sugg in suggestions):
                        corrections.append(f"'{word}' -> suggestions: {', '.join(suggestions)}")

            if corrections:
                return {
                    'status': 'warning',
                    'message': 'Misspelled words found:\n' + '\n'.join(corrections)
                }

        return {
            'status': 'pass',
            'message': 'No spelling errors detected'
        }
    except Exception as e:
        return {
            'status': 'error',
            'message': f'Error in spell check: {str(e)}'
        }


def analyze_content(text):
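    """Run all checks on the text and return the results dict plus the path to the generated report."""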
    try:
        report_gen = ReportGenerator()
        report_gen.add_header()
        report_gen.add_input_text(text)

        models = load_models()
        results = {}

        # Length check; stop early if the text is too long
        length_result = check_text_length(text)
        results['Length Check'] = length_result
        report_gen.add_check_result("Length Check", length_result['status'], length_result['message'])

        if length_result['status'] == 'fail':
            report_path = report_gen.save_report()
            return results, report_path

        # Hate speech / involuntary bias check
        hate_result = check_hate_speech_and_bias(text, models['hate_speech'][0], models['hate_speech'][1])
        results['Hate Speech / Involuntary Bias Check'] = hate_result
        report_gen.add_check_result("Hate Speech / Involuntary Bias Check", hate_result['status'], hate_result['message'])

        # Spelling check
        spell_result = check_spelling(text, models['spell_check'])
        results['Spelling Check'] = spell_result
        report_gen.add_check_result("Spelling Check", spell_result['status'], spell_result['message'])

        # Current events check (requires a News API key)
        if os.getenv('NEWS_API_KEY'):
            news_result = news_checker.check_content_against_news(text)
        else:
            news_result = {
                'status': 'warning',
                'message': 'News API key not configured. Skipping current events check.'
            }
        results['Current Events Context'] = news_result
        report_gen.add_check_result("Current Events Context", news_result['status'], news_result['message'])

        report_path = report_gen.save_report()
        return results, report_path
    except Exception as e:
        print(f"Error in analyze_content: {str(e)}")
        return {
            'Length Check': {'status': 'error', 'message': 'Analysis failed'},
            'Hate Speech / Involuntary Bias Check': {'status': 'error', 'message': 'Analysis failed'},
            'Spelling Check': {'status': 'error', 'message': 'Analysis failed'},
            'Current Events Context': {'status': 'error', 'message': 'Analysis failed'}
        }, None


def format_results(results):
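    """Format the results dict as display text with a status symbol for each check."""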
    status_symbols = {
        'pass': '✅',
        'fail': '❌',
        'warning': '⚠️',
        'error': '⚠️'
    }

    formatted_output = ""
    for check, result in results.items():
        symbol = status_symbols.get(result['status'], '❓')
        formatted_output += f"{check}: {symbol}\n"
        if result['message']:
            formatted_output += f"Details: {result['message']}\n\n"

    return formatted_output


def create_interface():
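    """Build the Gradio interface for the content validator."""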
    with gr.Blocks(title="Marketing Content Validator") as interface:
        gr.Markdown("# Marketing Content Validator")
        gr.Markdown("Paste your marketing content below to check for potential issues.")

        with gr.Row():
            with gr.Column():
                input_text = gr.TextArea(
                    label="Marketing Content",
                    placeholder="Enter your marketing content here (max 1000 characters)...",
                    lines=10
                )
                analyze_btn = gr.Button("Analyze Content")

            with gr.Column():
                output_text = gr.TextArea(
                    label="Analysis Results",
                    lines=10,
                    interactive=False
                )
                report_output = gr.File(label="Download Report")

        def run_analysis(text):
            # Run the analysis once and reuse it for both outputs
            results, report_path = analyze_content(text)
            return format_results(results), report_path

        analyze_btn.click(
            fn=run_analysis,
            inputs=input_text,
            outputs=[output_text, report_output]
        )

        gr.Markdown("""
        ### Notes:
        - Maximum text length: 1000 characters
        - Analysis may take up to 2 minutes
        - Results include checks for:
            - Text length
            - Hate speech and involuntary bias
            - Spelling
            - Negative news context
        """)

    return interface


if __name__ == "__main__":
    interface = create_interface()
    interface.launch()