gaur3009 committed on
Commit
8886bcb
·
verified ·
1 Parent(s): 1d2a66c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -185
app.py CHANGED
@@ -1,229 +1,241 @@
1
  import torch
2
  import gradio as gr
3
  import speech_recognition as sr
4
- import time
5
  import difflib
 
6
  import os
7
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
8
- from happytransformer import HappyTextToText, TTSettings
9
 
10
- # Constants
11
  MODEL_NAME = "prithivida/grammar_error_correcter_v1"
12
- CSS = """
13
- .gradio-container { max-width: 1400px !important; }
14
- .header { text-align: center; padding: 2rem; background: linear-gradient(135deg, #3b82f6, #6366f1); color: white; border-radius: 15px; }
15
- #container { height: 500px; width: 100%; background: #1a1a1a; border-radius: 10px; }
16
- .diff-ins { color: #10b981; background: #d1fae5; padding: 2px 4px; border-radius: 4px; }
17
- .diff-del { color: #ef4444; background: #fee2e2; padding: 2px 4px; border-radius: 4px; }
18
- """
19
-
20
- THREEJS_TEMPLATE = """
21
- <div id="container"></div>
22
- <script async src="https://unpkg.com/[email protected]/dist/es-module-shims.js"></script>
23
- <script type="importmap">
24
- {
25
- "imports": {
26
- "three": "https://unpkg.com/[email protected]/build/three.module.js",
27
- "three/addons/": "https://unpkg.com/[email protected]/examples/jsm/"
28
- }
29
- }
30
- </script>
31
 
32
- <script type="module">
33
- import * as THREE from 'three';
34
- import { OrbitControls } from 'three/addons/controls/OrbitControls.js';
35
-
36
- class GrammarVisualizer {
37
- constructor() {
38
- this.initScene();
39
- this.addLights();
40
- this.createGrammarSphere();
41
- this.setupControls();
42
- this.animate();
43
- }
44
-
45
- initScene() {
46
- this.scene = new THREE.Scene();
47
- this.camera = new THREE.PerspectiveCamera(75, 500/400, 0.1, 1000);
48
- this.renderer = new THREE.WebGLRenderer({ antialias: true, alpha: true });
49
- document.getElementById('container').appendChild(this.renderer.domElement);
50
- this.renderer.setSize(500, 400);
51
- this.camera.position.z = 5;
52
- }
53
-
54
- addLights() {
55
- const ambient = new THREE.AmbientLight(0x404040);
56
- const directional = new THREE.DirectionalLight(0xffffff, 1);
57
- directional.position.set(5, 5, 5);
58
- this.scene.add(ambient, directional);
59
- }
60
-
61
- createGrammarSphere() {
62
- const geometry = new THREE.SphereGeometry(2, 32, 32);
63
- this.material = new THREE.MeshPhongMaterial({
64
- color: 0x3b82f6,
65
- transparent: true,
66
- opacity: 0.9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  });
68
- this.sphere = new THREE.Mesh(geometry, this.material);
69
- this.scene.add(this.sphere);
70
- }
71
-
72
- setupControls() {
73
- this.controls = new OrbitControls(this.camera, this.renderer.domElement);
74
- this.controls.enableDamping = true;
75
- this.controls.dampingFactor = 0.05;
76
- }
77
-
78
- animate() {
79
- requestAnimationFrame(() => this.animate());
80
- this.controls.update();
81
- this.renderer.render(this.scene, this.camera);
82
- }
83
-
84
- updateVisuals(score) {
85
- const hue = score / 100;
86
- this.material.color.setHSL(hue, 0.8, 0.5);
87
- this.sphere.rotation.x = (score / 50) * Math.PI;
88
- this.sphere.rotation.y = (score / 75) * Math.PI;
89
- }
90
- }
91
-
92
- let visualizer;
93
- window.addEventListener('DOMContentLoaded', () => {
94
- visualizer = new GrammarVisualizer();
95
- });
96
 
97
- window.updateGrammarVisuals = (score) => {
98
- if(visualizer) visualizer.updateVisuals(score);
99
- };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  </script>
101
  """
