# app.py — "Update app.py" by gaur3009, commit 8886bcb (verified), 7.74 kB.
import torch
import gradio as gr
import speech_recognition as sr
import difflib
import json
import os
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Hub checkpoint used for grammatical error correction.
MODEL_NAME = "prithivida/grammar_error_correcter_v1"

# Loaded once at import time and shared by every request (tokenizer first, then model).
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_NAME),
    AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME),
)
# HTML + D3 v7 markup for the analytics dashboard. The embedded script defines
# updateDashboard(data), which redraws the score arc, the per-type error bars and
# one hoverable chip per correction from the dict built by process_input, or shows
# the message when that dict is {"error": ...}.
# NOTE(review): Gradio's gr.HTML is generally understood not to execute <script>
# tags contained in its value; render this markup in a script-capable context
# (e.g. an iframe srcdoc) — confirm for the Gradio version in use.
D3_TEMPLATE = """
<div id="d3-dashboard">
<svg id="main-chart"></svg>
<div id="text-vis" class="text-visualization"></div>
</div>
<script src="https://d3js.org/d3.v7.min.js"></script>
<script>
const container = d3.select("#d3-dashboard");
const width = 800;
const height = 500;
const margin = {top: 20, right: 30, bottom: 30, left: 40};
// Score arc lives in the top half and the bar chart in the bottom half,
// so the two never overlap.
const arcCenterY = height / 4;
const barTop = height / 2 + margin.top;
const barBaseline = height - margin.bottom;
const svg = container.select("#main-chart")
  .attr("width", width)
  .attr("height", height)
  .style("background", "#f8fafc")
  .style("border-radius", "12px");
const textVis = container.select("#text-vis")
  .style("margin-top", "20px")
  .style("min-height", "100px")
  .style("padding", "20px")
  .style("background", "white")
  .style("border-radius", "12px")
  .style("box-shadow", "0 2px 4px rgba(0,0,0,0.1)");
function updateDashboard(data) {
  // Clear previous elements
  svg.selectAll("*").remove();
  textVis.html("");
  // Failures arrive as {"error": message}; show the message and draw nothing else.
  if (!data || data.error) {
    textVis.text(data && data.error ? "Error: " + data.error : "");
    return;
  }
  // Score Arc, animated from 0 up to the (clamped 0-100) score.
  const score = Math.max(0, Math.min(100, Number(data.score) || 0));
  const arc = d3.arc()
    .innerRadius(80)
    .outerRadius(120)
    .startAngle(0)
    .endAngle(0);
  svg.append("path")
    .attr("transform", `translate(${width/2},${arcCenterY})`)
    .attr("d", arc())
    .attr("fill", "#3b82f6")
    .transition()
    .duration(1000)
    .attrTween("d", function() {
      const interpolate = d3.interpolate(0, score / 100);
      return function(t) {
        arc.endAngle(Math.PI * 2 * interpolate(t));
        return arc();
      };
    });
  svg.append("text")
    .attr("transform", `translate(${width/2},${arcCenterY})`)
    .attr("text-anchor", "middle")
    .attr("dominant-baseline", "middle")
    .style("font-size", "32px")
    .style("font-weight", "bold")
    .style("fill", "#1e3a8a")
    .text(score);
  // Error Distribution
  const errorTypes = Array.isArray(data.errors) ? data.errors : [];
  const x = d3.scaleBand()
    .domain(errorTypes.map(d => d.type))
    .range([margin.left, width - margin.right])
    .padding(0.2);
  // At least 1 so an all-zero distribution does not collapse the domain to [0, 0].
  const maxCount = Math.max(1, d3.max(errorTypes, d => d.count) || 0);
  const y = d3.scaleLinear()
    .domain([0, maxCount])
    .range([barBaseline, barTop]);
  svg.selectAll(".error-bar")
    .data(errorTypes)
    .enter().append("rect")
    .attr("class", "error-bar")
    .attr("x", d => x(d.type))
    .attr("y", d => y(d.count))
    .attr("width", x.bandwidth())
    .attr("height", d => barBaseline - y(d.count))
    .attr("fill", "#ef4444")
    .attr("rx", 4)
    .on("mouseover", function() {
      d3.select(this).attr("fill", "#dc2626");
    })
    .on("mouseout", function() {
      d3.select(this).attr("fill", "#ef4444");
    });
  svg.append("g")
    .attr("transform", `translate(0,${barBaseline})`)
    .call(d3.axisBottom(x));
  svg.append("g")
    .attr("transform", `translate(${margin.left},0)`)
    .call(d3.axisLeft(y).ticks(Math.min(maxCount, 5)).tickFormat(d3.format("d")));
  // Interactive Text Visualization. User-derived text is always inserted with
  // .text() (never .html()) so it cannot inject markup into the page.
  const words = textVis.selectAll(".word")
    .data(Array.isArray(data.corrections) ? data.corrections : [])
    .enter().append("div")
    .attr("class", "word")
    .style("display", "inline-block")
    .style("position", "relative")
    .style("margin", "2px")
    .style("padding", "4px 8px")
    .style("border-radius", "4px")
    .style("background", d => d.correct ? "#d1fae5" : "#fee2e2")
    .style("color", d => d.correct ? "#065f46" : "#991b1b")
    .style("cursor", "pointer")
    // Insertions have no original words; label them so the chip is not empty.
    .text(d => d.original || "(insertion)")
    .on("mouseover", function() {
      d3.select(this).style("filter", "brightness(90%)")
        .select(".tooltip").style("display", "block");
    })
    .on("mouseout", function() {
      d3.select(this).style("filter", "brightness(100%)")
        .select(".tooltip").style("display", "none");
    });
  // Tooltips stay hidden until their chip is hovered.
  const tooltip = words.append("div")
    .attr("class", "tooltip")
    .style("display", "none")
    .style("position", "absolute")
    .style("left", "0")
    .style("top", "100%")
    .style("z-index", "10")
    .style("min-width", "180px")
    .style("background", "white")
    .style("color", "#111827")
    .style("padding", "8px")
    .style("border-radius", "6px")
    .style("box-shadow", "0 2px 8px rgba(0,0,0,0.1)");
  tooltip.append("strong").text(d => d.type);
  tooltip.append("br");
  tooltip.append("span").text(d => d.message);
  tooltip.append("br");
  tooltip.append("em").text("Suggested:");
  tooltip.append("span").text(d => " " + (d.suggestion || ""));
}
</script>
"""
def analyze_errors(original, corrected):
    """Describe the word-level differences between two texts.

    Both strings are split on whitespace and aligned with
    ``difflib.SequenceMatcher``. Every non-equal opcode becomes one entry:
    a replacement is reported as ``"Grammar"``, while an insertion or a
    deletion is reported as ``"Structure"``.

    Args:
        original: The text as written or transcribed.
        corrected: The model's corrected version of ``original``.

    Returns:
        A list of dicts with keys ``type``, ``original``, ``suggestion`` and
        ``message``. ``original`` is ``""`` for insertions and
        ``suggestion`` is ``""`` for deletions.
    """
    # Split once, outside the loop; the word lists are loop-invariant.
    original_words = original.split()
    corrected_words = corrected.split()
    matcher = difflib.SequenceMatcher(None, original_words, corrected_words)
    errors = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            continue
        is_replace = tag == "replace"
        errors.append({
            "type": "Grammar" if is_replace else "Structure",
            "original": " ".join(original_words[i1:i2]),
            "suggestion": " ".join(corrected_words[j1:j2]),
            "message": "Improvement suggested" if is_replace else "Structural change",
        })
    return errors
def process_input(audio_path, text):
    """Correct the grammar of spoken or typed text and summarise the edits.

    Args:
        audio_path: Path to a recorded audio file, or a falsy value when there
            is no recording. If it names an existing file, the file is
            transcribed with ``recognize_google`` and the transcript replaces
            ``text``.
        text: Typed input, used when no audio file is available. May be None.

    Returns:
        A ``(report, corrected_text)`` tuple. ``report`` is either
        ``{"score", "errors", "corrections"}`` (the dashboard payload) or
        ``{"error": message}``, in which case ``corrected_text`` is ``""``.
    """
    try:
        # Handle audio input: an existing recording takes precedence over typed text.
        if audio_path and os.path.exists(audio_path):
            recognizer = sr.Recognizer()
            with sr.AudioFile(audio_path) as source:
                audio = recognizer.record(source)
            text = recognizer.recognize_google(audio)
        # Gradio normally passes "" for an empty textbox, but direct callers may pass None.
        if not text or not text.strip():
            return {"error": "No input provided"}, ""
        # Grammar correction. The prompt is truncated to 256 tokens, so on very
        # long inputs the words past that point are likely reported as
        # "Structure" deletions by analyze_errors below.
        inputs = tokenizer.encode("gec: " + text, return_tensors="pt",
                                  max_length=256, truncation=True)
        with torch.no_grad():
            outputs = model.generate(inputs, max_length=256, num_beams=5)
        corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Analysis: each differing word span costs 5 points, floored at 0.
        errors = analyze_errors(text, corrected)
        score = max(0, 100 - len(errors) * 5)
        # D3 data format
        error_counts = {
            "Grammar": sum(1 for e in errors if e["type"] == "Grammar"),
            "Structure": sum(1 for e in errors if e["type"] == "Structure"),
            "Spelling": 0,  # no spelling detector yet; keeps a fixed set of bars
        }
        d3_data = {
            "score": score,
            "errors": [{"type": k, "count": v} for k, v in error_counts.items()],
            "corrections": errors[:10],  # show only the first 10 corrections
        }
        return d3_data, corrected
    except sr.UnknownValueError:
        # Raised when no speech could be recognised; the exception typically
        # carries no message, so str(exc) would be an empty error.
        return {"error": "Could not understand the audio"}, ""
    except sr.RequestError as exc:
        return {"error": f"Speech recognition request failed: {exc}"}, ""
    except Exception as e:  # UI boundary: report the failure instead of crashing the app
        return {"error": str(e) or type(e).__name__}, ""
# Shown in the dashboard area before the first analysis has run.
_DASHBOARD_PLACEHOLDER = (
    '<div class="text-visualization">Run an analysis to see the dashboard.</div>'
)


def _dashboard_html(data):
    """Render one analysis report as a self-contained D3 dashboard.

    The D3 template plus an ``updateDashboard(data)`` call is packed into a
    sandboxed iframe ``srcdoc``, whose own document runs the scripts.
    NOTE(review): this assumes gr.HTML does not execute ``<script>`` tags in
    its value — confirm for the Gradio version in use.

    Args:
        data: The report produced by ``process_input``: a dict with
            ``score``/``errors``/``corrections``, or ``{"error": ...}``.
            May be ``None`` (or a JSON string) depending on the component state.

    Returns:
        An HTML string for the ``visualization`` component.
    """
    import html  # stdlib; only needed for escaping here

    if isinstance(data, str):  # tolerate a JSON-encoded payload
        try:
            data = json.loads(data)
        except ValueError:
            data = None
    if not isinstance(data, dict) or not data:
        return _DASHBOARD_PLACEHOLDER
    if "error" in data:
        return (
            '<div class="text-visualization" style="color:#991b1b">'
            f"Error: {html.escape(str(data['error']))}</div>"
        )
    # Escape <, > and & inside the JSON (JS decodes the \u00XX escapes back) so
    # user text such as "</script>" cannot terminate the inline script early.
    payload = (
        json.dumps(data)
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
        .replace("&", "\\u0026")
    )
    document = f"{D3_TEMPLATE}<script>updateDashboard({payload});</script>"
    # sandbox="allow-scripts" (no allow-same-origin) lets the D3 code run while
    # keeping it isolated from the surrounding app page.
    return (
        '<iframe title="Writing analytics dashboard" sandbox="allow-scripts" '
        'style="width:100%;height:760px;border:none;" '
        f'srcdoc="{html.escape(document, quote=True)}"></iframe>'
    )


with gr.Blocks(css="""
.gradio-container { max-width: 1400px!important; padding: 20px!important; }
#d3-dashboard { background: white; padding: 20px; border-radius: 12px; }
.text-visualization { font-family: monospace; font-size: 16px; }
""") as app:
    gr.Markdown("# ✨ AI Writing Analytics Suite")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input Section")
            audio = gr.Audio(sources=["microphone"], type="filepath",
                             label="🎀 Voice Input")
            text = gr.Textbox(lines=5, placeholder="📝 Enter text here...",
                              label="Text Input")
            btn = gr.Button("Analyze", variant="primary")
        with gr.Column(scale=2):
            gr.Markdown("### Writing Analytics")
            visualization = gr.HTML(_DASHBOARD_PLACEHOLDER)
            report = gr.JSON(label="Detailed Report")
            gr.Markdown("### Corrected Text")
            corrected = gr.Textbox(label="Result", interactive=False)
    # Examples handling
    gr.Markdown("### Example Inputs")
    gr.Examples(
        examples=[
            ["I is going to the park yesterday."],
            ["Their happy about there test results."],
            ["She dont like apples, but she like bananas."]
        ],
        inputs=[text],
        outputs=[report, corrected],
        fn=lambda t: process_input(None, t),
        cache_examples=False
    )
    btn.click(
        fn=process_input,
        inputs=[audio, text],
        outputs=[report, corrected]
    )
    # Redraw the dashboard whenever a new report lands in the JSON component.
    # The trigger must be `report`: a listener on `visualization` would only fire
    # when `visualization` itself changes, i.e. never.
    report.change(
        fn=_dashboard_html,
        inputs=[report],
        outputs=[visualization]
    )
if __name__ == "__main__":
    # Bind on all network interfaces at port 7860; share=False means no public share link.
    launch_kwargs = dict(server_name="0.0.0.0", server_port=7860, share=False)
    app.launch(**launch_kwargs)