Norakneath commited on
Commit
9d47a47
·
verified ·
1 Parent(s): 4df8e36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -17,7 +17,7 @@ def get_class_color(class_id):
17
  return CLASS_COLORS[class_id]
18
 
19
  # Class Names (Modify based on your dataset)
20
- CLASS_NAMES = {0: "Text Line", 1: "Heading", 2: "Signature"} # Example labels
21
 
22
  def detect_text_lines(image):
23
  """Detects text lines with two different confidence and IoU thresholds."""
@@ -38,8 +38,8 @@ def detect_text_lines(image):
38
 
39
  # Run YOLO text detection with specific thresholds
40
  results = model.predict(image, conf=conf, iou=iou, device="cpu")
41
- detected_boxes = results[0].boxes.xyxy.tolist()
42
- class_ids = results[0].boxes.cls.tolist()
43
  detected_boxes = [list(map(int, box)) for box in detected_boxes]
44
 
45
  # Draw bounding boxes on the image
@@ -52,7 +52,7 @@ def detect_text_lines(image):
52
  font = ImageFont.load_default() # Fallback in case font is missing
53
 
54
  for idx, (x1, y1, x2, y2) in enumerate(detected_boxes):
55
- class_id = int(class_ids[idx])
56
  color = get_class_color(class_id)
57
  class_name = CLASS_NAMES.get(class_id, f"Class {class_id}")
58
 
 
17
  return CLASS_COLORS[class_id]
18
 
19
  # Class Names (Modify based on your dataset)
20
+ CLASS_NAMES = {0: "Text", 1: "Non-Text"} # Example labels
21
 
22
  def detect_text_lines(image):
23
  """Detects text lines with two different confidence and IoU thresholds."""
 
38
 
39
  # Run YOLO text detection with specific thresholds
40
  results = model.predict(image, conf=conf, iou=iou, device="cpu")
41
+ detected_boxes = results[0].boxes.xyxy.tolist() if hasattr(results[0].boxes, 'xyxy') else []
42
+ class_ids = results[0].boxes.cls.tolist() if hasattr(results[0].boxes, 'cls') else []
43
  detected_boxes = [list(map(int, box)) for box in detected_boxes]
44
 
45
  # Draw bounding boxes on the image
 
52
  font = ImageFont.load_default() # Fallback in case font is missing
53
 
54
  for idx, (x1, y1, x2, y2) in enumerate(detected_boxes):
55
+ class_id = int(class_ids[idx]) if idx < len(class_ids) else -1
56
  color = get_class_color(class_id)
57
  class_name = CLASS_NAMES.get(class_id, f"Class {class_id}")
58