102
- # Load models
103
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
104
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
105
- happy_tt = HappyTextToText("T5", MODEL_NAME)
106
-
107
- def create_diff_html(original, corrected):
108
- d = difflib.Differ()
109
- diff = d.compare(original.split(), corrected.split())
110
- return " ".join([
111
- f'<span class="diff-ins">{p[2:]}</span> ' if p.startswith('+ ') else
112
- f'<span class="diff-del">{p[2:]}</span> ' if p.startswith('- ') else
113
- f'{p[2:]} ' for p in diff
114
- ])
115
-
116
- def analyze_grammar(text):
117
- inputs = tokenizer.encode("gec: " + text, return_tensors="pt", max_length=256, truncation=True)
118
- with torch.no_grad():
119
- outputs = model.generate(inputs, max_length=256, num_beams=5)
120
- correction = tokenizer.decode(outputs[0], skip_special_tokens=True)
121
-
122
- args = TTSettings(num_beams=5, min_length=1)
123
- happy_correction = happy_tt.generate_text("gec: " + text, args=args).text
124
-
125
- final_correction = happy_correction if len(happy_correction) > len(correction) else correction
126
- changes = sum(1 for a, b in zip(text.split(), final_correction.split()) if a != b)
127
- score = max(0, 100 - (changes * 2))
128
-
129
- return {
130
- "original": text,
131
- "corrected": final_correction,
132
- "score": score,
133
- "diff_html": create_diff_html(text, final_correction)
134
- }
135
 
136
  def process_input(audio_path, text):
137
- # Handle audio input
138
- if audio_path and os.path.exists(audio_path):
139
- try:
140
  recognizer = sr.Recognizer()
141
  with sr.AudioFile(audio_path) as source:
142
  audio = recognizer.record(source)
143
  text = recognizer.recognize_google(audio)
144
- except Exception as e:
145
- return [
146
- "Audio processing error",
147
- 0,
148
- f"<span style='color:red'>Error: {str(e)}</span>",
149
- "<script>window.updateGrammarVisuals(0)</script>"
150
- ]
151
-
152
- # Handle text input
153
- if not text.strip():
154
- return ["No input provided", 0, "", "<script>window.updateGrammarVisuals(0)</script>"]
155
-
156
- try:
157
- results = analyze_grammar(text)
158
- return [
159
- results['original'],
160
- results['score'],
161
- results['diff_html'],
162
- f"<script>window.updateGrammarVisuals({results['score']})</script>"
163
- ]
 
 
 
 
 
 
 
 
 
 
164
  except Exception as e:
165
- return [
166
- "Analysis error",
167
- 0,
168
- f"<span style='color:red'>Error: {str(e)}</span>",
169
- "<script>window.updateGrammarVisuals(0)</script>"
170
- ]
171
-
172
- with gr.Blocks(css=CSS) as app:
173
- gr.Markdown("""
174
- <div class="header">
175
- <h1>🌍 3D Grammar Analyzer Pro</h1>
176
- <p>Interactive AI-Powered Writing Assistant with 3D Visualization</p>
177
- </div>
178
- """)
179
 
180
  with gr.Row():
181
  with gr.Column(scale=1):
182
  gr.Markdown("### Input Section")
183
- audio_input = gr.Audio(sources=["microphone"], type="filepath", label="🎤 Voice Input")
184
- text_input = gr.Textbox(lines=5, placeholder="📝 Type your text here...", label="Text Input")
185
- submit_btn = gr.Button("🚀 Analyze Text", variant="primary")
 
 
186
 
187
  with gr.Column(scale=2):
188
- gr.Markdown("### 3D Visualization")
189
- threejs = gr.HTML(THREEJS_TEMPLATE)
 
190
 
