Spaces:
Sleeping
Sleeping
import torch | |
import gradio as gr | |
import speech_recognition as sr | |
import difflib | |
import json | |
import os | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
# Hugging Face checkpoint used for grammar correction (process_input feeds it
# text prefixed with "gec: ").
MODEL_NAME = "prithivida/grammar_error_correcter_v1"

# Tokenizer and model are loaded once, when this module is imported, and are
# then reused by every call to process_input.
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_NAME),
    AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME),
)
# Self-contained HTML document fragment for the analytics dashboard.
# It loads D3 v7 from the CDN and defines a global updateDashboard(data)
# function; the caller must invoke it (after this markup) with the report
# dict returned by process_input.
D3_TEMPLATE = """
<style>
  /* Repeated here so the dashboard stays styled when rendered outside the
     Gradio page's stylesheet (e.g. inside an iframe). */
  #d3-dashboard { background: white; padding: 20px; border-radius: 12px; }
  .text-visualization { font-family: monospace; font-size: 16px; }
</style>
<div id="d3-dashboard">
  <svg id="main-chart"></svg>
  <div id="text-vis" class="text-visualization"></div>
</div>
<script src="https://d3js.org/d3.v7.min.js"></script>
<script>
  // Layout: the score gauge occupies the upper half of the SVG and the
  // error-count bar chart the lower half, so the two never overlap.
  const container = d3.select("#d3-dashboard");
  const width = 800;
  const height = 500;
  const margin = {top: 20, right: 30, bottom: 30, left: 40};

  const svg = container.select("#main-chart")
    .attr("width", width)
    .attr("height", height)
    .style("background", "#f8fafc")
    .style("border-radius", "12px");

  const textVis = container.select("#text-vis")
    .style("margin-top", "20px")
    .style("min-height", "100px")
    .style("padding", "20px")
    .style("background", "white")
    .style("border-radius", "12px")
    .style("box-shadow", "0 2px 4px rgba(0,0,0,0.1)");

  // Render one analysis result.
  // Expected shape: {score, errors: [{type, count}], corrections: [{type,
  // original, suggestion, message}]}, or {error: "..."} on backend failure.
  function updateDashboard(data) {
    // Clear previous elements
    svg.selectAll("*").remove();
    textVis.html("");

    if (data && data.error !== undefined) {
      textVis.text("Error: " + (data.error || "unknown error"));
      return;
    }
    if (!data || !Array.isArray(data.errors)) {
      textVis.text("No analysis available yet.");
      return;
    }

    // --- Score gauge (upper half) ---
    const score = Math.max(0, Math.min(100, Number(data.score) || 0));
    const gauge = svg.append("g")
      .attr("transform", `translate(${width / 2},${height / 4})`);
    const arc = d3.arc()
      .innerRadius(70)
      .outerRadius(105)
      .startAngle(0);
    gauge.append("path")
      .attr("fill", "#3b82f6")
      .transition()
      .duration(1000)
      .attrTween("d", () => {
        const interpolate = d3.interpolate(0, score / 100);
        return t => arc.endAngle(Math.PI * 2 * interpolate(t))();
      });
    gauge.append("text")
      .attr("text-anchor", "middle")
      .attr("dominant-baseline", "middle")
      .style("font-size", "32px")
      .style("font-weight", "bold")
      .style("fill", "#1e3a8a")
      .text(score);

    // --- Error distribution (lower half) ---
    const errorTypes = data.errors;
    const chartTop = height / 2 + margin.top;
    const chartBottom = height - margin.bottom;
    const x = d3.scaleBand()
      .domain(errorTypes.map(d => d.type))
      .range([margin.left, width - margin.right])
      .padding(0.2);
    // "|| 1" keeps the domain non-degenerate when every count is zero (or the
    // list is empty); a [0, 0] domain maps 0 to mid-range and draws phantom bars.
    const y = d3.scaleLinear()
      .domain([0, d3.max(errorTypes, d => d.count) || 1])
      .range([chartBottom, chartTop]);

    svg.selectAll(".error-bar")
      .data(errorTypes)
      .enter().append("rect")
      .attr("class", "error-bar")
      .attr("x", d => x(d.type))
      .attr("y", d => y(d.count))
      .attr("width", x.bandwidth())
      .attr("height", d => chartBottom - y(d.count))
      .attr("fill", "#ef4444")
      .attr("rx", 4)
      .on("mouseover", function () { d3.select(this).attr("fill", "#dc2626"); })
      .on("mouseout", function () { d3.select(this).attr("fill", "#ef4444"); });

    svg.selectAll(".error-label")
      .data(errorTypes)
      .enter().append("text")
      .attr("class", "error-label")
      .attr("x", d => x(d.type) + x.bandwidth() / 2)
      .attr("y", chartBottom + 18)
      .attr("text-anchor", "middle")
      .style("font-size", "12px")
      .text(d => `${d.type} (${d.count})`);

    // --- Corrections: hover a token to reveal its suggestion ---
    const corrections = Array.isArray(data.corrections) ? data.corrections : [];
    const words = textVis.selectAll(".word")
      .data(corrections)
      .enter().append("div")
      .attr("class", "word")
      .style("position", "relative")
      .style("display", "inline-block")
      .style("margin", "2px")
      .style("padding", "4px 8px")
      .style("border-radius", "4px")
      .style("background", d => d.correct ? "#d1fae5" : "#fee2e2")
      .style("color", d => d.correct ? "#065f46" : "#991b1b")
      .style("cursor", "pointer")
      // .text, not .html: tokens come from user input and model output.
      // Must run before the tooltip is appended, since it replaces children.
      .text(d => d.original || "(missing word)")
      .on("mouseenter", function () {
        d3.select(this).style("filter", "brightness(90%)")
          .select(".tooltip").style("display", "block");
      })
      .on("mouseleave", function () {
        d3.select(this).style("filter", "brightness(100%)")
          .select(".tooltip").style("display", "none");
      });

    // Tooltips are hidden until their token is hovered; content is built
    // with text nodes so no correction text is interpreted as HTML.
    const tooltip = words.append("div")
      .attr("class", "tooltip")
      .style("display", "none")
      .style("position", "absolute")
      .style("top", "100%")
      .style("left", "0")
      .style("z-index", "10")
      .style("white-space", "nowrap")
      .style("color", "#111827")
      .style("background", "white")
      .style("padding", "8px")
      .style("border-radius", "6px")
      .style("box-shadow", "0 2px 8px rgba(0,0,0,0.1)");
    tooltip.append("strong").text(d => d.type);
    tooltip.append("br");
    tooltip.append("span").text(d => d.message);
    tooltip.append("br");
    tooltip.append("em").text("Suggested: ");
    tooltip.append("span").text(d => d.suggestion || "(remove)");
  }
</script>
"""
def analyze_errors(original, corrected):
    """Diff *original* against *corrected* word by word and list each edit.

    Args:
        original: Text as written by the user.
        corrected: The corrected version of the same text.

    Returns:
        A list with one dict per non-equal ``difflib`` opcode, each with keys
        ``type`` ("Grammar" for replacements, "Structure" for insertions and
        deletions), ``original`` and ``suggestion`` (the affected word spans,
        space-joined; empty for pure insertions/deletions) and ``message``.
    """
    # Tokenize once up front rather than re-splitting for every opcode.
    src_words = original.split()
    dst_words = corrected.split()
    matcher = difflib.SequenceMatcher(None, src_words, dst_words)

    errors = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            continue
        is_replace = tag == "replace"
        errors.append({
            "type": "Grammar" if is_replace else "Structure",
            "original": " ".join(src_words[i1:i2]),
            "suggestion": " ".join(dst_words[j1:j2]),
            "message": "Improvement suggested" if is_replace else "Structural change",
        })
    return errors
