import torch
from datasets import load_dataset, load_from_disk
from peft import LoraConfig, get_peft_model
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor, Trainer, TrainingArguments, BitsAndBytesConfig
import torchvision.transforms as transforms
# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "med_tongue_vision-zh_V0.1"

# 4-bit NF4 quantization with double quantization and float16 compute.
# NOTE(review): the original note on this config said "we skip some special
# modules that can't be quantized properly", but no skip list is passed
# here -- confirm whether one is required for this checkpoint.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Simply drop the quantization_config argument to load the original
# (unquantized) model instead.
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",
)
print(model)

processor = AutoProcessor.from_pretrained(model_id)
# Single-turn user message: a system-style text prompt, an image placeholder,
# and the question put to the model.
# NOTE(review): `sys_prompt` is not defined in this snippet -- presumably it
# is supplied by the surrounding script; confirm before running.
user_content = [
    {"type": "text", "text": sys_prompt},
    {"type": "image"},
    {"type": "text", "text": "告诉我图片中的舌象指标有哪些"},
]
messages = [{"role": "user", "content": user_content}]
# Render the chat messages into a prompt string, with the generation prompt
# appended.
text = processor.apply_chat_template(messages, add_generation_prompt=True)

# NOTE(review): `image` is not defined in this snippet -- presumably a
# PIL.Image opened by the surrounding script; confirm before running.
inputs = processor(text=[text.strip()], images=[image], return_tensors="pt", padding=True)
# The processor returns its tensors without placing them on the model's
# device; move them there so generate() does not mix devices.
inputs = inputs.to(model.device)

generated_ids = model.generate(**inputs, max_new_tokens=512)

# Slice off the first input_ids.size(1) positions (the prompt length) so
# that only the newly generated tokens are decoded.
prompt_length = inputs["input_ids"].size(1)
generated_texts = processor.batch_decode(generated_ids[:, prompt_length:], skip_special_tokens=True)
===== 结果 ====================
客户提问: 请描述这张图片中的舌象细分类别。
客户舌图: <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=494x521 at 0x7F5D32BDBC10>
舌图15类标签: 绛_舌胖_有齿痕_有裂纹_有点刺_无瘀斑_无瘀点_无老嫩_无歪斜_黄_厚苔_无腐苔_有腻苔_润_无剥脱
大模型识别结果: ['淡红_舌胖_有齿痕_有裂纹_有点刺_无瘀斑_无瘀点_无老嫩_无歪斜_白_厚苔_无腐苔_有腻苔_润_无剥脱']
这里呈现的是一个精度较低的模型,仅用于演示视觉多模态模型的能力。
如需基于更大规模舌诊数据集训练的高精度模型,请联系:[email protected]
- Downloads last month
- 46