191
- with gr.Row():
192
- grammar_score = gr.Number(label="📊 Grammar Score", precision=0)
193
- score_gauge = gr.BarPlot(x=["Score"], y=[0], color="#3b82f6", height=150)
194
-
195
- diff_output = gr.HTML(label="📝 Text Corrections")
196
- hidden_trigger = gr.HTML(visible=False)
197
-
198
- # Fixed examples configuration
199
- gr.Markdown("### Example Sentences")
200
  gr.Examples(
201
  examples=[
202
- ["I is going to the park yesterday."], # Text-only examples
203
  ["Their happy about there test results."],
204
  ["She dont like apples, but she like bananas."]
205
  ],
206
- inputs=[text_input], # Only text input
207
- outputs=[text_input, grammar_score, diff_output, hidden_trigger],
208
- fn=lambda text: process_input(None, text), # Explicitly handle text-only examples
209
- cache_examples=False # Disable caching to prevent startup issues
210
  )
211
 
212
- submit_btn.click(
213
  fn=process_input,
214
- inputs=[audio_input, text_input],
215
- outputs=[text_input, grammar_score, diff_output, hidden_trigger]
216
  )
217
 
218
- text_input.change(
219
- lambda x: analyze_grammar(x)["score"] if x else 0,
220
- inputs=text_input,
221
- outputs=grammar_score
 
222
  )
223
 
224
  if __name__ == "__main__":
225
  app.launch(
226
  server_name="0.0.0.0",
227
  server_port=7860,
228
- share=False # Disable sharing until basic functionality works
229
  )
 
1
  import torch
2
  import gradio as gr
3
  import speech_recognition as sr
 
4
  import difflib
5
+ import json
6
  import os
7
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
8
 
 
9
  MODEL_NAME = "prithivida/grammar_error_correcter_v1"
10
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
11
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ D3_TEMPLATE = """
14
+ <div id="d3-dashboard">
15
+ <svg id="main-chart"></svg>
16
+ <div id="text-vis" class="text-visualization"></div>
17
+ </div>
18
+ <script src="https://d3js.org/d3.v7.min.js"></script>
19
+ <script>
20
+ const container = d3.select("#d3-dashboard");
21
+ const width = 800;
22
+ const height = 500;
23
+ const margin = {top: 20, right: 30, bottom: 30, left: 40};
24
+
25
+ const svg = container.select("#main-chart")
26
+ .attr("width", width)
27
+ .attr("height", height)
28
+ .style("background", "#f8fafc")
29
+ .style("border-radius", "12px");
30
+
31
+ const textVis = container.select("#text-vis")
32
+ .style("margin-top", "20px")
33
+ .style("min-height", "100px")
34
+ .style("padding", "20px")
35
+ .style("background", "white")
36
+ .style("border-radius", "12px")
37
+ .style("box-shadow", "0 2px 4px rgba(0,0,0,0.1)");
38
+
39
+ function updateDashboard(data) {
40
+ // Clear previous elements
41
+ svg.selectAll("*").remove();
42
+ textVis.html("");
43
+
44
+ // Score Arc
45
+ const arc = d3.arc()
46
+ .innerRadius(80)
47
+ .outerRadius(120)
48
+ .startAngle(0)
49
+ .endAngle((Math.PI * 2 * data.score) / 100);
50
+
51
+ svg.append("path")
52
+ .attr("transform", `translate(${width/2},${height/3})`)
53
+ .attr("d", arc)
54
+ .attr("fill", "#3b82f6")
55
+ .transition()
56
+ .duration(1000)
57
+ .attrTween("d", function(d) {
58
+ const interpolate = d3.interpolate(0, data.score/100);
59
+ return function(t) {
60
+ arc.endAngle(Math.PI * 2 * interpolate(t));
61
+ return arc();
62
+ };
63
  });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
+ // Error Distribution
66
+ const errorTypes = data.errors;
67
+ const x = d3.scaleBand()
68
+ .domain(errorTypes.map(d => d.type))
69
+ .range([margin.left, width - margin.right])
70
+ .padding(0.2);
71
+
72
+ const y = d3.scaleLinear()
73
+ .domain([0, d3.max(errorTypes, d => d.count)])
74
+ .range([height/2 - margin.bottom, margin.top]);
75
+
76
+ svg.selectAll(".error-bar")
77
+ .data(errorTypes)
78
+ .enter().append("rect")
79
+ .attr("class", "error-bar")
80
+ .attr("x", d => x(d.type))
81
+ .attr("y", d => y(d.count))
82
+ .attr("width", x.bandwidth())
83
+ .attr("height", d => height/2 - margin.bottom - y(d.count))
84
+ .attr("fill", "#ef4444")
85
+ .attr("rx", 4)
86
+ .on("mouseover", function(event, d) {
87
+ d3.select(this).attr("fill", "#dc2626");
88
+ })
89
+ .on("mouseout", function(event, d) {
90
+ d3.select(this).attr("fill", "#ef4444");
91
+ });
92
+
93
+ // Interactive Text Visualization
94
+ const textBox = textVis.selectAll(".word")
95
+ .data(data.corrections)
96
+ .enter().append("div")
97
+ .attr("class", "word")
98
+ .style("display", "inline-block")
99
+ .style("margin", "2px")
100
+ .style("padding", "4px 8px")
101
+ .style("border-radius", "4px")
102
+ .style("background", d => d.correct ? "#d1fae5" : "#fee2e2")
103
+ .style("color", d => d.correct ? "#065f46" : "#991b1b")
104
+ .style("cursor", "pointer")
105
+ .on("mouseover", function(event, d) {
106
+ d3.select(this).style("filter", "brightness(90%)");
107
+ })
108
+ .on("mouseout", function(event, d) {
109
+ d3.select(this).style("filter", "brightness(100%)");
110
+ })
111
+ .html(d => d.original);
112
+
113
+ textBox.append("div")
114
+ .attr("class", "tooltip")
115
+ .style("position", "absolute")
116
+ .style("background", "white")
117
+ .style("padding", "8px")
118
+ .style("border-radius", "6px")
119
+ .style("box-shadow", "0 2px 8px rgba(0,0,0,0.1)")
120
+ .html(d => `
121
+ <strong>${d.type}</strong><br>
122
+ ${d.message}<br>
123
+ <em>Suggested:</em> ${d.suggestion}
124
+ `);
125
+ }
126
  </script>
127
  """
