🎬 Get Started

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from peft import PeftModel, PeftConfig

# step 1: Setup constants
# Hub repository id of the PEFT adapter; the base checkpoint is looked up from its config.
model_name = "StanfordAIMI/SRR-Mistral7b-finetuned"
# Kept for any later code that refers to it; the tensors below follow the
# model's actual placement instead, because device_map="auto" decides that.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# step 2: Load Tokenizer and Model
config = PeftConfig.from_pretrained(model_name)
base_model_name_or_path = config.base_model_name_or_path
model = AutoModelForCausalLM.from_pretrained(
    base_model_name_or_path,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path, trust_remote_code=True)
# No dedicated pad token is configured here, so the EOS token is reused for padding.
tokenizer.pad_token = tokenizer.eos_token

# Attach the fine-tuned adapter weights to the base model loaded above.
# The model must NOT be reloaded from `model_name` afterwards: that would
# discard this bf16 PEFT-wrapped model and load the full weights a second time.
model = PeftModel.from_pretrained(model, model_name)
model.eval()

# With device_map="auto" the weights are already placed; inputs must go to
# the device that holds the first parameters.
model_device = next(model.parameters()).device

# step 3: Inference (example from MIMIC-CXR dataset)
input_text = "<|system|> You are a radiology expert.<|end|> <|user|>CHEST RADIOGRAPH PERFORMED ON ___  COMPARISON: Prior exam from ___.  CLINICAL HISTORY: Weakness, assess pneumonia.  FINDINGS: Frontal and lateral views of the chest were provided. Midline sternotomy wires are again noted. The heart is poorly assessed, though remains enlarged. There are at least small bilateral pleural effusions.  There may be mild interstitial edema. No pneumothorax. Bony structures are demineralized with kyphotic angulation in the lower T-spine again noted.  IMPRESSION: Limited exam with small bilateral effusions, cardiomegaly, and possible mild interstitial edema. <|end|> \n<|assistant|> Output: "
# A single prompt needs no padding, so only truncation is requested. The
# tokenizer's own attention mask is used as-is: rebuilding it from
# `input_ids != pad_token_id` would also mask genuine EOS tokens, because
# pad and EOS share one id here.
inputs = tokenizer(input_text, truncation=True, max_length=512, return_tensors="pt")
input_ids = inputs["input_ids"].to(model_device)
attention_mask = inputs["attention_mask"].to(model_device)
# `input_ids` is passed by keyword because some PEFT releases define
# generate() with keyword arguments only.
generated_ids = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    max_new_tokens=286,
    min_new_tokens=120,
    num_beams=5,
    early_stopping=True,
    max_length=None,
    pad_token_id=tokenizer.pad_token_id,
)[0]
decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)

# The decoded sequence still contains the prompt; keep only what follows the
# final "Output:" marker.
decoded = decoded.rsplit("Output:", 1)[-1].strip()

print(decoded)
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support