import torch
import gradio as gr
import speech_recognition as sr
import difflib
import json
import os
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
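# Load the pretrained grammar-correction seq2seq model and its tokenizer once at startup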
MODEL_NAME = "prithivida/grammar_error_correcter_v1"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# Placeholder panel markup; the real D3.js dashboard markup/script can be dropped in here
D3_TEMPLATE = """
<div id="d3-dashboard"><p class="text-visualization">Run an analysis to see your writing analytics.</p></div>
"""
def analyze_errors(original, corrected):
    """Diff the original and corrected text word-by-word and describe each change."""
    orig_words, corr_words = original.split(), corrected.split()
    diff = difflib.SequenceMatcher(None, orig_words, corr_words)
    errors = []
    for tag, i1, i2, j1, j2 in diff.get_opcodes():
        if tag != 'equal':
            errors.append({
                'type': 'Grammar' if tag == 'replace' else 'Structure',
                'original': ' '.join(orig_words[i1:i2]),
                'suggestion': ' '.join(corr_words[j1:j2]),
                'message': 'Improvement suggested' if tag == 'replace' else 'Structural change'
            })
    return errors
def process_input(audio_path, text):
    """Transcribe optional audio, run grammar correction, and build the analytics payload."""
    try:
        # Handle audio input: transcribe the recording if one was provided
        if audio_path and os.path.exists(audio_path):
            recognizer = sr.Recognizer()
            with sr.AudioFile(audio_path) as source:
                audio = recognizer.record(source)
                # Google's free Web Speech API; requires an internet connection
                text = recognizer.recognize_google(audio)

        if not text or not text.strip():
            return {"error": "No input provided"}, ""

        # Grammar correction: prefix the input with the model's "gec: " task prompt
        inputs = tokenizer.encode("gec: " + text, return_tensors="pt",
                                  max_length=256, truncation=True)
        with torch.no_grad():
            outputs = model.generate(inputs, max_length=256, num_beams=5)
        corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Analysis: diff the two texts and score the result
        errors = analyze_errors(text, corrected)
        score = max(0, 100 - len(errors) * 5)  # 5 points off per detected issue

        # D3 data format
        error_counts = {
            "Grammar": sum(1 for e in errors if e['type'] == 'Grammar'),
            "Structure": sum(1 for e in errors if e['type'] == 'Structure'),
            "Spelling": 0  # Add spelling detection logic if available
        }
        d3_data = {
            "score": score,
            "errors": [{"type": k, "count": v} for k, v in error_counts.items()],
            "corrections": errors[:10]  # Show first 10 corrections
        }
        return d3_data, corrected
    except Exception as e:
        return {"error": str(e)}, ""
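# Build the Gradio interface: voice/text input on the left, analytics and corrected output on the right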
with gr.Blocks(css="""
.gradio-container { max-width: 1400px!important; padding: 20px!important; }
#d3-dashboard { background: white; padding: 20px; border-radius: 12px; }
.text-visualization { font-family: monospace; font-size: 16px; }
""") as app:
    gr.Markdown("# ✨ AI Writing Analytics Suite")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input Section")
            audio = gr.Audio(sources=["microphone"], type="filepath",
                             label="🎤 Voice Input")
            text = gr.Textbox(lines=5, placeholder="📝 Enter text here...",
                              label="Text Input")
            btn = gr.Button("Analyze", variant="primary")
        with gr.Column(scale=2):
            gr.Markdown("### Writing Analytics")
            visualization = gr.HTML(D3_TEMPLATE)
            report = gr.JSON(label="Detailed Report")
            gr.Markdown("### Corrected Text")
            corrected = gr.Textbox(label="Result", interactive=False)
    # Examples: with cache_examples=False these only pre-fill the textbox; press Analyze to run them
    gr.Markdown("### Example Inputs")
    gr.Examples(
        examples=[
            ["I is going to the park yesterday."],
            ["Their happy about there test results."],
            ["She dont like apples, but she like bananas."]
        ],
        inputs=[text],
        outputs=[report, corrected],
        fn=lambda t: process_input(None, t),
        cache_examples=False
    )
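    # Run the full analysis pipeline when the user clicks Analyze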
    btn.click(
        fn=process_input,
        inputs=[audio, text],
        outputs=[report, corrected]
    )
    # Re-render the analytics panel whenever the JSON report changes
    # (simple HTML summary as a fallback until a D3 rendering script is plugged in)
    report.change(
        fn=lambda data: D3_TEMPLATE if not data or "error" in data else (
            f'<div id="d3-dashboard"><h3>Score: {data["score"]}</h3><ul>'
            + "".join(f"<li>{e['type']}: {e['count']}</li>" for e in data["errors"])
            + "</ul></div>"
        ),
        inputs=[report],
        outputs=[visualization]
    )
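# Launch the app locally; setting share=True would expose a temporary public Gradio link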
if __name__ == "__main__":
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )