nursulu committed
Commit ceec8fc
1 Parent(s): 6c696fb

add functions
app.py CHANGED
@@ -1,17 +1,94 @@
 import streamlit as st
 from PIL import Image
+import base64
+import requests
+import json
+import os
+import re
+import torch
+from peft import PeftModel, PeftConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import argparse
 import io
 
-x = st.slider('Select a value')
-st.write(x, 'squared is', x * x)
+from utils.model_utils import get_model_caption
+from utils.image_utils import overlay_caption
+
+
+@st.cache_resource
+def load_models():
+    # Shared Gemma-2B backbone plus one LoRA adapter per mood
+    base_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
+    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
+    model_angry = PeftModel.from_pretrained(base_model, "NursNurs/outputs_gemma2b_angry")
+    model_happy = PeftModel.from_pretrained(base_model, "NursNurs/outputs_gemma2b_happy")
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    base_model.to(device)
+    model_happy.to(device)
+    model_angry.to(device)
+
+    # Register the adapters under names so they can be selected per request
+    base_model.load_adapter("NursNurs/outputs_gemma2b_happy", "happy")
+    base_model.load_adapter("NursNurs/outputs_gemma2b_angry", "angry")
+
+    return base_model, tokenizer, model_happy, model_angry, device
+
+
+def generate_meme_from_image(img_path, base_model, tokenizer, hf_token, output_dir, device='cuda'):
+    caption = get_model_caption(img_path, base_model, tokenizer, hf_token, device)
+    image = overlay_caption(caption, img_path, output_dir)
+    return image, caption
 
-st.title("Image Upload and Processing App")
 
-# Upload the image
-uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])
+def main():
+    st.title("Meme Generator with Mood")
 
-# Process and display if image is uploaded
-if uploaded_image is not None:
-    image = Image.open(uploaded_image)
-    st.image(image, caption="Uploaded Image", use_column_width=True)
+    base_model, tokenizer, model_happy, model_angry, device = load_models()
+
+    # Input widget to upload an image
+    uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])
+
+    # Input widget for the Hugging Face token (used by the hosted CLIP/BLIP endpoints)
+    hf_token = st.text_input("Enter your Hugging Face Token", type="password")
+
+    # Directory for saving the meme (could be exposed to users if needed)
+    output_dir = "results"
+
+    if uploaded_image is not None and hf_token:
+        # Convert the upload to a PIL image and persist it to disk:
+        # the captioning utils expect a file path, not an in-memory upload
+        img = Image.open(uploaded_image)
+        os.makedirs(output_dir, exist_ok=True)
+        img_path = os.path.join(output_dir, uploaded_image.name)
+        img.save(img_path)
+
+        # Generate the meme when the button is pressed
+        if st.button("Generate Meme"):
+            with st.spinner('Generating meme...'):
+                image, caption = generate_meme_from_image(img_path, base_model, tokenizer, hf_token, output_dir, device)
+
+            # Display the output
+            st.image(image, caption=f"Generated Meme: {caption}")
+
+            # Optionally allow downloading the meme
+            buf = io.BytesIO()
+            image.save(buf, format="PNG")
+            byte_im = buf.getvalue()
+
+            st.download_button(
+                label="Download Meme",
+                data=byte_im,
+                file_name="generated_meme.png",
+                mime="image/png"
+            )
+
+
+if __name__ == '__main__':
+    main()
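For reference, a minimal sketch of driving the new pipeline without the Streamlit UI. The example image path and the HF_TOKEN environment variable are assumptions for illustration, not part of this commit; outside a Streamlit session, @st.cache_resource should simply run the function uncached.

    # Hypothetical smoke test; assumes examples/cat.jpg exists and HF_TOKEN is set.
    import os
    from app import load_models, generate_meme_from_image

    base_model, tokenizer, model_happy, model_angry, device = load_models()
    image, caption = generate_meme_from_image(
        "examples/cat.jpg", base_model, tokenizer,
        os.environ["HF_TOKEN"], output_dir="results", device=device,
    )
    print(caption)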
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (236 Bytes).
 
utils/__pycache__/image_utils.cpython-311.pyc ADDED
Binary file (6.98 kB).
 
utils/__pycache__/model_utils.cpython-311.pyc ADDED
Binary file (5.05 kB).
 
utils/image_utils.py ADDED
@@ -0,0 +1,148 @@
+import os
+import re
+from PIL import Image, ImageDraw, ImageFont
+import textwrap
+
+
+def get_unique_filename(filename):
+    """
+    Generate a unique filename by appending a number if a file with the same name already exists.
+    """
+    if not os.path.exists(filename):
+        return filename
+
+    base, ext = os.path.splitext(filename)
+    counter = 1
+    new_filename = f"{base}_{counter}{ext}"
+
+    while os.path.exists(new_filename):
+        counter += 1
+        new_filename = f"{base}_{counter}{ext}"
+
+    return new_filename
+
+
+def save_image_with_unique_name(image, path):
+    unique_path = get_unique_filename(path)
+    image.save(unique_path)
+    print(f"Image saved as: {unique_path}")
+
+
+def find_text_in_answer(text):
+    print("Full caption:", text)
+    # Keep only what follows the "Caption:" marker, if the model emitted one
+    parts = text.split("Caption:")
+    text = parts[1] if len(parts) > 1 else parts[0]
+    text = text.replace("\n", "")
+    text = text.replace("model", "")
+    # Remove everything that looks like a <...> tag
+    text = re.sub(r'<[^>]*>', '', text)
+
+    # Remove non-alphanumeric characters (keeping spaces, ? and !)
+    text = re.sub(r'[^a-zA-Z0-9\?\!\s]', '', text)
+    print("Filtered caption:", text)
+    if text:
+        return text
+    else:
+        return "Me when I couldn't parse the model's answer but I still want you to smile :)"
+
+
+def draw_text(draw, text, position, font, max_width, outline_color="black", text_color="white", outline_width=2):
+    """
+    Draw text on the image with an outline, splitting it into lines if necessary
+    and returning the total height used by the text. The text is horizontally
+    centered within max_width.
+    """
+    print("Adding the caption on the image...")
+
+    # Split the text into multiple lines based on the max width
+    lines = []
+    words = text.split()
+    line = ''
+    for word in words:
+        test_line = f'{line} {word}'.strip()
+        bbox = draw.textbbox((0, 0), test_line, font=font)
+        width = bbox[2] - bbox[0]  # Width of the text
+        if width <= max_width:
+            line = test_line
+        else:
+            if line:  # Avoid appending empty lines
+                lines.append(line)
+            line = word
+    if line:
+        lines.append(line)
+
+    y = position[1]
+
+    # Draw the text with an outline (black) first, centered horizontally
+    for line in lines:
+        # Calculate the width of the line and adjust the x position to center it
+        bbox = draw.textbbox((0, 0), line, font=font)
+        line_width = bbox[2] - bbox[0]
+        x = (max_width - line_width) // 2 + position[0]
+
+        # Draw the outline by drawing the text multiple times around the original position
+        for offset_x in [-outline_width, 0, outline_width]:
+            for offset_y in [-outline_width, 0, outline_width]:
+                if offset_x != 0 or offset_y != 0:
+                    draw.text((x + offset_x, y + offset_y), line, font=font, fill=outline_color)
+
+        # Draw the main text (white) on top of the outline
+        draw.text((x, y), line, font=font, fill=text_color)
+        y += bbox[3] - bbox[1]  # Update y position based on line height
+
+    return y - position[1]  # Return the total height used by the text
+
+
+def calculate_text_height(caption, font, max_width):
+    """
+    Calculate the height of the text when drawn, given the caption, font, and maximum width.
+    """
+    image = Image.new('RGB', (max_width, 1))
+    draw = ImageDraw.Draw(image)
+    return draw_text(draw, caption, (0, 0), font, max_width)
+
+
+def add_caption(image_path, caption, output_path, top_margin=10, bottom_margin=10, max_caption_length=10, min_distance_from_bottom_mm=10):
+    image = Image.open(image_path)
+    draw = ImageDraw.Draw(image)
+    width, height = image.size
+
+    # Convert mm to pixels (assuming 96 DPI)
+    dpi = 96
+    min_distance_from_bottom_px = min_distance_from_bottom_mm * dpi / 25.4
+
+    # Split the caption into top and bottom parts if it is too long
+    if len(caption.split()) > max_caption_length:
+        font_size = 20
+        total_len = len(caption.split())
+        mid = total_len // 2
+
+        top_caption = " ".join(caption.split()[:mid])
+        bottom_caption = " ".join(caption.split()[mid:])
+    else:
+        top_caption = ""
+        bottom_caption = caption
+        font_size = 30
+
+    # Load a font
+    font = ImageFont.truetype("fonts/Anton/Anton-Regular.ttf", font_size)
+
+    # Top caption
+    top_caption_position = (width // 10, top_margin)
+    draw_text(draw, top_caption, top_caption_position, font, width - 2 * (width // 10))
+
+    # Bottom caption
+    if bottom_caption:  # Draw bottom caption only if it's not empty
+        # Calculate its height so it can be anchored above the bottom margin
+        bottom_caption_height = calculate_text_height(bottom_caption, font, width - 2 * (width // 10))
+        bottom_caption_position = (width // 10, height - min_distance_from_bottom_px - bottom_caption_height)
+        draw_text(draw, bottom_caption, bottom_caption_position, font, width - 2 * (width // 10))
+
+    save_image_with_unique_name(image, output_path)
+    return image
+
+
+def overlay_caption(text, img_path, output_dir):
+    img_name = os.path.basename(img_path)
+    text = find_text_in_answer(text)
+    text = text.strip(".")
+    os.makedirs(output_dir, exist_ok=True)  # Make sure the output directory exists
+    image = add_caption(img_path, text, os.path.join(output_dir, img_name))
+    return image
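A quick way to sanity-check the parsing and overlay logic by hand; the image path below is hypothetical, and fonts/Anton/Anton-Regular.ttf must be present in the repo.

    # Hypothetical usage; assumes examples/cat.jpg and the Anton font exist.
    from utils.image_utils import find_text_in_answer, overlay_caption

    raw = "<start_of_turn>model Caption: me pretending to work\n"
    print(find_text_in_answer(raw))   # -> roughly "me pretending to work"
    meme = overlay_caption(raw, "examples/cat.jpg", "results")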
utils/model_utils.py ADDED
@@ -0,0 +1,79 @@
+import base64
+import requests
+import json
+import pandas as pd
+import os
+from tqdm import tqdm
+import re
+import torch
+
+
+def query_clip(data, hf_token):
+    # Zero-shot image classification via the hosted Inference API
+    API_URL = "https://api-inference.huggingface.co/models/openai/clip-vit-base-patch32"
+    headers = {"Authorization": f"Bearer {hf_token}"}
+    with open(data["image_path"], "rb") as f:
+        img = f.read()
+    payload = {
+        "parameters": data["parameters"],
+        "inputs": base64.b64encode(img).decode("utf-8")
+    }
+    response = requests.post(API_URL, headers=headers, json=payload)
+    return response.json()
+
+
+def get_sentiment(img_path, hf_token):
+    print("Getting the sentiment of the image...")
+    output = query_clip({
+        "image_path": img_path,
+        "parameters": {"candidate_labels": ["angry", "happy"]},
+    }, hf_token)
+    try:
+        print("Sentiment:", output[0]['label'])
+        return output[0]['label']
+    except (KeyError, IndexError, TypeError):
+        print(output)
+        print("If the model is loading, try again in a minute. If you've reached a query limit (300 per hour), try within the next hour.")
+
+
+def query_blip(filename, hf_token):
+    # Image captioning via the hosted Inference API
+    API_URL = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
+    headers = {"Authorization": f"Bearer {hf_token}"}
+    with open(filename, "rb") as f:
+        file = f.read()
+    response = requests.post(API_URL, headers=headers, data=file)
+    return response.json()
+
+
+def get_description(img_path, hf_token):
+    print("Getting the context of the image...")
+    output = query_blip(img_path, hf_token)
+
+    try:
+        print("Context:", output[0]['generated_text'])
+        return output[0]['generated_text']
+    except (KeyError, IndexError, TypeError):
+        print(output)
+        print("The model is not available right now due to query limits. Try running again now or within the next hour.")
+
+
+def get_model_caption(img_path, base_model, tokenizer, hf_token, device='cuda'):
+    sentiment = get_sentiment(img_path, hf_token)
+    description = get_description(img_path, hf_token)
+
+    prompt_template = """
+    Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n
+    You are given a topic. Your task is to generate a meme caption based on the topic. Only output the meme caption and nothing more.
+    Topic: {query}
+    <end_of_turn>\n<start_of_turn>model Caption:
+    """
+    prompt = prompt_template.format(query=description)
+
+    print("Generating captions...")
+    encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+    model_inputs = encodeds.to(device)
+    base_model.set_adapter(sentiment)  # Pick the adapter matching the CLIP sentiment ("happy" or "angry")
+    base_model.to(device)
+    generated_ids = base_model.generate(**model_inputs, max_new_tokens=20, do_sample=True, pad_token_id=tokenizer.eos_token_id)
+    decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+    return decoded
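The two API helpers can also be exercised on their own; the token and image path below are placeholders, not values from this commit.

    # Hypothetical usage; "hf_xxx" stands in for a real Hugging Face token.
    from utils.model_utils import get_sentiment, get_description

    mood = get_sentiment("examples/cat.jpg", "hf_xxx")        # -> "happy" or "angry"
    context = get_description("examples/cat.jpg", "hf_xxx")   # -> short BLIP caption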