import asyncio
import io
import logging
import os
import re
import threading

import discord
import gradio as gr
import requests
import torch
from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

# Logging configuration: DEBUG and above, written through a StreamHandler.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s:%(levelname)s:%(name)s: %(message)s',
    handlers=[logging.StreamHandler()],
)

# Discord intents: the bot reads guild messages and their content/attachments.
intents = discord.Intents.default()
intents.message_content = True
intents.messages = True
intents.guilds = True
intents.guild_messages = True

# PaliGemma captioning model, loaded once at import time and kept on the CPU.
MODEL_ID = "gokaygokay/sd3-long-captioner"
model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID).to("cpu").eval()
processor = PaliGemmaProcessor.from_pretrained(MODEL_ID)


def modify_caption(caption: str) -> str:
    """Remove the first "captured from " / "captured at " phrase from *caption*.

    Matching is case-insensitive, and only the first occurrence is replaced.
    The occurrence may be anywhere in the string, not only at its start.
    """
    prefix_substrings = [
        ('captured from ', ''),
        ('captured at ', ''),
    ]
    pattern = '|'.join(re.escape(opening) for opening, _ in prefix_substrings)
    # Keys are stored lower-case and matches are looked up lower-cased. With
    # IGNORECASE the matched text can be e.g. "Captured from ", which would
    # otherwise raise KeyError.
    replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}

    def replace_fn(match):
        return replacers[match.group(0).lower()]

    return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)


def _generate_caption(image) -> str:
    """Synchronously caption *image* with the model and return the cleaned text.

    *image* is passed straight to ``processor``. The Discord path supplies a
    PIL image. The Gradio path supplies whatever ``gr.Image`` yields
    (presumably a NumPy array with the default settings; confirm if changed).
    """
    prompt = "caption en"
    # Feed the raw image to the processor exactly once. An earlier version
    # re-fed its normalised pixel_values (scaled by 255 and cast to uint8),
    # which does not reconstruct the original pixels.
    model_inputs = processor(text=prompt, images=image, return_tensors="pt").to("cpu")
    input_len = model_inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
    # generate() returns the prompt tokens followed by new ones; keep only the new ones.
    decoded = processor.decode(generation[0][input_len:], skip_special_tokens=True)
    return modify_caption(decoded)


async def create_captions_rich(image: Image.Image) -> str:
    """Caption *image* without blocking the running event loop.

    Preprocessing and generation both run in the loop's default executor.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _generate_caption, image)


# Gradio interface helpers
def create_captions_rich_sync(image):
    """Synchronous entry point for Gradio callbacks.

    Returns an empty string when no image was supplied (``image is None``).
    Otherwise it calls the model directly, without starting an event loop.
    """
    if image is None:
        return ""
    return _generate_caption(image)


css = """
  #mkd {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
  }
"""
# Gradio interface
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1><center>PaliGemma Fine-tuned for Long Captioning</center></h1>")
    with gr.Tab(label="PaliGemma Long Captioner"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture")
                submit_btn = gr.Button(value="Submit")
            output = gr.Text(label="Caption")

        gr.Examples(
            [["image1.jpg"], ["image2.jpg"], ["image3.png"], ["image4.jpg"], ["image5.jpg"], ["image6.PNG"]],
            inputs=[input_img],
            outputs=[output],
            fn=create_captions_rich_sync,
            label='Try captioning on examples',
        )

        submit_btn.click(create_captions_rich_sync, [input_img], [output])


def run_gradio():
    """Launch the Gradio UI; intended to be run on a background thread."""
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("GRADIO_SERVER_PORT", 7861)),
        inbrowser=True,
    )


# Only messages in this channel, or in threads whose parent is this channel, are handled.
SPECIFIC_CHANNEL_ID = int(os.getenv("DISCORD_CHANNEL_ID", "123456789012345678"))  # from env var, or set directly

# Timeout in seconds for downloading a Discord attachment.
DOWNLOAD_TIMEOUT_SECONDS = 30


class MyClient(discord.Client):
    """Discord client that replies with a caption for a message's first attachment."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # True while a message is being captioned; messages arriving meanwhile are ignored.
        self.is_processing = False
        # Thread running the Gradio server; created on the first on_ready only.
        self._gradio_thread = None

    async def on_ready(self):
        logging.info(f'{self.user}로 로그인되었습니다!')
        # on_ready can be dispatched again (e.g. after a reconnect). Start Gradio
        # only once; a second launch would try to bind the same port again.
        if self._gradio_thread is None:
            self._gradio_thread = threading.Thread(target=run_gradio, daemon=True)
            self._gradio_thread.start()
            logging.info("Gradio 서버가 시작되었습니다.")

    async def on_message(self, message):
        """Caption the first attachment of qualifying messages, one at a time."""
        if message.author == self.user:
            return
        if not self.is_message_in_specific_channel(message):
            return
        if self.is_processing:
            return

        self.is_processing = True
        try:
            if message.attachments:
                image_url = message.attachments[0].url
                response = await process_image(image_url, message)
                await message.channel.send(response)
        finally:
            self.is_processing = False

    def is_message_in_specific_channel(self, message):
        """Return True if *message* is in SPECIFIC_CHANNEL_ID or a thread under it."""
        return message.channel.id == SPECIFIC_CHANNEL_ID or (
            isinstance(message.channel, discord.Thread) and message.channel.parent_id == SPECIFIC_CHANNEL_ID
        )


async def process_image(image_url, message):
    """Download *image_url*, caption it, and build the reply mentioning the author."""
    image = await download_image(image_url)
    caption = await create_captions_rich(image)
    return f"{message.author.mention}, 인식된 이미지 설명: {caption}"


async def download_image(url):
    """Fetch *url* and return it as an RGB PIL image.

    The blocking HTTP request runs in the default executor so the Discord event
    loop keeps running. Raises ``requests.HTTPError`` on a non-2xx response and
    ``requests.Timeout`` if the download takes longer than DOWNLOAD_TIMEOUT_SECONDS.
    """
    import asyncio  # local import so this block does not depend on the module import list

    def _fetch():
        response = requests.get(url, timeout=DOWNLOAD_TIMEOUT_SECONDS)
        response.raise_for_status()
        return response.content

    loop = asyncio.get_running_loop()
    content = await loop.run_in_executor(None, _fetch)
    return Image.open(io.BytesIO(content)).convert("RGB")  # normalise to RGB


if __name__ == "__main__":
    token = os.getenv('DISCORD_TOKEN')
    if not token:
        raise SystemExit("DISCORD_TOKEN environment variable is not set.")
    discord_client = MyClient(intents=intents)
    discord_client.run(token)