File size: 2,616 Bytes
b502a48 24d193b b502a48 a66b74b b502a48 8bd959a b502a48 24d193b b502a48 8bd959a 24d193b a66b74b 8bd959a 24d193b 8bd959a 24d193b b502a48 8bd959a b502a48 8bd959a ac510fc 8bd959a 24d193b 8bd959a ac510fc b502a48 24d193b b502a48 8bd959a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import gradio as gr
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import spaces
import torch
import re
from PIL import Image
# Hub repository that supplies both the captioning weights and the processor.
_CAPTIONER_REPO = "gokaygokay/sd3-long-captioner"

# Loaded once at import time so every request reuses the same objects.
# The model is placed on the CPU and switched to eval mode for inference.
model = PaliGemmaForConditionalGeneration.from_pretrained(_CAPTIONER_REPO)
model = model.to("cpu").eval()
processor = PaliGemmaProcessor.from_pretrained(_CAPTIONER_REPO)
def modify_caption(caption: str) -> str:
    """
    Remove the first "captured from " / "captured at " phrase from a caption.

    Matching is case-insensitive and is not anchored to the start of the
    string: only the first occurrence, wherever it appears, is removed and
    any later occurrences are left in place.

    Args:
        caption (str): A string containing a caption.

    Returns:
        str: The caption with the first matching phrase removed, or the
        caption unchanged when no phrase is present.
    """
    prefix_substrings = [
        ('captured from ', ''),
        ('captured at ', ''),
    ]
    pattern = '|'.join(re.escape(opening) for opening, _ in prefix_substrings)
    # Keys are stored lowercase so lookups can be made case-insensitively.
    replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}

    def replace_fn(match):
        # The regex runs with IGNORECASE, so the matched text may be
        # "Captured from " or "CAPTURED AT "; normalise it before the lookup,
        # otherwise any non-lowercase match raises KeyError.
        return replacers[match.group(0).lower()]

    return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
def create_captions_rich(files):
    """
    Caption every uploaded image and return the results as one string.

    Each image is run through the module-level ``processor`` and ``model``
    with the fixed prompt "caption en"; the generated text is passed through
    ``modify_caption`` before being collected. A failure on one image is
    recorded as an error line and does not stop the remaining images.

    Args:
        files: The uploaded files. Each entry is either an object whose
            ``.name`` attribute holds the file path, or a plain string path.
            ``None`` or an empty sequence is accepted.

    Returns:
        str: One entry per input file, joined with newlines. Each entry is
        either the caption or an "Error ..." message. Returns an empty
        string when there are no files.
    """
    # The click handler receives None when nothing has been uploaded.
    if not files:
        return ""

    captions = []
    prompt = "caption en"
    for file_path in files:
        # Use ``.name`` when present, otherwise treat the entry as the path.
        path = getattr(file_path, "name", file_path)
        try:
            image = Image.open(path)
        except Exception as e:
            captions.append(f"Error opening image: {e}")
            continue

        try:
            # ``with image`` closes the underlying file once inference is done.
            # Preprocessing is inside the try so a bad image only fails itself.
            with image, torch.no_grad():
                model_inputs = processor(text=prompt, images=image, return_tensors="pt").to("cpu")
                input_len = model_inputs["input_ids"].shape[-1]
                generation = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
            # Drop the prompt tokens so only the newly generated text is decoded.
            generation = generation[0][input_len:]
            decoded = processor.decode(generation, skip_special_tokens=True)
            captions.append(modify_caption(decoded))
        except Exception as e:
            captions.append(f"Error generating caption: {e}")
    return "\n".join(captions)
# Custom stylesheet passed to gr.Blocks below. It styles the element with
# id "mkd" as a fixed-height, scrollable, bordered box.
# NOTE(review): no component in the visible code sets elem_id="mkd", so this
# rule may currently have no effect — confirm before relying on it.
css = """
#mkd {
height: 500px;
overflow: auto;
border: 16px solid #ccc;
}
"""
# Build the UI: a file uploader and a "Start" button in one column, with a
# read-only textbox that receives the generated prompts.
# NOTE(review): the nesting below was reconstructed from a flattened source;
# confirm that the textbox belongs in the row beside the input column.
with gr.Blocks(css=css) as demo:
    # The closing tags must be </center></h1> for the heading to be well-formed.
    gr.HTML("<h1><center>Fine-tuned PaliGemma for SD3 Image Guided Prompt Generation.</center></h1>")
    with gr.Tab(label="Image to Prompt for SD3"):
        with gr.Row():
            with gr.Column():
                input_files = gr.Files(label="Input Images")
                submit_btn = gr.Button(value="Start")
            outputs = gr.Textbox(label="Prompts", lines=10, interactive=False)
        # Clicking "Start" captions every uploaded file and fills the textbox.
        submit_btn.click(create_captions_rich, inputs=[input_files], outputs=[outputs])

demo.launch(debug=True)