kritsadaK committed on
Commit
4b84b24
1 Parent(s): 16ab143

Initial commit

Browse files
Files changed (1) hide show
  1. app.py +2 -6
app.py CHANGED
@@ -14,12 +14,10 @@ warnings.filterwarnings("ignore", category=UserWarning, module="torchvision")
14
  try:
15
  tokenizer = AutoTokenizer.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased", use_fast=False)
16
  model = AutoModelForMaskedLM.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased")
17
- model_name = "airesearch/wangchanberta-base-att-spm-uncased"
18
  except Exception:
19
  st.warning("Switching to xlm-roberta-base model due to compatibility issues.")
20
- tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
21
  model = AutoModelForMaskedLM.from_pretrained("xlm-roberta-base")
22
- model_name = "xlm-roberta-base"
23
 
24
  # Initialize the fill-mask pipeline
25
  pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, framework="pt")
@@ -68,7 +66,7 @@ Feel free to enter your own sentence with `<mask>` and explore the predictions!
68
 
69
  # User input box
70
  st.subheader("Input Text")
71
- input_text = st.text_input("Enter a sentence with `<mask>` to find similar predictions:", "ผู้ใช้งานท่าอากาศยานนานาชาติ <mask> มีกว่าสามล้านคน")
72
 
73
  # Ensure the input includes a `<mask>`
74
  if "<mask>" not in input_text:
@@ -90,10 +88,8 @@ if input_text:
90
  result = pipe(input_text)
91
 
92
  for r in result:
93
- # Adjust based on observed output structure
94
  prediction_text = r.get('sequence', '')
95
 
96
- # Only proceed if we have a valid prediction text
97
  if prediction_text:
98
  prediction_embedding = get_embedding(prediction_text)
99
  similarity = cosine_similarity(input_embedding, prediction_embedding)[0][0]
 
14
  try:
15
  tokenizer = AutoTokenizer.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased", use_fast=False)
16
  model = AutoModelForMaskedLM.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased")
 
17
  except Exception:
18
  st.warning("Switching to xlm-roberta-base model due to compatibility issues.")
19
+ tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base", use_fast=False)
20
  model = AutoModelForMaskedLM.from_pretrained("xlm-roberta-base")
 
21
 
22
  # Initialize the fill-mask pipeline
23
  pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, framework="pt")
 
66
 
67
  # User input box
68
  st.subheader("Input Text")
69
+ input_text = st.text_input("Enter a sentence with `<mask>` to find similar predictions:", "นักท่องเที่ยวจำนวนมากเลือกที่จะไปเยือน <mask> เพื่อสัมผัสธรรมชาติ")
70
 
71
  # Ensure the input includes a `<mask>`
72
  if "<mask>" not in input_text:
 
88
  result = pipe(input_text)
89
 
90
  for r in result:
 
91
  prediction_text = r.get('sequence', '')
92
 
 
93
  if prediction_text:
94
  prediction_embedding = get_embedding(prediction_text)
95
  similarity = cosine_similarity(input_embedding, prediction_embedding)[0][0]