Spaces:
Sleeping
Sleeping
Update app/model/model.py
Browse files
- app/model/model.py +6 -2
app/model/model.py
CHANGED
@@ -28,9 +28,11 @@ class LLM:
|
|
28 |
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
29 |
if tokenizer.pad_token_id is None:
|
30 |
tokenizer.pad_token_id = tokenizer.eos_token_id
|
|
|
31 |
return model, tokenizer
|
32 |
|
33 |
def language_detection(self, input_text):
|
|
|
34 |
# Prompt with one shot for each language
|
35 |
prompt = f"""Identify the language of the following sentences. Options: 'english', 'español', 'française' .
|
36 |
* <Identity theft is not a joke, millions of families suffer every year>(english)
|
@@ -39,9 +41,11 @@ class LLM:
|
|
39 |
* <{input_text}>"""
|
40 |
# Generation and extraction of the language tag
|
41 |
answer_ids = self.model.generate(**self.tokenizer([prompt], return_tensors="pt"), max_new_tokens=10)
|
42 |
-
answer = self.tokenizer.batch_decode(answer_ids, skip_special_tokens=False)[0]
|
|
|
|
|
43 |
pattern = r'\b(?:' + '|'.join(map(re.escape, self.lang_codes.keys())) + r')\b'
|
44 |
-
lang = re.search(pattern,
|
45 |
# Returns tag identified or 'unk' if none is detected
|
46 |
return self.lang_codes[lang.group()] if lang else self.lang_codes["unknown"]
|
47 |
|
|
|
28 |
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
29 |
if tokenizer.pad_token_id is None:
|
30 |
tokenizer.pad_token_id = tokenizer.eos_token_id
|
31 |
+
print("Model and tokenizer loaded.")
|
32 |
return model, tokenizer
|
33 |
|
34 |
def language_detection(self, input_text):
|
35 |
+
print(f"### Input text\n{input_text}")
|
36 |
# Prompt with one shot for each language
|
37 |
prompt = f"""Identify the language of the following sentences. Options: 'english', 'español', 'française' .
|
38 |
* <Identity theft is not a joke, millions of families suffer every year>(english)
|
|
|
41 |
* <{input_text}>"""
|
42 |
# Generation and extraction of the language tag
|
43 |
answer_ids = self.model.generate(**self.tokenizer([prompt], return_tensors="pt"), max_new_tokens=10)
|
44 |
+
answer = self.tokenizer.batch_decode(answer_ids, skip_special_tokens=False)[0]
|
45 |
+
print(answer)
|
46 |
+
generation = answer.split(prompt)[1]
|
47 |
pattern = r'\b(?:' + '|'.join(map(re.escape, self.lang_codes.keys())) + r')\b'
|
48 |
+
lang = re.search(pattern, generation, flags=re.IGNORECASE)
|
49 |
# Returns tag identified or 'unk' if none is detected
|
50 |
return self.lang_codes[lang.group()] if lang else self.lang_codes["unknown"]
|
51 |
|