import base64
import io
import os

import requests
from PIL import Image


def image_to_bytes(image):
    """Convert a PIL Image to JPEG-encoded bytes."""
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    return buffer.getvalue()


def query_clip(data, hf_token):
    """Zero-shot classify an image against candidate labels with CLIP via the HF Inference API."""
    API_URL = "https://api-inference.huggingface.co/models/openai/clip-vit-base-patch32"
    headers = {"Authorization": f"Bearer {hf_token}"}

    # Encode the PIL image as base64 so it can be sent inside a JSON payload.
    img_bytes = image_to_bytes(data["image"])
    encoded_img = base64.b64encode(img_bytes).decode("utf-8")

    payload = {
        "parameters": data["parameters"],
        "inputs": encoded_img,
    }
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()


def get_sentiment(img, hf_token):
    """Classify the image as 'angry' or 'happy' using CLIP zero-shot classification."""
    print("Getting the sentiment of the image...")
    output = query_clip({
        "image": img,
        "parameters": {"candidate_labels": ["angry", "happy"]},
    }, hf_token)
    try:
        print("Sentiment:", output[0]["label"])
        return output[0]["label"]
    except (KeyError, IndexError, TypeError):
        print(output)
        print("If the model is still loading, try again in a minute. If you've reached the query limit (300 per hour), try again within the next hour.")
        return None


def query_blip(img, hf_token):
    """Caption an image with BLIP via the HF Inference API."""
    API_URL = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
    headers = {"Authorization": f"Bearer {hf_token}"}

    # The captioning endpoint expects the raw image bytes as the request body.
    img_bytes = image_to_bytes(img)
    response = requests.post(API_URL, headers=headers, data=img_bytes)
    return response.json()


def get_description(img, hf_token):
    """Get a short textual description (caption) of the image from BLIP."""
    print("Getting the context of the image...")
    output = query_blip(img, hf_token)

    try:
        print("Context:", output[0]["generated_text"])
        return output[0]["generated_text"]
    except (KeyError, IndexError, TypeError):
        print(output)
        print("The model is not available right now (still loading or query limit reached). Try again now or within the next hour.")
        return None


def get_model_caption(img, base_model, tokenizer, hf_token, device="cuda"):
    """Generate a meme caption for a PIL image using the sentiment-specific adapter."""
    sentiment = get_sentiment(img, hf_token)
    description = get_description(img, hf_token)

    prompt_template = """
Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n
You are given a topic. Your task is to generate a meme caption based on the topic. Only output the meme caption and nothing more.
Topic: {query}
<end_of_turn>\n<start_of_turn>model Caption:
"""
    prompt = prompt_template.format(query=description)

    print("Generating captions...")
    encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
    model_inputs = encodeds.to(device)

    # The adapters are named after the sentiment labels ("angry"/"happy"),
    # so the classification result selects which LoRA weights are active.
    print("sentiment", sentiment)
    base_model.set_adapter(sentiment)
    base_model.to(device)
    generated_ids = base_model.generate(
        **model_inputs,
        max_new_tokens=20,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return decoded
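

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of wiring the pieces together, assuming a Gemma-style base
# model fine-tuned with two LoRA adapters named "angry" and "happy" (the names
# set_adapter() expects above). The model name, adapter paths, image path, and
# the HF_TOKEN environment variable are placeholders, not values from the
# original code. Pass device="cpu" if no CUDA device is available.
if __name__ == "__main__":
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    hf_token = os.environ.get("HF_TOKEN")  # Hugging Face API token (assumed env var)

    base_name = "google/gemma-2b-it"  # hypothetical base model
    tokenizer = AutoTokenizer.from_pretrained(base_name)
    model = AutoModelForCausalLM.from_pretrained(base_name)

    # Attach one adapter per sentiment; adapter_name must match the CLIP labels.
    model = PeftModel.from_pretrained(model, "adapters/happy", adapter_name="happy")
    model.load_adapter("adapters/angry", adapter_name="angry")

    img = Image.open("meme.jpg").convert("RGB")  # hypothetical input image
    print(get_model_caption(img, model, tokenizer, hf_token))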