# Meme-caption-generator/utils/model_utils.py
import base64
import io

import requests
from PIL import Image


def image_to_bytes(image):
    """Convert a PIL Image to JPEG-encoded bytes."""
    buffer = io.BytesIO()
    if image.mode != "RGB":
        image = image.convert("RGB")  # JPEG has no alpha channel
    image.save(buffer, format="JPEG")
    return buffer.getvalue()


def query_clip(data, hf_token):
    """Query the hosted CLIP zero-shot classifier with a base64-encoded image."""
    API_URL = "https://api-inference.huggingface.co/models/openai/clip-vit-base-patch32"
    headers = {"Authorization": f"Bearer {hf_token}"}
    img_bytes = image_to_bytes(data["image"])
    encoded_img = base64.b64encode(img_bytes).decode("utf-8")
    payload = {
        "parameters": data["parameters"],
        "inputs": encoded_img,
    }
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()
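
# A successful zero-shot classification response is typically a list of
# label/score pairs sorted by score, e.g.:
#   [{"label": "happy", "score": 0.93}, {"label": "angry", "score": 0.07}]
# While the hosted model is cold-starting, the API instead returns a dict
# such as {"error": "Model ... is currently loading", "estimated_time": 20.0}.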


def get_sentiment(img, hf_token):
    """Classify the image as "angry" or "happy" via zero-shot CLIP."""
    print("Getting the sentiment of the image...")
    output = query_clip({
        "image": img,
        "parameters": {"candidate_labels": ["angry", "happy"]},
    }, hf_token)
    try:
        print("Sentiment:", output[0]["label"])
        return output[0]["label"]
    except (KeyError, IndexError, TypeError):
        print(output)
        print("If the model is still loading, try again in a minute. "
              "If you've hit the query limit (300 per hour), retry within the next hour.")


def query_blip(img, hf_token):
    """Query the hosted BLIP captioning model with raw JPEG bytes."""
    API_URL = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
    headers = {"Authorization": f"Bearer {hf_token}"}
    img_bytes = image_to_bytes(img)
    # The image-to-text endpoint expects the raw image bytes as the request
    # body, not a multipart form, so send them directly via data=.
    response = requests.post(API_URL, headers=headers, data=img_bytes)
    return response.json()
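
# A successful captioning response is typically a one-element list, e.g.:
#   [{"generated_text": "a cat sitting on a laptop keyboard"}]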


def get_description(img, hf_token):
    """Caption the image with BLIP to use as the meme topic."""
    print("Getting the context of the image...")
    output = query_blip(img, hf_token)
    try:
        print("Context:", output[0]["generated_text"])
        return output[0]["generated_text"]
    except (KeyError, IndexError, TypeError):
        print(output)
        print("The model is unavailable right now, likely due to query limits. "
              "Try again now or within the next hour.")


def get_model_caption(img, base_model, tokenizer, hf_token, device="cuda"):
    """Generate a meme caption: CLIP picks the sentiment adapter, BLIP supplies the topic."""
    sentiment = get_sentiment(img, hf_token)
    description = get_description(img, hf_token)
    prompt_template = """
Below is an instruction that describes a task. Write a response that appropriately completes the request.

You are given a topic. Your task is to generate a meme caption based on the topic. Only output the meme caption and nothing more.
Topic: {query}
<end_of_turn>
<start_of_turn>model Caption:
"""
    prompt = prompt_template.format(query=description)
    print("Generating captions...")
    encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
    model_inputs = encoded.to(device)
    # Switch to the LoRA adapter matching the detected sentiment
    # (adapters are expected to be named "angry" and "happy").
    base_model.set_adapter(sentiment)
    base_model.to(device)
    generated_ids = base_model.generate(
        **model_inputs,
        max_new_tokens=20,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens so the prompt is not echoed back.
    new_tokens = generated_ids[0][model_inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
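

# Usage sketch -- everything below is illustrative, not part of the original
# module: the Gemma checkpoint and the adapter repo paths are placeholders,
# and it assumes LoRA adapters trained under the names "angry" and "happy"
# plus an HF_TOKEN environment variable for the Inference API calls.
if __name__ == "__main__":
    import os

    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
    base = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")
    # Load one LoRA adapter per sentiment; get_model_caption() switches
    # between them with set_adapter().
    model = PeftModel.from_pretrained(base, "your-org/angry-lora", adapter_name="angry")
    model.load_adapter("your-org/happy-lora", adapter_name="happy")

    img = Image.open("test_meme.jpg")
    print(get_model_caption(img, model, tokenizer, os.environ["HF_TOKEN"]))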