Cicciokr commited on
Commit
90e7939
·
verified ·
1 Parent(s): 57ead9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -1
app.py CHANGED
@@ -68,12 +68,36 @@ input_text = st.text_area(
68
  tokenizer_roberta = AutoTokenizer.from_pretrained("Cicciokr/Roberta-Base-Latin-Uncased")
69
  model_roberta = AutoModelForMaskedLM.from_pretrained("Cicciokr/Roberta-Base-Latin-Uncased")
70
  fill_mask_roberta = pipeline("fill-mask", model=model_roberta, tokenizer=tokenizer_roberta)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  # Se l'utente ha inserito (o selezionato) un testo
73
  if input_text:
74
  # Sostituiamo [MASK] con <mask> (lo tokenizer Roberta se lo aspetta così)
75
  input_text_roberta = input_text.replace("[MASK]", "<mask>")
76
- predictions_roberta = fill_mask_roberta(input_text_roberta)
 
 
77
 
78
  st.subheader("Risultati delle previsioni:")
79
  for pred in predictions_roberta:
 
68
  tokenizer_roberta = AutoTokenizer.from_pretrained("Cicciokr/Roberta-Base-Latin-Uncased")
69
  model_roberta = AutoModelForMaskedLM.from_pretrained("Cicciokr/Roberta-Base-Latin-Uncased")
70
  fill_mask_roberta = pipeline("fill-mask", model=model_roberta, tokenizer=tokenizer_roberta)
71
+ punctuation_marks = {".", ",", ";", ":", "!", "?"}
72
+
73
+ def get_valid_predictions(sentence, max_attempts=3, top_k=5):
74
+ attempt = 0
75
+ filtered_predictions = []
76
+
77
+ while attempt < max_attempts:
78
+ predictions = fill_mask_roberta(sentence, top_k=top_k)
79
+
80
+ # Filtra le predizioni rimuovendo la punteggiatura
81
+ filtered_predictions = [
82
+ pred for pred in predictions if pred["token_str"] not in punctuation_marks
83
+ ]
84
+
85
+ # Se troviamo almeno una parola valida, interrompiamo il ciclo
86
+ if filtered_predictions:
87
+ break
88
+
89
+ attempt += 1
90
+
91
+ return filtered_predictions
92
+
93
 
94
  # Se l'utente ha inserito (o selezionato) un testo
95
  if input_text:
96
  # Sostituiamo [MASK] con <mask> (lo tokenizer Roberta se lo aspetta così)
97
  input_text_roberta = input_text.replace("[MASK]", "<mask>")
98
+
99
+
100
+ predictions_roberta = get_valid_predictions(input_text_roberta)
101
 
102
  st.subheader("Risultati delle previsioni:")
103
  for pred in predictions_roberta: