Hannes Kuchelmeister commited on
Commit
5bf8787
1 Parent(s): 18a45c1

remove error causeing (?) ResNET parameter and reduce range of lr

Browse files
configs/hparams_search/focusResNetMSE_150.yaml CHANGED
@@ -55,12 +55,11 @@ hydra:
55
  choices: [true, false]
56
  model.lr:
57
  type: float
58
- low: 0.0001
59
  high: 0.01
60
  model.resnet_type:
61
  type: categorical
62
- choices: [
63
- "ResNet",
64
  "resnet18",
65
  "resnet34",
66
  "resnet50",
 
55
  choices: [true, false]
56
  model.lr:
57
  type: float
58
+ low: 0.001
59
  high: 0.01
60
  model.resnet_type:
61
  type: categorical
62
+ choices: [
 
63
  "resnet18",
64
  "resnet34",
65
  "resnet50",
src/models/focus_resnet_module.py CHANGED
@@ -12,7 +12,7 @@ import torchvision.models as models
12
  class ResNetLitModule(LightningModule):
13
  def __init__(
14
  self,
15
- resnet_type: str = "ResNet",
16
  pretrained=False,
17
  lr: float = 0.001,
18
  weight_decay: float = 0.0005,
@@ -22,7 +22,7 @@ class ResNetLitModule(LightningModule):
22
  Args:
23
  resnet_type (str, optional): Type of the used resnet network. Defaults to
24
  "ResNet".
25
- Can be one of the following values: "ResNet", "resnet18",
26
  "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d",
27
  "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2"
28
  pretrained (bool, optional): if True loads pytorch pretrained models.
@@ -48,9 +48,7 @@ class ResNetLitModule(LightningModule):
48
 
49
  self.pretrained = pretrained
50
 
51
- if resnet_type == "ResNet":
52
- resnet_constructor = models.ResNet
53
- elif resnet_type == "resnet18":
54
  resnet_constructor = models.resnet18
55
  elif resnet_type == "resnet34":
56
  resnet_constructor = models.resnet34
 
12
  class ResNetLitModule(LightningModule):
13
  def __init__(
14
  self,
15
+ resnet_type: str = "resnet18",
16
  pretrained=False,
17
  lr: float = 0.001,
18
  weight_decay: float = 0.0005,
 
22
  Args:
23
  resnet_type (str, optional): Type of the used resnet network. Defaults to
24
  "ResNet".
25
+ Can be one of the following values: "resnet18",
26
  "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d",
27
  "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2"
28
  pretrained (bool, optional): if True loads pytorch pretrained models.
 
48
 
49
  self.pretrained = pretrained
50
 
51
+ if resnet_type == "resnet18":
 
 
52
  resnet_constructor = models.resnet18
53
  elif resnet_type == "resnet34":
54
  resnet_constructor = models.resnet34