File size: 4,238 Bytes
630f14f
93cb70c
39ba994
630f14f
 
 
 
93cb70c
630f14f
 
 
 
 
 
 
 
aba5af4
630f14f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9051efa
630f14f
 
 
 
 
 
 
 
 
 
39ba994
93cb70c
630f14f
39ba994
1786f21
630f14f
 
 
39ba994
 
93cb70c
3201c4f
39ba994
630f14f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Import necessary libraries
import requests
import io
from PIL import Image
import matplotlib.pyplot as plt
from transformers import MarianMTModel, MarianTokenizer, pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import os  # For accessing environment variables

# Constants for model names and API URLs
class Constants:
    """Model identifiers, API endpoint, and auth headers shared by the app."""

    TRANSLATION_MODEL_NAME = "Helsinki-NLP/opus-mt-mul-en"
    IMAGE_GENERATION_API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
    GPT_NEO_MODEL_NAME = "EleutherAI/gpt-neo-125M"
    # Hugging Face API token, read once from the environment when this class
    # body executes (i.e. at import time); None when the variable is unset.
    HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY")
    # Only send an Authorization header when a token is configured; otherwise
    # the f-string would produce the literal "Bearer None", which is never a
    # valid token.
    HEADERS = {"Authorization": f"Bearer {HUGGINGFACE_API_KEY}"} if HUGGINGFACE_API_KEY else {}

# Translation Class
class Translator:
    """Translate text to English with the MarianMT model named in ``Constants``.

    The tokenizer and model are loaded once in ``__init__`` and wrapped in a
    Hugging Face ``translation`` pipeline that ``translate`` reuses.
    """

    def __init__(self):
        self.tokenizer = MarianTokenizer.from_pretrained(Constants.TRANSLATION_MODEL_NAME)
        self.model = MarianMTModel.from_pretrained(Constants.TRANSLATION_MODEL_NAME)
        self.pipeline = pipeline("translation", model=self.model, tokenizer=self.tokenizer)

    def translate(self, tamil_text, max_length=40):
        """Translate Tamil text to English.

        Args:
            tamil_text: Source text to translate.
            max_length: Upper bound on the generated translation length in
                tokens, forwarded to the pipeline. Defaults to 40 (the value
                that used to be hard-coded); longer inputs need a larger value
                or their translation is cut off.

        Returns:
            The English translation, or a string starting with
            ``"Translation error:"`` if the pipeline raised.
        """
        try:
            translation = self.pipeline(tamil_text, max_length=max_length)
            return translation[0]['translation_text']
        except Exception as e:
            # Report the failure as text instead of raising, so the caller
            # (and ultimately the UI) always receives a string.
            return f"Translation error: {e}"


# Image Generation Class
class ImageGenerator:
    """Generate images by POSTing prompts to the Hugging Face Inference API."""

    def __init__(self, timeout=120):
        """Create a generator bound to ``Constants.IMAGE_GENERATION_API_URL``.

        Args:
            timeout: Seconds ``requests.post`` may wait for the API before
                giving up. Pass ``None`` to wait indefinitely (the previous
                behaviour, which could hang a request forever).
        """
        self.api_url = Constants.IMAGE_GENERATION_API_URL
        self.timeout = timeout

    def generate(self, prompt):
        """Generate an image based on the given prompt.

        Args:
            prompt: Text prompt sent as the ``"inputs"`` field of the JSON body.

        Returns:
            A ``PIL.Image.Image`` decoded from the response body on HTTP 200,
            otherwise ``None`` (failures are printed, not raised).
        """
        try:
            response = requests.post(
                self.api_url,
                headers=Constants.HEADERS,
                json={"inputs": prompt},
                timeout=self.timeout,
            )
            if response.status_code != 200:
                # The response body may carry the API's error details, so
                # include a short excerpt to make failures diagnosable.
                print(
                    f"Image generation failed: Status code {response.status_code}: "
                    f"{response.text[:200]}"
                )
                return None
            return Image.open(io.BytesIO(response.content))
        except Exception as e:
            # Best-effort: a failed image must not break the rest of the UI.
            print(f"Image generation error: {e}")
            return None


