crscardellino commited on
Commit
1e86311
·
1 Parent(s): cad4540

Fixed an error regarding the verification of a type

Browse files
Files changed (1) hide show
  1. chatbot.py +6 -6
chatbot.py CHANGED
@@ -17,7 +17,7 @@ prompt.
17
  import argparse
18
  import torch
19
 
20
- from transformers import AutoModelForCausalLM, AutoTokenizer
21
  from typing import Optional, Union
22
 
23
 
@@ -28,9 +28,9 @@ class ChatBot:
28
 
29
  Parameters
30
  ----------
31
- base_model : str | AutoModelForCausalLM
32
  A name (path in hugging face hub) for a model, or the model itself.
33
- tokenizer : AutoTokenizer | None
34
  Needed in case the base_model is a given model, otherwise it will load
35
  the same model given by the base_model path.
36
  initial_prompt : str
@@ -53,8 +53,8 @@ class ChatBot:
53
  """
54
 
55
  def __init__(self,
56
- base_model: Union[str, AutoModelForCausalLM],
57
- tokenizer: Optional[AutoTokenizer] = None,
58
  initial_prompt: Optional[str] = None,
59
  keep_context: bool = False,
60
  creative: bool = False,
@@ -69,7 +69,7 @@ class ChatBot:
69
  )
70
  self.tokenizer = AutoTokenizer.from_pretrained(base_model)
71
  else:
72
- assert isinstance(self.tokenizer, AutoTokenizer),\
73
  "If the base model is given, the tokenizer should be given as well"
74
  self.model = base_model
75
  self.tokenizer = tokenizer
 
17
  import argparse
18
  import torch
19
 
20
+ from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
21
  from typing import Optional, Union
22
 
23
 
 
28
 
29
  Parameters
30
  ----------
31
+ base_model : str | PreTrainedModel
32
  A name (path in hugging face hub) for a model, or the model itself.
33
+ tokenizer : PreTrainedTokenizerBase | None
34
  Needed in case the base_model is a given model, otherwise it will load
35
  the same model given by the base_model path.
36
  initial_prompt : str
 
53
  """
54
 
55
  def __init__(self,
56
+ base_model: Union[str, PreTrainedModel],
57
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
58
  initial_prompt: Optional[str] = None,
59
  keep_context: bool = False,
60
  creative: bool = False,
 
69
  )
70
  self.tokenizer = AutoTokenizer.from_pretrained(base_model)
71
  else:
72
+ assert isinstance(tokenizer, PreTrainedTokenizerBase),\
73
  "If the base model is given, the tokenizer should be given as well"
74
  self.model = base_model
75
  self.tokenizer = tokenizer