pszemraj commited on
Commit
2b8c4c9
1 Parent(s): defd486

✨ add ability to force CPU

Browse files

Signed-off-by: peter szemraj <[email protected]>

Files changed (1) hide show
  1. aggregate.py +9 -2
aggregate.py CHANGED
@@ -54,15 +54,22 @@ class BatchAggregator:
54
  DEFAULT_INSTRUCTION = "Write a comprehensive yet concise summary that pulls together the main points of the following text:"
55
 
56
  def __init__(
57
- self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", **kwargs
 
 
 
58
  ):
59
  """
60
  __init__ initializes the BatchAggregator class.
61
 
62
  :param str model_name: model name to use, default: "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
 
63
  """
64
  self.device = None
65
  self.is_compiled = False
 
 
 
66
  self.logger = logging.getLogger(__name__)
67
  self.init_model(model_name)
68
 
@@ -105,7 +112,7 @@ class BatchAggregator:
105
 
106
  :raises Exception: if the pipeline cannot be created
107
  """
108
- self.device = 0 if torch.cuda.is_available() else -1
109
  try:
110
  self.logger.info(
111
  f"Creating pipeline with model {model_name} on device {self.device}"
 
54
  DEFAULT_INSTRUCTION = "Write a comprehensive yet concise summary that pulls together the main points of the following text:"
55
 
56
  def __init__(
57
+ self,
58
+ model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1",
59
+ force_cpu: bool = False,
60
+ **kwargs,
61
  ):
62
  """
63
  __init__ initializes the BatchAggregator class.
64
 
65
  :param str model_name: model name to use, default: "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
66
+ :param bool force_cpu: force the model to run on CPU, default: False
67
  """
68
  self.device = None
69
  self.is_compiled = False
70
+ self.model_name = None
71
+ self.aggregator = None
72
+ self.force_cpu = force_cpu
73
  self.logger = logging.getLogger(__name__)
74
  self.init_model(model_name)
75
 
 
112
 
113
  :raises Exception: if the pipeline cannot be created
114
  """
115
+ self.device = 0 if torch.cuda.is_available() and not self.force_cpu else -1
116
  try:
117
  self.logger.info(
118
  f"Creating pipeline with model {model_name} on device {self.device}"