# Import necessary libraries
import requests
import io
from PIL import Image
import matplotlib.pyplot as plt
from transformers import MarianMTModel, MarianTokenizer, pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import os # For accessing environment variables


# Constants for model names and API URLs
class Constants:
    TRANSLATION_MODEL_NAME = "Helsinki-NLP/opus-mt-mul-en"
    IMAGE_GENERATION_API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
    GPT_NEO_MODEL_NAME = "EleutherAI/gpt-neo-125M"

    # Get the Hugging Face API token from environment variables
    HEADERS = {"Authorization": f"Bearer {os.getenv('HUGGINGFACE_API_KEY')}"}


# Translation Class
class Translator:
    def __init__(self):
        self.tokenizer = MarianTokenizer.from_pretrained(Constants.TRANSLATION_MODEL_NAME)
        self.model = MarianMTModel.from_pretrained(Constants.TRANSLATION_MODEL_NAME)
        self.pipeline = pipeline("translation", model=self.model, tokenizer=self.tokenizer)

    def translate(self, tamil_text):
        """Translate Tamil text to English."""
        try:
            translation = self.pipeline(tamil_text, max_length=40)
            return translation[0]['translation_text']
        except Exception as e:
            return f"Translation error: {str(e)}"


# Image Generation Class
class ImageGenerator:
    def __init__(self):
        self.api_url = Constants.IMAGE_GENERATION_API_URL

    def generate(self, prompt):
        """Generate an image based on the given prompt."""
        try:
            response = requests.post(self.api_url, headers=Constants.HEADERS, json={"inputs": prompt})
            if response.status_code == 200:
                image_bytes = response.content
                return Image.open(io.BytesIO(image_bytes))
            else:
                print(f"Image generation failed: Status code {response.status_code}")
                return None
        except Exception as e:
            print(f"Image generation error: {str(e)}")
            return None
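# Note: the hosted Inference API can also return a non-200 status while the model
# is loading or when rate limits apply; in those cases generate() simply returns
# None and the Gradio output shows no image.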


# Creative Text Generation Class
class CreativeTextGenerator:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(Constants.GPT_NEO_MODEL_NAME)
        self.model = AutoModelForCausalLM.from_pretrained(Constants.GPT_NEO_MODEL_NAME)

    def generate(self, translated_text):
        """Generate creative text based on the translated text."""
        input_ids = self.tokenizer(translated_text, return_tensors='pt').input_ids
        generated_text_ids = self.model.generate(input_ids, max_length=100)
        return self.tokenizer.decode(generated_text_ids[0], skip_special_tokens=True)


# Main Application Class
class TransArtApp:
    def __init__(self):
        self.translator = Translator()
        self.image_generator = ImageGenerator()
        self.creative_text_generator = CreativeTextGenerator()

    def process(self, tamil_text):
        """Handle the full workflow: translate, then generate an image and creative text."""
        translated_text = self.translator.translate(tamil_text)
        image = self.image_generator.generate(translated_text)
        creative_text = self.creative_text_generator.generate(translated_text)
        return translated_text, creative_text, image
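# Note that translate() returns an error string on failure, so the image and
# creative-text steps would then operate on that error message rather than a real
# prompt; callers may want to check for the "Translation error:" prefix first.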


# Function to display images
def show_image(image):
    """Display an image using matplotlib."""
    if image:
        plt.imshow(image)
        plt.axis('off')  # Hide axes
        plt.show()
    else:
        print("No image to display.")


# Create an instance of the TransArt app
app = TransArtApp()


# Gradio interface function
def gradio_interface(tamil_text):
    """Interface function for Gradio."""
    translated_text, creative_text, image = app.process(tamil_text)
    return translated_text, creative_text, image


# Create Gradio interface
interface = gr.Interface(
    fn=gradio_interface,
    inputs="text",
    outputs=["text", "text", "image"],
    title="Tamil to English Translation, Image Generation & Creative Text",
    description="Enter Tamil text to translate to English, generate an image, and create creative text based on the translation."
)

# Launch Gradio app
if __name__ == "__main__":
    interface.launch()
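
# To run locally (a minimal sketch, assuming the dependencies imported above are
# installed and the token is exported in the shell):
#   pip install requests pillow matplotlib transformers torch sentencepiece gradio
#   export HUGGINGFACE_API_KEY=<your Hugging Face token>
#   python app.py
# Gradio then prints a local URL where the interface can be opened.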