def _transcribe_audio(audio_path):
    """Return the transcript of the audio file at *audio_path*.

    Uses speech_recognition's ``recognize_google`` (a network request); any
    exception it raises propagates to the caller.
    """
    recognizer = sr.Recognizer()
    with sr.AudioFile(audio_path) as source:
        audio = recognizer.record(source)
    return recognizer.recognize_google(audio)


def _correct_grammar(text):
    """Return the model's grammar-corrected rewrite of *text*.

    The input is prefixed with "gec: " and truncated to 256 tokens; the output
    is generated with 5-beam search, also capped at 256 tokens.
    """
    inputs = tokenizer.encode(
        "gec: " + text, return_tensors="pt", max_length=256, truncation=True
    )
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        outputs = model.generate(inputs, max_length=256, num_beams=5)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def _build_dashboard_data(errors):
    """Convert analyze_errors() output into the payload the dashboard reads."""
    # Each detected edit costs 5 points, floored at 0.
    score = max(0, 100 - len(errors) * 5)
    error_counts = {
        "Grammar": sum(1 for e in errors if e["type"] == "Grammar"),
        "Structure": sum(1 for e in errors if e["type"] == "Structure"),
        "Spelling": 0,  # placeholder: no spelling detection is implemented
    }
    return {
        "score": score,
        "errors": [{"type": k, "count": v} for k, v in error_counts.items()],
        "corrections": errors[:10],  # show at most the first 10 corrections
    }


def process_input(audio_path, text):
    """Analyze spoken or typed text and return ``(report, corrected_text)``.

    Args:
        audio_path: Path to a recorded audio file, or None. When it points to
            an existing file, its transcript replaces *text*.
        text: Typed input; None is treated as empty.

    Returns:
        A ``(dict, str)`` pair. On success the dict has ``score``, ``errors``
        and ``corrections`` keys and the string is the corrected text. On
        failure the dict is ``{"error": message}`` and the string is empty.
    """
    try:
        if audio_path and os.path.exists(audio_path):
            text = _transcribe_audio(audio_path)
        text = text or ""
        if not text.strip():
            return {"error": "No input provided"}, ""

        corrected = _correct_grammar(text)
        errors = analyze_errors(text, corrected)
        return _build_dashboard_data(errors), corrected
    except sr.UnknownValueError:
        # Raised when speech is unintelligible; it usually carries no message,
        # which would otherwise surface as an empty error string.
        return {"error": "Could not understand the audio input"}, ""
    except sr.RequestError as exc:
        return {"error": f"Speech recognition request failed: {exc}"}, ""
    except Exception as exc:  # UI boundary: report the failure, don't crash the app
        return {"error": str(exc) or type(exc).__name__}, ""
def _render_dashboard(data):
    """Return HTML that displays the D3 dashboard for an analysis report.

    gr.HTML inserts its value as markup without executing ``<script>`` tags,
    so D3_TEMPLATE (which loads D3 and defines ``updateDashboard``) is placed
    in a sandboxed iframe via ``srcdoc``, where its scripts do run.

    Args:
        data: The report dict produced by process_input, or None/non-dict
            before the first analysis (renders the template without data).
    """
    import html  # local import keeps this block self-contained

    document = D3_TEMPLATE
    if isinstance(data, dict):
        # Escape "<" so text inside the JSON cannot close the inline <script>
        # early; "\u003c" is an equivalent escape in both JSON and JS strings.
        payload = json.dumps(data).replace("<", "\\u003c")
        document += f"<script>updateDashboard({payload});</script>"
    return (
        f'<iframe sandbox="allow-scripts" srcdoc="{html.escape(document, quote=True)}" '
        'style="width: 100%; height: 780px; border: none;"></iframe>'
    )


with gr.Blocks(css="""
.gradio-container { max-width: 1400px!important; padding: 20px!important; }
#d3-dashboard { background: white; padding: 20px; border-radius: 12px; }
.text-visualization { font-family: monospace; font-size: 16px; }
""") as app:
    gr.Markdown("# ✨ AI Writing Analytics Suite")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input Section")
            audio = gr.Audio(sources=["microphone"], type="filepath",
                             label="🎤 Voice Input")
            text = gr.Textbox(lines=5, placeholder="📝 Enter text here...",
                              label="Text Input")
            btn = gr.Button("Analyze", variant="primary")
        with gr.Column(scale=2):
            gr.Markdown("### Writing Analytics")
            visualization = gr.HTML(_render_dashboard(None))
            report = gr.JSON(label="Detailed Report")
            gr.Markdown("### Corrected Text")
            corrected = gr.Textbox(label="Result", interactive=False)

    # Examples handling
    gr.Markdown("### Example Inputs")
    gr.Examples(
        examples=[
            ["I is going to the park yesterday."],
            ["Their happy about there test results."],
            ["She dont like apples, but she like bananas."],
        ],
        inputs=[text],
        outputs=[report, corrected],
        fn=lambda t: process_input(None, t),
        cache_examples=False,
        # With caching off, fn/outputs are only used when run_on_click is set;
        # otherwise clicking an example merely fills the textbox.
        run_on_click=True,
    )

    btn.click(
        fn=process_input,
        inputs=[audio, text],
        outputs=[report, corrected],
    )

    # The dashboard is derived from the report, so redraw it whenever a new
    # report arrives (from the Analyze button or an example click).
    report.change(
        fn=_render_dashboard,
        inputs=[report],
        outputs=[visualization],
    )
if __name__ == "__main__":
    # Bind on every interface at fixed port 7860 and do not create a public
    # share link.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, share=False)
    app.launch(**launch_options)