# File size: 6,697 Bytes
# 8f65667
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from flask import Flask, request, jsonify, render_template
import torch
from torchvision import transforms
from PIL import Image
import os
import torch.nn as nn
import timm
from torchvision.models import swin_t, Swin_T_Weights, vit_b_16, ViT_B_16_Weights
from transformers import GPT2LMHeadModel, GPT2Tokenizer

app = Flask(__name__)

# Uploaded images are written to this folder by the prediction routes.
UPLOAD_FOLDER = os.path.join('static', 'uploads')
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# Create the folder up front: saving an upload opens the destination path
# directly and does not create missing parent directories, so the first
# request would otherwise fail on a fresh checkout.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the LLM model and tokenizer from the local models/LLM directory.
# The path is built with os.path.join so it resolves on every OS; a literal
# 'models\\LLM' only names a directory on Windows.
LLM_DIR = os.path.join('models', 'LLM')
model = GPT2LMHeadModel.from_pretrained(LLM_DIR).to(device)
tokenizer = GPT2Tokenizer.from_pretrained(LLM_DIR)
separator_token = tokenizer.eos_token  # Separator token for the model

# Define and load the fine-tuned image classifiers.
#
# Checkpoint paths are built with os.path.join so they resolve on every OS
# (a literal 'models\\name.pth' is only a nested path on Windows).
#
# NOTE(review): every load_state_dict call below passes strict=False, so
# checkpoint keys that do not match the model are skipped without an error.
# A wrong or renamed checkpoint would silently leave the ImageNet backbone
# and the freshly initialised head in place - confirm the keys match.

# Gastrointestinal Model (4 classes: Diverticulosis, Neoplasm, Peritonitis, Ureters)
# NOTE(review): the gastrointestinal route flattens outputs that have more
# than two dimensions, which suggests this model may not always return a
# (1, num_classes) tensor - confirm that replacing `head` with a plain Linear
# layer is valid for the installed timm version.
gastrointestinal_classes = ['Diverticulosis', 'Neoplasm', 'Peritonitis', 'Ureters']
gastrointestinal_model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
gastrointestinal_model.head = nn.Linear(gastrointestinal_model.head.in_features, len(gastrointestinal_classes))
gastrointestinal_model = gastrointestinal_model.to(device)
gastrointestinal_model.load_state_dict(
    torch.load(os.path.join('models', 'gastrointestinal_model_swin.pth'), map_location=device, weights_only=True),
    strict=False,
)
gastrointestinal_model.eval()

# Chest CT Model (4 classes: Adenocarcinoma, Large cell carcinoma, Normal, Squamous cell carcinoma)
chest_ct_classes = ['Adenocarcinoma', 'Large Cell Carcinoma', 'Normal', 'Squamous Cell Carcinoma']
chest_ct_model = swin_t(weights=Swin_T_Weights.IMAGENET1K_V1)
chest_ct_model.head = nn.Linear(chest_ct_model.head.in_features, len(chest_ct_classes))
chest_ct_model = chest_ct_model.to(device)
chest_ct_model.load_state_dict(
    torch.load(os.path.join('models', 'best_model.pth'), map_location=device, weights_only=True),
    strict=False,
)
chest_ct_model.eval()

# Chest X-ray Model (2 classes: Normal, Pneumonia)
chest_xray_classes = ['Normal', 'Pneumonia']
chest_xray_model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
chest_xray_model.heads.head = nn.Linear(chest_xray_model.heads.head.in_features, len(chest_xray_classes))
chest_xray_model = chest_xray_model.to(device)
chest_xray_model.load_state_dict(
    torch.load(os.path.join('models', 'best_model_vit_chest_xray.pth'), map_location=device, weights_only=True),
    strict=False,
)
chest_xray_model.eval()

# Preprocessing shared by all three classifiers: resize to 224x224, convert
# to a tensor, then normalise each channel with the mean/std below.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]
_PREPROCESS_STEPS = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
transform = transforms.Compose(_PREPROCESS_STEPS)


# Helper function to load and transform images
def process_image(image_path):
    """Load an image and return it as a single-item batch on ``device``.

    Args:
        image_path: Location of the image, passed straight to
            ``PIL.Image.open``.

    Returns:
        The output of the module-level ``transform`` for the RGB-converted
        image, with a leading batch dimension added and moved to ``device``.

    Raises:
        OSError: If the file cannot be opened or decoded as an image
            (propagated from PIL).
    """
    # Open inside a context manager so the underlying file handle is released
    # as soon as convert() has produced its own copy of the pixel data,
    # instead of staying open until garbage collection.
    with Image.open(image_path) as source:
        rgb_image = source.convert('RGB')
    return transform(rgb_image).unsqueeze(0).to(device)

