import os

import torch
from model import IntentPredictModel
from transformers import T5Tokenizer, GPT2LMHeadModel, GPT2Tokenizer
from diffusers import StableDiffusionPipeline


class Chat:
    def __init__(
        self,
        intent_predict_model: IntentPredictModel,
        intent_predict_tokenizer: T5Tokenizer,
        text_dialog_model: GPT2LMHeadModel,
        text_dialog_tokenizer: GPT2Tokenizer,
        text2image_model: StableDiffusionPipeline,
        device="cuda:0"
    ):
        self.intent_predict_model = intent_predict_model.to(device)
        self.intent_predict_tokenizer = intent_predict_tokenizer
        self.text_dialog_model = text_dialog_model.to(device)
        self.text_dialog_tokenizer = text_dialog_tokenizer
        self.text2image_model = text2image_model.to(device)
        self.device = device

        # Style suffixes appended to a generated caption before it is passed to
        # the text-to-image model; the "human" variant is used when the caption
        # mentions a person.
        self.extra_prompt = {
            "human": ", facing the camera, photograph, highly detailed face, depth of field, moody light, style by Yasmin Albatoul, Harry Fayt, centered, extremely detailed, Nikon D850, award winning photography",
            "others": ", depth of field. bokeh. soft light. by Yasmin Albatoul, Harry Fayt. centered. extremely detailed. Nikon D850, (35mm|50mm|85mm). award winning photography."
        }
        # Keywords that mark a caption as describing a person.
        self.human_words = ["man", "men", "woman", "women", "people", "person", "human", "male",
                            "female", "boy", "girl", "child", "kid", "baby", "player"]
        self.negative_prompt = "cartoon, anime, ugly, asian, (aged, white beard, black skin, wrinkle:1.1), (bad proportions, unnatural feature, incongruous feature:1.4), (blurry, un-sharp, fuzzy, un-detailed skin:1.2), (facial contortion, poorly drawn face, deformed iris, deformed pupils:1.3), (mutated hands and fingers:1.5), disconnected hands, disconnected limbs"

        self.context_for_intent = ""
        self.context_for_text_dialog = ""

    def intent_predict(self, context: str) -> bool:
        """Predict whether the bot should share a photo at this turn."""
        context_encoded = self.intent_predict_tokenizer.encode_plus(
            text=context,
            add_special_tokens=True,
            truncation=True,
            max_length=512,
            return_attention_mask=True,
            return_tensors='pt'
        )
        input_ids = context_encoded['input_ids'].to(self.device)
        attention_mask = context_encoded['attention_mask'].to(self.device)

        pred_logits = self.intent_predict_model(input_ids=input_ids, attention_mask=attention_mask).logits
        pred_label = torch.max(pred_logits, dim=1)[1]
        return bool(pred_label.item())  # label 1 -> share a photo

    def generate_response(self, context: str, share_photo: bool, num_beams: int) -> str:
        tokenizer = self.text_dialog_tokenizer

        tag_list = ["[UTT]", "[DST]"]  # text replies start with [UTT]; image captions start with [DST]
        tag_id_dic = {tag: tokenizer.convert_tokens_to_ids(tag) for tag in tag_list}
        tag = "[DST]" if share_photo else "[UTT]"

        # Forbid degenerate continuations: back-to-back tags, or a tag followed
        # immediately by end-of-text.
        bad_words = ["[UTT] [UTT]", "[UTT] [DST]", "[UTT] <|endoftext|>",
                     "[DST] [UTT]", "[DST] [DST]", "[DST] <|endoftext|>"]

        input_ids = tokenizer.encode(context, add_special_tokens=False, return_tensors='pt')

        generated_ids = self.text_dialog_model.generate(
            input_ids.to(self.device),
            max_new_tokens=64,
            min_new_tokens=3,
            do_sample=False,
            num_beams=num_beams,
            length_penalty=0.7,
            num_beam_groups=5,
            no_repeat_ngram_size=3,
            bad_words_ids=tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids,
            # Force the first generated token to be the tag. Since generated_ids
            # includes input_ids, the tag sits at position input_ids.shape[-1].
            forced_decoder_ids=[[input_ids.shape[-1], tag_id_dic[tag]]],
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
        generated_tokens = tokenizer.convert_ids_to_tokens(
            generated_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True
        )

        # Cut the response off at the next tag, if one appears.
        end, i = 0, 0
        for i, token in enumerate(generated_tokens):
            if i == 0:
                # forced_decoder_ids guarantees the first token is the tag,
                # so start scanning from the second token.
                continue
            if token in tag_list:
                end = i
                break
        if end == 0 and i != 0:  # no further tag was generated
            end = len(generated_tokens)

        response_tokens = generated_tokens[1:end]
        response_str = tokenizer.convert_tokens_to_string(response_tokens).lstrip()
        return response_str

    def respond(self, message, num_beams, text2image_seed, chat_history, chat_state):
        # Append the new user message to both running contexts.
        if self.context_for_intent == "":
            self.context_for_intent += message
        else:
            self.context_for_intent += " [SEP] " + message
        self.context_for_text_dialog += "[UTT] " + message

        share_photo = self.intent_predict(self.context_for_intent)
        response = self.generate_response(self.context_for_text_dialog, share_photo, num_beams)

        if share_photo:
            print(f"Image Caption: {response}")

            # Pick the prompt suffix: "human" if the caption mentions a person.
            prompt_type = "others"
            for human_word in self.human_words:
                if human_word in response:
                    prompt_type = "human"
                    break
            caption = response + self.extra_prompt[prompt_type]

            generator = torch.Generator(device=self.device).manual_seed(text2image_seed)
            image = self.text2image_model(
                prompt=caption,
                negative_prompt=self.negative_prompt,
                num_inference_steps=20,
                guidance_scale=7.5,
                generator=generator
            ).images[0]

            os.makedirs("generated_images", exist_ok=True)  # ensure the output directory exists
            save_image_path = f"generated_images/{response}.png"
            image.save(save_image_path)

            self.context_for_intent += " [SEP] " + response
            self.context_for_text_dialog += "[DST] " + response
            chat_history.append((message, (save_image_path, None)))
        else:
            print(f"Bot: {response}")
            self.context_for_intent += " [SEP] " + response
            self.context_for_text_dialog += "[UTT] " + response
            chat_history.append((message, response))

        return "", chat_history, chat_state
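

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): how the Chat
# class above might be wired together. The checkpoint names/paths and the
# IntentPredictModel loading call are assumptions; the dialog model and
# tokenizer are assumed to have been fine-tuned with the [UTT]/[DST] special
# tokens already added to their vocabulary.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    intent_tokenizer = T5Tokenizer.from_pretrained("t5-base")  # assumed base checkpoint
    intent_model = IntentPredictModel.from_pretrained("path/to/intent_checkpoint")  # hypothetical API from model.py

    dialog_tokenizer = GPT2Tokenizer.from_pretrained("path/to/dialog_checkpoint")  # assumed fine-tuned weights
    dialog_model = GPT2LMHeadModel.from_pretrained("path/to/dialog_checkpoint")

    sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    chat = Chat(intent_model, intent_tokenizer, dialog_model, dialog_tokenizer, sd_pipe)

    # num_beams must be divisible by the num_beam_groups=5 hard-coded in
    # generate_response, so 5 is the smallest valid value here.
    history, state = [], None
    _, history, state = chat.respond(
        "Show me a photo of a sunset over the sea.",
        num_beams=5,
        text2image_seed=42,
        chat_history=history,
        chat_state=state,
    )
    print(history)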