import asyncio
import io
import logging
import os
import re
import threading

import discord
import gradio as gr
import requests
import torch
from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
# Logging setup: configure the root logger at DEBUG level with a
# "time:level:logger-name: message" record format, emitted through a StreamHandler.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s', handlers=[logging.StreamHandler()])
# Discord intents setup: start from the library defaults, then explicitly
# switch on the message- and guild-related intents this bot relies on.
intents = discord.Intents.default()
# NOTE(review): presumably required so on_message can see message attachments —
# confirm against the discord.py Intents documentation.
intents.message_content = True
intents.messages = True
intents.guilds = True
intents.guild_messages = True
# PaliGemma model setup (CPU mode).
# Both objects are created once, at import time, from the same checkpoint id and
# are shared by every captioning request. The model is moved to the CPU and put
# in eval mode before use.
model = PaliGemmaForConditionalGeneration.from_pretrained("gokaygokay/sd3-long-captioner").to("cpu").eval()
processor = PaliGemmaProcessor.from_pretrained("gokaygokay/sd3-long-captioner")
# Boilerplate openings stripped from generated captions, mapped to their
# replacement text. Keys are lower-case because matching is case-insensitive
# and the matched text is lower-cased before the lookup.
_CAPTION_PREFIX_REPLACEMENTS = {
    'captured from ': '',
    'captured at ': '',
}
_CAPTION_PREFIX_PATTERN = re.compile(
    '|'.join(re.escape(opening) for opening in _CAPTION_PREFIX_REPLACEMENTS),
    flags=re.IGNORECASE,
)


def modify_caption(caption: str) -> str:
    """Remove the first "captured from " / "captured at " phrase from *caption*.

    Matching is case-insensitive and only the first occurrence (anywhere in
    the string) is replaced; the rest of the caption is returned unchanged.

    Args:
        caption: Raw caption text produced by the model.

    Returns:
        The caption with the first matching phrase replaced.
    """
    def replace_fn(match):
        # The pattern is case-insensitive, so the matched text may be
        # "Captured From "; normalise it before looking up the replacement
        # (indexing with the raw match raised KeyError for such input).
        return _CAPTION_PREFIX_REPLACEMENTS[match.group(0).lower()]

    return _CAPTION_PREFIX_PATTERN.sub(replace_fn, caption, count=1)
async def create_captions_rich(image: Image.Image) -> str:
    """Generate a long English caption for *image* with the PaliGemma model.

    All blocking work (preprocessing, generation, decoding) runs in the event
    loop's default executor so the calling event loop is not stalled while the
    model runs on the CPU.

    Args:
        image: The picture to describe.

    Returns:
        The decoded caption after ``modify_caption`` post-processing.
    """
    prompt = "caption en"

    def _generate() -> str:
        # Give the image to the processor exactly once. Previously the
        # processor's own pixel_values output was multiplied by 255, cast to
        # uint8 and then run through the processor a second time, so the
        # model saw a twice-preprocessed (corrupted) image.
        model_inputs = processor(text=prompt, images=image, return_tensors="pt").to("cpu")
        input_len = model_inputs["input_ids"].shape[-1]
        generation = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
        # Drop the prompt tokens; keep only the newly generated ones.
        new_tokens = generation[0][input_len:]
        return processor.decode(new_tokens, skip_special_tokens=True)

    # get_running_loop() is the correct call inside a coroutine.
    loop = asyncio.get_running_loop()
    decoded = await loop.run_in_executor(None, _generate)
    return modify_caption(decoded)
# Gradio interface setup
def create_captions_rich_sync(image):
    """Synchronous wrapper around ``create_captions_rich`` for Gradio callbacks.

    Args:
        image: The picture supplied by the Gradio image component, or ``None``
            when the user submitted without providing one.

    Returns:
        The caption text, or an empty string when no image was provided.
    """
    # Compare with `is None` rather than truthiness: the input may be an
    # array-like object whose truth value is ambiguous.
    if image is None:
        return ""
    return asyncio.run(create_captions_rich(image))
