Spaces:
Sleeping
Sleeping
import torch | |
import gradio as gr | |
import speech_recognition as sr | |
import time | |
import difflib | |
import random | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
from happytransformer import HappyTextToText, TTSettings | |
# Constants
# Checkpoint name passed to AutoTokenizer/AutoModelForSeq2SeqLM.from_pretrained
# and to HappyTextToText further down; both loaders use this same value.
MODEL_NAME = "prithivida/grammar_error_correcter_v1"
# Custom stylesheet handed to gr.Blocks(css=...).
#   .gradio-container     - widens the overall Gradio page
#   .header               - gradient banner in the title gr.Markdown
#   #container            - host <div> for the Three.js canvas (THREEJS_TEMPLATE)
#   .diff-ins / .diff-del - inserted / deleted word spans from create_diff_html
CSS = """
.gradio-container { max-width: 1400px !important; }
.header { text-align: center; padding: 2rem; background: linear-gradient(135deg, #3b82f6, #6366f1); color: white; border-radius: 15px; }
#container { height: 500px; width: 100%; background: #1a1a1a; border-radius: 10px; }
.diff-ins { color: #10b981; background: #d1fae5; padding: 2px 4px; border-radius: 4px; }
.diff-del { color: #ef4444; background: #fee2e2; padding: 2px 4px; border-radius: 4px; }
"""
# Three.js Template
# HTML fragment rendered by a gr.HTML component: a host <div id="container">
# plus an ES-module script that draws a sphere whose colour and rotation follow
# the grammar score. The page drives it through window.updateGrammarVisuals().
# NOTE(review): the package versions in the unpkg URLs were garbled in this
# copy of the file; they are pinned here to a three.js release that ships
# examples/jsm/controls/OrbitControls.js — bump deliberately if needed.
# NOTE(review): browsers do not execute <script> tags inserted via innerHTML;
# if gr.HTML renders this that way, load the script through
# gr.Blocks(head=...) instead — confirm against the installed Gradio version.
THREEJS_TEMPLATE = """
<div id="container"></div>
<script async src="https://unpkg.com/es-module-shims@1.8.0/dist/es-module-shims.js"></script>
<script type="importmap">
{
  "imports": {
    "three": "https://unpkg.com/three@0.160.0/build/three.module.js",
    "three/addons/": "https://unpkg.com/three@0.160.0/examples/jsm/"
  }
}
</script>
<script type="module">
import * as THREE from 'three';
import { OrbitControls } from 'three/addons/controls/OrbitControls.js';

class GrammarVisualizer {
  constructor(container) {
    this.container = container;
    this.initScene();
    this.addLights();
    this.createGrammarSphere();
    this.setupControls();
    this.animate();
  }

  // Size from the host element (CSS: 100% x 500px) rather than a fixed
  // 500x400, falling back to that default if the element has no layout yet.
  size() {
    const width = this.container.clientWidth || 500;
    const height = this.container.clientHeight || 400;
    return [width, height];
  }

  initScene() {
    const [width, height] = this.size();
    this.scene = new THREE.Scene();
    this.camera = new THREE.PerspectiveCamera(75, width / height, 0.1, 1000);
    this.renderer = new THREE.WebGLRenderer({ antialias: true, alpha: true });
    this.renderer.setSize(width, height);
    this.container.appendChild(this.renderer.domElement);
    this.camera.position.z = 5;
    window.addEventListener('resize', () => this.onResize());
  }

  onResize() {
    const [width, height] = this.size();
    this.camera.aspect = width / height;
    this.camera.updateProjectionMatrix();
    this.renderer.setSize(width, height);
  }

  addLights() {
    const ambient = new THREE.AmbientLight(0x404040);
    const directional = new THREE.DirectionalLight(0xffffff, 1);
    directional.position.set(5, 5, 5);
    this.scene.add(ambient, directional);
  }

  createGrammarSphere() {
    const geometry = new THREE.SphereGeometry(2, 32, 32);
    this.material = new THREE.MeshPhongMaterial({
      color: 0x3b82f6,
      transparent: true,
      opacity: 0.9
    });
    this.sphere = new THREE.Mesh(geometry, this.material);
    this.scene.add(this.sphere);
  }

  setupControls() {
    this.controls = new OrbitControls(this.camera, this.renderer.domElement);
    this.controls.enableDamping = true;
    this.controls.dampingFactor = 0.05;
  }

  animate() {
    requestAnimationFrame(() => this.animate());
    this.controls.update();
    this.renderer.render(this.scene, this.camera);
  }

  updateVisuals(score) {
    // Clamp to 0..100 and map onto hue 0 (red) .. 1/3 (green). A plain
    // score/100 hue wraps 100 back round to red, the same colour as 0.
    const clamped = Math.min(100, Math.max(0, Number(score) || 0));
    const hue = (clamped / 100) / 3;
    this.material.color.setHSL(hue, 0.8, 0.5);
    this.sphere.rotation.x = (clamped / 50) * Math.PI;
    this.sphere.rotation.y = (clamped / 75) * Math.PI;
  }
}

let visualizer;
let pendingScore = null;

function init() {
  const container = document.getElementById('container');
  if (!container || visualizer) return;
  visualizer = new GrammarVisualizer(container);
  // Apply a score that arrived before the scene existed.
  if (pendingScore !== null) visualizer.updateVisuals(pendingScore);
}

// A fragment injected after page load runs after DOMContentLoaded has
// already fired, so only wait for that event while the document is loading.
if (document.readyState === 'loading') {
  window.addEventListener('DOMContentLoaded', init);
} else {
  init();
}

window.updateGrammarVisuals = (score) => {
  if (visualizer) {
    visualizer.updateVisuals(score);
  } else {
    pendingScore = score;
  }
};
</script>
"""
# Load models
# The same MODEL_NAME checkpoint is loaded twice: as a raw transformers
# tokenizer/model pair and through HappyTransformer's T5 wrapper.
# analyze_grammar queries both. HappyTextToText presumably holds its own copy
# of the weights — verify memory use if this matters.
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_NAME),
    AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME),
)
happy_tt = HappyTextToText("T5", MODEL_NAME)
def create_diff_html(original, corrected):
    """Render a word-level diff between *original* and *corrected* as HTML.

    Words present only in *original* are wrapped in ``<span class="diff-del">``,
    words present only in *corrected* in ``<span class="diff-ins">`` (both
    classes are styled by ``CSS``), and unchanged words are emitted as plain
    text. Tokens are separated by a single space.

    Args:
        original: Text before correction (typed by the user or transcribed).
        corrected: Text after correction.

    Returns:
        An HTML string; empty when both inputs contain no words.
    """
    import html  # local import keeps this fix self-contained

    old_words = original.split()
    new_words = corrected.split()
    # SequenceMatcher opcodes give clean equal/replace/delete/insert runs.
    # difflib.Differ would additionally yield "? " hint lines for near-identical
    # words (e.g. "dont" -> "don't"), which would be rendered as stray text.
    # autojunk=False keeps frequent words in long texts from being treated as
    # junk and mis-aligned.
    matcher = difflib.SequenceMatcher(None, old_words, new_words, autojunk=False)
    pieces = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            pieces.extend(html.escape(word) for word in old_words[i1:i2])
            continue
        # Every word is escaped: the text is user-controlled and the result is
        # displayed in a gr.HTML component.
        pieces.extend(
            f'<span class="diff-del">{html.escape(word)}</span>' for word in old_words[i1:i2]
        )
        pieces.extend(
            f'<span class="diff-ins">{html.escape(word)}</span>' for word in new_words[j1:j2]
        )
    return " ".join(pieces)
