Spaces:
Sleeping
Sleeping
Initial commit
Browse files
app.py
CHANGED
@@ -14,12 +14,10 @@ warnings.filterwarnings("ignore", category=UserWarning, module="torchvision")
|
|
14 |
try:
|
15 |
tokenizer = AutoTokenizer.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased", use_fast=False)
|
16 |
model = AutoModelForMaskedLM.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased")
|
17 |
-
model_name = "airesearch/wangchanberta-base-att-spm-uncased"
|
18 |
except Exception:
|
19 |
st.warning("Switching to xlm-roberta-base model due to compatibility issues.")
|
20 |
-
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
|
21 |
model = AutoModelForMaskedLM.from_pretrained("xlm-roberta-base")
|
22 |
-
model_name = "xlm-roberta-base"
|
23 |
|
24 |
# Initialize the fill-mask pipeline
|
25 |
pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, framework="pt")
|
@@ -68,7 +66,7 @@ Feel free to enter your own sentence with `<mask>` and explore the predictions!
|
|
68 |
|
69 |
# User input box
|
70 |
st.subheader("Input Text")
|
71 |
-
input_text = st.text_input("Enter a sentence with `<mask>` to find similar predictions:", "
|
72 |
|
73 |
# Ensure the input includes a `<mask>`
|
74 |
if "<mask>" not in input_text:
|
@@ -90,10 +88,8 @@ if input_text:
|
|
90 |
result = pipe(input_text)
|
91 |
|
92 |
for r in result:
|
93 |
-
# Adjust based on observed output structure
|
94 |
prediction_text = r.get('sequence', '')
|
95 |
|
96 |
-
# Only proceed if we have a valid prediction text
|
97 |
if prediction_text:
|
98 |
prediction_embedding = get_embedding(prediction_text)
|
99 |
similarity = cosine_similarity(input_embedding, prediction_embedding)[0][0]
|
|
|
14 |
try:
|
15 |
tokenizer = AutoTokenizer.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased", use_fast=False)
|
16 |
model = AutoModelForMaskedLM.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased")
|
|
|
17 |
except Exception:
|
18 |
st.warning("Switching to xlm-roberta-base model due to compatibility issues.")
|
19 |
+
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base", use_fast=False)
|
20 |
model = AutoModelForMaskedLM.from_pretrained("xlm-roberta-base")
|
|
|
21 |
|
22 |
# Initialize the fill-mask pipeline
|
23 |
pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, framework="pt")
|
|
|
66 |
|
67 |
# User input box
|
68 |
st.subheader("Input Text")
|
69 |
+
input_text = st.text_input("Enter a sentence with `<mask>` to find similar predictions:", "นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน <mask> เพื่อสัมผัสธรรมชาติ")
|
70 |
|
71 |
# Ensure the input includes a `<mask>`
|
72 |
if "<mask>" not in input_text:
|
|
|
88 |
result = pipe(input_text)
|
89 |
|
90 |
for r in result:
|
|
|
91 |
prediction_text = r.get('sequence', '')
|
92 |
|
|
|
93 |
if prediction_text:
|
94 |
prediction_embedding = get_embedding(prediction_text)
|
95 |
similarity = cosine_similarity(input_embedding, prediction_embedding)[0][0]
|