ianpan commited on
Commit
39592a1
·
verified ·
1 Parent(s): 19db4ca

Upload model

Browse files
Files changed (1) hide show
  1. modeling.py +15 -4
modeling.py CHANGED
@@ -113,7 +113,7 @@ class CTCropModel(PreTrainedModel):
113
  fname: str = "",
114
  sort_by_instance_number: bool = False,
115
  exclude_invalid_dicoms: bool = False,
116
- ):
117
  attributes = [
118
  "pixel_array",
119
  "RescaleSlope",
@@ -281,7 +281,7 @@ class CTCropModel(PreTrainedModel):
281
  coords: torch.Tensor,
282
  buffer: float | tuple[float, float] = 0.05,
283
  empty_threshold: float = 1e-4,
284
- ):
285
  coords = coords.clone()
286
  empty = (coords < empty_threshold).all(dim=1)
287
  # assumes coords is a torch.Tensor of shape (N, 4) containing
@@ -352,7 +352,12 @@ class CTCropModel(PreTrainedModel):
352
  raw_hu: bool = False,
353
  remove_empty_slices: bool = False,
354
  add_buffer: float | tuple[float, float] | None = None,
355
- ) -> np.ndarray | tuple[np.ndarray, list[int]]:
 
 
 
 
 
356
  assert mode in ["2d", "3d"]
357
  if device is None:
358
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -393,5 +398,11 @@ class CTCropModel(PreTrainedModel):
393
  empty_indices = list(torch.where(empty)[0].cpu().numpy())
394
  print(f"removing {empty.sum()} empty slices ...")
395
  cropped = cropped[~empty.cpu().numpy()]
396
- return cropped, empty_indices
 
 
 
 
 
 
397
  return cropped
 
113
  fname: str = "",
114
  sort_by_instance_number: bool = False,
115
  exclude_invalid_dicoms: bool = False,
116
+ ) -> bool:
117
  attributes = [
118
  "pixel_array",
119
  "RescaleSlope",
 
281
  coords: torch.Tensor,
282
  buffer: float | tuple[float, float] = 0.05,
283
  empty_threshold: float = 1e-4,
284
+ ) -> torch.Tensor:
285
  coords = coords.clone()
286
  empty = (coords < empty_threshold).all(dim=1)
287
  # assumes coords is a torch.Tensor of shape (N, 4) containing
 
352
  raw_hu: bool = False,
353
  remove_empty_slices: bool = False,
354
  add_buffer: float | tuple[float, float] | None = None,
355
+ return_coords: bool = False,
356
+ ) -> (
357
+ np.ndarray
358
+ | tuple[np.ndarray, list[int]]
359
+ | tuple[np.ndarray, list[int], list[int]]
360
+ ):
361
  assert mode in ["2d", "3d"]
362
  if device is None:
363
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
398
  empty_indices = list(torch.where(empty)[0].cpu().numpy())
399
  print(f"removing {empty.sum()} empty slices ...")
400
  cropped = cropped[~empty.cpu().numpy()]
401
+ if not isinstance(cropped, tuple):
402
+ cropped = (cropped,)
403
+ cropped = cropped + (empty_indices,)
404
+ if return_coords:
405
+ if not isinstance(cropped, tuple):
406
+ cropped = (cropped,)
407
+ cropped = cropped + ([x1, y1, x2, y2],)
408
  return cropped