ardavey commited on
Commit
9d6fdf7
1 Parent(s): 43cfaa4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -39
app.py CHANGED
@@ -2,9 +2,6 @@ import streamlit as st
2
  import torch
3
  from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
4
  from PIL import Image
5
- from transformers import AutoProcessor, AutoModelForImageTextToText
6
-
7
-
8
 
9
  # Load the model and processor
10
  model_id = "brucewayne0459/paligemma_derm"
@@ -38,7 +35,7 @@ st.markdown(
38
 
39
  # Streamlit app title and instructions
40
  st.title("Skin Condition Identifier")
41
- st.write("Upload an image and provide a text prompt to identify the skin condition.")
42
 
43
  # Column layout for input and display
44
  col1, col2 = st.columns([3, 2])
@@ -46,51 +43,41 @@ col1, col2 = st.columns([3, 2])
46
  with col1:
47
  # File uploader for image
48
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
49
- prompt = "Identify the skin condition?"
50
-
51
- # Text input for prompt
52
- input_text = st.text_input("Enter your prompt:", prompt)
53
-
54
  with col2:
55
- # Display uploaded image (if any)
56
  if uploaded_file:
 
57
  input_image = Image.open(uploaded_file).convert("RGB")
58
-
59
- # Resize image for display (300x300 pixels)
60
  resized_image = input_image.resize((300, 300))
61
-
62
- # Display the resized image
63
  st.image(resized_image, caption="Uploaded Image (300x300)", use_container_width=True)
64
 
65
- # Process and display the result when the button is clicked
66
- if uploaded_file and st.button("Analyze"):
67
- if not input_text.strip():
68
- st.error("Please provide a valid prompt!")
69
- else:
 
70
  try:
71
- # Resize image for processing (512x512 pixels)
72
- max_size = (512, 512)
73
- input_image = input_image.resize(max_size)
74
-
75
  # Prepare inputs
76
- with st.spinner("Processing..."):
77
- inputs = processor(
78
- text=input_text,
79
- images=input_image,
80
- return_tensors="pt",
81
- padding="longest"
82
- ).to(device)
83
-
84
- # Generate output with default max_new_tokens
85
- default_max_tokens = 50 # Set a default value for max tokens
86
- with torch.no_grad():
87
- outputs = model.generate(**inputs, max_new_tokens=default_max_tokens)
88
-
89
- # Decode output
90
- decoded_output = processor.decode(outputs[0], skip_special_tokens=True)
91
-
92
  # Display result
93
  st.success("Analysis Complete!")
94
  st.write("**Model Output:**", decoded_output)
 
95
  except Exception as e:
96
  st.error(f"Error: {str(e)}")
 
2
  import torch
3
  from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
4
  from PIL import Image
 
 
 
5
 
6
  # Load the model and processor
7
  model_id = "brucewayne0459/paligemma_derm"
 
35
 
36
  # Streamlit app title and instructions
37
  st.title("Skin Condition Identifier")
38
+ st.write("Upload an image and provide a custom prompt to identify the skin condition.")
39
 
40
  # Column layout for input and display
41
  col1, col2 = st.columns([3, 2])
 
43
  with col1:
44
  # File uploader for image
45
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
46
+ prompt = st.text_input("Enter your prompt:", "Identify the skin condition?")
47
+
 
 
 
48
  with col2:
 
49
  if uploaded_file:
50
+ # Open and resize the uploaded image
51
  input_image = Image.open(uploaded_file).convert("RGB")
 
 
52
  resized_image = input_image.resize((300, 300))
 
 
53
  st.image(resized_image, caption="Uploaded Image (300x300)", use_container_width=True)
54
 
55
+ # Resize image for processing (512x512 pixels)
56
+ max_size = (512, 512)
57
+ processed_image = input_image.resize(max_size)
58
+
59
+ # Predict automatically when the image is uploaded or the prompt changes
60
+ with st.spinner("Processing..."):
61
  try:
 
 
 
 
62
  # Prepare inputs
63
+ inputs = processor(
64
+ text=prompt,
65
+ images=processed_image,
66
+ return_tensors="pt",
67
+ padding="longest"
68
+ ).to(device)
69
+
70
+ # Generate output
71
+ default_max_tokens = 50 # Set a default value for max tokens
72
+ with torch.no_grad():
73
+ outputs = model.generate(**inputs, max_new_tokens=default_max_tokens)
74
+
75
+ # Decode output
76
+ decoded_output = processor.decode(outputs[0], skip_special_tokens=True)
77
+
 
78
  # Display result
79
  st.success("Analysis Complete!")
80
  st.write("**Model Output:**", decoded_output)
81
+
82
  except Exception as e:
83
  st.error(f"Error: {str(e)}")