import streamlit as st

import argparse
import base64
import io
import json
import os
import re
import sys

import requests
import torch
from PIL import Image
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils.image_utils import overlay_caption
from utils.model_utils import get_model_caption

# Debug output: show which Python interpreter Streamlit is running under
st.write("Python executable being used:")
st.write(sys.executable)


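# load_models() is wrapped in st.cache_resource so the heavyweight base model,
# tokenizer, and LoRA adapters are loaded once and reused across Streamlit
# reruns instead of being reloaded on every user interaction.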
@st.cache_resource
def load_models():
    # Base Gemma-2B model and tokenizer
    base_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")

    # LoRA-finetuned variants of the base model for each mood
    model_angry = PeftModel.from_pretrained(base_model, "NursNurs/outputs_gemma2b_angry")
    model_happy = PeftModel.from_pretrained(base_model, "NursNurs/outputs_gemma2b_happy")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    base_model.to(device)
    model_happy.to(device)
    model_angry.to(device)

    # Register the mood adapters on the base model so they can be selected by name
    base_model.load_adapter("NursNurs/outputs_gemma2b_happy", "happy")
    base_model.load_adapter("NursNurs/outputs_gemma2b_angry", "angry")

    return base_model, tokenizer, model_happy, model_angry, device


def generate_meme_from_image(img_path, base_model, tokenizer, hf_token, device='cuda'):
    # Generate a caption with the adapter-augmented language model,
    # then draw it onto the image.
    caption = get_model_caption(img_path, base_model, tokenizer, hf_token)
    print(caption)
    image = overlay_caption(caption, img_path)
    return image, caption

st.title("Image Upload and Processing App")


def main():
    st.title("Meme Generator with Mood")

    base_model, tokenizer, model_happy, model_angry, device = load_models()

    # Input widget to upload an image
    uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])

    # Input widget to add Hugging Face token
    hf_token = st.text_input("Enter your Hugging Face Token", type='password')

    # Dropdown to select mood
    # mood = st.selectbox("Select Mood", options=["happy", "angry"])

    # Directory for saving the meme (optional, but you can let users set this if needed)
    output_dir = "results"

    if uploaded_image is not None and hf_token:
        # Convert uploaded image to a PIL image
        img = Image.open(uploaded_image)

        # Generate meme when button is pressed
        if st.button("Generate Meme"):
            with st.spinner('Generating meme...'):
                image, caption = generate_meme_from_image(img, base_model, tokenizer, hf_token, device)

                # Display the output
                st.image(image, caption=f"Generated Meme: {caption}")

                # Convert the PIL image to bytes so it can be passed to the
                # download button (convert to RGB in case the overlay added
                # an alpha channel, which JPEG cannot store).
                buf = io.BytesIO()
                image.convert("RGB").save(buf, format="JPEG")
                byte_im = buf.getvalue()

                st.download_button(
                    label="Download Image with Caption",
                    data=byte_im,
                    file_name="captioned_image.jpg",
                    mime="image/jpeg"
                )

if __name__ == '__main__':
    main()