import dataclasses
from enum import auto, Enum
from threading import Thread
from typing import Any, List, Optional

import torch
from PIL import Image
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer


class SeparatorStyle(Enum):
    """Different separator styles."""
    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    # system_img: List[Image.Image] = []
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: Optional[str] = None
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + message + seps[i % 2]
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")
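
    # Example: with SINGLE style, sep="###", and roles ("Human: ", "Assistant: "),
    # one turn renders roughly as
    #     "<system>###Human: <question>###Assistant: "
    # where the trailing role with an empty message cues the model to generate a reply.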

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            # system_img=self.system_img,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id)

    def dict(self):
        return {
            "system": self.system,
            # "system_img": self.system_img,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=None, encounters=1):
        super().__init__()
        self.stops = stops if stops is not None else []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        # Stop once the generated sequence ends with any of the stop id patterns.
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False
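
# Illustrative construction of a custom stopping criterion. The token ids below are
# assumptions for a LLaMA-style tokenizer (where "###" may split into one or two tokens);
# the actual ids depend on the tokenizer in use:
#
#     stop_words_ids = [torch.tensor([835]), torch.tensor([2277, 29937])]
#     stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])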


CONV_VISION_Vicuna0 = Conversation(
    system="Given the following medical scan image: <Img>ImageContent</Img>. "
           "You will be able to see the image once I provide it to you. Please answer as an "
           "AI-assisted doctor specialised in radiology: a person who evaluates MRI and CT scan "
           "images to verify traces of tumor in the scanned image and point them out. I will "
           "provide you with the patient's MRI and CT scan images, and your task is to use the "
           "latest artificial intelligence tools, such as medical imaging software and other "
           "machine learning programs, to diagnose whether a tumor is present and to generate a "
           "report of 50 to 100 words according to the findings, specifying the location of the "
           "tumor by highlighting it. You should also incorporate traditional methods such as "
           "medical question answering, drug prediction from the described disease, and disease "
           "symptom analysis; provide remedies and a diet plan; and act like a personal health "
           "advisor, answering my queries accurately, informatively and understandably.",
    roles=("Human: ", "Assistant: "),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

CONV_VISION_LLama2 = Conversation(
    system="Given the following medical scan image: <Img>ImageContent</Img>. "
           "You will be able to see the image once I provide it to you. Please answer as an "
           "AI-assisted doctor specialised in radiology: a person who evaluates MRI and CT scan "
           "images to verify traces of tumor in the scanned image and point them out. I will "
           "provide you with the patient's MRI and CT scan images, and your task is to use the "
           "latest artificial intelligence tools, such as medical imaging software and other "
           "machine learning programs, to diagnose whether a tumor is present and to generate a "
           "report of 50 to 100 words according to the findings, specifying the location of the "
           "tumor by highlighting it. You should also incorporate traditional methods such as "
           "medical question answering, drug prediction from the described disease, and disease "
           "symptom analysis; provide remedies and a diet plan; and act like a personal health "
           "advisor, answering my queries accurately, informatively and understandably.",
    roles=("<s>[INST] ", " [/INST] "),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="",
)
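
# Both templates use SeparatorStyle.SINGLE: the Vicuna-style template separates turns
# with "###", while the Llama-2 template relies on the "<s>[INST] ... [/INST]" chat
# markers and an empty separator.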


class Chat:
    def __init__(self, model, vis_processor, device='cuda:0', stopping_criteria=None):
        self.device = device
        self.model = model
        self.vis_processor = vis_processor

        if stopping_criteria is not None:
            self.stopping_criteria = stopping_criteria
        else:
            # Default to stopping on token id 2, the EOS token of LLaMA-family tokenizers.
            stop_words_ids = [torch.tensor([2]).to(self.device)]
            self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def ask(self, text, conv):
        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
                and conv.messages[-1][1][-6:] == '</Img>':  # last message is an image.
            conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
        else:
            conv.append_message(conv.roles[0], text)

    def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
                       repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
        conv.append_message(conv.roles[1], None)
        embs = self.get_context_emb(conv, img_list)

        # Truncate the context from the left if it would exceed max_length.
        current_max_len = embs.shape[1] + max_new_tokens
        if current_max_len - max_length > 0:
            print('Warning: The number of tokens in the current conversation exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        begin_idx = max(0, current_max_len - max_length)
        embs = embs[:, begin_idx:]

        generation_kwargs = dict(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=float(temperature),
        )
        return generation_kwargs

    def answer(self, conv, img_list, **kwargs):
        generation_dict = self.answer_prepare(conv, img_list, **kwargs)

        output_token = self.model.llama_model.generate(**generation_dict)[0]
        output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)

        output_text = output_text.split('###')[0]  # remove the stop sign '###'
        output_text = output_text.split('Assistant:')[-1].strip()

        conv.messages[-1][1] = output_text
        return output_text, output_token.cpu().numpy()

    def stream_answer(self, conv, img_list, **kwargs):
        generation_kwargs = self.answer_prepare(conv, img_list, **kwargs)
        streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
        generation_kwargs['streamer'] = streamer
        # Run generation on a background thread so the caller can consume the streamer.
        thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs)
        thread.start()
        return streamer
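
    # Illustrative consumption of the streamer (text arrives incrementally as it is generated):
    #
    #     streamer = chat.stream_answer(conv, img_list, max_new_tokens=300)
    #     for new_text in streamer:
    #         print(new_text, end='', flush=True)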

    def encode_img(self, img_list):
        image = img_list.pop(0)
        if isinstance(image, str):  # an image path
            raw_image = Image.open(image).convert('RGB')
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, Image.Image):
            raw_image = image
            image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
        elif isinstance(image, torch.Tensor):
            if len(image.shape) == 3:
                image = image.unsqueeze(0)
            image = image.to(self.device)

        image_emb, _ = self.model.encode_img(image)
        img_list.append(image_emb)

    def upload_img(self, image, conv, img_list):
        conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
        img_list.append(image)
        msg = "Received."
        return msg
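
    # Typical call order: upload_img() registers the image placeholder in the conversation,
    # encode_img() replaces the raw image in img_list with its embedding, ask() appends the
    # user question, and answer() / stream_answer() generates the reply.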

    def get_context_emb(self, conv, img_list):
        prompt = conv.get_prompt()
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        # Interleave text-segment embeddings with image embeddings at each <ImageHere> slot.
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs
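

# Minimal end-to-end sketch (assumes a loaded MiniGPT-4-style `model` and `vis_processor`
# from the surrounding repo; 'scan.png' and the question are hypothetical):
#
#     chat = Chat(model, vis_processor, device='cuda:0')
#     conv = CONV_VISION_Vicuna0.copy()
#     img_list = []
#     chat.upload_img('scan.png', conv, img_list)
#     chat.encode_img(img_list)
#     chat.ask('Is there a tumor visible in this scan?', conv)
#     answer, _ = chat.answer(conv, img_list, max_new_tokens=300)
#     print(answer)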