ardavey commited on
Commit
5a6aade
1 Parent(s): 17f137b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
4
+ from PIL import Image
5
+
6
+ # Load the model and processor
7
+ model_id = "brucewayne0459/paligemma_derm"
8
+ processor = AutoProcessor.from_pretrained(model_id)
9
+ model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, device_map={"": 0})
10
+ model.eval()
11
+
12
+ # Set device
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ model.to(device)
15
+
16
+ # Streamlit app
17
+ st.title("Skin Condition Identifier")
18
+ st.write("Upload an image and provide a text prompt to identify the skin condition.")
19
+
20
+ # File uploader for image
21
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
22
+
23
+ # Text input for prompt
24
+ input_text = st.text_input("Enter your prompt:", "Identify the skin condition?")
25
+
26
+ # Process and display the result when the button is clicked
27
+ if uploaded_file is not None and st.button("Analyze"):
28
+ try:
29
+ # Open the uploaded image
30
+ input_image = Image.open(uploaded_file).convert("RGB")
31
+ st.image(input_image, caption="Uploaded Image", use_column_width=True)
32
+
33
+ # Prepare inputs
34
+ inputs = processor(
35
+ text=input_text,
36
+ images=input_image,
37
+ return_tensors="pt",
38
+ padding="longest"
39
+ ).to(device)
40
+
41
+ # Generate output
42
+ max_new_tokens = 50
43
+ with torch.no_grad():
44
+ outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
45
+
46
+ # Decode output
47
+ decoded_output = processor.decode(outputs[0], skip_special_tokens=True)
48
+
49
+ # Display result
50
+ st.success("Analysis Complete!")
51
+ st.write("**Model Output:**", decoded_output)
52
+ except Exception as e:
53
+ st.error(f"Error: {str(e)}")