---
license: cc-by-nc-4.0
library_name: mamba-tiny-16384-clmbr
tags:
- healthcare
- medical
extra_gated_prompt: "You agree to all terms outlined in 'The EHRSHOT Credentialed Health Data License' (see https://shahlab.stanford.edu/ehrshot_license). Access requires a verified CITI training certificate using the same process outlined by PhysioNet (see https://physionet.org/about/citi-course/). Please complete the 'Data or Specimens Only Research' course and please provide proof via the verification URL, which takes the form https://www.citiprogram.org/verify/?XXXXXX. You agree to not use the model to conduct experiments that cause harm to human subjects."
extra_gated_fields:
  Full Name: text
  Email: text
  Affiliation: text
  CITI Certification Verification URL: text
  I agree to all terms outlined in 'The EHRSHOT Credentialed Health Data License': checkbox
  I agree to use this model for non-commercial use ONLY: checkbox
---
# mamba-tiny-16384-clmbr
This is a **mamba** model with a context length of **16384** tokens and **121100544** parameters, from the [Context Clues paper](https://arxiv.org/abs/2412.16178).
It is a foundation model trained from scratch on the structured data within 2.57 million deidentified EHRs from Stanford Medicine.
As input, this model expects a sequence of coded medical events that have been mapped to Standard Concepts within the [OMOP-CDM vocabulary](https://ohdsi.github.io/CommonDataModel/index.html). As output, the model can generate either (a) synthetic future timelines or (b) a vector representation of a patient which can then be used for downstream prediction tasks.
## Usage
First, install the `hf_ehr` package along with `transformers` and `torch`:
```bash
pip install transformers torch hf_ehr
```
Second, run this Python script to tokenize a patient, run inference, and extract a patient representation:
```python
from transformers import AutoModelForCausalLM
from hf_ehr.data.tokenization import CLMBRTokenizer
from hf_ehr.config import Event
from typing import List, Dict
import torch
####################################
# 1. Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("StanfordShahLab/mamba-tiny-16384-clmbr")
# Load the tokenizer through the imported `CLMBRTokenizer` class, which provides
# the `Event`-aware methods used below (e.g. `convert_events_to_tokens`)
tokenizer = CLMBRTokenizer.from_pretrained("StanfordShahLab/mamba-tiny-16384-clmbr")
####################################
# 2. Define patient as sequence of `Event` objects. Only `code` is required.
patient: List[Event] = [
    Event(code='SNOMED/3950001', value=None, unit=None, start=None, end=None, omop_table=None),
    Event(code='Gender/F', value=None, unit=None, start=None, end=None, omop_table=None),
    Event(code='Ethnicity/Hispanic', value=None, unit=None, start=None, end=None, omop_table=None),
    Event(code='SNOMED/609040007', value=None, unit=None, start=None, end=None, omop_table=None),
    Event(code='LOINC/2236-8', value=-3.0, unit=None, start=None, end=None, omop_table=None),
    Event(code='SNOMED/12199005', value=26.3, unit=None, start=None, end=None, omop_table=None),
]
####################################
# 3. Tokenize patient
batch: Dict[str, torch.Tensor] = tokenizer([ patient ], add_special_tokens=True, return_tensors='pt')
# > batch = {
# 'input_ids': tensor([[ 5, 0, 7, 9, 27, 2049, 6557, 22433, 1]]),
# 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0]]),
# 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])
# }
textual_tokens: List[str] = tokenizer.convert_events_to_tokens(patient)
# > textual_tokens = ['SNOMED/3950001', 'Gender/F', 'Ethnicity/Hispanic', 'SNOMED/609040007', 'LOINC/2236-8 || None || -1.7976931348623157e+308 - 4.0', 'SNOMED/12199005 || None || 26.0 - 28.899999618530273']
####################################
# 4. Run model
logits = model(**batch).logits
# > logits.shape = torch.Size([1, 9, 39818])
####################################
# 5. Get patient representation for finetuning (usually we choose the last token's logits)
representation = logits[:, -1, :]
# > representation.shape = torch.Size([1, 39818])
```
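Step 5 above takes `logits[:, -1, :]`, which is the final event only when the sequence has no padding after it. If you batch several patients of different lengths, position `-1` may hold a padding token for the shorter timelines. The sketch below uses plain Python lists in place of tensors and picks the last attended position per patient from `attention_mask`. It assumes that a mask value of `1` marks real tokens; `last_token_indices` and `select_last` are hypothetical helper names, not part of `hf_ehr`.

```python
from typing import List, Sequence


def last_token_indices(attention_mask: Sequence[Sequence[int]]) -> List[int]:
    """Return, for each row, the index of the last position whose mask value is 1."""
    indices = []
    for row in attention_mask:
        positions = [i for i, m in enumerate(row) if m == 1]
        if not positions:
            raise ValueError("each row needs at least one attended token")
        indices.append(positions[-1])
    return indices


def select_last(
    vectors: Sequence[Sequence[Sequence[float]]],
    attention_mask: Sequence[Sequence[int]],
) -> List[List[float]]:
    """Pick the vector at the last attended position of every row."""
    indices = last_token_indices(attention_mask)
    return [list(vectors[b][i]) for b, i in enumerate(indices)]


# Toy batch (made-up numbers): 2 patients, 3 positions, 2 values per position.
toy_vectors = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],  # patient 0: no padding
    [[1.0, 1.1], [1.2, 1.3], [0.0, 0.0]],  # patient 1: last position is padding
]
toy_mask = [[1, 1, 1], [1, 1, 0]]
toy_repr = select_last(toy_vectors, toy_mask)
```

With real tensors, the same selection can be written with integer indexing, for example `logits[torch.arange(logits.shape[0]), torch.tensor(idx)]`, where `idx` is the list returned by `last_token_indices`.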
## Model Details
- **Developed by:** Shah lab @ Stanford University
- **Funded by:** Stanford Health Care
- **Shared by:** Shah lab @ Stanford University
- **Model type:** mamba
- **Languages:** Electronic health record codes (as standardized by the [OMOP-CDM](https://ohdsi.github.io/CommonDataModel/index.html))
- **License:** CC BY-NC 4.0
- **Finetuned from model:** N/A -- trained from scratch
## Uses
This model is intended to generate representations for patients based on the structured data within their electronic health record.
These representations can then be used for downstream tasks such as predicting diagnoses, detecting anomalies, or doing propensity score matching for causal inference.
### Direct Use
You will likely want to tune the model for your downstream use case.
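One lightweight option, simpler than updating the model's own weights, is a linear probe: keep the model frozen, extract one representation per patient, and fit a small classifier on top. The sketch below illustrates the idea with a hand-written logistic regression on synthetic vectors. The data, the 4-dimensional feature size, and the helper names `fit_logistic_probe` and `predict_proba` are all assumptions made for illustration; in practice a library implementation would be the practical choice, and the features would be the patient representations produced by the model.

```python
import math
import random
from typing import List, Sequence, Tuple


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict_proba(weights: Sequence[float], bias: float, x: Sequence[float]) -> float:
    """Probability of the positive label for one feature vector."""
    return sigmoid(sum(w * xi for w, xi in zip(weights, x)) + bias)


def fit_logistic_probe(
    features: Sequence[Sequence[float]],
    labels: Sequence[int],
    lr: float = 0.1,
    epochs: int = 200,
) -> Tuple[List[float], float]:
    """Fit logistic regression with full-batch gradient descent."""
    n, d = len(features), len(features[0])
    weights, bias = [0.0] * d, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * d, 0.0
        for x, y in zip(features, labels):
            err = predict_proba(weights, bias, x) - y
            for j in range(d):
                grad_w[j] += err * x[j]
            grad_b += err
        weights = [w - lr * g / n for w, g in zip(weights, grad_w)]
        bias -= lr * grad_b / n
    return weights, bias


# Synthetic stand-ins for frozen patient representations (NOT real model output):
# positive patients cluster around +1, negative patients around -1.
rng = random.Random(0)
toy_features, toy_labels = [], []
for _ in range(100):
    label = rng.randint(0, 1)
    center = 1.0 if label == 1 else -1.0
    toy_features.append([center + rng.gauss(0.0, 0.5) for _ in range(4)])
    toy_labels.append(label)

probe_w, probe_b = fit_logistic_probe(toy_features, toy_labels)
toy_accuracy = sum(
    (predict_proba(probe_w, probe_b, x) >= 0.5) == bool(y)
    for x, y in zip(toy_features, toy_labels)
) / len(toy_labels)
```

A probe like this is cheap to train and leaves the pretrained weights untouched, which makes it a reasonable first baseline before attempting full fine-tuning.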
### Out-of-Scope Use
This model is for research purposes only. It is not for use in any real-world decision making that impacts patients, providers, or hospital operations.
## Bias, Risks, and Limitations
This model was trained on a corpus of 2 billion tokens sourced from 2.57 million patients from Stanford Medicine.
The model will thus reflect the patterns of how care is delivered at Stanford Medicine, in addition to the racial and socioeconomic makeup of Stanford Medicine's patient base.
This model may not generalize well to other hospitals and demographic mixes.
While this is technically a generative model, we have not tested its generative abilities and thus do not anticipate it being used to generate synthetic EHR records.
We aim to explore its generative abilities in future work.
## Training Details
Full training details are provided in our accompanying paper, [Context Clues](https://arxiv.org/abs/2412.16178).
### Training Data
The model is trained on 2 billion tokens sourced from 2.57 million patients from the [Stanford Medicine Research Data Repository (STARR)](https://academic.oup.com/jamiaopen/article/6/3/ooad054/7236015),
which contains structured EHR data from both Stanford Health Care (primarily adult care) and Lucile Packard Children’s Hospital (primarily pediatric care).
The dataset contains only structured data (i.e. no clinical text or images) and covers demographics (e.g. age, sex, race), diagnoses, procedures, laboratory results, medication prescriptions, and other coded clinical observations.
The data is formatted according to the [Observational Medical Outcomes Partnership Common Data Model (OMOP-CDM)](https://ohdsi.github.io/CommonDataModel/cdm53.html).
All data that we work with is deidentified.
### Training Procedure
We train our model using an autoregressive next-code prediction objective: given the previous codes in a patient's timeline, the model predicts the next code.
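As a rough illustration of this objective, the sketch below computes the mean negative log-likelihood of each next code given the logits at the preceding position, using plain Python lists. It is a minimal sketch of a standard next-token cross-entropy, not the authors' training code; the helper names `log_softmax` and `next_code_loss` and the toy vocabulary of 3 codes are assumptions made for illustration.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Log-probabilities from raw scores, using the max-shift trick for stability."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def next_code_loss(logits: Sequence[Sequence[float]], token_ids: Sequence[int]) -> float:
    """Mean negative log-likelihood of token t+1 under the logits at position t."""
    if len(logits) != len(token_ids):
        raise ValueError("need one row of logits per token")
    if len(token_ids) < 2:
        raise ValueError("need at least two tokens to form a prediction target")
    total = 0.0
    for t in range(len(token_ids) - 1):
        total -= log_softmax(logits[t])[token_ids[t + 1]]
    return total / (len(token_ids) - 1)


# Toy timeline of 3 codes over a hypothetical vocabulary of size 3.
toy_ids = [0, 2, 1]

# A model with no preference between codes pays log(3) per prediction.
uniform_logits = [[0.0, 0.0, 0.0] for _ in toy_ids]
uniform_loss = next_code_loss(uniform_logits, toy_ids)

# A model that strongly favors the correct next code pays almost nothing.
# The logits at the final position have no target and are ignored.
confident_logits = [[0.0, 0.0, 10.0], [0.0, 10.0, 0.0], [0.0, 0.0, 0.0]]
confident_loss = next_code_loss(confident_logits, toy_ids)
```

The shift by one position is the key detail: the logits at position `t` are scored against the code at position `t + 1`, so a timeline of `n` codes yields `n - 1` prediction targets.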
## Citation
**BibTeX:**
```bibtex
@article{wornow2024contextclues,
  title={Context Clues: Evaluating Long Context Models for Clinical Prediction Tasks on EHRs},
  author={Michael Wornow and Suhana Bedi and Miguel Angel Fuentes Hernandez and Ethan Steinberg and Jason Alan Fries and Christopher Ré and Sanmi Koyejo and Nigam H. Shah},
  year={2024},
  eprint={2412.16178},
  url={https://arxiv.org/abs/2412.16178},
}
```
## Model Card Authors
Michael Wornow, Suhana Bedi, Ethan Steinberg