# LLM helper function to generate answers
def generate_answer(question, max_length=1024):
    """Generate the language model's reply to *question*.

    The question is followed by the separator token, encoded, and handed to
    ``model.generate``. Only the tokens produced after the prompt are decoded.

    Args:
        question: Text of the question; concatenated with ``separator_token``.
        max_length: Forwarded to ``model.generate`` as ``max_length``.

    Returns:
        The decoded continuation with special tokens skipped.
    """
    model.eval()  # Switch the model to evaluation mode before generating

    prompt_ids = tokenizer.encode(
        question + separator_token,
        return_tensors='pt',
    ).to(device)
    prompt_length = prompt_ids.shape[-1]

    generated = model.generate(
        prompt_ids,
        max_length=max_length,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Drop the echoed prompt: keep only what follows it in the first sequence.
    continuation = generated[0, prompt_length:]
    return tokenizer.decode(continuation, skip_special_tokens=True)

# Prediction routes for each model
def _safe_upload_name(raw_name):
    """Return the last path component of an uploaded file name, or ''."""
    # The client controls the name, so keep only its final component: a value
    # such as '../../app.py' must not be able to escape the upload folder.
    # Backslashes are normalised first because os.path.basename does not treat
    # them as separators on POSIX.
    name = os.path.basename((raw_name or '').replace('\\', '/'))
    return '' if name in ('.', '..') else name


def _predict_from_upload(classifier, class_names):
    """Save the uploaded image, classify it, and build the JSON response.

    Args:
        classifier: Model called on the preprocessed image tensor.
        class_names: Labels indexed by the model's predicted class index.

    Returns:
        ``jsonify({'prediction': label})`` on success, or an
        ``(error response, status)`` tuple: 400 when the upload is missing,
        unnamed or not a readable image, 500 when the model output does not
        have the expected size.
    """
    if 'file' not in request.files:
        return jsonify({"error": "No file uploaded"}), 400

    file = request.files['file']
    filename = _safe_upload_name(file.filename)
    if not filename:
        # An empty name would make the save path point at the folder itself.
        return jsonify({"error": "No file selected"}), 400

    file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(file_path)

    # Preprocess the image; PIL reports undecodable files as OSError.
    try:
        image_tensor = process_image(file_path)
    except OSError:
        return jsonify({"error": "Invalid image file"}), 400

    with torch.no_grad():
        output = classifier(image_tensor)

        # If the output has extra dimensions, flatten it per sample
        if output.dim() > 2:
            output = output.reshape(output.size(0), -1)

        # Exactly one image was submitted, so exactly one row is expected
        if output.size(0) != 1:
            return jsonify({"error": "Unexpected output size"}), 500

        _, predicted = torch.max(output, 1)
        predicted_index = predicted.item()

    # Guard the label lookup in case the output is wider than the label list.
    if predicted_index >= len(class_names):
        return jsonify({"error": "Unexpected output size"}), 500

    return jsonify({'prediction': class_names[predicted_index]})


@app.route('/predict_gastrointestinal', methods=['POST'])
def predict_gastrointestinal():
    """Classify the uploaded image with the gastrointestinal model."""
    return _predict_from_upload(gastrointestinal_model, gastrointestinal_classes)


@app.route('/predict_chest_ct', methods=['POST'])
def predict_chest_ct():
    """Classify the uploaded image with the chest CT model."""
    return _predict_from_upload(chest_ct_model, chest_ct_classes)


@app.route('/predict_chest_xray', methods=['POST'])
def predict_chest_xray():
    """Classify the uploaded image with the chest X-ray model."""
    return _predict_from_upload(chest_xray_model, chest_xray_classes)


# New LLM route for asking questions
@app.route('/ask_llm', methods=['POST'])
def ask_llm():
    """Answer the question posted as JSON ``{"question": "..."}``.

    Returns:
        ``{'answer': ...}`` on success; a 400 error when the body is not a
        JSON object or carries no non-empty string question; a 500 error when
        generation raises.
    """
    # silent=True yields None instead of raising on a missing or malformed
    # JSON body, so every bad request gets the same JSON 400 response.
    payload = request.get_json(silent=True)
    user_question = payload.get('question') if isinstance(payload, dict) else None

    # Reject non-string values here: they would otherwise fail inside
    # generate_answer and be reported as a server error.
    if not isinstance(user_question, str) or not user_question.strip():
        return jsonify({"error": "No question provided"}), 400

    try:
        # Generate answer using the fine-tuned GPT-2 model
        answer = generate_answer(user_question)
    except Exception as e:
        # Top-level boundary: record the traceback, then report the failure.
        app.logger.exception("LLM generation failed")
        return jsonify({"error": f"An error occurred: {str(e)}"}), 500

    return jsonify({'answer': answer})


# Main route for the homepage
@app.route('/')
def index():
    """Render and return the ``index.html`` template."""
    homepage = render_template('index.html')
    return homepage

if __name__ == "__main__":
    # Debug mode turns on the interactive debugger, which allows code
    # execution from the browser, so it is opt-in (FLASK_DEBUG=1) rather
    # than hard-coded on.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(debug=debug_enabled)