Valeriy Sinyukov commited on
Commit
82ec9f7
·
1 Parent(s): 5c5407c

Remove model wrappers, use dict and model input

Browse files
app.py CHANGED
@@ -2,7 +2,6 @@ import pandas as pd
2
  import streamlit as st
3
 
4
  from category_classification.models import models as class_models
5
- from common import Input
6
  from languages import *
7
  from results import process_results
8
 
@@ -33,7 +32,7 @@ authors = st.text_area(authors_label[lang], height=text_area_height(2))
33
  abstract = st.text_area(abstract_label[lang], height=text_area_height(5))
34
 
35
  if title:
36
- input = Input(title=title, abstract=abstract, authors=authors)
37
  model = load_class_model(model_name)
38
  results = model(input)
39
  results = process_results(results, lang)
 
2
  import streamlit as st
3
 
4
  from category_classification.models import models as class_models
 
5
  from languages import *
6
  from results import process_results
7
 
 
32
  abstract = st.text_area(abstract_label[lang], height=text_area_height(5))
33
 
34
  if title:
35
+ input = {"title": title, "abstract": abstract, "authors": authors}
36
  model = load_class_model(model_name)
37
  results = model(input)
38
  results = process_results(results, lang)
category_classification/models/HibiscusMaximus__scibert_paper_classification/model.py CHANGED
@@ -2,16 +2,8 @@ from transformers import pipeline
2
 
3
  name = "HibiscusMaximus/scibert_paper_classification"
4
 
5
- class SciBertPaperClassifier:
6
- def __init__(self):
7
- self.pipeline = pipeline("paper-classification", model=name)
8
-
9
- def __call__(self, input):
10
- return self.pipeline(input)
11
-
12
-
13
  def get_model():
14
- return SciBertPaperClassifier()
15
 
16
 
17
  supported_langs = ["en"]
 
2
 
3
  name = "HibiscusMaximus/scibert_paper_classification"
4
 
 
 
 
 
 
 
 
 
5
  def get_model():
6
+ return pipeline("paper-classification", model=name)
7
 
8
 
9
  supported_langs = ["en"]
category_classification/models/oracat__bert_paper_classifier/model.py CHANGED
@@ -2,15 +2,8 @@ from transformers import pipeline
2
 
3
  name = "oracat/bert-paper-classifier"
4
 
5
-
6
- class BertPaperClassifierModel:
7
- def __init__(self):
8
- self.pipeline = pipeline("text-classification", model=name)
9
-
10
- def __call__(self, input):
11
- return self.pipeline(input.title + ' ' + input.abstract)
12
-
13
  def get_model():
14
- return BertPaperClassifierModel()
 
15
 
16
- supported_langs = ['en']
 
2
 
3
  name = "oracat/bert-paper-classifier"
4
 
 
 
 
 
 
 
 
 
5
  def get_model():
6
+ return pipeline("paper-classification", model=name)
7
+
8
 
9
+ supported_langs = ["en"]
category_classification/models/oracat__bert_paper_classifier_arxiv/model.py CHANGED
@@ -2,15 +2,8 @@ from transformers import pipeline
2
 
3
  name = "oracat/bert-paper-classifier-arxiv"
4
 
5
-
6
- class BertPaperClassifierArxivModel:
7
- def __init__(self):
8
- self.pipeline = pipeline("text-classification", model=name)
9
-
10
- def __call__(self, input):
11
- return self.pipeline(input.title + ' ' + input.abstract)
12
-
13
  def get_model():
14
- return BertPaperClassifierArxivModel()
 
15
 
16
- supported_langs = ['en']
 
2
 
3
  name = "oracat/bert-paper-classifier-arxiv"
4
 
 
 
 
 
 
 
 
 
5
  def get_model():
6
+ return pipeline("paper-classification", model=name)
7
+
8
 
9
+ supported_langs = ["en"]
category_classification/models/pipeline.py CHANGED
@@ -6,16 +6,28 @@ import torch
6
  from transformers import Pipeline, AutoModelForSequenceClassification
7
  from transformers.pipelines import PIPELINE_REGISTRY
8
 
 
9
  class PapersClassificationPipeline(Pipeline):
10
  def _sanitize_parameters(self, **kwargs):
11
  return {}, {}, {}
