Hannes Kuchelmeister
commited on
Commit
•
5bf8787
1
Parent(s):
18a45c1
remove error causeing (?) ResNET parameter and reduce range of lr
Browse files
configs/hparams_search/focusResNetMSE_150.yaml
CHANGED
@@ -55,12 +55,11 @@ hydra:
|
|
55 |
choices: [true, false]
|
56 |
model.lr:
|
57 |
type: float
|
58 |
-
low: 0.
|
59 |
high: 0.01
|
60 |
model.resnet_type:
|
61 |
type: categorical
|
62 |
-
choices: [
|
63 |
-
"ResNet",
|
64 |
"resnet18",
|
65 |
"resnet34",
|
66 |
"resnet50",
|
|
|
55 |
choices: [true, false]
|
56 |
model.lr:
|
57 |
type: float
|
58 |
+
low: 0.001
|
59 |
high: 0.01
|
60 |
model.resnet_type:
|
61 |
type: categorical
|
62 |
+
choices: [
|
|
|
63 |
"resnet18",
|
64 |
"resnet34",
|
65 |
"resnet50",
|
src/models/focus_resnet_module.py
CHANGED
@@ -12,7 +12,7 @@ import torchvision.models as models
|
|
12 |
class ResNetLitModule(LightningModule):
|
13 |
def __init__(
|
14 |
self,
|
15 |
-
resnet_type: str = "
|
16 |
pretrained=False,
|
17 |
lr: float = 0.001,
|
18 |
weight_decay: float = 0.0005,
|
@@ -22,7 +22,7 @@ class ResNetLitModule(LightningModule):
|
|
22 |
Args:
|
23 |
resnet_type (str, optional): Type of the used resnet network. Defaults to
|
24 |
"ResNet".
|
25 |
-
Can be one of the following values: "
|
26 |
"resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d",
|
27 |
"resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2"
|
28 |
pretrained (bool, optional): if True loads pytorch pretrained models.
|
@@ -48,9 +48,7 @@ class ResNetLitModule(LightningModule):
|
|
48 |
|
49 |
self.pretrained = pretrained
|
50 |
|
51 |
-
if resnet_type == "
|
52 |
-
resnet_constructor = models.ResNet
|
53 |
-
elif resnet_type == "resnet18":
|
54 |
resnet_constructor = models.resnet18
|
55 |
elif resnet_type == "resnet34":
|
56 |
resnet_constructor = models.resnet34
|
|
|
12 |
class ResNetLitModule(LightningModule):
|
13 |
def __init__(
|
14 |
self,
|
15 |
+
resnet_type: str = "resnet18",
|
16 |
pretrained=False,
|
17 |
lr: float = 0.001,
|
18 |
weight_decay: float = 0.0005,
|
|
|
22 |
Args:
|
23 |
resnet_type (str, optional): Type of the used resnet network. Defaults to
|
24 |
"ResNet".
|
25 |
+
Can be one of the following values: "resnet18",
|
26 |
"resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d",
|
27 |
"resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2"
|
28 |
pretrained (bool, optional): if True loads pytorch pretrained models.
|
|
|
48 |
|
49 |
self.pretrained = pretrained
|
50 |
|
51 |
+
if resnet_type == "resnet18":
|
|
|
|
|
52 |
resnet_constructor = models.resnet18
|
53 |
elif resnet_type == "resnet34":
|
54 |
resnet_constructor = models.resnet34
|