Spaces:
Paused
Paused
x-lai
commited on
Commit
·
e629203
1
Parent(s):
6144294
Release training script
Browse filesFormer-commit-id: 784d960221ee0a4799cfc44ead65ce515294c716
- model/LISA.py +3 -3
- train_ds.py +1 -1
model/LISA.py
CHANGED
@@ -6,15 +6,15 @@ import torch.nn.functional as F
|
|
6 |
from peft import LoraConfig, get_peft_model
|
7 |
from transformers import BitsAndBytesConfig, CLIPVisionModel
|
8 |
|
9 |
-
from transformers import CLIPVisionModel, BitsAndBytesConfig
|
10 |
-
from .llava.model.llava import LlavaLlamaForCausalLM
|
11 |
-
from .segment_anything import build_sam_vit_h
|
12 |
from utils.utils import (
|
13 |
DEFAULT_IM_END_TOKEN,
|
14 |
DEFAULT_IM_START_TOKEN,
|
15 |
DEFAULT_IMAGE_PATCH_TOKEN,
|
16 |
)
|
17 |
|
|
|
|
|
|
|
18 |
|
19 |
def dice_loss(
|
20 |
inputs: torch.Tensor,
|
|
|
6 |
from peft import LoraConfig, get_peft_model
|
7 |
from transformers import BitsAndBytesConfig, CLIPVisionModel
|
8 |
|
|
|
|
|
|
|
9 |
from utils.utils import (
|
10 |
DEFAULT_IM_END_TOKEN,
|
11 |
DEFAULT_IM_START_TOKEN,
|
12 |
DEFAULT_IMAGE_PATCH_TOKEN,
|
13 |
)
|
14 |
|
15 |
+
from .llava.model.llava import LlavaLlamaForCausalLM
|
16 |
+
from .segment_anything import build_sam_vit_h
|
17 |
+
|
18 |
|
19 |
def dice_loss(
|
20 |
inputs: torch.Tensor,
|
train_ds.py
CHANGED
@@ -63,7 +63,7 @@ def parse_args(args):
|
|
63 |
parser.add_argument("--dataset_dir", default="./dataset", type=str)
|
64 |
parser.add_argument("--log_base_dir", default="./runs", type=str)
|
65 |
parser.add_argument("--exp_name", default="lisa", type=str)
|
66 |
-
parser.add_argument("--epochs", default=
|
67 |
parser.add_argument("--steps_per_epoch", default=500, type=int)
|
68 |
parser.add_argument(
|
69 |
"--batch_size", default=2, type=int, help="batch size per device per step"
|
|
|
63 |
parser.add_argument("--dataset_dir", default="./dataset", type=str)
|
64 |
parser.add_argument("--log_base_dir", default="./runs", type=str)
|
65 |
parser.add_argument("--exp_name", default="lisa", type=str)
|
66 |
+
parser.add_argument("--epochs", default=10, type=int)
|
67 |
parser.add_argument("--steps_per_epoch", default=500, type=int)
|
68 |
parser.add_argument(
|
69 |
"--batch_size", default=2, type=int, help="batch size per device per step"
|