Update README.md
Browse files
README.md
CHANGED
@@ -188,6 +188,7 @@ Results:
|
|
188 |
|
189 |
For more information, see: [Model Recycling](https://ibm.github.io/model-recycling/)
|
190 |
|
|
|
191 |
# Citation
|
192 |
|
193 |
More details in this [article](https://arxiv.org/abs/2301.05948):
|
@@ -200,7 +201,26 @@ More details on this [article:](https://arxiv.org/abs/2301.05948)
|
|
200 |
year={2023}
|
201 |
}
|
202 |
```
|
203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
|
206 |
# Model Card Contact
|
|
|
188 |
|
189 |
For more information, see: [Model Recycling](https://ibm.github.io/model-recycling/)
|
190 |
|
191 |
+
|
192 |
# Citation
|
193 |
|
194 |
More details in this [article](https://arxiv.org/abs/2301.05948):
|
|
|
201 |
year={2023}
|
202 |
}
|
203 |
```
|
204 |
+
|
205 |
+
# Loading a specific classifier
|
206 |
+
|
207 |
+
```python
|
208 |
+
import torch
import transformers
from torch import nn
|
209 |
+
|
210 |
+
# Task whose classifier head is selected; must be one of model.config.tasks.
TASK_NAME = "hh-rlhf"
|
211 |
+
|
212 |
+
class MultiTask(transformers.DebertaV2ForMultipleChoice):
    """``DebertaV2ForMultipleChoice`` subclass with one linear head per task.

    In addition to whatever the parent constructor builds, this class
    registers a task embedding table ``Z`` and a ``classifiers`` module list.
    Their sizes are read from ``config.tasks`` and
    ``config.classifiers_size``.
    """

    def __init__(self, *args, **kwargs):
        """Build the parent model, then add the per-task modules.

        Args:
            *args: Forwarded unchanged to the parent constructor.
            **kwargs: Accepted for call compatibility but not forwarded.
        """
        super().__init__(*args)
        num_tasks = len(self.config.tasks)
        head_sizes = self.config.classifiers_size
        # One 768-dimensional vector per task.
        # NOTE(review): 768 presumably matches the encoder hidden size — confirm.
        self.Z = nn.Embedding(num_tasks, 768)
        # Each entry of classifiers_size is unpacked as the positional
        # arguments of nn.Linear, giving one head per entry.
        self.classifiers = nn.ModuleList(
            [nn.Linear(*size) for size in head_sizes]
        )
|
219 |
+
|
220 |
+
# Load the checkpoint into the MultiTask wrapper defined above.
model = MultiTask.from_pretrained(
    "sileod/deberta-v3-base-tasksource-nli",
    ignore_mismatched_sizes=True,
)
# Map every task name to its position in the config's task list, then look up
# the requested task (raises KeyError if TASK_NAME is not listed).
task_index = {name: idx for idx, name in enumerate(model.config.tasks)}[TASK_NAME]
# Make the head for TASK_NAME the model's active classifier.
model.classifier = model.classifiers[task_index]
|
223 |
+
```
|
224 |
|
225 |
|
226 |
# Model Card Contact
|