---
language: en
license: apache-2.0
pipeline_tag: text-generation
base_model: t5-small
library_name: transformers
widget:
- text: "A 35-year-old female presents with a 2-week history of persistent cough..."
---

# Medical Generation Model

## Overview

This repository contains a T5 model fine-tuned to generate medical diagnoses and treatment recommendations. The model was trained on clinical scenarios and produces contextually relevant diagnosis and treatment text from a free-text clinical prompt.

## Model Details

- **Model Type**: T5 (encoder-decoder, text-to-text)
- **Model Size**: small (fine-tuned from `t5-small`)
- **Tokenizer**: T5 tokenizer (SentencePiece)
- **Training Data**: Clinical scenarios and medical texts

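The sketch below makes "small" concrete by counting the parameters of the standard `t5-small` architecture. It assumes the fine-tuned checkpoint keeps the stock `t5-small` configuration: a vocabulary of 32,128, `d_model` 512, `d_ff` 2,048, 6 encoder and 6 decoder layers, and 8 heads of size 64. It also assumes 32 relative-position buckets and input/output embeddings tied. The function name `t5_param_count` is hypothetical. If the vocabulary was resized during fine-tuning, the total shifts accordingly.

```python
def t5_param_count(vocab=32128, d_model=512, d_ff=2048, n_layers=6,
                   n_heads=8, d_kv=64, rel_buckets=32):
    """Parameter count for an original (v1.0-style) T5 encoder-decoder.

    Assumptions: no bias terms, one embedding matrix shared by encoder,
    decoder and output layer, relative position bias only in the first
    layer of each stack, and layer norms with a single scale vector.
    """
    inner = n_heads * d_kv
    attention = 4 * d_model * inner      # q, k, v and output projections
    ffn = 2 * d_model * d_ff             # wi and wo of the ReLU feed-forward
    norm = d_model                       # one scale vector per layer norm
    rel_bias = rel_buckets * n_heads     # learned relative position bias

    encoder_block = attention + ffn + 2 * norm
    decoder_block = 2 * attention + ffn + 3 * norm  # self- and cross-attention

    embeddings = vocab * d_model
    encoder = n_layers * encoder_block + rel_bias + norm  # + final layer norm
    decoder = n_layers * decoder_block + rel_bias + norm
    return embeddings + encoder + decoder


total = t5_param_count()
print(f"{total:,} parameters")  # 60,506,624 parameters
```

This agrees with the roughly 60M parameters usually quoted for `t5-small`. More than a quarter of them sit in the shared embedding matrix.
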
## Installation

To use this model, install the required libraries with `pip`:

```bash
pip install transformers
pip install tensorflow
pip install sentencepiece  # required by T5Tokenizer
# With TensorFlow >= 2.16 (Keras 3), recent transformers releases also need tf-keras
pip install tf-keras
```

## Usage

Load the fine-tuned model and tokenizer, then generate a diagnosis and treatment suggestion for a clinical prompt:

```python
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

# Load the fine-tuned model and tokenizer from the Hugging Face Hub
model_id = "Ra-Is/medical-gen-small"
model = TFT5ForConditionalGeneration.from_pretrained(model_id)
tokenizer = T5Tokenizer.from_pretrained(model_id)

# Prepare a sample input prompt
input_prompt = (
    "A 35-year-old female presents with a 2-week history of "
    "persistent cough, shortness of breath, and fatigue. She has "
    "a history of asthma and has recently been exposed to a sick "
    "family member with a respiratory infection. Chest X-ray shows "
    "bilateral infiltrates. What is the likely diagnosis, and what "
    "should be the treatment?"
)

# Tokenize the input as TensorFlow tensors
input_ids = tokenizer(input_prompt, return_tensors="tf").input_ids

# Generate the output (diagnosis and treatment)
outputs = model.generate(
    input_ids,
    max_length=512,       # maximum length of the generated sequence, in tokens
    num_beams=5,          # keep 5 candidate sequences at each step
    do_sample=True,       # sample next tokens within the beams
    temperature=1.0,      # 1.0 leaves the logits unscaled
    top_k=50,             # sample only among the 50 most likely tokens
    top_p=0.9,            # and within the smallest set covering 90% probability
    early_stopping=True,  # stop once num_beams finished candidates exist
)

# Decode and print the output
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)
```

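The `temperature`, `top_k`, and `top_p` arguments in the `generate` call reshape the next-token distribution before a token is sampled. The pure-Python sketch below illustrates that filtering for a single decoding step on made-up logits. It is a simplified illustration under stated assumptions, not the transformers implementation. The helper name `filter_top_k_top_p` and the toy logits are hypothetical. With `num_beams=5` and `do_sample=True`, the library applies this kind of filtering inside beam search rather than to a single sequence.

```python
import math


def filter_top_k_top_p(logits, top_k=50, top_p=0.9, temperature=1.0):
    """Return {token_id: probability} for the tokens that survive filtering.

    Hypothetical single-step sketch: temperature scaling, then top-k,
    then top-p (nucleus) filtering, renormalizing after each cut.
    """
    # Temperature scaling followed by a numerically stable softmax.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = {i: e / total for i, e in enumerate(exps)}

    # Top-k: keep only the k most likely token ids, then renormalize.
    ranked = sorted(probs, key=probs.get, reverse=True)[:top_k]
    mass = sum(probs[i] for i in ranked)
    ranked_probs = [(i, probs[i] / mass) for i in ranked]

    # Top-p: keep the shortest most-likely prefix whose mass reaches top_p.
    kept, cumulative = [], 0.0
    for i, p in ranked_probs:
        kept.append((i, p))
        cumulative += p
        if cumulative >= top_p:
            break

    # Renormalize over the surviving tokens so they form a distribution.
    mass = sum(p for _, p in kept)
    return {i: p / mass for i, p in kept}


# Toy logits for a 5-token vocabulary (illustrative values only).
logits = [2.0, 1.0, 0.5, 0.1, -1.0]
print(filter_top_k_top_p(logits, top_k=3, top_p=0.8))  # keeps only token ids 0 and 1
```

Lowering `top_p` or `temperature` concentrates sampling on fewer, likelier tokens. Raising them lets less likely continuations through.
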