# Custom stylesheet passed to gr.Blocks: gives the element with id "mkd" a fixed
# 500px height, scrolling overflow and a thin grey border.
# NOTE(review): no component with elem_id "mkd" is visible in this file — confirm
# whether this rule is still needed.
css = """
#mkd {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""
# Gradio UI: one tab with an image input, a submit button, a caption output
# and a row of clickable example images.
# NOTE(review): the nesting below was reconstructed from the statement order;
# confirm the layout matches the intended design.
with gr.Blocks(css=css) as demo:
    # Page title. The literal used to be split across two lines (an
    # unterminated string, i.e. a SyntaxError); it is now a single literal.
    # NOTE(review): any HTML markup around the title appears to have been
    # lost — re-add heading tags if they were intended.
    gr.HTML("PaliGemma Fine-tuned for Long Captioning")
    with gr.Tab(label="PaliGemma Long Captioner"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture")
                submit_btn = gr.Button(value="Submit")
            output = gr.Text(label="Caption")
        gr.Examples(
            [["image1.jpg"], ["image2.jpg"], ["image3.png"], ["image4.jpg"], ["image5.jpg"], ["image6.PNG"]],
            inputs=[input_img],
            outputs=[output],
            fn=create_captions_rich_sync,
            label='Try captioning on examples'
        )
        # Clicking "Submit" captions the uploaded picture.
        submit_btn.click(create_captions_rich_sync, [input_img], [output])
# Launch the Gradio server (used as the target of a background thread)
def run_gradio():
    """Start the Gradio demo, listening on all interfaces.

    The port is read from the GRADIO_SERVER_PORT environment variable and
    falls back to 7861 when the variable is not set.
    """
    port = int(os.getenv("GRADIO_SERVER_PORT", 7861))
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": port,
        "inbrowser": True,
    }
    demo.launch(**launch_options)
# ID of the single channel the bot responds in: taken from the
# DISCORD_CHANNEL_ID environment variable, or the hard-coded value below.
_channel_id_text = os.environ.get("DISCORD_CHANNEL_ID", "123456789012345678")
SPECIFIC_CHANNEL_ID = int(_channel_id_text)
# Discord bot setup
class MyClient(discord.Client):
    """Discord client that captions image attachments posted in one channel.

    Only one message is processed at a time; messages that arrive while a
    caption is being generated are ignored.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # True while a caption request is in flight.
        self.is_processing = False
        # Ensures the Gradio server thread is started only once.
        self._gradio_started = False

    async def on_ready(self):
        """Log the login and start the Gradio server on the first ready event."""
        logging.info('%s로 로그인되었습니다!', self.user)
        # on_ready may be dispatched more than once (e.g. after a reconnect);
        # starting a second Gradio server would try to bind the same port.
        if self._gradio_started:
            return
        self._gradio_started = True
        threading.Thread(target=run_gradio, daemon=True).start()
        logging.info("Gradio 서버가 시작되었습니다.")

    async def on_message(self, message):
        """Caption the first attachment of a message from the watched channel."""
        if message.author == self.user:
            return
        if not self.is_message_in_specific_channel(message):
            return
        if self.is_processing:
            return
        if not message.attachments:
            return
        # There is no await between the is_processing check above and this
        # assignment, so two handlers cannot both pass the check.
        self.is_processing = True
        try:
            image_url = message.attachments[0].url
            response = await process_image(image_url, message)
            await message.channel.send(response)
        finally:
            self.is_processing = False

    def is_message_in_specific_channel(self, message):
        """Return True if *message* is in the watched channel or a thread under it."""
        channel = message.channel
        if channel.id == SPECIFIC_CHANNEL_ID:
            return True
        return isinstance(channel, discord.Thread) and channel.parent_id == SPECIFIC_CHANNEL_ID
async def process_image(image_url, message):
    """Download the image at *image_url*, caption it and build the reply text.

    Args:
        image_url: URL of the image to describe.
        message: The Discord message that triggered the request; its author
            is mentioned in the reply.

    Returns:
        The reply string: the author's mention followed by the caption.
    """
    picture = await download_image(image_url)
    description = await create_captions_rich(picture)
    mention = message.author.mention
    return f"{mention}, 인식된 이미지 설명: {description}"
# Upper bound for the HTTP download so a stalled server cannot hang a request forever.
_DOWNLOAD_TIMEOUT_SECONDS = 30


def _fetch_image(url):
    """Blocking helper: download *url* and decode the body as an RGB image."""
    response = requests.get(url, timeout=_DOWNLOAD_TIMEOUT_SECONDS)
    # Convert to RGB so later processing always receives three channels.
    return Image.open(io.BytesIO(response.content)).convert("RGB")


async def download_image(url):
    """Download the image at *url* without blocking the event loop.

    ``requests`` is a blocking library, so the download and decode run in the
    loop's default executor; calling it directly inside the coroutine would
    stall every other task on the loop until the download finished.

    Args:
        url: Address of the image to fetch.

    Returns:
        The downloaded picture as an RGB ``PIL.Image.Image``.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _fetch_image, url)
if __name__ == "__main__":
    # Script entry point: create and run the bot exactly once. The client
    # creation and run() call used to be duplicated, which would start a
    # second bot as soon as the first run() returned.
    token = os.getenv('DISCORD_TOKEN')
    if not token:
        # Fail with a clear message instead of handing None to the client.
        raise SystemExit("DISCORD_TOKEN environment variable is not set")
    discord_client = MyClient(intents=intents)
    discord_client.run(token)