runninglsy
commited on
Update modeling_ovis.py
Browse files- modeling_ovis.py +32 -13
modeling_ovis.py
CHANGED
@@ -1,10 +1,12 @@
|
|
1 |
import logging
|
2 |
import os
|
|
|
3 |
from importlib import import_module
|
4 |
from typing import List, Callable, Union, Optional, Dict
|
5 |
|
6 |
import PIL.Image
|
7 |
import torch
|
|
|
8 |
from torch import Tensor
|
9 |
from torch.nn import init
|
10 |
from torch.nn.functional import softmax, gumbel_softmax, pad
|
@@ -556,25 +558,42 @@ class Ovis(OvisPreTrainedModel):
|
|
556 |
cache_cls = HybridCache
|
557 |
llm = self.get_llm()
|
558 |
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
565 |
|
566 |
if need_new_cache:
|
567 |
if hasattr(llm.config, "_pre_quantization_dtype"):
|
568 |
cache_dtype = llm.config._pre_quantization_dtype
|
569 |
else:
|
570 |
cache_dtype = llm.dtype
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
|
576 |
-
|
577 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
578 |
else:
|
579 |
llm._cache.reset()
|
580 |
return llm._cache
|
|
|
1 |
import logging
|
2 |
import os
|
3 |
+
from packaging import version
|
4 |
from importlib import import_module
|
5 |
from typing import List, Callable, Union, Optional, Dict
|
6 |
|
7 |
import PIL.Image
|
8 |
import torch
|
9 |
+
import transformers
|
10 |
from torch import Tensor
|
11 |
from torch.nn import init
|
12 |
from torch.nn.functional import softmax, gumbel_softmax, pad
|
|
|
558 |
cache_cls = HybridCache
|
559 |
llm = self.get_llm()
|
560 |
|
561 |
+
if version.parse(transformers.__version__) >= version.parse("4.46.0"):
|
562 |
+
need_new_cache = (
|
563 |
+
not hasattr(llm, "_cache")
|
564 |
+
or (not isinstance(llm._cache, cache_cls))
|
565 |
+
or llm._cache.batch_size != batch_size
|
566 |
+
or llm._cache.max_cache_len < max_cache_len
|
567 |
+
)
|
568 |
+
else:
|
569 |
+
need_new_cache = (
|
570 |
+
not hasattr(llm, "_cache")
|
571 |
+
or (not isinstance(llm._cache, cache_cls))
|
572 |
+
or llm._cache.max_batch_size != batch_size
|
573 |
+
or llm._cache.max_cache_len < max_cache_len
|
574 |
+
)
|
575 |
|
576 |
if need_new_cache:
|
577 |
if hasattr(llm.config, "_pre_quantization_dtype"):
|
578 |
cache_dtype = llm.config._pre_quantization_dtype
|
579 |
else:
|
580 |
cache_dtype = llm.dtype
|
581 |
+
if version.parse(transformers.__version__) >= version.parse("4.46.0"):
|
582 |
+
llm._cache = cache_cls(
|
583 |
+
config=llm.config,
|
584 |
+
batch_size=batch_size,
|
585 |
+
max_cache_len=max_cache_len,
|
586 |
+
device=llm.device,
|
587 |
+
dtype=cache_dtype,
|
588 |
+
)
|
589 |
+
else:
|
590 |
+
llm._cache = cache_cls(
|
591 |
+
config=llm.config,
|
592 |
+
max_batch_size=batch_size,
|
593 |
+
max_cache_len=max_cache_len,
|
594 |
+
device=llm.device,
|
595 |
+
dtype=cache_dtype,
|
596 |
+
)
|
597 |
else:
|
598 |
llm._cache.reset()
|
599 |
return llm._cache
|