---
license: apache-2.0
datasets:
- kor_nli
language:
- ko
metrics:
- accuracy
pipeline_tag: zero-shot-classification
---
|
|
|
**This model is based on the following repository: https://github.com/Huffon/klue-transformers-tutorial.git**
|
|
|
This model was built by fine-tuning klue/roberta-base on the mnli and xnli portions of kor_nli, following the GitHub repository above.
|
| train_loss | val_loss | accuracy | epochs | batch size | learning rate |
| --- | --- | --- | --- | --- | --- |
| 0.326 | 0.538 | 0.811 | 3 | 32 | 2e-5 |
|
|
|
|
|
Models that do not use `token_type_ids`, such as RoBERTa, cannot be used with the zero-shot pipeline out of the box (as of transformers==4.7.0).
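The toy sketch below shows why passing the premise and hypothesis as a sentence pair causes trouble. It rests on an assumption this card does not state: the checkpoint keeps RoBERTa's single-row token-type embedding (`type_vocab_size = 1`), while its BERT-style tokenizer assigns segment id 1 to the second sentence of a pair. Whitespace splitting stands in for real subword tokenization, and the helper names are hypothetical.

```python
# Toy illustration only: whitespace splitting stands in for real subword tokenization.
# ASSUMPTION: the token-type embedding has a single row (type_vocab_size = 1),
# as in standard RoBERTa configs, so only segment id 0 is valid.
TYPE_VOCAB_SIZE = 1


def encode_pair(premise, hypothesis):
    """BERT-style pair encoding: [CLS] A [SEP] B [SEP], segment 0 for A, 1 for B."""
    first = ["[CLS]"] + premise.split() + ["[SEP]"]
    second = hypothesis.split() + ["[SEP]"]
    return first + second, [0] * len(first) + [1] * len(second)


def encode_joined(text):
    """Single-sequence encoding: every token gets segment id 0."""
    tokens = ["[CLS]"] + text.split() + ["[SEP]"]
    return tokens, [0] * len(tokens)


def fits_embedding(token_type_ids):
    """True if every segment id indexes a row of the token-type embedding table."""
    return all(0 <= t < TYPE_VOCAB_SIZE for t in token_type_ids)


premise, hypothesis = "KOSPI rises", "This is about stocks."
pair_tokens, pair_types = encode_pair(premise, hypothesis)
joined_tokens, joined_types = encode_joined(f"{premise} [SEP] {hypothesis}")

print(pair_types)                    # [0, 0, 0, 0, 1, 1, 1, 1, 1]
print(fits_embedding(pair_types))    # False: id 1 has no embedding row
print(joined_tokens == pair_tokens)  # True: same tokens, only segment ids differ
print(fits_embedding(joined_types))  # True
```

In this toy, joining the two sentences around the separator token yields the same token sequence while keeping every segment id at 0. That is the idea behind the modified argument handler.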
|
You therefore need to add code that converts the inputs, as shown below. This code is also adapted from the code in the GitHub repository above.
|
|
|
```python
from abc import ABC, abstractmethod

from transformers import AutoTokenizer

# The handler below reads `tokenizer.sep_token`, so this module-level tokenizer must exist first
tokenizer = AutoTokenizer.from_pretrained("pongjin/roberta_with_kornli")


class ArgumentHandler(ABC):
    """
    Base interface for handling arguments for each :class:`~transformers.pipelines.Pipeline`.
    """

    @abstractmethod
    def __call__(self, *args, **kwargs):
        raise NotImplementedError()


class CustomZeroShotClassificationArgumentHandler(ArgumentHandler):
    """
    Handles arguments for zero-shot text classification by turning each possible label into an NLI
    premise/hypothesis pair.
    """

    def _parse_labels(self, labels):
        # Accept either a list of labels or a single comma-separated string
        if isinstance(labels, str):
            labels = [label.strip() for label in labels.split(",")]
        return labels

    def __call__(self, sequences, labels, hypothesis_template):
        if len(labels) == 0 or len(sequences) == 0:
            raise ValueError("You must include at least one label and at least one sequence.")
        if hypothesis_template.format(labels[0]) == hypothesis_template:
            raise ValueError(
                (
                    'The provided hypothesis_template "{}" was not able to be formatted with the target labels. '
                    "Make sure the passed template includes formatting syntax such as {{}} where the label should go."
                ).format(hypothesis_template)
            )

        if isinstance(sequences, str):
            sequences = [sequences]
        labels = self._parse_labels(labels)

        sequence_pairs = []
        for sequence in sequences:
            for label in labels:
                # Modified part: passing the two sentences as a pair would make the tokenizer add
                # `token_type_ids` automatically. To avoid this, join them into one string
                # around `sep_token` beforehand.
                sequence_pairs.append(f"{sequence} {tokenizer.sep_token} {hypothesis_template.format(label)}")

        # Return format kept from the original handler: (inputs to tokenize, original sequences)
        return sequence_pairs, sequences
```
|
|
|
Then pass this handler to the pipeline when defining the classifier:
|
```python
from transformers import pipeline

# Uses the custom handler (and its module-level `tokenizer`) defined above
classifier = pipeline(
    "zero-shot-classification",
    args_parser=CustomZeroShotClassificationArgumentHandler(),
    model="pongjin/roberta_with_kornli",
)
```
|
#### Results
|
```python
# Headline: "Ex-dividend D-1: KOSPI rising at the 2330 level... foreign investors and institutions buying"
sequence = "배당락 D-1 코스피, 2330선 상승세...외인·기관 사자"
# Labels: foreign exchange, exchange rate, economy, finance, real estate, stocks
candidate_labels = ["외환", "환율", "경제", "금융", "부동산", "주식"]

classifier(
    sequence,
    candidate_labels,
    hypothesis_template="이는 {}에 관한 것이다.",  # "This is about {}."
)

# Output:
# {'sequence': '배당락 D-1 코스피, 2330선 상승세...외인·기관 사자',
#  'labels': ['주식', '금융', '경제', '외환', '환율', '부동산'],
#  'scores': [0.5052872896194458,
#   0.17972524464130402,
#   0.13852974772453308,
#   0.09460823982954025,
#   0.042949128895998,
#   0.038900360465049744]}
```
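The six scores above sum to about 1. This is because, in single-label mode, the transformers zero-shot pipeline keeps only the entailment logit of each premise/hypothesis pair and applies a softmax across the candidate labels. It finds the entailment class by looking for a label in the model config whose name starts with "entail", and falls back to the last index otherwise. Below is a minimal plain-Python sketch of that scoring step. The `LABEL2ID` mapping and the logits are made up for illustration; they are not this model's actual config or outputs.

```python
import math

# Hypothetical label mapping; the real one lives in the model's config.json
LABEL2ID = {"entailment": 0, "neutral": 1, "contradiction": 2}


def entailment_index(label2id):
    """First label whose name starts with "entail"; otherwise the last index (-1)."""
    for label, idx in label2id.items():
        if label.lower().startswith("entail"):
            return idx
    return -1


def zero_shot_scores(pair_logits, candidate_labels, label2id):
    """Single-label mode: softmax over the entailment logits of all candidate labels."""
    ent = entailment_index(label2id)
    entail_logits = [logits[ent] for logits in pair_logits]
    top = max(entail_logits)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in entail_logits]
    total = sum(exps)
    scores = [e / total for e in exps]
    # Sort labels by descending score, as in the pipeline output
    ranked = sorted(zip(candidate_labels, scores), key=lambda item: item[1], reverse=True)
    return [label for label, _ in ranked], [score for _, score in ranked]


# Made-up logits, one row per premise/hypothesis pair
labels, scores = zero_shot_scores(
    [[2.0, 0.1, -1.0], [0.5, 0.3, 0.2], [-1.0, 0.0, 2.5]],
    ["주식", "금융", "부동산"],
    LABEL2ID,
)
print(labels)                 # ['주식', '금융', '부동산']
print(round(sum(scores), 6))  # 1.0
```

Because the softmax runs across labels instead of within each pair, the scores in single-label mode are relative: adding or removing a candidate label changes every other label's score.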