Valeriy Sinyukov
committed on
Commit
·
82ec9f7
1
Parent(s):
5c5407c
Remove model wrappers, use dict and model input
Browse files
- app.py +1 -2
- category_classification/models/HibiscusMaximus__scibert_paper_classification/model.py +1 -9
- category_classification/models/oracat__bert_paper_classifier/model.py +3 -10
- category_classification/models/oracat__bert_paper_classifier_arxiv/model.py +3 -10
- category_classification/models/pipeline.py +15 -3
- category_classification/models/translation.py +17 -18
- common.py +0 -5
app.py
CHANGED
@@ -2,7 +2,6 @@ import pandas as pd
|
|
2 |
import streamlit as st
|
3 |
|
4 |
from category_classification.models import models as class_models
|
5 |
-
from common import Input
|
6 |
from languages import *
|
7 |
from results import process_results
|
8 |
|
@@ -33,7 +32,7 @@ authors = st.text_area(authors_label[lang], height=text_area_height(2))
|
|
33 |
abstract = st.text_area(abstract_label[lang], height=text_area_height(5))
|
34 |
|
35 |
if title:
|
36 |
-
input =
|
37 |
model = load_class_model(model_name)
|
38 |
results = model(input)
|
39 |
results = process_results(results, lang)
|
|
|
2 |
import streamlit as st
|
3 |
|
4 |
from category_classification.models import models as class_models
|
|
|
5 |
from languages import *
|
6 |
from results import process_results
|
7 |
|
|
|
32 |
abstract = st.text_area(abstract_label[lang], height=text_area_height(5))
|
33 |
|
34 |
if title:
|
35 |
+
input = {"title": title, "abstract": abstract, "authors": authors}
|
36 |
model = load_class_model(model_name)
|
37 |
results = model(input)
|
38 |
results = process_results(results, lang)
|
category_classification/models/HibiscusMaximus__scibert_paper_classification/model.py
CHANGED
@@ -2,16 +2,8 @@ from transformers import pipeline
|
|
2 |
|
3 |
name = "HibiscusMaximus/scibert_paper_classification"
|
4 |
|
5 |
-
class SciBertPaperClassifier:
|
6 |
-
def __init__(self):
|
7 |
-
self.pipeline = pipeline("paper-classification", model=name)
|
8 |
-
|
9 |
-
def __call__(self, input):
|
10 |
-
return self.pipeline(input)
|
11 |
-
|
12 |
-
|
13 |
def get_model():
|
14 |
-
return
|
15 |
|
16 |
|
17 |
supported_langs = ["en"]
|
|
|
2 |
|
3 |
name = "HibiscusMaximus/scibert_paper_classification"
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
def get_model():
|
6 |
+
return pipeline("paper-classification", model=name)
|
7 |
|
8 |
|
9 |
supported_langs = ["en"]
|
category_classification/models/oracat__bert_paper_classifier/model.py
CHANGED
@@ -2,15 +2,8 @@ from transformers import pipeline
|
|
2 |
|
3 |
name = "oracat/bert-paper-classifier"
|
4 |
|
5 |
-
|
6 |
-
class BertPaperClassifierModel:
|
7 |
-
def __init__(self):
|
8 |
-
self.pipeline = pipeline("text-classification", model=name)
|
9 |
-
|
10 |
-
def __call__(self, input):
|
11 |
-
return self.pipeline(input.title + ' ' + input.abstract)
|
12 |
-
|
13 |
def get_model():
|
14 |
-
return
|
|
|
15 |
|
16 |
-
supported_langs = [
|
|
|
2 |
|
3 |
name = "oracat/bert-paper-classifier"
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
def get_model():
|
6 |
+
return pipeline("paper-classification", model=name)
|
7 |
+
|
8 |
|
9 |
+
supported_langs = ["en"]
|
category_classification/models/oracat__bert_paper_classifier_arxiv/model.py
CHANGED
@@ -2,15 +2,8 @@ from transformers import pipeline
|
|
2 |
|
3 |
name = "oracat/bert-paper-classifier-arxiv"
|
4 |
|
5 |
-
|
6 |
-
class BertPaperClassifierArxivModel:
|
7 |
-
def __init__(self):
|
8 |
-
self.pipeline = pipeline("text-classification", model=name)
|
9 |
-
|
10 |
-
def __call__(self, input):
|
11 |
-
return self.pipeline(input.title + ' ' + input.abstract)
|
12 |
-
|
13 |
def get_model():
|
14 |
-
return
|
|
|
15 |
|
16 |
-
supported_langs = [
|
|
|
2 |
|
3 |
name = "oracat/bert-paper-classifier-arxiv"
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
def get_model():
|
6 |
+
return pipeline("paper-classification", model=name)
|
7 |
+
|
8 |
|
9 |
+
supported_langs = ["en"]
|
category_classification/models/pipeline.py
CHANGED
@@ -6,16 +6,28 @@ import torch
|
|
6 |
from transformers import Pipeline, AutoModelForSequenceClassification
|
7 |
from transformers.pipelines import PIPELINE_REGISTRY
|
8 |
|
|
|
9 |
class PapersClassificationPipeline(Pipeline):
|
10 |
def _sanitize_parameters(self, **kwargs):
|
11 |
return {}, {}, {}
|
12 |
|
13 |
def preprocess(self, inputs):
|
14 |
-
if
|
|
|
|
|
|
|
|
|
15 |
inputs = [inputs]
|
|
|
|
|
|
|
16 |
texts = [
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
19 |
for paper in inputs
|
20 |
]
|
21 |
inputs = self.tokenizer(
|
|
|
6 |
from transformers import Pipeline, AutoModelForSequenceClassification
|
7 |
from transformers.pipelines import PIPELINE_REGISTRY
|
8 |
|
9 |
+
|
10 |
class PapersClassificationPipeline(Pipeline):
|
11 |
def _sanitize_parameters(self, **kwargs):
|
12 |
return {}, {}, {}
|
13 |
|
14 |
def preprocess(self, inputs):
|
15 |
+
if (
|
16 |
+
not isinstance(inputs, tp.Iterable)
|
17 |
+
or isinstance(inputs, tp.Dict)
|
18 |
+
or isinstance(inputs, str)
|
19 |
+
):
|
20 |
inputs = [inputs]
|
21 |
+
title = "title"
|
22 |
+
authors = "authors"
|
23 |
+
abstract = "abstract"
|
24 |
texts = [
|
25 |
+
(
|
26 |
+
f"AUTHORS: {' '.join(paper[title]) if isinstance(paper[authors], list) else paper[authors]} "
|
27 |
+
f"TITLE: {paper[title]} ABSTRACT: {paper[abstract]}"
|
28 |
+
if not isinstance(paper, str)
|
29 |
+
else paper
|
30 |
+
)
|
31 |
for paper in inputs
|
32 |
]
|
33 |
inputs = self.tokenizer(
|
category_classification/models/translation.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from collections import namedtuple
|
2 |
from functools import partial
|
3 |
|
@@ -13,28 +14,27 @@ def get_translator():
|
|
13 |
torch_dtype="auto",
|
14 |
)
|
15 |
|
16 |
-
|
17 |
-
def __init__(self, title, abstract, authors):
|
18 |
-
self.title = title
|
19 |
-
self.abstract = abstract
|
20 |
-
self.authors = authors
|
21 |
-
|
22 |
class TranslationModel:
|
23 |
def __init__(self, get_model):
|
24 |
self.translator = get_translator()
|
25 |
self.model = get_model()
|
26 |
|
27 |
-
def __call__(self, input):
|
28 |
-
def
|
29 |
-
if
|
30 |
-
return ""
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
|
|
38 |
return out
|
39 |
|
40 |
|
@@ -43,4 +43,3 @@ def create_translation_models(models):
|
|
43 |
f"{name} (С помощью перевода)": partial(TranslationModel, get_model=get_model)
|
44 |
for name, get_model in models.items()
|
45 |
}
|
46 |
-
|
|
|
1 |
+
import typing as tp
|
2 |
from collections import namedtuple
|
3 |
from functools import partial
|
4 |
|
|
|
14 |
torch_dtype="auto",
|
15 |
)
|
16 |
|
17 |
+
|
|
|
|
|
|
|
|
|
|
|
18 |
class TranslationModel:
|
19 |
def __init__(self, get_model):
|
20 |
self.translator = get_translator()
|
21 |
self.model = get_model()
|
22 |
|
23 |
+
def __call__(self, input, **kwargs):
|
24 |
+
def transform_input_dict_to_str(input):
|
25 |
+
if isinstance(input, tp.Dict):
|
26 |
+
return input["authors"] + " " + input["abstract"] + " " + input["title"]
|
27 |
+
|
28 |
+
if not isinstance(input, tp.Iterable) or isinstance(input, tp.Dict):
|
29 |
+
input = [input]
|
30 |
+
input = [transform_input_dict_to_str(i) for i in input]
|
31 |
+
translated_input = self.translator(input)
|
32 |
+
translated = [
|
33 |
+
translated_i["translation_text"] for translated_i in translated_input
|
34 |
+
]
|
35 |
+
out = self.model(translated)
|
36 |
+
if 1 == len(out):
|
37 |
+
return out[0]
|
38 |
return out
|
39 |
|
40 |
|
|
|
43 |
f"{name} (С помощью перевода)": partial(TranslationModel, get_model=get_model)
|
44 |
for name, get_model in models.items()
|
45 |
}
|
|
common.py
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
class Input:
|
2 |
-
def __init__(self, title, abstract=None, authors=None):
|
3 |
-
self.title = title
|
4 |
-
self.abstract = abstract if abstract is not None else ''
|
5 |
-
self.authors = authors if authors is not None else ''
|
|
|
|
|
|
|
|
|
|
|
|