dstars's picture
Update demo/agent.py
2c2417f verified
import os
import logging
from datetime import datetime
import gradio as gr
from PIL import Image
# from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig, ChatTemplateConfig
from lmdeploy import pipeline, GenerationConfig, ChatTemplateConfig
from lmdeploy.vl import load_image
class ConversationalAgent:
def __init__(self, model_path, outputs_dir, device='cpu'):
# 传入 device 参数,并设置 pipeline 时指定设备
self.device = device
self.pipe = pipeline(
model_path,
chat_template_config=ChatTemplateConfig(model_name='internvl2-internlm2'),
# backend_config=TurbomindEngineConfig(session_len=8192),
# backend_config=PytorchEngineConfig(max_length=8192),
device=self.device
)
self.uploaded_images_storage = os.path.join(outputs_dir, "uploaded")
self.uploaded_images_storage = os.path.abspath(self.uploaded_images_storage)
os.makedirs(self.uploaded_images_storage, exist_ok=True)
self.sess = None
def start_chat(self, chat_state):
self.sess = None
self.context = ""
self.current_image_id = -1
self.image_list = []
self.pixel_values_list = []
self.seen_image_idx = []
logging.info("=" * 30 + "Start Chat" + "=" * 30)
return (
#gr.update(interactive=False), # [image] Image
gr.update(interactive=True, placeholder='input the text.'), # [input_text] Textbox
gr.update(interactive=False), # [start_btn] Button
gr.update(interactive=True), # [clear_btn] Button
gr.update(interactive=True), # [image] Image
gr.update(interactive=True), # [upload_btn] Button
chat_state # [chat_state] State
)
def restart_chat(self, chat_state):
self.sess = None
self.context = ""
self.current_image_id = -1
self.image_list = []
self.pixel_values_list = []
self.seen_image_idx = []
logging.info("=" * 30 + "End Chat" + "=" * 30)
return (
None, # [chatbot] Chatbot
#gr.update(value=None, interactive=True), # [image] Image
gr.update(interactive=False, placeholder="Please click the <Start Chat> button to start chat!"), # [input_text] Textbox
gr.update(interactive=True), # [start] Button
gr.update(interactive=False), # [clear] Button
gr.update(value=None, interactive=False), # [image] Image
gr.update(interactive=False), # [upload_btn] Button
chat_state # [chat_state] State
)
def upload_image(self, image: Image.Image, chat_history: gr.Chatbot, chat_state: gr.State):
logging.info(f"type(image): {type(image)}")
self.image_list.append(image)
save_image_path = os.path.join(self.uploaded_images_storage, "{}.jpg".format(len(os.listdir(self.uploaded_images_storage))))
image.save(save_image_path)
logging.info(f"image save path: {save_image_path}")
chat_history.append((gr.HTML(f'<img src="./file={save_image_path}" style="width: 200px; height: auto; display: inline-block;">'), "Received."))
return None, chat_history, chat_state
def respond(
self,
message,
image,
chat_history: gr.Chatbot,
top_p,
temperature,
chat_state,
):
current_time = datetime.now().strftime("%b%d-%H:%M:%S")
logging.info(f"Time: {current_time}")
logging.info(f"User: {message}")
gen_config = GenerationConfig(top_p=top_p, temperature=temperature)
chat_input = message
if image is not None:
save_image_path = os.path.join(self.uploaded_images_storage, "{}.jpg".format(len(os.listdir(self.uploaded_images_storage))))
image.save(save_image_path)
logging.info(f"image save path: {save_image_path}")
chat_input = (message, image)
if self.sess is None:
self.sess = self.pipe.chat(chat_input, gen_config=gen_config)
else:
self.sess = self.pipe.chat(chat_input, session=self.sess, gen_config=gen_config)
response = self.sess.response.text
if image is not None:
chat_history.append((gr.HTML(f'{message}\n\n<img src="./file={save_image_path}" style="width: 200px; height: auto; display: inline-block;">'), response))
else:
chat_history.append((message, response))
logging.info(f"generated text = \n{response}")
return "", None, chat_history, chat_state