# Captured from a Hugging Face Spaces page (Space status: Sleeping).
# Import necessary libraries
import io
import logging
import os  # For accessing environment variables

import gradio as gr
import matplotlib.pyplot as plt
import requests
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import MarianMTModel, MarianTokenizer, pipeline
# Constants for model names and API URLs
class Constants:
    """Model identifiers, endpoint URL and auth headers shared by the app."""

    # Hugging Face model id of the MarianMT translator (multilingual -> English).
    TRANSLATION_MODEL_NAME = "Helsinki-NLP/opus-mt-mul-en"
    # Hosted inference endpoint used for text-to-image generation.
    IMAGE_GENERATION_API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
    # Hugging Face model id of the causal LM used for creative text.
    GPT_NEO_MODEL_NAME = "EleutherAI/gpt-neo-125M"
    # Get the Hugging Face API token from environment variables. The
    # Authorization header is only added when a non-empty token is set:
    # formatting an unset variable would send the literal "Bearer None",
    # which is never a valid credential.
    _API_KEY = os.getenv("HUGGINGFACE_API_KEY")
    HEADERS = {"Authorization": f"Bearer {_API_KEY}"} if _API_KEY else {}
# Translation Class
class Translator:
    """Translate Tamil text to English with a MarianMT translation pipeline."""

    def __init__(self):
        self.tokenizer = MarianTokenizer.from_pretrained(Constants.TRANSLATION_MODEL_NAME)
        self.model = MarianMTModel.from_pretrained(Constants.TRANSLATION_MODEL_NAME)
        # Hand the already-loaded objects to the pipeline so it uses exactly
        # this model/tokenizer pair.
        self.pipeline = pipeline("translation", model=self.model, tokenizer=self.tokenizer)

    def translate(self, tamil_text, max_length=40):
        """Translate Tamil text to English.

        Args:
            tamil_text: Source text to translate.
            max_length: Generation length limit forwarded to the translation
                pipeline. The default (40) matches the previous hard-coded
                value; raise it for longer passages, which are otherwise cut
                off mid-sentence.

        Returns:
            The English translation; ``""`` for empty/whitespace-only input;
            or ``"Translation error: <details>"`` if the pipeline raised
            (errors are reported as text so the UI can display them).
        """
        if not tamil_text or not tamil_text.strip():
            return ""  # nothing to translate; don't run the model on blank input
        try:
            translation = self.pipeline(tamil_text, max_length=max_length)
            return translation[0]['translation_text']
        except Exception as e:  # UI boundary: surface the failure as text
            return f"Translation error: {e}"
# Image Generation Class
class ImageGenerator:
    """Generate an image for a text prompt via the hosted inference endpoint."""

    _log = logging.getLogger(__name__)

    def __init__(self, timeout=120):
        """Set up the generator.

        Args:
            timeout: Seconds to wait for the endpoint before giving up.
                ``requests`` never times out unless a timeout is given, so
                without one a stalled call would block the worker forever.
        """
        self.api_url = Constants.IMAGE_GENERATION_API_URL
        self.timeout = timeout

    def generate(self, prompt):
        """Generate an image based on the given prompt.

        Args:
            prompt: English text prompt (normally the translated text).

        Returns:
            A fully decoded ``PIL.Image.Image``. Returns ``None`` when the
            prompt is blank, the endpoint answers with a non-200 status, or
            the request/decoding fails. This is best-effort: failures are
            logged, not raised.
        """
        if not prompt or not prompt.strip():
            return None  # nothing to draw; skip the network round-trip
        try:
            response = requests.post(
                self.api_url,
                headers=Constants.HEADERS,
                json={"inputs": prompt},
                timeout=self.timeout,
            )
            if response.status_code != 200:
                # Log the start of the body too: an error payload is usually
                # more informative than the bare status code.
                self._log.warning(
                    "Image generation failed: Status code %s: %s",
                    response.status_code,
                    response.text[:200],
                )
                return None
            image = Image.open(io.BytesIO(response.content))
            # Image.open only reads the header; decode now so a corrupt or
            # truncated payload is caught here rather than later in the UI.
            image.load()
            return image
        except Exception as e:  # UI boundary: show "no image" instead of crashing
            self._log.warning("Image generation error: %s", e)
            return None
