Ayamohamed commited on
Commit
dda41bd
·
verified ·
1 Parent(s): 30784af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -13
app.py CHANGED
@@ -5,31 +5,29 @@ from PIL import Image
5
  from huggingface_hub import hf_hub_download
6
  import gradio as gr
7
 
 
8
  model_path = hf_hub_download(repo_id="Ayamohamed/DiaClassification", filename="dia_none_classifier_full.pth")
9
- model = torch.load(model_path, map_location=torch.device("cpu"), weights_only=False)
10
- model.eval()
 
 
11
 
12
  transform = transforms.Compose([
13
  transforms.Resize((224, 224)),
14
  transforms.ToTensor(),
15
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
16
  ])
17
-
18
-
19
- def predict(image):
20
  try:
21
- print("Received image:", image)
22
- image = transform(image).unsqueeze(0)
23
- print("Transformed image shape:", image.shape)
24
-
25
- # Model inference
26
  with torch.no_grad():
27
- output = model(image)
28
  print("Model output:", output)
29
  class_idx = torch.argmax(output, dim=1).item()
30
 
31
  return "Diagram" if class_idx == 0 else "Not Diagram"
32
-
33
  except Exception as e:
34
  print("Error during prediction:", str(e))
35
  return f"Prediction Error: {str(e)}"
 
5
  from huggingface_hub import hf_hub_download
6
  import gradio as gr
7
 
8
+ # Download model from Hugging Face Hub
9
  model_path = hf_hub_download(repo_id="Ayamohamed/DiaClassification", filename="dia_none_classifier_full.pth")
10
+
11
+ # Load model
12
+ model_hg = torch.load(model_path)
13
+ model_hg.eval()
14
 
15
  transform = transforms.Compose([
16
  transforms.Resize((224, 224)),
17
  transforms.ToTensor(),
18
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
19
  ])
20
+ def predict(image_path):
21
+
 
22
  try:
23
+ image = Image.open(image_path).convert("RGB")
24
+ image = transform(image).unsqueeze(0)
 
 
 
25
  with torch.no_grad():
26
+ output = model_hg(image)
27
  print("Model output:", output)
28
  class_idx = torch.argmax(output, dim=1).item()
29
 
30
  return "Diagram" if class_idx == 0 else "Not Diagram"
 
31
  except Exception as e:
32
  print("Error during prediction:", str(e))
33
  return f"Prediction Error: {str(e)}"