x-lai commited on
Commit
e629203
·
1 Parent(s): 6144294

Release training script

Browse files

Former-commit-id: 784d960221ee0a4799cfc44ead65ce515294c716

Files changed (2) hide show
  1. model/LISA.py +3 -3
  2. train_ds.py +1 -1
model/LISA.py CHANGED
@@ -6,15 +6,15 @@ import torch.nn.functional as F
6
  from peft import LoraConfig, get_peft_model
7
  from transformers import BitsAndBytesConfig, CLIPVisionModel
8
 
9
- from transformers import CLIPVisionModel, BitsAndBytesConfig
10
- from .llava.model.llava import LlavaLlamaForCausalLM
11
- from .segment_anything import build_sam_vit_h
12
  from utils.utils import (
13
  DEFAULT_IM_END_TOKEN,
14
  DEFAULT_IM_START_TOKEN,
15
  DEFAULT_IMAGE_PATCH_TOKEN,
16
  )
17
 
 
 
 
18
 
19
  def dice_loss(
20
  inputs: torch.Tensor,
 
6
  from peft import LoraConfig, get_peft_model
7
  from transformers import BitsAndBytesConfig, CLIPVisionModel
8
 
 
 
 
9
  from utils.utils import (
10
  DEFAULT_IM_END_TOKEN,
11
  DEFAULT_IM_START_TOKEN,
12
  DEFAULT_IMAGE_PATCH_TOKEN,
13
  )
14
 
15
+ from .llava.model.llava import LlavaLlamaForCausalLM
16
+ from .segment_anything import build_sam_vit_h
17
+
18
 
19
  def dice_loss(
20
  inputs: torch.Tensor,
train_ds.py CHANGED
@@ -63,7 +63,7 @@ def parse_args(args):
63
  parser.add_argument("--dataset_dir", default="./dataset", type=str)
64
  parser.add_argument("--log_base_dir", default="./runs", type=str)
65
  parser.add_argument("--exp_name", default="lisa", type=str)
66
- parser.add_argument("--epochs", default=20, type=int)
67
  parser.add_argument("--steps_per_epoch", default=500, type=int)
68
  parser.add_argument(
69
  "--batch_size", default=2, type=int, help="batch size per device per step"
 
63
  parser.add_argument("--dataset_dir", default="./dataset", type=str)
64
  parser.add_argument("--log_base_dir", default="./runs", type=str)
65
  parser.add_argument("--exp_name", default="lisa", type=str)
66
+ parser.add_argument("--epochs", default=10, type=int)
67
  parser.add_argument("--steps_per_epoch", default=500, type=int)
68
  parser.add_argument(
69
  "--batch_size", default=2, type=int, help="batch size per device per step"