---
language: en
license: apache-2.0
pipeline_tag: text2text-generation
base_model: t5-small
library_name: transformers

widget:
  - text: "A 35-year-old female presents with a 2-week history of persistent cough..."
---

# Medical Generation Model

## Overview

This repository contains a fine-tuned T5 model that generates medical diagnoses and treatment recommendations. It was fine-tuned on clinical scenarios to produce contextually relevant outputs from free-text medical prompts.

## Model Details

- **Model Type**: T5 (sequence-to-sequence)
- **Base Model**: `t5-small`
- **Tokenizer**: T5 tokenizer (SentencePiece-based)
- **Training Data**: Clinical scenarios and medical texts
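
These details can be confirmed programmatically. A minimal sketch that inspects the checkpoint's configuration, using the `Ra-Is/medical-gen-small` repository id from the usage example below:

```python
# Sketch: verify the architecture and tokenizer class of the checkpoint.
from transformers import AutoConfig, AutoTokenizer

model_id = "Ra-Is/medical-gen-small"

config = AutoConfig.from_pretrained(model_id)
print(config.model_type)  # expected: "t5"

tokenizer = AutoTokenizer.from_pretrained(model_id)
print(type(tokenizer).__name__)  # a T5 tokenizer class
```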

## Installation

To use this model, install the required libraries with `pip` (`sentencepiece` is required by the T5 tokenizer):

```bash
pip install transformers tensorflow sentencepiece
```

## Usage

Load the fine-tuned model and tokenizer, then generate a diagnosis and treatment recommendation from a clinical prompt:

```python
# Load the fine-tuned model and tokenizer
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

model_id = "Ra-Is/medical-gen-small"
model = TFT5ForConditionalGeneration.from_pretrained(model_id)
tokenizer = T5Tokenizer.from_pretrained(model_id)

# Prepare a sample input prompt
input_prompt = (
    "A 35-year-old female presents with a 2-week history of "
    "persistent cough, shortness of breath, and fatigue. She has "
    "a history of asthma and has recently been exposed to a sick "
    "family member with a respiratory infection. Chest X-ray shows "
    "bilateral infiltrates. What is the likely diagnosis, and what "
    "should be the treatment?"
)

# Tokenize the input
input_ids = tokenizer(input_prompt, return_tensors="tf").input_ids

# Generate the output (diagnosis and treatment)
outputs = model.generate(
    input_ids,
    max_length=512,
    num_beams=5,        # keep the 5 best candidate sequences
    temperature=1.0,
    top_k=50,
    top_p=0.9,
    do_sample=True,     # sample within each beam
    early_stopping=True,
)

# Decode and print the generated text
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)
```
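
Because `do_sample=True` is combined with `num_beams=5`, generation samples within each beam, so repeated runs may return different outputs; remove `do_sample=True` (and the sampling parameters) for deterministic beam search.

If you work in PyTorch instead of TensorFlow, the same checkpoint can usually be loaded through the PyTorch T5 class. A minimal sketch, assuming the repository hosts TensorFlow weights (`from_tf=True` converts them on the fly and still requires TensorFlow to be installed):

```python
# Minimal PyTorch sketch. Assumption: the repo ships TensorFlow weights,
# so we convert them with from_tf=True (TensorFlow must be installed).
from transformers import T5Tokenizer, T5ForConditionalGeneration

model_id = "Ra-Is/medical-gen-small"
tokenizer = T5Tokenizer.from_pretrained(model_id)
model = T5ForConditionalGeneration.from_pretrained(model_id, from_tf=True)

input_ids = tokenizer(
    "A 35-year-old female presents with a 2-week history of persistent cough...",
    return_tensors="pt",
).input_ids

# Deterministic beam search with the same length/beam settings as above
outputs = model.generate(input_ids, max_length=512, num_beams=5, early_stopping=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```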