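# Hugging Face file-viewer residue captured with this source (not part of the program):
# author gaur3009 · commit eb83e04 "Create app.py" (verified) · 7.13 kB · raw / history / blame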
import difflib
import html
import random
import time

import gradio as gr
import pandas as pd
import speech_recognition as sr
import torch
from happytransformer import HappyTextToText, TTSettings
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Constants
# Hub checkpoint shared by the transformers model and the HappyTextToText wrapper.
MODEL_NAME = "prithivida/grammar_error_correcter_v1"

# Stylesheet handed to gr.Blocks(css=...). `.diff-ins` / `.diff-del` style the
# spans produced by create_diff_html; `#container` hosts the Three.js canvas.
# Assembled one rule per entry; the empty first and last entries reproduce the
# leading and trailing newline of the stylesheet text.
CSS = "\n".join([
    "",
    ".gradio-container { max-width: 1400px !important; }",
    ".header { text-align: center; padding: 2rem; background: linear-gradient(135deg, #3b82f6, #6366f1); color: white; border-radius: 15px; }",
    "#container { height: 500px; width: 100%; background: #1a1a1a; border-radius: 10px; }",
    ".diff-ins { color: #10b981; background: #d1fae5; padding: 2px 4px; border-radius: 4px; }",
    ".diff-del { color: #ef4444; background: #fee2e2; padding: 2px 4px; border-radius: 4px; }",
    "",
])
# Three.js Template
# HTML/JS snippet rendered by gr.HTML in the UI. It loads three.js from unpkg via
# an import map (es-module-shims polyfills import maps on older browsers), draws
# a sphere whose colour and rotation reflect the grammar score, and exposes
# window.updateGrammarVisuals(score) for the page to call.
# The package versions in the URLs are pinned explicitly (three r160 ships
# examples/jsm/controls/OrbitControls.js).
# NOTE(review): browsers do not execute <script> tags inserted via innerHTML,
# which is how gr.HTML typically renders its value — confirm these scripts run,
# or move the <script> parts into gr.Blocks(head=...).
THREEJS_TEMPLATE = """
<div id="container"></div>
<script async src="https://unpkg.com/es-module-shims@1.3.6/dist/es-module-shims.js"></script>
<script type="importmap">
{
  "imports": {
    "three": "https://unpkg.com/three@0.160.0/build/three.module.js",
    "three/addons/": "https://unpkg.com/three@0.160.0/examples/jsm/"
  }
}
</script>
<script type="module">
import * as THREE from 'three';
import { OrbitControls } from 'three/addons/controls/OrbitControls.js';

class GrammarVisualizer {
    constructor(container) {
        this.container = container;
        this.initScene();
        this.addLights();
        this.createGrammarSphere();
        this.setupControls();
        this.animate();
    }

    initScene() {
        // Size the canvas to its host (fall back to 500x400 if it has no layout yet).
        const width = this.container.clientWidth || 500;
        const height = this.container.clientHeight || 400;
        this.scene = new THREE.Scene();
        this.camera = new THREE.PerspectiveCamera(75, width / height, 0.1, 1000);
        this.renderer = new THREE.WebGLRenderer({ antialias: true, alpha: true });
        this.renderer.setSize(width, height);
        this.container.appendChild(this.renderer.domElement);
        this.camera.position.z = 5;
    }

    addLights() {
        const ambient = new THREE.AmbientLight(0x404040);
        const directional = new THREE.DirectionalLight(0xffffff, 1);
        directional.position.set(5, 5, 5);
        this.scene.add(ambient, directional);
    }

    createGrammarSphere() {
        const geometry = new THREE.SphereGeometry(2, 32, 32);
        this.material = new THREE.MeshPhongMaterial({
            color: 0x3b82f6,
            transparent: true,
            opacity: 0.9
        });
        this.sphere = new THREE.Mesh(geometry, this.material);
        this.scene.add(this.sphere);
    }

    setupControls() {
        this.controls = new OrbitControls(this.camera, this.renderer.domElement);
        this.controls.enableDamping = true;
        this.controls.dampingFactor = 0.05;
    }

    animate() {
        requestAnimationFrame(() => this.animate());
        this.controls.update();
        this.renderer.render(this.scene, this.camera);
    }

    updateVisuals(score) {
        const clamped = Math.min(Math.max(Number(score) || 0, 0), 100);
        // Map 0..100 onto hue 0 (red) .. 1/3 (green). A full 0..1 sweep would
        // wrap around so that a perfect score rendered the same red as 0.
        const hue = (clamped / 100) / 3;
        this.material.color.setHSL(hue, 0.8, 0.5);
        this.sphere.rotation.x = (clamped / 50) * Math.PI;
        this.sphere.rotation.y = (clamped / 75) * Math.PI;
    }
}

let visualizer;

function initVisualizer() {
    const container = document.getElementById('container');
    if (container && !visualizer) {
        visualizer = new GrammarVisualizer(container);
    }
}

// Module scripts run deferred (or late, when injected), often after
// DOMContentLoaded has already fired, so only wait if the DOM is still loading.
if (document.readyState === 'loading') {
    window.addEventListener('DOMContentLoaded', initVisualizer);
} else {
    initVisualizer();
}

window.updateGrammarVisuals = (score) => {
    if (visualizer) visualizer.updateVisuals(score);
};
</script>
"""
# Load models.
# The same MODEL_NAME checkpoint is loaded twice: once as a raw transformers
# tokenizer/model pair (driven via generate) and once behind HappyTextToText.
# NOTE(review): this holds two copies of the weights in memory.
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_NAME),
    AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME),
)
happy_tt = HappyTextToText("T5", MODEL_NAME)
def create_diff_html(original, corrected):
    """Render a word-level diff of *original* vs *corrected* as an HTML fragment.

    Inserted words are wrapped in ``<span class="diff-ins">``, deleted words in
    ``<span class="diff-del">``, and unchanged words are emitted as plain text.
    Every word is HTML-escaped because it comes from user input and the result
    is rendered by ``gr.HTML``.

    Args:
        original: The text as entered (or transcribed).
        corrected: The model-corrected text.

    Returns:
        The pieces joined by single spaces; ``""`` when neither input has words.
    """
    pieces = []
    # Differ also emits "? " lines (intraline ^/+/- markers) for similar words;
    # they are hints, not text, so they are skipped.
    for line in difflib.Differ().compare(original.split(), corrected.split()):
        code, word = line[:2], html.escape(line[2:])
        if code == "+ ":
            pieces.append(f'<span class="diff-ins">{word}</span>')
        elif code == "- ":
            pieces.append(f'<span class="diff-del">{word}</span>')
        elif code == "  ":
            pieces.append(word)
    return " ".join(pieces)
def _count_word_changes(original, corrected):
    """Return how many words must be replaced, inserted or deleted to turn *original* into *corrected*.

    Words are aligned with difflib.SequenceMatcher rather than compared position
    by position, so one inserted or deleted word counts once instead of shifting
    every following word out of place.
    """
    matcher = difflib.SequenceMatcher(
        None, original.split(), corrected.split(), autojunk=False
    )
    return sum(
        max(a_end - a_start, b_end - b_start)
        for tag, a_start, a_end, b_start, b_end in matcher.get_opcodes()
        if tag != "equal"
    )


def analyze_grammar(text):
    """Correct *text* with the grammar model and score how much it had to change.

    The ``"gec: "``-prefixed prompt is run through both the raw transformers
    model and the HappyTextToText wrapper (5-beam search each), and the longer
    of the two outputs is kept.

    Args:
        text: The sentence(s) to check.

    Returns:
        A dict with keys ``"original"`` (the input), ``"corrected"`` (the chosen
        correction), ``"score"`` (``100 - 2 * word_edits``, floored at 0) and
        ``"diff_html"`` (word diff from ``create_diff_html``).
    """
    if not text or not text.strip():
        # Nothing to correct (the UI's change handler can pass whitespace); skip
        # both beam searches and report an unchanged text.
        return {"original": text, "corrected": text, "score": 100, "diff_html": ""}

    prompt = "gec: " + text

    inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=256, truncation=True)
    with torch.no_grad():
        outputs = model.generate(inputs, max_length=256, num_beams=5)
    correction = tokenizer.decode(outputs[0], skip_special_tokens=True)

    settings = TTSettings(num_beams=5, min_length=1)
    happy_correction = happy_tt.generate_text(prompt, args=settings).text

    # NOTE(review): "longer output wins" is a heuristic, not a quality measure —
    # confirm it is the intended way to choose between the two corrections.
    final_correction = happy_correction if len(happy_correction) > len(correction) else correction

    changes = _count_word_changes(text, final_correction)
    score = max(0, 100 - changes * 2)
    return {
        "original": text,
        "corrected": final_correction,
        "score": score,
        "diff_html": create_diff_html(text, final_correction),
    }
def process_input(audio_path, text):
    """Transcribe an optional recording, then grammar-check the resulting text.

    Args:
        audio_path: Path to a recorded audio file (``gr.Audio(type="filepath")``),
            or ``None``/empty when nothing was recorded. When given, its Google
            Speech transcription replaces ``text``.
        text: Text typed by the user; may be ``None`` or blank.

    Returns:
        A 4-tuple ``(text, score, diff_html, script_html)`` for the outputs
        ``[text_input, grammar_score, diff_output, hidden_trigger]``. On failure
        the first element is a status message and the rest are ``0``/``""``.
    """
    if audio_path:
        recognizer = sr.Recognizer()
        with sr.AudioFile(audio_path) as source:
            audio = recognizer.record(source)
        try:
            text = recognizer.recognize_google(audio)
        except sr.UnknownValueError:
            return "Could not understand audio", 0, "", ""
        except sr.RequestError:
            # The web recognition API failed (e.g. no network); report it instead
            # of raising out of the Gradio handler.
            return "Speech recognition service unavailable", 0, "", ""

    if not text or not text.strip():
        return "No input", 0, "", ""

    results = analyze_grammar(text)
    score = results["score"]
    return (
        results["original"],
        score,
        results["diff_html"],
        f"<script>window.updateGrammarVisuals({score})</script>",
    )
def _score_frame(score):
    """Return the one-row DataFrame plotted by ``score_gauge`` for *score*.

    ``gr.Number`` may deliver ``None`` (e.g. when cleared); that is plotted as 0.
    """
    return pd.DataFrame({"metric": ["Score"], "score": [score or 0]})


with gr.Blocks(css=CSS) as app:
    gr.Markdown("""
<div class="header">
<h1>🌍 3D Grammar Analyzer Pro</h1>
<p>Interactive AI-Powered Writing Assistant with 3D Visualization</p>
</div>
""")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input Section")
            audio_input = gr.Audio(sources=["microphone"], type="filepath", label="🎤 Voice Input")
            text_input = gr.Textbox(lines=5, placeholder="📝 Type your text here...", label="Text Input")
            submit_btn = gr.Button("🚀 Analyze Text", variant="primary")
        with gr.Column(scale=2):
            gr.Markdown("### 3D Visualization")
            threejs = gr.HTML(THREEJS_TEMPLATE)
            with gr.Row():
                grammar_score = gr.Number(label="📊 Grammar Score", precision=0)
                # BarPlot's x / y (and color) name DataFrame columns, so the
                # chart is fed a DataFrame and refreshed from grammar_score below.
                score_gauge = gr.BarPlot(
                    value=_score_frame(0),
                    x="metric",
                    y="score",
                    y_lim=[0, 100],
                    height=150,
                    label="Score",
                )
            diff_output = gr.HTML(label="📝 Text Corrections")
            hidden_trigger = gr.HTML(visible=False)

    # Output order matches the tuple returned by process_input.
    analysis_outputs = [text_input, grammar_score, diff_output, hidden_trigger]

    gr.Markdown("### Example Sentences")
    gr.Examples(
        examples=[
            ["I is going to the park yesterday."],
            ["Their happy about there test results."],
            ["She dont like apples, but she like bananas."],
        ],
        inputs=[text_input],
        outputs=analysis_outputs,
        # Examples pass one value per entry in `inputs`; process_input expects
        # (audio_path, text), so adapt the single text argument.
        fn=lambda text: process_input(None, text),
        cache_examples=True,
    )

    submit_btn.click(
        fn=process_input,
        inputs=[audio_input, text_input],
        outputs=analysis_outputs,
    )

    # NOTE(review): this re-runs both beam-search models on every edit, including
    # the programmatic update made by submit; consider .submit/.blur if too slow.
    text_input.change(
        lambda x: analyze_grammar(x)["score"] if x and x.strip() else 0,
        inputs=text_input,
        outputs=grammar_score,
    )

    # Keep the bar chart in sync with every score update.
    grammar_score.change(_score_frame, inputs=grammar_score, outputs=score_gauge)

    # Scripts placed into gr.HTML via innerHTML (hidden_trigger) are not executed
    # by the browser, so drive the visualizer through the event's `js` hook. It
    # is a no-op until window.updateGrammarVisuals has actually been defined.
    grammar_score.change(
        fn=None,
        inputs=grammar_score,
        outputs=None,
        js="(score) => { if (window.updateGrammarVisuals) { window.updateGrammarVisuals(score); } return []; }",
    )
if __name__ == "__main__":
    # Serve on all interfaces at the default Gradio port and request a public
    # share link. No favicon_path is passed: it must be a local image file
    # path, so a remote URL cannot be used there.
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )