Spaces:
Runtime error
Runtime error
import discord | |
import logging | |
import os | |
import asyncio | |
import subprocess | |
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor | |
import torch | |
import re | |
import requests | |
from PIL import Image | |
import io | |
# Logging setup: DEBUG level and above, timestamped records sent to a
# StreamHandler (stderr by default).
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s:%(levelname)s:%(name)s: %(message)s',
    level=logging.DEBUG,
)
# Discord gateway intents. Intents.default() leaves the privileged
# message_content intent off, and it is needed to read message contents;
# the message/guild flags are also switched on explicitly.
intents = discord.Intents.default()
intents.message_content = intents.messages = True
intents.guilds = intents.guild_messages = True
# PaliGemma model setup: load the captioning checkpoint and its processor once,
# at import time. Prefer the GPU, but fall back to the CPU so the bot can still
# start (slowly) on machines without CUDA instead of crashing on import.
MODEL_ID = "gokaygokay/sd3-long-captioner"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID).to(DEVICE).eval()
processor = PaliGemmaProcessor.from_pretrained(MODEL_ID)
def modify_caption(caption: str) -> str:
    """Remove the first "captured from " / "captured at " phrase from *caption*.

    Matching is case-insensitive and only the first occurrence is removed,
    wherever it appears in the string (the match is not anchored to the start).

    Args:
        caption: Raw caption text produced by the model.

    Returns:
        The caption with the first matching phrase removed, or the caption
        unchanged if none of the phrases occur.
    """
    prefix_substrings = [
        ('captured from ', ''),
        ('captured at ', ''),
    ]
    pattern = '|'.join(re.escape(opening) for opening, _ in prefix_substrings)
    # Keys are stored lower-case. Under re.IGNORECASE the matched text may be
    # capitalised (e.g. "Captured from "), so normalise it before the lookup;
    # otherwise the lookup raises KeyError.
    replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}

    def replace_fn(match):
        return replacers[match.group(0).lower()]

    return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
def create_captions_rich(image: Image.Image) -> str:
    """Generate a long English caption for *image* with the PaliGemma model.

    The image is given to the processor exactly once, together with the
    "caption en" task prompt. Pre-processing it separately and feeding the
    resulting pixel tensor back in would resize and normalise it a second time.

    Args:
        image: PIL image to caption. Non-RGB images (RGBA, palette, grayscale)
            are converted to RGB first.

    Returns:
        The decoded caption after post-processing by ``modify_caption``.
    """
    prompt = "caption en"
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Move inputs to wherever the model lives instead of assuming "cuda".
    model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
    input_len = model_inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
    # generate() returns prompt tokens followed by new tokens; keep only the new ones.
    generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    return modify_caption(decoded)
# Discord bot setup.

# ID of the channel the bot serves; threads whose parent is this channel are
# served as well. Read from the DISCORD_CHANNEL_ID environment variable. When
# the variable is unset the value is 0, which matches no channel, so the bot
# ignores every message until it is configured.
SPECIFIC_CHANNEL_ID = int(os.getenv("DISCORD_CHANNEL_ID", "0"))


class MyClient(discord.Client):
    """Discord client that replies with a caption for an image attachment.

    Only messages in SPECIFIC_CHANNEL_ID (or in threads under it) are handled,
    and only one message is processed at a time: messages that arrive while
    a caption is being generated are dropped, not queued.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # True while a message is being captioned; used to drop concurrent work.
        self.is_processing = False
        # Handle of the web.py subprocess. Kept so it is started at most once,
        # because on_ready can fire again after the gateway reconnects.
        self._web_process = None

    async def on_ready(self):
        """Log the login and start the web.py companion server once."""
        logging.info(f'{self.user}λ‘ λ‘κ·ΈμΈλμμ΅λλ€!')
        if self._web_process is None:
            self._web_process = subprocess.Popen(["python", "web.py"])
            logging.info("Web.py μλ²κ° μμλμμ΅λλ€.")

    async def on_message(self, message):
        """Caption the first attachment of an eligible message and reply to it."""
        if message.author == self.user:
            return  # never react to the bot's own replies
        if not self.is_message_in_specific_channel(message):
            return
        if self.is_processing:
            return  # busy with another message: drop this one
        self.is_processing = True
        try:
            if message.attachments:
                image_url = message.attachments[0].url
                response = await process_image(image_url, message)
                await message.channel.send(response)
        finally:
            # Release the flag even if captioning or sending raised.
            self.is_processing = False

    def is_message_in_specific_channel(self, message):
        """Return True if *message* is in SPECIFIC_CHANNEL_ID or a thread under it."""
        return message.channel.id == SPECIFIC_CHANNEL_ID or (
            isinstance(message.channel, discord.Thread) and message.channel.parent_id == SPECIFIC_CHANNEL_ID
        )
async def process_image(image_url, message):
    """Download the image at *image_url*, caption it, and build the reply text.

    Captioning (``create_captions_rich``) is synchronous model inference, so
    it runs in the event loop's default executor. Running it inline would
    block the loop and stall the Discord gateway heartbeat for the duration
    of generation.

    Args:
        image_url: URL of the image attachment to caption.
        message: The triggering Discord message; its author is mentioned in
            the reply.

    Returns:
        The reply string mentioning the author and containing the caption.
    """
    image = await download_image(image_url)
    loop = asyncio.get_running_loop()
    caption = await loop.run_in_executor(None, create_captions_rich, image)
    return f"{message.author.mention}, μΈμλ μ΄λ―Έμ§ μ€λͺ : {caption}"
def _fetch_image_sync(url, timeout):
    """Blocking helper: GET *url* and decode the response body as a PIL image."""
    response = requests.get(url, timeout=timeout)
    # Fail on HTTP errors here rather than letting PIL choke on an error page.
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))
    # Image.open is lazy; decode the pixel data now, while still off the event loop.
    image.load()
    return image


async def download_image(url, timeout=30):
    """Download *url* and return it as a PIL image without blocking the event loop.

    The blocking HTTP request and image decode run in the loop's default
    executor.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds requests waits on the connection or a read before
            giving up. Defaults to 30.

    Returns:
        The decoded ``PIL.Image.Image``.

    Raises:
        requests.HTTPError: The server answered with an error status.
        requests.Timeout: The connection or a read exceeded *timeout*.
        PIL.UnidentifiedImageError: The body is not a decodable image.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _fetch_image_sync, url, timeout)
if __name__ == "__main__":
    # The bot token comes from the environment. Fail fast with a clear message
    # instead of handing None to discord.py's login.
    token = os.getenv('DISCORD_TOKEN')
    if not token:
        raise SystemExit("DISCORD_TOKEN environment variable is not set")
    discord_client = MyClient(intents=intents)
    discord_client.run(token)