# Creative Text Generation Class
class CreativeTextGenerator:
    """Continue an English text with GPT-Neo to produce "creative" text."""

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(Constants.GPT_NEO_MODEL_NAME)
        self.model = AutoModelForCausalLM.from_pretrained(Constants.GPT_NEO_MODEL_NAME)

    def generate(self, translated_text, max_new_tokens=100):
        """Generate creative text based on translated text.

        Args:
            translated_text: English prompt (normally the translation).
            max_new_tokens: Maximum number of tokens to generate *beyond* the
                prompt. The previous ``max_length=100`` counted the prompt
                too, so a long prompt left little or no room for new text.

        Returns:
            The decoded model output (for this decoder-only model it starts
            with the prompt itself). Returns ``""`` for empty/whitespace-only
            input, or ``"Creative text error: <details>"`` if tokenization or
            generation raised. The error format mirrors the translator's
            error strings.
        """
        if not translated_text or not translated_text.strip():
            return ""  # generate() cannot run on an empty token sequence
        try:
            # Pass the full tokenizer output (input_ids + attention_mask) and
            # an explicit pad id: GPT-Neo defines no pad token of its own.
            inputs = self.tokenizer(translated_text, return_tensors='pt')
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        except Exception as e:  # keep the rest of the pipeline's results alive
            return f"Creative text error: {e}"
# Main Application Class
class TransArtApp:
    """Orchestrate translation, image generation and creative-text generation."""

    # Prefix of the error strings Translator.translate returns on failure.
    # Downstream generation is skipped for these so the error message is not
    # sent to the image API or the language model as if it were content.
    _TRANSLATION_ERROR_PREFIX = "Translation error:"

    def __init__(self):
        self.translator = Translator()
        self.image_generator = ImageGenerator()
        self.creative_text_generator = CreativeTextGenerator()

    def process(self, tamil_text):
        """Handle the full workflow: translate, generate image, and creative text.

        Args:
            tamil_text: User-supplied Tamil text.

        Returns:
            ``(translated_text, creative_text, image)``. If the input is blank,
            this is ``("", "", None)``. If translation produced nothing or
            failed, it is ``(translated_text, "", None)``, with the translator's
            message kept so the UI can show it.
        """
        if not tamil_text or not tamil_text.strip():
            return "", "", None
        translated_text = self.translator.translate(tamil_text)
        if not translated_text or translated_text.startswith(self._TRANSLATION_ERROR_PREFIX):
            # Nothing valid to illustrate or continue; avoid a wasted API call.
            return translated_text, "", None
        image = self.image_generator.generate(translated_text)
        creative_text = self.creative_text_generator.generate(translated_text)
        return translated_text, creative_text, image
# Function to display images
def show_image(image):
    """Display an image using matplotlib (local debugging helper).

    Nothing in this file calls it; the Gradio UI renders images itself.

    Args:
        image: Anything ``plt.imshow`` accepts (e.g. a PIL image or an array),
            or ``None``, in which case a notice is printed instead.
    """
    # Explicit None check: truthiness of a numpy array is ambiguous and
    # raises ValueError, even though imshow accepts arrays.
    if image is None:
        print("No image to display.")
        return
    plt.imshow(image)
    plt.axis('off')  # Hide axes
    plt.show()
# Create an instance of the TransArt app.
# Built once at import time: the translation and text-generation models are
# loaded here, and this single instance then serves every Gradio request.
app = TransArtApp()


# Gradio interface function
def gradio_interface(tamil_text):
    """Gradio callback that runs the whole TransArt pipeline.

    Args:
        tamil_text: Text typed into the interface's input box.

    Returns:
        A 3-tuple ``(translated_text, creative_text, image)`` in the same
        order as the interface's ``outputs`` list.
    """
    result = app.process(tamil_text)
    english, story, picture = result  # unpacking enforces exactly three parts
    return english, story, picture
# Create Gradio interface
interface = gr.Interface(
    fn=gradio_interface,
    # Explicit, labelled components replace the bare "text"/"image"
    # shortcuts, which carry no descriptive labels. Without labels the
    # translation and the creative text are hard to tell apart in the UI.
    # Order must match the tuple returned by gradio_interface.
    inputs=gr.Textbox(label="Tamil text", lines=3),
    outputs=[
        gr.Textbox(label="English translation"),
        gr.Textbox(label="Creative text"),
        gr.Image(label="Generated image"),
    ],
    title="Tamil to English Translation, Image Generation & Creative Text",
    description="Enter Tamil text to translate to English, generate an image, and create creative text based on the translation.",
)

# Launch Gradio app
if __name__ == "__main__":
    interface.launch()