Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,229 +1,241 @@
|
|
1 |
import torch
|
2 |
import gradio as gr
|
3 |
import speech_recognition as sr
|
4 |
-
import time
|
5 |
import difflib
|
|
|
6 |
import os
|
7 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
8 |
-
from happytransformer import HappyTextToText, TTSettings
|
9 |
|
10 |
-
# Constants
|
11 |
MODEL_NAME = "prithivida/grammar_error_correcter_v1"
|
12 |
-
|
13 |
-
|
14 |
-
.header { text-align: center; padding: 2rem; background: linear-gradient(135deg, #3b82f6, #6366f1); color: white; border-radius: 15px; }
|
15 |
-
#container { height: 500px; width: 100%; background: #1a1a1a; border-radius: 10px; }
|
16 |
-
.diff-ins { color: #10b981; background: #d1fae5; padding: 2px 4px; border-radius: 4px; }
|
17 |
-
.diff-del { color: #ef4444; background: #fee2e2; padding: 2px 4px; border-radius: 4px; }
|
18 |
-
"""
|
19 |
-
|
20 |
-
THREEJS_TEMPLATE = """
|
21 |
-
<div id="container"></div>
|
22 |
-
<script async src="https://unpkg.com/[email protected]/dist/es-module-shims.js"></script>
|
23 |
-
<script type="importmap">
|
24 |
-
{
|
25 |
-
"imports": {
|
26 |
-
"three": "https://unpkg.com/[email protected]/build/three.module.js",
|
27 |
-
"three/addons/": "https://unpkg.com/[email protected]/examples/jsm/"
|
28 |
-
}
|
29 |
-
}
|
30 |
-
</script>
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
});
|
68 |
-
this.sphere = new THREE.Mesh(geometry, this.material);
|
69 |
-
this.scene.add(this.sphere);
|
70 |
-
}
|
71 |
-
|
72 |
-
setupControls() {
|
73 |
-
this.controls = new OrbitControls(this.camera, this.renderer.domElement);
|
74 |
-
this.controls.enableDamping = true;
|
75 |
-
this.controls.dampingFactor = 0.05;
|
76 |
-
}
|
77 |
-
|
78 |
-
animate() {
|
79 |
-
requestAnimationFrame(() => this.animate());
|
80 |
-
this.controls.update();
|
81 |
-
this.renderer.render(this.scene, this.camera);
|
82 |
-
}
|
83 |
-
|
84 |
-
updateVisuals(score) {
|
85 |
-
const hue = score / 100;
|
86 |
-
this.material.color.setHSL(hue, 0.8, 0.5);
|
87 |
-
this.sphere.rotation.x = (score / 50) * Math.PI;
|
88 |
-
this.sphere.rotation.y = (score / 75) * Math.PI;
|
89 |
-
}
|
90 |
-
}
|
91 |
-
|
92 |
-
let visualizer;
|
93 |
-
window.addEventListener('DOMContentLoaded', () => {
|
94 |
-
visualizer = new GrammarVisualizer();
|
95 |
-
});
|
96 |
|
97 |
-
|
98 |
-
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
</script>
|
101 |
"""
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
def analyze_grammar(text):
|
117 |
-
inputs = tokenizer.encode("gec: " + text, return_tensors="pt", max_length=256, truncation=True)
|
118 |
-
with torch.no_grad():
|
119 |
-
outputs = model.generate(inputs, max_length=256, num_beams=5)
|
120 |
-
correction = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
121 |
-
|
122 |
-
args = TTSettings(num_beams=5, min_length=1)
|
123 |
-
happy_correction = happy_tt.generate_text("gec: " + text, args=args).text
|
124 |
-
|
125 |
-
final_correction = happy_correction if len(happy_correction) > len(correction) else correction
|
126 |
-
changes = sum(1 for a, b in zip(text.split(), final_correction.split()) if a != b)
|
127 |
-
score = max(0, 100 - (changes * 2))
|
128 |
-
|
129 |
-
return {
|
130 |
-
"original": text,
|
131 |
-
"corrected": final_correction,
|
132 |
-
"score": score,
|
133 |
-
"diff_html": create_diff_html(text, final_correction)
|
134 |
-
}
|
135 |
|
136 |
def process_input(audio_path, text):
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
recognizer = sr.Recognizer()
|
141 |
with sr.AudioFile(audio_path) as source:
|
142 |
audio = recognizer.record(source)
|
143 |
text = recognizer.recognize_google(audio)
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
except Exception as e:
|
165 |
-
return
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
gr.Markdown(""
|
174 |
-
<div class="header">
|
175 |
-
<h1>🌍 3D Grammar Analyzer Pro</h1>
|
176 |
-
<p>Interactive AI-Powered Writing Assistant with 3D Visualization</p>
|
177 |
-
</div>
|
178 |
-
""")
|
179 |
|
180 |
with gr.Row():
|
181 |
with gr.Column(scale=1):
|
182 |
gr.Markdown("### Input Section")
|
183 |
-
|
184 |
-
|
185 |
-
|
|
|
|
|
186 |
|
187 |
with gr.Column(scale=2):
|
188 |
-
gr.Markdown("###
|
189 |
-
|
|
|
190 |
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
hidden_trigger = gr.HTML(visible=False)
|
197 |
-
|
198 |
-
# Fixed examples configuration
|
199 |
-
gr.Markdown("### Example Sentences")
|
200 |
gr.Examples(
|
201 |
examples=[
|
202 |
-
["I is going to the park yesterday."],
|
203 |
["Their happy about there test results."],
|
204 |
["She dont like apples, but she like bananas."]
|
205 |
],
|
206 |
-
inputs=[
|
207 |
-
outputs=[
|
208 |
-
fn=lambda
|
209 |
-
cache_examples=False
|
210 |
)
|
211 |
|
212 |
-
|
213 |
fn=process_input,
|
214 |
-
inputs=[
|
215 |
-
outputs=[
|
216 |
)
|
217 |
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
|
|
222 |
)
|
223 |
|
224 |
if __name__ == "__main__":
|
225 |
app.launch(
|
226 |
server_name="0.0.0.0",
|
227 |
server_port=7860,
|
228 |
-
share=False
|
229 |
)
|
|
|
1 |
import torch
|
2 |
import gradio as gr
|
3 |
import speech_recognition as sr
|
|
|
4 |
import difflib
|
5 |
+
import json
|
6 |
import os
|
7 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
8 |
|
|
|
9 |
MODEL_NAME = "prithivida/grammar_error_correcter_v1"
|
10 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
11 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
+
D3_TEMPLATE = """
|
14 |
+
<div id="d3-dashboard">
|
15 |
+
<svg id="main-chart"></svg>
|
16 |
+
<div id="text-vis" class="text-visualization"></div>
|
17 |
+
</div>
|
18 |
+
<script src="https://d3js.org/d3.v7.min.js"></script>
|
19 |
+
<script>
|
20 |
+
const container = d3.select("#d3-dashboard");
|
21 |
+
const width = 800;
|
22 |
+
const height = 500;
|
23 |
+
const margin = {top: 20, right: 30, bottom: 30, left: 40};
|
24 |
+
|
25 |
+
const svg = container.select("#main-chart")
|
26 |
+
.attr("width", width)
|
27 |
+
.attr("height", height)
|
28 |
+
.style("background", "#f8fafc")
|
29 |
+
.style("border-radius", "12px");
|
30 |
+
|
31 |
+
const textVis = container.select("#text-vis")
|
32 |
+
.style("margin-top", "20px")
|
33 |
+
.style("min-height", "100px")
|
34 |
+
.style("padding", "20px")
|
35 |
+
.style("background", "white")
|
36 |
+
.style("border-radius", "12px")
|
37 |
+
.style("box-shadow", "0 2px 4px rgba(0,0,0,0.1)");
|
38 |
+
|
39 |
+
function updateDashboard(data) {
|
40 |
+
// Clear previous elements
|
41 |
+
svg.selectAll("*").remove();
|
42 |
+
textVis.html("");
|
43 |
+
|
44 |
+
// Score Arc
|
45 |
+
const arc = d3.arc()
|
46 |
+
.innerRadius(80)
|
47 |
+
.outerRadius(120)
|
48 |
+
.startAngle(0)
|
49 |
+
.endAngle((Math.PI * 2 * data.score) / 100);
|
50 |
+
|
51 |
+
svg.append("path")
|
52 |
+
.attr("transform", `translate(${width/2},${height/3})`)
|
53 |
+
.attr("d", arc)
|
54 |
+
.attr("fill", "#3b82f6")
|
55 |
+
.transition()
|
56 |
+
.duration(1000)
|
57 |
+
.attrTween("d", function(d) {
|
58 |
+
const interpolate = d3.interpolate(0, data.score/100);
|
59 |
+
return function(t) {
|
60 |
+
arc.endAngle(Math.PI * 2 * interpolate(t));
|
61 |
+
return arc();
|
62 |
+
};
|
63 |
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
+
// Error Distribution
|
66 |
+
const errorTypes = data.errors;
|
67 |
+
const x = d3.scaleBand()
|
68 |
+
.domain(errorTypes.map(d => d.type))
|
69 |
+
.range([margin.left, width - margin.right])
|
70 |
+
.padding(0.2);
|
71 |
+
|
72 |
+
const y = d3.scaleLinear()
|
73 |
+
.domain([0, d3.max(errorTypes, d => d.count)])
|
74 |
+
.range([height/2 - margin.bottom, margin.top]);
|
75 |
+
|
76 |
+
svg.selectAll(".error-bar")
|
77 |
+
.data(errorTypes)
|
78 |
+
.enter().append("rect")
|
79 |
+
.attr("class", "error-bar")
|
80 |
+
.attr("x", d => x(d.type))
|
81 |
+
.attr("y", d => y(d.count))
|
82 |
+
.attr("width", x.bandwidth())
|
83 |
+
.attr("height", d => height/2 - margin.bottom - y(d.count))
|
84 |
+
.attr("fill", "#ef4444")
|
85 |
+
.attr("rx", 4)
|
86 |
+
.on("mouseover", function(event, d) {
|
87 |
+
d3.select(this).attr("fill", "#dc2626");
|
88 |
+
})
|
89 |
+
.on("mouseout", function(event, d) {
|
90 |
+
d3.select(this).attr("fill", "#ef4444");
|
91 |
+
});
|
92 |
+
|
93 |
+
// Interactive Text Visualization
|
94 |
+
const textBox = textVis.selectAll(".word")
|
95 |
+
.data(data.corrections)
|
96 |
+
.enter().append("div")
|
97 |
+
.attr("class", "word")
|
98 |
+
.style("display", "inline-block")
|
99 |
+
.style("margin", "2px")
|
100 |
+
.style("padding", "4px 8px")
|
101 |
+
.style("border-radius", "4px")
|
102 |
+
.style("background", d => d.correct ? "#d1fae5" : "#fee2e2")
|
103 |
+
.style("color", d => d.correct ? "#065f46" : "#991b1b")
|
104 |
+
.style("cursor", "pointer")
|
105 |
+
.on("mouseover", function(event, d) {
|
106 |
+
d3.select(this).style("filter", "brightness(90%)");
|
107 |
+
})
|
108 |
+
.on("mouseout", function(event, d) {
|
109 |
+
d3.select(this).style("filter", "brightness(100%)");
|
110 |
+
})
|
111 |
+
.html(d => d.original);
|
112 |
+
|
113 |
+
textBox.append("div")
|
114 |
+
.attr("class", "tooltip")
|
115 |
+
.style("position", "absolute")
|
116 |
+
.style("background", "white")
|
117 |
+
.style("padding", "8px")
|
118 |
+
.style("border-radius", "6px")
|
119 |
+
.style("box-shadow", "0 2px 8px rgba(0,0,0,0.1)")
|
120 |
+
.html(d => `
|
121 |
+
<strong>${d.type}</strong><br>
|
122 |
+
${d.message}<br>
|
123 |
+
<em>Suggested:</em> ${d.suggestion}
|
124 |
+
`);
|
125 |
+
}
|
126 |
</script>
|
127 |
"""
|
128 |
+
|
129 |
+
def analyze_errors(original, corrected):
|
130 |
+
diff = difflib.SequenceMatcher(None, original.split(), corrected.split())
|
131 |
+
errors = []
|
132 |
+
for tag, i1, i2, j1, j2 in diff.get_opcodes():
|
133 |
+
if tag != 'equal':
|
134 |
+
error = {
|
135 |
+
'type': 'Grammar' if tag == 'replace' else 'Structure',
|
136 |
+
'original': ' '.join(original.split()[i1:i2]),
|
137 |
+
'suggestion': ' '.join(corrected.split()[j1:j2]),
|
138 |
+
'message': 'Improvement suggested' if tag == 'replace' else 'Structural change'
|
139 |
+
}
|
140 |
+
errors.append(error)
|
141 |
+
return errors
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
def process_input(audio_path, text):
|
144 |
+
try:
|
145 |
+
# Handle audio input
|
146 |
+
if audio_path and os.path.exists(audio_path):
|
147 |
recognizer = sr.Recognizer()
|
148 |
with sr.AudioFile(audio_path) as source:
|
149 |
audio = recognizer.record(source)
|
150 |
text = recognizer.recognize_google(audio)
|
151 |
+
|
152 |
+
if not text.strip():
|
153 |
+
return {"error": "No input provided"}, ""
|
154 |
+
|
155 |
+
# Grammar correction
|
156 |
+
inputs = tokenizer.encode("gec: " + text, return_tensors="pt",
|
157 |
+
max_length=256, truncation=True)
|
158 |
+
with torch.no_grad():
|
159 |
+
outputs = model.generate(inputs, max_length=256, num_beams=5)
|
160 |
+
corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
161 |
+
|
162 |
+
# Analysis
|
163 |
+
errors = analyze_errors(text, corrected)
|
164 |
+
score = max(0, 100 - len(errors)*5)
|
165 |
+
|
166 |
+
# D3 data format
|
167 |
+
error_counts = {
|
168 |
+
"Grammar": sum(1 for e in errors if e['type'] == 'Grammar'),
|
169 |
+
"Structure": sum(1 for e in errors if e['type'] == 'Structure'),
|
170 |
+
"Spelling": 0 # Add spelling detection logic if available
|
171 |
+
}
|
172 |
+
|
173 |
+
d3_data = {
|
174 |
+
"score": score,
|
175 |
+
"errors": [{"type": k, "count": v} for k, v in error_counts.items()],
|
176 |
+
"corrections": errors[:10] # Show first 10 corrections
|
177 |
+
}
|
178 |
+
|
179 |
+
return d3_data, corrected
|
180 |
+
|
181 |
except Exception as e:
|
182 |
+
return {"error": str(e)}, ""
|
183 |
+
|
184 |
+
with gr.Blocks(css="""
|
185 |
+
.gradio-container { max-width: 1400px!important; padding: 20px!important; }
|
186 |
+
#d3-dashboard { background: white; padding: 20px; border-radius: 12px; }
|
187 |
+
.text-visualization { font-family: monospace; font-size: 16px; }
|
188 |
+
""") as app:
|
189 |
+
|
190 |
+
gr.Markdown("# ✨ AI Writing Analytics Suite")
|
|
|
|
|
|
|
|
|
|
|
191 |
|
192 |
with gr.Row():
|
193 |
with gr.Column(scale=1):
|
194 |
gr.Markdown("### Input Section")
|
195 |
+
audio = gr.Audio(sources=["microphone"], type="filepath",
|
196 |
+
label="🎤 Voice Input")
|
197 |
+
text = gr.Textbox(lines=5, placeholder="📝 Enter text here...",
|
198 |
+
label="Text Input")
|
199 |
+
btn = gr.Button("Analyze", variant="primary")
|
200 |
|
201 |
with gr.Column(scale=2):
|
202 |
+
gr.Markdown("### Writing Analytics")
|
203 |
+
visualization = gr.HTML(D3_TEMPLATE)
|
204 |
+
report = gr.JSON(label="Detailed Report")
|
205 |
|
206 |
+
gr.Markdown("### Corrected Text")
|
207 |
+
corrected = gr.Textbox(label="Result", interactive=False)
|
208 |
+
|
209 |
+
# Examples handling
|
210 |
+
gr.Markdown("### Example Inputs")
|
|
|
|
|
|
|
|
|
211 |
gr.Examples(
|
212 |
examples=[
|
213 |
+
["I is going to the park yesterday."],
|
214 |
["Their happy about there test results."],
|
215 |
["She dont like apples, but she like bananas."]
|
216 |
],
|
217 |
+
inputs=[text],
|
218 |
+
outputs=[report, corrected],
|
219 |
+
fn=lambda t: process_input(None, t),
|
220 |
+
cache_examples=False
|
221 |
)
|
222 |
|
223 |
+
btn.click(
|
224 |
fn=process_input,
|
225 |
+
inputs=[audio, text],
|
226 |
+
outputs=[report, corrected]
|
227 |
)
|
228 |
|
229 |
+
# JavaScript update
|
230 |
+
visualization.change(
|
231 |
+
fn=lambda data: f"<script>updateDashboard({json.dumps(data)})</script>",
|
232 |
+
inputs=[report],
|
233 |
+
outputs=[visualization]
|
234 |
)
|
235 |
|
236 |
if __name__ == "__main__":
|
237 |
app.launch(
|
238 |
server_name="0.0.0.0",
|
239 |
server_port=7860,
|
240 |
+
share=False
|
241 |
)
|