shamaayan commited on
Commit
4aa700d
·
1 Parent(s): 8c22397
Files changed (1) hide show
  1. handler.py +7 -5
handler.py CHANGED
@@ -1,12 +1,14 @@
1
- from transformers import pipeline
2
  from typing import Dict, List, Any
3
  from tokenizers.decoders import WordPiece
4
 
5
 
6
  class EndpointHandler:
7
- def __init__(self):
8
- self.model = pipeline('ner', model='dicta-il/dictabert-ner', aggregation_strategy='simple')
9
- self.model.tokenizer.backend_tokenizer.decoder = WordPiece()
 
 
10
 
11
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
12
  """
@@ -16,4 +18,4 @@ class EndpointHandler:
16
  Return:
17
  A :obj:`list` | `dict`: will be serialized and returned
18
  """
19
- return self.model(data['inputs'])
 
1
+ from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
2
  from typing import Dict, List, Any
3
  from tokenizers.decoders import WordPiece
4
 
5
 
6
  class EndpointHandler:
7
+ def __init__(self, path="."):
8
+ model = AutoModelForTokenClassification.from_pretrained(path)
9
+ tokenizer = AutoTokenizer.from_pretrained(path)
10
+ self.pipeline = pipeline('ner', model=model, tokenizer=tokenizer, aggregation_strategy='simple')
11
+ self.pipeline.tokenizer.backend_tokenizer.decoder = WordPiece()
12
 
13
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
14
  """
 
18
  Return:
19
  A :obj:`list` | `dict`: will be serialized and returned
20
  """
21
+ return self.pipeline(data['inputs'])