runninglsy committed on
Commit
8c7ac68
·
verified ·
1 Parent(s): ffaa2a5

Update modeling_ovis.py

Browse files
Files changed (1) hide show
  1. modeling_ovis.py +32 -13
modeling_ovis.py CHANGED
@@ -1,10 +1,12 @@
1
  import logging
2
  import os
 
3
  from importlib import import_module
4
  from typing import List, Callable, Union, Optional, Dict
5
 
6
  import PIL.Image
7
  import torch
 
8
  from torch import Tensor
9
  from torch.nn import init
10
  from torch.nn.functional import softmax, gumbel_softmax, pad
@@ -556,25 +558,42 @@ class Ovis(OvisPreTrainedModel):
556
  cache_cls = HybridCache
557
  llm = self.get_llm()
558
 
559
- need_new_cache = (
560
- not hasattr(llm, "_cache")
561
- or (not isinstance(llm._cache, cache_cls))
562
- or llm._cache.batch_size != batch_size
563
- or llm._cache.max_cache_len < max_cache_len
564
- )
 
 
 
 
 
 
 
 
565
 
566
  if need_new_cache:
567
  if hasattr(llm.config, "_pre_quantization_dtype"):
568
  cache_dtype = llm.config._pre_quantization_dtype
569
  else:
570
  cache_dtype = llm.dtype
571
- llm._cache = cache_cls(
572
- config=llm.config,
573
- batch_size=batch_size,
574
- max_cache_len=max_cache_len,
575
- device=llm.device,
576
- dtype=cache_dtype,
577
- )
 
 
 
 
 
 
 
 
 
578
  else:
579
  llm._cache.reset()
580
  return llm._cache
 
1
  import logging
2
  import os
3
+ from packaging import version
4
  from importlib import import_module
5
  from typing import List, Callable, Union, Optional, Dict
6
 
7
  import PIL.Image
8
  import torch
9
+ import transformers
10
  from torch import Tensor
11
  from torch.nn import init
12
  from torch.nn.functional import softmax, gumbel_softmax, pad
 
558
  cache_cls = HybridCache
559
  llm = self.get_llm()
560
 
561
+ if version.parse(transformers.__version__) >= version.parse("4.46.0"):
562
+ need_new_cache = (
563
+ not hasattr(llm, "_cache")
564
+ or (not isinstance(llm._cache, cache_cls))
565
+ or llm._cache.batch_size != batch_size
566
+ or llm._cache.max_cache_len < max_cache_len
567
+ )
568
+ else:
569
+ need_new_cache = (
570
+ not hasattr(llm, "_cache")
571
+ or (not isinstance(llm._cache, cache_cls))
572
+ or llm._cache.max_batch_size != batch_size
573
+ or llm._cache.max_cache_len < max_cache_len
574
+ )
575
 
576
  if need_new_cache:
577
  if hasattr(llm.config, "_pre_quantization_dtype"):
578
  cache_dtype = llm.config._pre_quantization_dtype
579
  else:
580
  cache_dtype = llm.dtype
581
+ if version.parse(transformers.__version__) >= version.parse("4.46.0"):
582
+ llm._cache = cache_cls(
583
+ config=llm.config,
584
+ batch_size=batch_size,
585
+ max_cache_len=max_cache_len,
586
+ device=llm.device,
587
+ dtype=cache_dtype,
588
+ )
589
+ else:
590
+ llm._cache = cache_cls(
591
+ config=llm.config,
592
+ max_batch_size=batch_size,
593
+ max_cache_len=max_cache_len,
594
+ device=llm.device,
595
+ dtype=cache_dtype,
596
+ )
597
  else:
598
  llm._cache.reset()
599
  return llm._cache