BioMike commited on
Commit
583b2ec
1 Parent(s): 0e658dd

Update interfaces/base_pipeline.py

Browse files
Files changed (1) hide show
  1. interfaces/base_pipeline.py +2 -2
interfaces/base_pipeline.py CHANGED
@@ -2,7 +2,7 @@ from utca.core import RenameAttribute,Flush
2
  from utca.implementation.predictors import TokenSearcherPredictor, TokenSearcherPredictorConfig
3
  from utca.implementation.tasks import TokenSearcherNER, TokenSearcherNERPostprocessor
4
  from utca.implementation.predictors.token_searcher.token_searcher_pipeline import TokenClassificationPipeline
5
- from transformers import AutoTokenizer, AutoModelForTokenClassification
6
 
7
  predictor = TokenSearcherPredictor(
8
  TokenSearcherPredictorConfig(
@@ -29,7 +29,7 @@ def generate_pipeline(threshold: float = 0.5):
29
  tokenizer = AutoTokenizer.from_pretrained("knowledgator/UTC-DeBERTa-large-v2")
30
  model = AutoModelForTokenClassification.from_pretrained("knowledgator/UTC-DeBERTa-large-v2")
31
 
32
- transformers_pipeline = TokenClassificationPipeline(device="cpu", model=model, tokenizer=tokenizer, aggregation_strategy = 'first')
33
 
34
  if __name__=="__main__":
35
  pipeline = generate_pipeline()
 
2
  from utca.implementation.predictors import TokenSearcherPredictor, TokenSearcherPredictorConfig
3
  from utca.implementation.tasks import TokenSearcherNER, TokenSearcherNERPostprocessor
4
  from utca.implementation.predictors.token_searcher.token_searcher_pipeline import TokenClassificationPipeline
5
+ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
6
 
7
  predictor = TokenSearcherPredictor(
8
  TokenSearcherPredictorConfig(
 
29
  tokenizer = AutoTokenizer.from_pretrained("knowledgator/UTC-DeBERTa-large-v2")
30
  model = AutoModelForTokenClassification.from_pretrained("knowledgator/UTC-DeBERTa-large-v2")
31
 
32
+ transformers_pipeline = pipeline(task="ner", model ="knowledgator/UTC-DeBERTa-small", pipeline_class =TokenClassificationPipeline, aggregation_strategy = "first", batch_size = 12)
33
 
34
  if __name__=="__main__":
35
  pipeline = generate_pipeline()