Andrei-Iulian SĂCELEANU committed on
Commit
731f5de
·
1 Parent(s): 674a3ea

large files

Browse files
.gitattributes CHANGED
@@ -32,3 +32,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ checkpoints/ filter=lfs diff=lfs merge=lfs -text
36
+ checkpoints/freematch_tune.index filter=lfs diff=lfs merge=lfs -text
37
+ checkpoints/mixmatch.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
38
+ checkpoints/mixmatch.index filter=lfs diff=lfs merge=lfs -text
39
+ checkpoints/fixmatch_tune.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
40
+ checkpoints/fixmatch_tune.index filter=lfs diff=lfs merge=lfs -text
41
+ checkpoints/freematch_tune.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -33,9 +33,16 @@ def ssl_predict(in_text, model_type):
33
  truncation=True,
34
  return_tensors="tf"
35
  )
36
- if model_type == "freematch":
 
 
 
 
37
  model = FixMatchTune(encoder_name="andrei-saceleanu/ro-offense-freematch")
38
  model.cls_head.load_weights("./checkpoints/freematch_tune")
 
 
 
39
 
40
  preds, _ = model([toks["input_ids"],toks["attention_mask"]], training=False)
41
  probs = list(preds[0].numpy())
 
33
  truncation=True,
34
  return_tensors="tf"
35
  )
36
+
37
+ if model_type == "fixmatch":
38
+ model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
39
+ model.load_weights("./checkpoints/fixmatch_tune")
40
+ elif model_type == "freematch":
41
  model = FixMatchTune(encoder_name="andrei-saceleanu/ro-offense-freematch")
42
  model.cls_head.load_weights("./checkpoints/freematch_tune")
43
+ elif model_type == "mixmatch":
44
+ model = MixMatch(encoder_name="andrei-saceleanu/ro-offense-mixmatch")
45
+ model.cls_head.load_weights("./checkpoints/mixmatch")
46
 
47
  preds, _ = model([toks["input_ids"],toks["attention_mask"]], training=False)
48
  probs = list(preds[0].numpy())
checkpoints/fixmatch_tune.data-00000-of-00001 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b8dfdc2f8ad0f036e0bfad3676782816272e98fdf63f09d54883a768084451f8
3
+ size 461147136
checkpoints/fixmatch_tune.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:208a951860ca39c7b24278e111864cfca1ec65bcda0ab9ba90f9fa4e052341a2
3
+ size 14764
checkpoints/freematch_tune.data-00000-of-00001 CHANGED
Binary files a/checkpoints/freematch_tune.data-00000-of-00001 and b/checkpoints/freematch_tune.data-00000-of-00001 differ
 
checkpoints/freematch_tune.index CHANGED
Binary files a/checkpoints/freematch_tune.index and b/checkpoints/freematch_tune.index differ
 
checkpoints/mixmatch.data-00000-of-00001 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5520588c6d43a9fda3dd5b111f92366fd4d855f321393b01de5d014f4bbe76f1
3
+ size 855091
checkpoints/mixmatch.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:699017f5ef5ef36f5e8dad59fa659290960bc36aca0936ca340492add17b23e3
3
+ size 518
models.py CHANGED
@@ -1,8 +1,10 @@
 
1
  import tensorflow as tf
2
  from transformers import TFAutoModel
3
 
4
 
5
  class FixMatchTune(tf.keras.Model):
 
6
  def __init__(
7
  self,
8
  encoder_name="readerbench/RoBERT-base",
@@ -35,3 +37,28 @@ class FixMatchTune(tf.keras.Model):
35
  weak_preds = self.cls_head(weaks,training=training)
36
 
37
  return weak_preds, strong_preds
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model definitions"""
2
  import tensorflow as tf
3
  from transformers import TFAutoModel
4
 
5
 
6
  class FixMatchTune(tf.keras.Model):
7
+ """fixmatch"""
8
  def __init__(
9
  self,
10
  encoder_name="readerbench/RoBERT-base",
 
37
  weak_preds = self.cls_head(weaks,training=training)
38
 
39
  return weak_preds, strong_preds
40
+
41
+ class MixMatch(tf.keras.Model):
42
+ """mixmatch"""
43
+ def __init__(self,bert_model="readerbench/RoBERT-base",num_classes=4,**kwargs):
44
+ super(MixMatch,self).__init__(**kwargs)
45
+ self.bert = TFAutoModel.from_pretrained(bert_model)
46
+
47
+ self.num_classes = num_classes
48
+
49
+ self.cls_head = tf.keras.Sequential([
50
+ tf.keras.layers.Dense(256,activation="relu"),
51
+ tf.keras.layers.Dropout(0.2),
52
+ tf.keras.layers.Dense(64,activation="relu"),
53
+ tf.keras.layers.Dense(self.num_classes, activation="softmax")
54
+ ])
55
+
56
+ self.augment = tf.keras.layers.GaussianNoise(stddev=2)
57
+
58
+ def call(self, inputs, training):
59
+ ids, mask = inputs
60
+
61
+ embeds = self.bert(input_ids=ids, attention_mask=mask,training=training).pooler_output
62
+ augs = self.augment(embeds,training=training)
63
+
64
+ return self.cls_head(augs,training=training)