Spaces:
Runtime error
Runtime error
Update models/videochat.py
Browse files- models/videochat.py +9 -1
models/videochat.py
CHANGED
@@ -51,7 +51,13 @@ class VideoChat(Blip2Base):
|
|
51 |
|
52 |
self.tokenizer = self.init_tokenizer()
|
53 |
self.low_resource = low_resource
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
self.vit_precision = vit_precision
|
56 |
print(f'Loading VIT. Use fp16: {vit_precision}')
|
57 |
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
@@ -133,6 +139,7 @@ class VideoChat(Blip2Base):
|
|
133 |
use_auth_token=os.environ["HF_TOKEN"],
|
134 |
)
|
135 |
else:
|
|
|
136 |
self.llama_model = LlamaForCausalLM.from_pretrained(
|
137 |
llama_model_path,
|
138 |
torch_dtype=torch.float16,
|
@@ -140,6 +147,7 @@ class VideoChat(Blip2Base):
|
|
140 |
load_in_8bit=True,
|
141 |
device_map="auto"
|
142 |
)
|
|
|
143 |
|
144 |
print("freeze LLAMA")
|
145 |
for name, param in self.llama_model.named_parameters():
|
|
|
51 |
|
52 |
self.tokenizer = self.init_tokenizer()
|
53 |
self.low_resource = low_resource
|
54 |
+
self.llama_model = LlamaForCausalLM.from_pretrained(
|
55 |
+
llama_model_path,
|
56 |
+
torch_dtype=torch.float16,
|
57 |
+
use_auth_token=os.environ["HF_TOKEN"],
|
58 |
+
load_in_8bit=True,
|
59 |
+
device_map="auto"
|
60 |
+
)
|
61 |
self.vit_precision = vit_precision
|
62 |
print(f'Loading VIT. Use fp16: {vit_precision}')
|
63 |
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
|
|
139 |
use_auth_token=os.environ["HF_TOKEN"],
|
140 |
)
|
141 |
else:
|
142 |
+
'''
|
143 |
self.llama_model = LlamaForCausalLM.from_pretrained(
|
144 |
llama_model_path,
|
145 |
torch_dtype=torch.float16,
|
|
|
147 |
load_in_8bit=True,
|
148 |
device_map="auto"
|
149 |
)
|
150 |
+
'''
|
151 |
|
152 |
print("freeze LLAMA")
|
153 |
for name, param in self.llama_model.named_parameters():
|