def _count_word_changes(original_words, corrected_words):
    """Return how many words differ between two word lists after alignment.

    Each non-matching run contributes max(words removed, words inserted), so a
    single inserted or deleted word counts once instead of shifting every later
    word out of position.
    """
    matcher = difflib.SequenceMatcher(None, original_words, corrected_words, autojunk=False)
    return sum(
        max(i2 - i1, j2 - j1)
        for tag, i1, i2, j1, j2 in matcher.get_opcodes()
        if tag != "equal"
    )


def analyze_grammar(text):
    """Correct *text* with the GEC checkpoint and score how much it changed.

    Two corrections are generated from the same prompt: one via
    ``model.generate`` (5 beams; input truncated to 256 tokens, output capped at
    256 tokens) and one via HappyTransformer (5 beams). The HappyTransformer
    output is used only if it is strictly longer in characters.

    Args:
        text: The sentence(s) to check.

    Returns:
        dict with keys:
            ``original``  - the input text,
            ``corrected`` - the chosen correction,
            ``score``     - int, 100 minus 2 per changed word, floored at 0,
            ``diff_html`` - word diff rendered by ``create_diff_html``.
    """
    prompt = "gec: " + text  # both model calls share this task prefix

    inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=256, truncation=True)
    with torch.no_grad():
        outputs = model.generate(inputs, max_length=256, num_beams=5)
    correction = tokenizer.decode(outputs[0], skip_special_tokens=True)

    args = TTSettings(num_beams=5, min_length=1)
    happy_correction = happy_tt.generate_text(prompt, args=args).text
    final_correction = happy_correction if len(happy_correction) > len(correction) else correction

    # Align words before counting; a positional zip() would count every word
    # after an insertion/deletion as changed and ignore trailing extra words.
    changes = _count_word_changes(text.split(), final_correction.split())
    score = max(0, 100 - changes * 2)
    return {
        "original": text,
        "corrected": final_correction,
        "score": score,
        "diff_html": create_diff_html(text, final_correction),
    }
def process_input(audio_path, text):
    """Optionally transcribe audio, then run grammar analysis on the text.

    Args:
        audio_path: File path of a recording (gr.Audio with type="filepath"),
            or a falsy value when nothing was recorded. When given, its Google
            Speech Recognition transcript replaces *text*.
        text: Typed text; may be None or empty.

    Returns:
        A 4-tuple for the outputs (text box, score, diff HTML, script-trigger
        HTML). On failure the first item is a status message and the others
        are ``0`` / ``""`` placeholders.
    """
    if audio_path:
        recognizer = sr.Recognizer()
        with sr.AudioFile(audio_path) as source:
            audio = recognizer.record(source)
        try:
            text = recognizer.recognize_google(audio)
        except sr.UnknownValueError:
            return "Could not understand audio", 0, "", ""
        except sr.RequestError:
            # recognize_google calls a remote API; report network/quota
            # failures instead of crashing the request.
            return "Speech recognition service unavailable", 0, "", ""

    # Guard None as well as empty/whitespace-only input.
    if not text or not text.strip():
        return "No input", 0, "", ""

    results = analyze_grammar(text)
    score = results["score"]
    return (
        results["original"],
        score,
        results["diff_html"],
        f"<script>window.updateGrammarVisuals({score})</script>",
    )
with gr.Blocks(css=CSS) as app:
    gr.Markdown("""
<div class="header">
<h1>π 3D Grammar Analyzer Pro</h1>
<p>Interactive AI-Powered Writing Assistant with 3D Visualization</p>
</div>
""")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input Section")
            audio_input = gr.Audio(sources=["microphone"], type="filepath", label="π€ Voice Input")
            text_input = gr.Textbox(lines=5, placeholder="π Type your text here...", label="Text Input")
            submit_btn = gr.Button("π Analyze Text", variant="primary")
        with gr.Column(scale=2):
            gr.Markdown("### 3D Visualization")
            threejs = gr.HTML(THREEJS_TEMPLATE)
            with gr.Row():
                grammar_score = gr.Number(label="π Grammar Score", precision=0)
                # gr.BarPlot expects a DataFrame value plus column *names* for
                # x/y/color; literal lists and a hex colour are not valid, and
                # the plot was never updated. A read-only slider shows the same
                # 0-100 score and is driven by grammar_score further down.
                score_gauge = gr.Slider(
                    minimum=0, maximum=100, value=0, step=1,
                    label="Score Gauge", interactive=False,
                )
            diff_output = gr.HTML(label="π Text Corrections")
            # Receives the <script> snippet returned by process_input.
            hidden_trigger = gr.HTML(visible=False)

    gr.Markdown("### Example Sentences")
    gr.Examples(
        examples=[
            ["I is going to the park yesterday."],
            ["Their happy about there test results."],
            ["She dont like apples, but she like bananas."],
        ],
        inputs=[text_input],
        outputs=[text_input, grammar_score, diff_output, hidden_trigger],
        # Examples supply only the text box, but process_input takes
        # (audio_path, text); calling it with one argument fails while caching.
        fn=lambda text: process_input(None, text),
        cache_examples=True,
    )

    submit_btn.click(
        fn=process_input,
        inputs=[audio_input, text_input],
        outputs=[text_input, grammar_score, diff_output, hidden_trigger],
    )
    # Score typed text when the user leaves the box. A `.change` listener
    # reran both beam-search models on every keystroke, and again whenever
    # process_input wrote text back into this box.
    text_input.blur(
        lambda x: analyze_grammar(x)["score"] if x and x.strip() else 0,
        inputs=text_input,
        outputs=grammar_score,
    )
    # Keep the gauge and the 3D view in sync with every score update.
    grammar_score.change(lambda s: s or 0, inputs=grammar_score, outputs=score_gauge)
    # <script> tags placed in hidden_trigger are presumably inserted via
    # innerHTML and therefore not executed, so also call the visualizer from a
    # frontend-only listener (harmless if the Three.js script never loaded).
    grammar_score.change(
        None,
        inputs=grammar_score,
        outputs=None,
        js="(score) => { if (window.updateGrammarVisuals) { window.updateGrammarVisuals(score); } }",
    )
if __name__ == "__main__":
    # favicon_path must be a local .png/.gif/.ico file; a URL cannot be served
    # as the favicon file. The URL pointed at the Gradio logo, which is
    # presumably Gradio's default favicon anyway, so the argument is omitted.
    # share=True creates a public link when run locally (ignored on Spaces).
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )