import torch

from datasets import load_dataset, load_from_disk

from peft import LoraConfig, get_peft_model

from PIL import Image

from transformers import AutoModelForVision2Seq, AutoProcessor, Trainer, TrainingArguments, BitsAndBytesConfig

import torchvision.transforms as transforms

# Pick CUDA when it is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "med_tongue_vision-zh_V0.1"

# 4-bit quantization settings: NF4 quant type, double quantization enabled,
# float16 compute dtype. No modules are excluded from quantization here.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Simply take off the quantization_config arg if you want to load the original model.
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",
)
print(model)

processor = AutoProcessor.from_pretrained(model_id)

# NOTE(review): `sys_prompt` and `image` are not defined anywhere in this
# snippet. Define them before this point: `sys_prompt` as the system-prompt
# string and `image` as the tongue photo to analyse (the sample output shows a
# PIL JPEG image, so presumably `Image.open(<path>)`) — confirm with the author.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": sys_prompt},
            # Placeholder marking where the image is inserted in the prompt.
            {"type": "image"},
            {"type": "text", "text": "告诉我图片中的舌象指标有哪些"},
        ],
    }
]

text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=[text.strip()], images=[image], return_tensors="pt", padding=True)
# The processor returns CPU tensors; move them to the device the model was
# loaded onto so that generate() does not fail with a device mismatch.
inputs = inputs.to(model.device)

generated_ids = model.generate(**inputs, max_new_tokens=512)

# generate() returns prompt + continuation; slice off the prompt tokens so
# only the newly generated answer is decoded.
prompt_length = inputs["input_ids"].size(1)
generated_texts = processor.batch_decode(
    generated_ids[:, prompt_length:],
    skip_special_tokens=True,
)

===== 结果 ====================

客户提问: 请描述这张图片中的舌象细分类别。

客户舌图: <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=494x521 at 0x7F5D32BDBC10>

舌图15类标签: 绛_舌胖_有齿痕_有裂纹_有点刺_无瘀斑_无瘀点_无老嫩_无歪斜_黄_厚苔_无腐苔_有腻苔_润_无剥脱

大模型识别结果: ['淡红_舌胖_有齿痕_有裂纹_有点刺_无瘀斑_无瘀点_无老嫩_无歪斜_白_厚苔_无腐苔_有腻苔_润_无剥脱']

这里呈现的是一个精度较低的模型,仅用于演示视觉多模态模型的能力;

如需基于更大规模舌诊数据集训练的高精度模型,请联系:[email protected]

Downloads last month
46
Safetensors
Model size
8.4B params
Tensor type
BF16
·
Inference API
Inference API (serverless) does not yet support transformers models for this pipeline type.