128
+
129
def analyze_errors(original, corrected):
    """Describe the word-level edits that turn *original* into *corrected*.

    Both strings are split on whitespace and the resulting word lists are
    aligned with ``difflib.SequenceMatcher``. Every opcode other than
    ``equal`` produces one entry in the result.

    Args:
        original: The text before correction.
        corrected: The text after correction.

    Returns:
        A list of dicts in opcode order, each with the keys:
        ``type`` (``'Grammar'`` for a ``replace`` opcode, ``'Structure'``
        for any other non-equal opcode), ``original`` (the affected words
        from *original*, space-joined; empty for an insertion),
        ``suggestion`` (the replacement words from *corrected*,
        space-joined; empty for a deletion) and ``message``.
    """
    # Tokenise once up front instead of re-splitting inside the loop.
    original_words = original.split()
    corrected_words = corrected.split()
    matcher = difflib.SequenceMatcher(None, original_words, corrected_words)

    errors = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == 'equal':
            continue
        is_replace = tag == 'replace'
        errors.append({
            'type': 'Grammar' if is_replace else 'Structure',
            'original': ' '.join(original_words[i1:i2]),
            'suggestion': ' '.join(corrected_words[j1:j2]),
            'message': 'Improvement suggested' if is_replace else 'Structural change',
        })
    return errors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
def process_input(audio_path, text):
    """Correct the grammar of spoken or typed input and build the report data.

    When *audio_path* names an existing file, it is transcribed with
    ``speech_recognition`` (Google recognizer) and the transcript replaces
    *text*. The text is then run through the module-level ``tokenizer`` and
    ``model`` with the ``"gec: "`` prefix, and the differences between input
    and output are summarised by ``analyze_errors``.

    Args:
        audio_path: Path to a recorded audio file, or a falsy value when
            there is no recording.
        text: Text to analyse; ignored when a recording is transcribed.

    Returns:
        A ``(report, corrected_text)`` tuple. On success ``report`` has the
        keys ``score`` (100 minus 5 per detected edit, floored at 0),
        ``errors`` (a list of ``{"type", "count"}`` entries) and
        ``corrections`` (at most the first 10 edits). When there is no usable
        input, or any step raises, ``report`` is ``{"error": <message>}`` and
        ``corrected_text`` is ``""``.
    """
    try:
        # Handle audio input: a recording takes precedence over typed text.
        if audio_path and os.path.exists(audio_path):
            recognizer = sr.Recognizer()
            with sr.AudioFile(audio_path) as source:
                audio = recognizer.record(source)
            text = recognizer.recognize_google(audio)

        # Guard against None as well as blank text so the caller gets the
        # intended message instead of a raw AttributeError string.
        if not text or not text.strip():
            return {"error": "No input provided"}, ""

        # Grammar correction
        inputs = tokenizer.encode("gec: " + text, return_tensors="pt",
                                  max_length=256, truncation=True)
        with torch.no_grad():
            outputs = model.generate(inputs, max_length=256, num_beams=5)
        corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Analysis
        errors = analyze_errors(text, corrected)
        score = max(0, 100 - len(errors) * 5)

        # D3 data format
        error_counts = {
            "Grammar": sum(1 for e in errors if e['type'] == 'Grammar'),
            "Structure": sum(1 for e in errors if e['type'] == 'Structure'),
            "Spelling": 0,  # Add spelling detection logic if available
        }

        d3_data = {
            "score": score,
            "errors": [{"type": k, "count": v} for k, v in error_counts.items()],
            "corrections": errors[:10],  # Show first 10 corrections
        }

        return d3_data, corrected

    except Exception as e:
        # UI boundary: report the failure in the payload rather than raising
        # into the Gradio event handler.
        return {"error": str(e)}, ""
183
+
184
+ with gr.Blocks(css="""
185
+ .gradio-container { max-width: 1400px!important; padding: 20px!important; }
186
+ #d3-dashboard { background: white; padding: 20px; border-radius: 12px; }
187
+ .text-visualization { font-family: monospace; font-size: 16px; }
188
+ """) as app:
189
+
190
+ gr.Markdown("# ✨ AI Writing Analytics Suite")
 
 
 
 
 
191
 
192
  with gr.Row():
193
  with gr.Column(scale=1):
194
  gr.Markdown("### Input Section")
195
+ audio = gr.Audio(sources=["microphone"], type="filepath",
196
+ label="🎤 Voice Input")
197
+ text = gr.Textbox(lines=5, placeholder="📝 Enter text here...",
198
+ label="Text Input")
199
+ btn = gr.Button("Analyze", variant="primary")
200
 
201
  with gr.Column(scale=2):
202
+ gr.Markdown("### Writing Analytics")
203
+ visualization = gr.HTML(D3_TEMPLATE)
204
+ report = gr.JSON(label="Detailed Report")
205
 
206
+ gr.Markdown("### Corrected Text")
207
+ corrected = gr.Textbox(label="Result", interactive=False)
208
+
209
+ # Examples handling
210
+ gr.Markdown("### Example Inputs")
 
 
 
 
211
  gr.Examples(
212
  examples=[
213
+ ["I is going to the park yesterday."],
214
  ["Their happy about there test results."],
215
  ["She dont like apples, but she like bananas."]
216
  ],
217
+ inputs=[text],
218
+ outputs=[report, corrected],
219
+ fn=lambda t: process_input(None, t),
220
+ cache_examples=False
221
  )
222
 
223
+ btn.click(
224
  fn=process_input,
225
+ inputs=[audio, text],
226
+ outputs=[report, corrected]
227
  )
228
 
229
+ # JavaScript update
230
+ visualization.change(
231
+ fn=lambda data: f"<script>updateDashboard({json.dumps(data)})</script>",
232
+ inputs=[report],
233
+ outputs=[visualization]
234
  )
235
 
236
  if __name__ == "__main__":
237
  app.launch(
238
  server_name="0.0.0.0",
239
  server_port=7860,
240
+ share=False
241
  )