pravin0077 commited on
Commit
630f14f
·
verified ·
1 Parent(s): 9051efa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -70
app.py CHANGED
@@ -1,83 +1,115 @@
1
- from transformers import MarianMTModel, MarianTokenizer, AutoModelForCausalLM, AutoTokenizer
2
  import requests
3
- from PIL import Image
4
  import io
 
 
 
 
5
  import gradio as gr
6
- import os
7
-
8
- # Load the MarianMT model and tokenizer for translation (Tamil to English)
9
- model_name = "Helsinki-NLP/opus-mt-mul-en"
10
- translation_model = MarianMTModel.from_pretrained(model_name)
11
- translation_tokenizer = MarianTokenizer.from_pretrained(model_name)
12
-
13
- # Load GPT-Neo for creative text generation
14
- text_generation_model_name = "EleutherAI/gpt-neo-1.3B"
15
- text_generation_model = AutoModelForCausalLM.from_pretrained(text_generation_model_name)
16
- text_generation_tokenizer = AutoTokenizer.from_pretrained(text_generation_model_name)
17
-
18
- # Add padding token to GPT-Neo tokenizer if not present
19
- if text_generation_tokenizer.pad_token is None:
20
- text_generation_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
21
-
22
- # Hugging Face API for FLUX.1 image generation
23
- API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
24
- headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACE_API_KEY')}"}
25
-
26
- # Query Hugging Face API to generate image
27
- def query(payload):
28
- response = requests.post(API_URL, headers=headers, json=payload)
29
- if response.status_code == 200:
30
- return response.content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  else:
32
- return f"Error: {response.status_code} - {response.text}"
33
-
34
- # Translate Tamil text to English
35
- def translate_text(tamil_text):
36
- inputs = translation_tokenizer(tamil_text, return_tensors="pt", padding=True, truncation=True)
37
- translated_tokens = translation_model.generate(**inputs)
38
- translation = translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
39
- return translation
40
-
41
- # Generate an image based on the translated text
42
- def generate_image(prompt):
43
- image_bytes = query({"inputs": prompt})
44
- if isinstance(image_bytes, str) and "Error" in image_bytes:
45
- return image_bytes # Return the error message if there's a problem with the API call
46
- image = Image.open(io.BytesIO(image_bytes))
47
- return image
48
-
49
- # Generate creative text based on the translated English text
50
- def generate_creative_text(translated_text):
51
- inputs = text_generation_tokenizer(translated_text, return_tensors="pt", padding=True, truncation=True)
52
- generated_tokens = text_generation_model.generate(**inputs, max_length=100)
53
- creative_text = text_generation_tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
54
- return creative_text
55
-
56
- # Function to handle the full workflow
57
- def translate_generate_image_and_text(tamil_text):
58
- # Step 1: Translate Tamil to English
59
- translated_text = translate_text(tamil_text)
60
-
61
- # Step 2: Generate an image from the translated text
62
- image = generate_image(translated_text)
63
-
64
- # Step 3: Generate creative text from the translated text
65
- creative_text = generate_creative_text(translated_text)
66
-
67
  return translated_text, creative_text, image
68
 
 
69
  # Create Gradio interface
70
  interface = gr.Interface(
71
- fn=translate_generate_image_and_text,
72
- inputs=gr.Textbox(label="Enter Tamil Text"), # Input for Tamil text
73
- outputs=[
74
- gr.Textbox(label="Translated Text"), # Output for translated text
75
- gr.Textbox(label="Creative Generated Text"),# Output for creative text
76
- gr.Image(label="Generated Image") # Output for generated image
77
- ],
78
  title="Tamil to English Translation, Image Generation & Creative Text",
79
  description="Enter Tamil text to translate to English, generate an image, and create creative text based on the translation."
80
  )
81
 
82
  # Launch Gradio app
83
- interface.launch()
 
 
1
+ # Import necessary libraries
2
  import requests
 
3
  import io
4
+ from PIL import Image
5
+ import matplotlib.pyplot as plt
6
+ from transformers import MarianMTModel, MarianTokenizer, pipeline
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM
8
  import gradio as gr
