ardavey commited on
Commit
717b6e0
1 Parent(s): 9335b96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -7
app.py CHANGED
@@ -3,7 +3,6 @@ import torch
3
  from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
4
  from PIL import Image
5
 
6
- # Load model and processor
7
  model_id = "brucewayne0459/paligemma_derm"
8
  processor = AutoProcessor.from_pretrained(model_id)
9
  model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
@@ -59,11 +58,9 @@ elif uploaded_file:
59
  st.error(f"Error loading image: {str(e)}")
60
  input_image = None
61
 
62
- # Display and process the image
63
  with col2:
64
  if input_image:
65
  try:
66
- # Display the uploaded or captured image
67
  resized_image = input_image.resize((300, 300))
68
  st.image(resized_image, caption="Selected Image (300x300)", use_container_width=True)
69
 
@@ -72,7 +69,6 @@ with col2:
72
  processed_image = input_image.resize(max_size)
73
 
74
  with st.spinner("Processing..."):
75
- # Prepare inputs for the model
76
  inputs = processor(
77
  text=prompt,
78
  images=processed_image,
@@ -80,12 +76,10 @@ with col2:
80
  padding="longest"
81
  ).to(device)
82
 
83
- # Generate output from the model
84
- default_max_tokens = 50 # Default value for max tokens
85
  with torch.no_grad():
86
  outputs = model.generate(**inputs, max_new_tokens=default_max_tokens)
87
 
88
- # Decode and clean the output
89
  decoded_output = processor.decode(outputs[0], skip_special_tokens=True)
90
  if prompt in decoded_output:
91
  decoded_output = decoded_output.replace(prompt, "").strip()
 
3
  from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
4
  from PIL import Image
5
 
 
6
  model_id = "brucewayne0459/paligemma_derm"
7
  processor = AutoProcessor.from_pretrained(model_id)
8
  model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
 
58
  st.error(f"Error loading image: {str(e)}")
59
  input_image = None
60
 
 
61
  with col2:
62
  if input_image:
63
  try:
 
64
  resized_image = input_image.resize((300, 300))
65
  st.image(resized_image, caption="Selected Image (300x300)", use_container_width=True)
66
 
 
69
  processed_image = input_image.resize(max_size)
70
 
71
  with st.spinner("Processing..."):
 
72
  inputs = processor(
73
  text=prompt,
74
  images=processed_image,
 
76
  padding="longest"
77
  ).to(device)
78
 
79
+ default_max_tokens = 50
 
80
  with torch.no_grad():
81
  outputs = model.generate(**inputs, max_new_tokens=default_max_tokens)
82
 
 
83
  decoded_output = processor.decode(outputs[0], skip_special_tokens=True)
84
  if prompt in decoded_output:
85
  decoded_output = decoded_output.replace(prompt, "").strip()