12
 
13
  def preprocess(self, inputs):
14
- if not isinstance(inputs, tp.Iterable):
 
 
 
 
15
  inputs = [inputs]
 
 
 
16
  texts = [
17
- f"AUTHORS: {' '.join(paper.authors) if isinstance(paper.authors, list) else paper.authors} "
18
- f"TITLE: {paper.title} ABSTRACT: {paper.abstract}"
 
 
 
 
19
  for paper in inputs
20
  ]
21
  inputs = self.tokenizer(
 
6
  from transformers import Pipeline, AutoModelForSequenceClassification
7
  from transformers.pipelines import PIPELINE_REGISTRY
8
 
9
+
10
  class PapersClassificationPipeline(Pipeline):
11
  def _sanitize_parameters(self, **kwargs):
12
  return {}, {}, {}
13
 
14
  def preprocess(self, inputs):
15
+ if (
16
+ not isinstance(inputs, tp.Iterable)
17
+ or isinstance(inputs, tp.Dict)
18
+ or isinstance(inputs, str)
19
+ ):
20
  inputs = [inputs]
21
+ title = "title"
22
+ authors = "authors"
23
+ abstract = "abstract"
24
  texts = [
25
+ (
26
+ f"AUTHORS: {' '.join(paper[title]) if isinstance(paper[authors], list) else paper[authors]} "
27
+ f"TITLE: {paper[title]} ABSTRACT: {paper[abstract]}"
28
+ if not isinstance(paper, str)
29
+ else paper
30
+ )
31
  for paper in inputs
32
  ]
33
  inputs = self.tokenizer(
category_classification/models/translation.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from collections import namedtuple
2
  from functools import partial
3
 
@@ -13,28 +14,27 @@ def get_translator():
13
  torch_dtype="auto",
14
  )
15
 
16
- class Input:
17
- def __init__(self, title, abstract, authors):
18
- self.title = title
19
- self.abstract = abstract
20
- self.authors = authors
21
-
22
  class TranslationModel:
23
  def __init__(self, get_model):
24
  self.translator = get_translator()
25
  self.model = get_model()
26
 
27
- def __call__(self, input):
28
- def translate(text):
29
- if text is None or text.strip() == "":
30
- return ""
31
- text = str(text).strip()
32
- translated = self.translator(text)[0]['translation_text']
33
- return translated
34
- title = translate(input.title)
35
- abstract = translate(input.abstract)
36
- authors = translate(input.authors)
37
- out = self.model(Input(title, abstract, authors))
 
 
 
 
38
  return out
39
 
40
 
@@ -43,4 +43,3 @@ def create_translation_models(models):
43
  f"{name} (С помощью перевода)": partial(TranslationModel, get_model=get_model)
44
  for name, get_model in models.items()
45
  }
46
-
 
1
+ import typing as tp
2
  from collections import namedtuple
3
  from functools import partial
4
 
 
14
  torch_dtype="auto",
15
  )
16
 
17
+
 
 
 
 
 
18
  class TranslationModel:
19
  def __init__(self, get_model):
20
  self.translator = get_translator()
21
  self.model = get_model()
22
 
23
+ def __call__(self, input, **kwargs):
24
+ def transform_input_dict_to_str(input):
25
+ if isinstance(input, tp.Dict):
26
+ return input["authors"] + " " + input["abstract"] + " " + input["title"]
27
+
28
+ if not isinstance(input, tp.Iterable) or isinstance(input, tp.Dict):
29
+ input = [input]
30
+ input = [transform_input_dict_to_str(i) for i in input]
31
+ translated_input = self.translator(input)
32
+ translated = [
33
+ translated_i["translation_text"] for translated_i in translated_input
34
+ ]
35
+ out = self.model(translated)
36
+ if 1 == len(out):
37
+ return out[0]
38
  return out
39
 
40
 
 
43
  f"{name} (С помощью перевода)": partial(TranslationModel, get_model=get_model)
44
  for name, get_model in models.items()
45
  }
 
common.py DELETED
@@ -1,5 +0,0 @@
1
- class Input:
2
- def __init__(self, title, abstract=None, authors=None):
3
- self.title = title
4
- self.abstract = abstract if abstract is not None else ''
5
- self.authors = authors if authors is not None else ''