9
+ import os # For accessing environment variables
10
+
11
+ # Constants for model names and API URLs
12
+ class Constants:
13
+ TRANSLATION_MODEL_NAME = "Helsinki-NLP/opus-mt-mul-en"
14
+ IMAGE_GENERATION_API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
15
+ GPT_NEO_MODEL_NAME = "EleutherAI/gpt-neo-125M"
16
+ # Get the Hugging Face API token from environment variables
17
+ HEADERS = {"Authorization": f"Bearer {os.getenv('HF_API_TOKEN')}"}
18
+
19
+ # Translation Class
20
+ class Translator:
21
+ def __init__(self):
22
+ self.tokenizer = MarianTokenizer.from_pretrained(Constants.TRANSLATION_MODEL_NAME)
23
+ self.model = MarianMTModel.from_pretrained(Constants.TRANSLATION_MODEL_NAME)
24
+ self.pipeline = pipeline("translation", model=self.model, tokenizer=self.tokenizer)
25
+
26
+ def translate(self, tamil_text):
27
+ """Translate Tamil text to English."""
28
+ try:
29
+ translation = self.pipeline(tamil_text, max_length=40)
30
+ return translation[0]['translation_text']
31
+ except Exception as e:
32
+ return f"Translation error: {str(e)}"
33
+
34
+
35
+ # Image Generation Class
36
+ class ImageGenerator:
37
+ def __init__(self):
38
+ self.api_url = Constants.IMAGE_GENERATION_API_URL
39
+
40
+ def generate(self, prompt):
41
+ """Generate an image based on the given prompt."""
42
+ try:
43
+ response = requests.post(self.api_url, headers=Constants.HEADERS, json={"inputs": prompt})
44
+ if response.status_code == 200:
45
+ image_bytes = response.content
46
+ return Image.open(io.BytesIO(image_bytes))
47
+ else:
48
+ print(f"Image generation failed: Status code {response.status_code}")
49
+ return None
50
+ except Exception as e:
51
+ print(f"Image generation error: {str(e)}")
52
+ return None
53
+
54
+
55
+ # Creative Text Generation Class
56
+ class CreativeTextGenerator:
57
+ def __init__(self):
58
+ self.tokenizer = AutoTokenizer.from_pretrained(Constants.GPT_NEO_MODEL_NAME)
59
+ self.model = AutoModelForCausalLM.from_pretrained(Constants.GPT_NEO_MODEL_NAME)
60
+
61
+ def generate(self, translated_text):
62
+ """Generate creative text based on translated text."""
63
+ input_ids = self.tokenizer(translated_text, return_tensors='pt').input_ids
64
+ generated_text_ids = self.model.generate(input_ids, max_length=100)
65
+ return self.tokenizer.decode(generated_text_ids[0], skip_special_tokens=True)
66
+
67
+
68
+ # Main Application Class
69
+ class TransArtApp:
70
+ def __init__(self):
71
+ self.translator = Translator()
72
+ self.image_generator = ImageGenerator()
73
+ self.creative_text_generator = CreativeTextGenerator()
74
+
75
+ def process(self, tamil_text):
76
+ """Handle the full workflow: translate, generate image, and creative text."""
77
+ translated_text = self.translator.translate(tamil_text)
78
+ image = self.image_generator.generate(translated_text)
79
+ creative_text = self.creative_text_generator.generate(translated_text)
80
+ return translated_text, creative_text, image
81
+
82
+
83
+ # Function to display images
84
+ def show_image(image):
85
+ """Display an image using matplotlib."""
86
+ if image:
87
+ plt.imshow(image)
88
+ plt.axis('off') # Hide axes
89
+ plt.show()
90
  else:
91
+ print("No image to display.")
92
+
93
+
94
+ # Create an instance of the TransArt app
95
+ app = TransArtApp()
96
+
97
+ # Gradio interface function
98
+ def gradio_interface(tamil_text):
99
+ """Interface function for Gradio."""
100
+ translated_text, creative_text, image = app.process(tamil_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  return translated_text, creative_text, image
102
 
103
+
104
  # Create Gradio interface
105
  interface = gr.Interface(
106
+ fn=gradio_interface,
107
+ inputs="text",
108
+ outputs=["text", "text", "image"],
 
 
 
 
109
  title="Tamil to English Translation, Image Generation & Creative Text",
110
  description="Enter Tamil text to translate to English, generate an image, and create creative text based on the translation."
111
  )
112
 
113
  # Launch Gradio app
114
+ if __name__ == "__main__":
115
+ interface.launch()