|
import streamlit as st |
|
from PIL import Image |
|
import base64 |
|
import requests |
|
import json |
|
import os |
|
import re |
|
import torch |
|
from peft import PeftModel, PeftConfig |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import argparse |
|
import io |
|
|
|
from utils.model_utils import get_model_caption |
|
from utils.image_utils import overlay_caption |
|
|
|
# Install dependencies from a shell, not from inside the app: `!pip ...` is
# IPython/Jupyter shell-escape syntax and is a SyntaxError in a plain Python
# script such as one launched with `streamlit run`. Equivalent shell command:
#   pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
|
|
|
@st.cache_resource
def load_models():
    """Load the Gemma-2B base model, its tokenizer and the happy/angry PEFT adapters.

    Decorated with ``st.cache_resource`` (Streamlit's cache for shared
    resources such as models), so the weights are not reloaded on every rerun.

    Returns:
        tuple: ``(base_model, tokenizer, model_happy, model_angry, device)``.
        ``device`` is ``cuda`` when ``torch.cuda.is_available()``, else ``cpu``.
    """
    base_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")

    # NOTE(review): both wrappers below are built on the *same* base_model
    # object. PEFT documents that PeftModel.from_pretrained may modify the
    # model it is given in place, so model_angry and model_happy are probably
    # not independent (the second load likely affects the first one too).
    # Confirm before relying on them as separate moods.
    model_angry = PeftModel.from_pretrained(base_model, "NursNurs/outputs_gemma2b_angry")
    model_happy = PeftModel.from_pretrained(base_model, "NursNurs/outputs_gemma2b_happy")

    # Prefer the GPU when torch can see one; otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    base_model.to(device)
    model_happy.to(device)
    model_angry.to(device)

    # Attach the same two adapters again, directly on base_model, under the
    # explicit names "happy" and "angry" -- presumably so callers can switch
    # between them by name; verify against get_model_caption.
    # NOTE(review): these adapters are loaded *after* the .to(device) calls
    # above; confirm their weights also end up on `device`.
    base_model.load_adapter("NursNurs/outputs_gemma2b_happy", "happy")
    base_model.load_adapter("NursNurs/outputs_gemma2b_angry", "angry")

    return base_model, tokenizer, model_happy, model_angry, device
|
|
|
|
|
|
|
def generate_meme_from_image(img_path, base_model, tokenizer, hf_token, output_dir, device='cuda'):
    """Produce a captioned meme for a single image.

    The caption is generated by ``get_model_caption`` and then drawn onto the
    image by ``overlay_caption``.

    Args:
        img_path: Image handed unchanged to both helpers. Despite the name,
            ``main`` passes an already-opened PIL image here.
        base_model: Model forwarded to ``get_model_caption``.
        tokenizer: Tokenizer forwarded to ``get_model_caption``.
        hf_token: Hugging Face access token forwarded to ``get_model_caption``.
        output_dir: Output location forwarded to ``overlay_caption``.
        device: Accepted for API compatibility; not used by this function.

    Returns:
        A ``(image, caption)`` pair: the value returned by ``overlay_caption``
        and the generated caption text.
    """
    generated = get_model_caption(img_path, base_model, tokenizer, hf_token)
    # Tuple elements are evaluated left to right, so the overlay is rendered
    # from the finished caption before the pair is returned.
    return overlay_caption(generated, img_path, output_dir), generated
|
|
|
# Module-level page heading, rendered on every script run.
# NOTE(review): main() also renders its own title ("Meme Generator with Mood"),
# so the page currently shows two headings; consider keeping only one.
st.title(body="Image Upload and Processing App")
|
|
|
|
|
def main():
    """Render the meme-generator page.

    Loads the cached models, asks for an image and a Hugging Face token, and,
    once both are provided and "Generate Meme" is clicked, shows the captioned
    meme together with a PNG download button.
    """
    st.title("Meme Generator with Mood")

    # load_models() returns a 5-tuple; the mood-specific wrappers are not used
    # by this page yet, so they are bound to underscore-prefixed names.
    base_model, tokenizer, _model_happy, _model_angry, device = load_models()

    uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])
    hf_token = st.text_input("Enter your Hugging Face Token", type="password")

    output_dir = "results"

    # Nothing to do until both an image and a token have been supplied.
    if uploaded_image is None or not hf_token:
        return

    img = Image.open(uploaded_image)

    if not st.button("Generate Meme"):
        return

    # output_dir is handed to overlay_caption; create it up front in case the
    # helper writes into it without creating it (assumption -- verify helper).
    os.makedirs(output_dir, exist_ok=True)

    with st.spinner('Generating meme...'):
        # Pass output_dir explicitly and device by keyword so each value lands
        # in the parameter generate_meme_from_image declares for it (a bare
        # positional `device` here would fill the `output_dir` slot).
        image, caption = generate_meme_from_image(
            img, base_model, tokenizer, hf_token, output_dir, device=device
        )

    st.image(image, caption=f"Generated Meme: {caption}")

    # Serialize the meme to PNG bytes in memory for the download button.
    buf = io.BytesIO()
    image.save(buf, format="PNG")

    st.download_button(
        label="Download Meme",
        data=buf.getvalue(),
        file_name="generated_meme.png",
        mime="image/png",
    )
|
|
|
# Script entry point: run the app when this file is executed directly
# (for example through `streamlit run`).
if __name__ == '__main__':
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|