Upload model
Browse files- modeling.py +15 -4
modeling.py
CHANGED
@@ -113,7 +113,7 @@ class CTCropModel(PreTrainedModel):
|
|
113 |
fname: str = "",
|
114 |
sort_by_instance_number: bool = False,
|
115 |
exclude_invalid_dicoms: bool = False,
|
116 |
-
):
|
117 |
attributes = [
|
118 |
"pixel_array",
|
119 |
"RescaleSlope",
|
@@ -281,7 +281,7 @@ class CTCropModel(PreTrainedModel):
|
|
281 |
coords: torch.Tensor,
|
282 |
buffer: float | tuple[float, float] = 0.05,
|
283 |
empty_threshold: float = 1e-4,
|
284 |
-
):
|
285 |
coords = coords.clone()
|
286 |
empty = (coords < empty_threshold).all(dim=1)
|
287 |
# assumes coords is a torch.Tensor of shape (N, 4) containing
|
@@ -352,7 +352,12 @@ class CTCropModel(PreTrainedModel):
|
|
352 |
raw_hu: bool = False,
|
353 |
remove_empty_slices: bool = False,
|
354 |
add_buffer: float | tuple[float, float] | None = None,
|
355 |
-
|
|
|
|
|
|
|
|
|
|
|
356 |
assert mode in ["2d", "3d"]
|
357 |
if device is None:
|
358 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -393,5 +398,11 @@ class CTCropModel(PreTrainedModel):
|
|
393 |
empty_indices = list(torch.where(empty)[0].cpu().numpy())
|
394 |
print(f"removing {empty.sum()} empty slices ...")
|
395 |
cropped = cropped[~empty.cpu().numpy()]
|
396 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
397 |
return cropped
|
|
|
113 |
fname: str = "",
|
114 |
sort_by_instance_number: bool = False,
|
115 |
exclude_invalid_dicoms: bool = False,
|
116 |
+
) -> bool:
|
117 |
attributes = [
|
118 |
"pixel_array",
|
119 |
"RescaleSlope",
|
|
|
281 |
coords: torch.Tensor,
|
282 |
buffer: float | tuple[float, float] = 0.05,
|
283 |
empty_threshold: float = 1e-4,
|
284 |
+
) -> torch.Tensor:
|
285 |
coords = coords.clone()
|
286 |
empty = (coords < empty_threshold).all(dim=1)
|
287 |
# assumes coords is a torch.Tensor of shape (N, 4) containing
|
|
|
352 |
raw_hu: bool = False,
|
353 |
remove_empty_slices: bool = False,
|
354 |
add_buffer: float | tuple[float, float] | None = None,
|
355 |
+
return_coords: bool = False,
|
356 |
+
) -> (
|
357 |
+
np.ndarray
|
358 |
+
| tuple[np.ndarray, list[int]]
|
359 |
+
| tuple[np.ndarray, list[int], list[int]]
|
360 |
+
):
|
361 |
assert mode in ["2d", "3d"]
|
362 |
if device is None:
|
363 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
398 |
empty_indices = list(torch.where(empty)[0].cpu().numpy())
|
399 |
print(f"removing {empty.sum()} empty slices ...")
|
400 |
cropped = cropped[~empty.cpu().numpy()]
|
401 |
+
if not isinstance(cropped, tuple):
|
402 |
+
cropped = (cropped,)
|
403 |
+
cropped = cropped + (empty_indices,)
|
404 |
+
if return_coords:
|
405 |
+
if not isinstance(cropped, tuple):
|
406 |
+
cropped = (cropped,)
|
407 |
+
cropped = cropped + ([x1, y1, x2, y2],)
|
408 |
return cropped
|