sarahyurick
committed on
Update README.md
README.md
CHANGED
@@ -8,17 +8,22 @@ license: apache-2.0
 
 # Model Overview
 This is a text classification model to classify documents into one of 26 domain classes:
-
 'Adult', 'Arts_and_Entertainment', 'Autos_and_Vehicles', 'Beauty_and_Fitness', 'Books_and_Literature', 'Business_and_Industrial', 'Computers_and_Electronics', 'Finance', 'Food_and_Drink', 'Games', 'Health', 'Hobbies_and_Leisure', 'Home_and_Garden', 'Internet_and_Telecom', 'Jobs_and_Education', 'Law_and_Government', 'News', 'Online_Communities', 'People_and_Society', 'Pets_and_Animals', 'Real_Estate', 'Science', 'Sensitive_Subjects', 'Shopping', 'Sports', 'Travel_and_Transportation'
 # Model Architecture
-The model architecture is DeBERTa V3 Base
-Context length is 512 tokens
-
 ## Training data:
 - 1 million Common Crawl samples, labeled using Google Cloud’s Natural Language API: https://cloud.google.com/natural-language/docs/classifying-text
 - 500k Wikipedia articles, curated using Wikipedia-API: https://pypi.org/project/Wikipedia-API/
 ## Training steps:
 Model was trained in multiple rounds using Wikipedia and Common Crawl data, labeled by a combination of pseudo labels and Google Cloud API.
 # How To Use This Model
 ## Input
 The model takes one or several paragraphs of text as input.
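The input and context-length notes in this hunk suggest that long documents may need to be split before classification. The following is a minimal sketch, not part of the commit: it greedily packs paragraphs into chunks under a size budget, using whitespace-separated words as a rough stand-in for tokens. The function name `pack_paragraphs` and the word-count proxy are assumptions for illustration; the model's tokenizer typically yields more tokens than words, so a smaller budget may be needed in practice.

```python
from typing import List

MAX_TOKENS = 512  # context length stated in the model card


def pack_paragraphs(paragraphs: List[str], budget: int = MAX_TOKENS) -> List[str]:
    """Greedily group paragraphs into chunks whose approximate size fits the budget.

    Assumption: whitespace-separated words approximate tokens. A single
    paragraph larger than the budget is kept as its own chunk and would
    still be truncated by the tokenizer.
    """
    chunks: List[str] = []
    current: List[str] = []
    current_len = 0
    for paragraph in paragraphs:
        n_words = len(paragraph.split())
        # Start a new chunk when adding this paragraph would exceed the budget.
        if current and current_len + n_words > budget:
            chunks.append("\n\n".join(current))
            current, current_len = [], 0
        current.append(paragraph)
        current_len += n_words
    if current:
        chunks.append("\n\n".join(current))
    return chunks
```

Each chunk could then be passed to the tokenizer with `truncation=True`, as the README's usage code does, and the per-chunk predictions combined in whatever way suits the application.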
@@ -38,15 +43,14 @@ Example output:
 Food_and_Drink
 ```
 
-# How to
 
-The inference code is available on [NeMo Curator's GitHub repository](https://github.com/NVIDIA/NeMo-Curator).
 
-# How to
-To use the
 
 ```python
-
 import torch
 from torch import nn
 from transformers import AutoModel, AutoTokenizer, AutoConfig
@@ -55,9 +59,9 @@ from huggingface_hub import PyTorchModelHubMixin
 class CustomModel(nn.Module, PyTorchModelHubMixin):
     def __init__(self, config):
         super(CustomModel, self).__init__()
-        self.model = AutoModel.from_pretrained(config[
-        self.dropout = nn.Dropout(config[
-        self.fc = nn.Linear(self.model.config.hidden_size, len(config[
 
     def forward(self, input_ids, attention_mask):
         features = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
@@ -73,7 +77,7 @@ model = CustomModel.from_pretrained("nvidia/domain-classifier")
 # Prepare and process inputs
 text_samples = ["Sports is a popular domain", "Politics is a popular domain"]
 inputs = tokenizer(text_samples, return_tensors="pt", padding="longest", truncation=True)
-outputs = model(inputs[
 
 # Predict and display results
 predicted_classes = torch.argmax(outputs, dim=1)
@@ -122,6 +126,7 @@ PR-AUC score for each domain:
 # References
 - https://arxiv.org/abs/2111.09543
 - https://github.com/microsoft/DeBERTa
 # License
 License to use this model is covered by the Apache License 2.0. By downloading the public and release version of the model, you accept the terms and conditions of the Apache License 2.0.
-This repository contains the code for the domain classifier model.
@@ -8,17 +8,22 @@ license: apache-2.0
 
 # Model Overview
 This is a text classification model to classify documents into one of 26 domain classes:
+```
 'Adult', 'Arts_and_Entertainment', 'Autos_and_Vehicles', 'Beauty_and_Fitness', 'Books_and_Literature', 'Business_and_Industrial', 'Computers_and_Electronics', 'Finance', 'Food_and_Drink', 'Games', 'Health', 'Hobbies_and_Leisure', 'Home_and_Garden', 'Internet_and_Telecom', 'Jobs_and_Education', 'Law_and_Government', 'News', 'Online_Communities', 'People_and_Society', 'Pets_and_Animals', 'Real_Estate', 'Science', 'Sensitive_Subjects', 'Shopping', 'Sports', 'Travel_and_Transportation'
+```
+
 # Model Architecture
+- The model architecture is DeBERTa V3 Base
+- Context length is 512 tokens
+
+# Training Details
 ## Training data:
 - 1 million Common Crawl samples, labeled using Google Cloud’s Natural Language API: https://cloud.google.com/natural-language/docs/classifying-text
 - 500k Wikipedia articles, curated using Wikipedia-API: https://pypi.org/project/Wikipedia-API/
+
 ## Training steps:
 Model was trained in multiple rounds using Wikipedia and Common Crawl data, labeled by a combination of pseudo labels and Google Cloud API.
+
 # How To Use This Model
 ## Input
 The model takes one or several paragraphs of text as input.
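The 26 class names in this hunk are what the classifier ultimately reports for each input. As a rough illustration (not part of the commit), the sketch below maps per-class scores to a class name by taking the index of the largest score. It assumes the output indices follow the order of the list above; the authoritative mapping is the `id2label` entry in the model config. The helper name `predicted_domains` is hypothetical.

```python
from typing import List

# Class names copied from the model card. Assumption: output index i
# corresponds to LABELS[i]; check the model config's "id2label" to confirm.
LABELS = [
    "Adult", "Arts_and_Entertainment", "Autos_and_Vehicles", "Beauty_and_Fitness",
    "Books_and_Literature", "Business_and_Industrial", "Computers_and_Electronics",
    "Finance", "Food_and_Drink", "Games", "Health", "Hobbies_and_Leisure",
    "Home_and_Garden", "Internet_and_Telecom", "Jobs_and_Education",
    "Law_and_Government", "News", "Online_Communities", "People_and_Society",
    "Pets_and_Animals", "Real_Estate", "Science", "Sensitive_Subjects",
    "Shopping", "Sports", "Travel_and_Transportation",
]


def predicted_domains(scores: List[List[float]]) -> List[str]:
    """Return the highest-scoring class name for each row of scores."""
    names = []
    for row in scores:
        if len(row) != len(LABELS):
            raise ValueError(f"expected {len(LABELS)} scores, got {len(row)}")
        # Index of the largest score, equivalent to an argmax over the row.
        best = max(range(len(row)), key=row.__getitem__)
        names.append(LABELS[best])
    return names
```

With the real model, the rows would come from the model output (for example via `outputs.tolist()`), and reading `id2label` from the loaded config is the safer way to recover the names.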
@@ -38,15 +43,14 @@ Example output:
 Food_and_Drink
 ```
 
+# How to Use in NVIDIA NeMo Curator
 
+The inference code is available on [NeMo Curator's GitHub repository](https://github.com/NVIDIA/NeMo-Curator). Check out this [example notebook](https://github.com/NVIDIA/NeMo-Curator/blob/main/tutorials/distributed_data_classification/distributed_data_classification.ipynb) to get started.
 
+# How to Use in Transformers
+To use the domain classifier, use the following code:
 
 ```python
 import torch
 from torch import nn
 from transformers import AutoModel, AutoTokenizer, AutoConfig
@@ -55,9 +59,9 @@ from huggingface_hub import PyTorchModelHubMixin
 class CustomModel(nn.Module, PyTorchModelHubMixin):
     def __init__(self, config):
         super(CustomModel, self).__init__()
+        self.model = AutoModel.from_pretrained(config["base_model"])
+        self.dropout = nn.Dropout(config["fc_dropout"])
+        self.fc = nn.Linear(self.model.config.hidden_size, len(config["id2label"]))
 
     def forward(self, input_ids, attention_mask):
         features = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
@@ -73,7 +77,7 @@ model = CustomModel.from_pretrained("nvidia/domain-classifier")
 # Prepare and process inputs
 text_samples = ["Sports is a popular domain", "Politics is a popular domain"]
 inputs = tokenizer(text_samples, return_tensors="pt", padding="longest", truncation=True)
+outputs = model(inputs["input_ids"], inputs["attention_mask"])
 
 # Predict and display results
 predicted_classes = torch.argmax(outputs, dim=1)
@@ -122,6 +126,7 @@ PR-AUC score for each domain:
 # References
 - https://arxiv.org/abs/2111.09543
 - https://github.com/microsoft/DeBERTa
+
 # License
 License to use this model is covered by the Apache License 2.0. By downloading the public and release version of the model, you accept the terms and conditions of the Apache License 2.0.
+This repository contains the code for the domain classifier model.