# Creative Text Generation Class
class CreativeTextGenerator:
    """Continue a piece of English text with the GPT-Neo model named in ``Constants``."""

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(Constants.GPT_NEO_MODEL_NAME)
        self.model = AutoModelForCausalLM.from_pretrained(Constants.GPT_NEO_MODEL_NAME)

    def generate(self, translated_text, max_length=100):
        """Generate creative text based on translated text.

        Args:
            translated_text: English prompt text.
            max_length: Maximum length in tokens of the generated sequence,
                forwarded to ``model.generate``; defaults to 100 (the value
                that used to be hard-coded).

        Returns:
            The decoded output of the first generated sequence, ``""`` for
            empty/whitespace-only (or ``None``) input, or a string starting
            with ``"Creative text generation error:"`` if tokenization or
            generation raised.
        """
        # Blank input has nothing to continue and would otherwise give the
        # model an empty (or invalid) prompt.
        if not translated_text or not translated_text.strip():
            return ""
        try:
            inputs = self.tokenizer(translated_text, return_tensors='pt')
            generated_text_ids = self.model.generate(
                inputs.input_ids,
                # Pass the mask and pad id explicitly instead of letting
                # generate() infer them from a tokenizer that has no pad token.
                attention_mask=inputs.attention_mask,
                max_length=max_length,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            return self.tokenizer.decode(generated_text_ids[0], skip_special_tokens=True)
        except Exception as e:
            # Match Translator's behaviour: report failures as text so the
            # other outputs (translation, image) still reach the UI.
            return f"Creative text generation error: {e}"


# Main Application Class
class TransArtApp:
    """Wire translation, image generation, and creative text generation together."""

    # Prefix Translator.translate uses when it reports a failure as text.
    _TRANSLATION_ERROR_PREFIX = "Translation error:"

    def __init__(self):
        self.translator = Translator()
        self.image_generator = ImageGenerator()
        self.creative_text_generator = CreativeTextGenerator()

    def process(self, tamil_text):
        """Handle the full workflow: translate, generate image, and creative text.

        Args:
            tamil_text: User-supplied source text.

        Returns:
            A ``(translated_text, creative_text, image)`` tuple. For blank
            input this is ``("", "", None)``. If translation failed, it is
            ``(error_message, "", None)`` and the downstream generators are
            not called, so the error text is never used as a prompt.
        """
        if not tamil_text or not tamil_text.strip():
            return "", "", None
        translated_text = self.translator.translate(tamil_text)
        if translated_text.startswith(self._TRANSLATION_ERROR_PREFIX):
            return translated_text, "", None
        image = self.image_generator.generate(translated_text)
        creative_text = self.creative_text_generator.generate(translated_text)
        return translated_text, creative_text, image


# Function to display images
def show_image(image):
    """Display an image using matplotlib.

    Args:
        image: Anything ``plt.imshow`` accepts (e.g. a PIL image or a numpy
            array), or ``None``, in which case a message is printed instead.
    """
    # Explicit None check: truthiness of a multi-element numpy array raises
    # ValueError, so `if image:` would reject valid array input.
    if image is None:
        print("No image to display.")
        return
    plt.imshow(image)
    plt.axis('off')  # Hide axes
    plt.show()


# Create an instance of the TransArt app.
# NOTE: this runs at import time. TransArtApp.__init__ builds a Translator and a
# CreativeTextGenerator, which call from_pretrained on their models, so
# importing this module is slow (and may fetch model weights the first time).
app = TransArtApp()

# Gradio interface function
def gradio_interface(tamil_text):
    """Gradio callback that runs the module-level app on the submitted text.

    Returns a three-item tuple (translation, creative text, image) in the same
    order as the interface's output components.
    """
    results = app.process(tamil_text)
    # Unpacking enforces that exactly three outputs came back.
    translation, story, picture = results
    return translation, story, picture


# Create Gradio interface. Components are labelled so the two text outputs
# (translation vs. creative text) can be told apart in the UI.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(label="Tamil text"),
    outputs=[
        gr.Textbox(label="English translation"),
        gr.Textbox(label="Creative text"),
        gr.Image(label="Generated image"),
    ],
    title="Tamil to English Translation, Image Generation & Creative Text",
    description="Enter Tamil text to translate to English, generate an image, and create creative text based on the translation.",
)

# Launch the Gradio app only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()