Dhahlan2000 committed
Commit fe172d6 (verified) · Parent: 7ba4ef3

Update app.py

Files changed (1): app.py (+20 −17)
app.py CHANGED
@@ -44,25 +44,28 @@ def transliterate_to_sinhala(text):
     return transliterate.process('Velthuis', 'Sinhala', text)
 
 # Load conversation model
-conv_model_name = "microsoft/Phi-3-mini-4k-instruct" # Use GPT-2 instead of the gated model
-tokenizer = AutoTokenizer.from_pretrained(conv_model_name, trust_remote_code=True)
-model = AutoModelForCausalLM.from_pretrained(conv_model_name, trust_remote_code=True).to(device)
+# conv_model_name = "microsoft/Phi-3-mini-4k-instruct" # Use GPT-2 instead of the gated model
+# tokenizer = AutoTokenizer.from_pretrained(conv_model_name, trust_remote_code=True)
+# model = AutoModelForCausalLM.from_pretrained(conv_model_name, trust_remote_code=True).to(device)
+
+client = InferenceClient("google/gemma-2b-it")
 
 def conversation_predict(text):
-    pipe = pipeline(
-        "text-generation",
-        model=model,
-        tokenizer=tokenizer,
-    )
-    generation_args = {
-        "max_new_tokens": 500,
-        "return_full_text": False,
-        "temperature": 0.0,
-        "do_sample": False,
-    }
-
-    output = pipe(text, **generation_args)
-    return output[0]['generated_text']
+    return client.text_generation(text, return_full_text=False)
+    # pipe = pipeline(
+    #     "text-generation",
+    #     model=model,
+    #     tokenizer=tokenizer,
+    # )
+    # generation_args = {
+    #     "max_new_tokens": 500,
+    #     "return_full_text": False,
+    #     "temperature": 0.0,
+    #     "do_sample": False,
+    # }
+
+    # output = pipe(text, **generation_args)
+    # return output[0]['generated_text']
     # input_ids = tokenizer(text, return_tensors="pt").to(device)
     # outputs = model.generate(**input_ids)
     # return tokenizer.decode(outputs[0])
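
In effect, the commit swaps local Phi-3 inference (AutoModelForCausalLM plus a text-generation pipeline) for a remote call through huggingface_hub's InferenceClient. A minimal sketch of the resulting code path is below, assuming the `from huggingface_hub import InferenceClient` import is added elsewhere in app.py (it is not visible in this hunk) and that the environment is authenticated with a Hugging Face token that has access to google/gemma-2b-it, which is a gated model.

# Sketch only: reflects this hunk plus an assumed import statement.
from huggingface_hub import InferenceClient

# Route generation through the hosted Inference API instead of loading
# model weights locally on the Space.
client = InferenceClient("google/gemma-2b-it")

def conversation_predict(text: str) -> str:
    # return_full_text=False returns only the completion, not the prompt,
    # matching the behavior of the removed pipeline call.
    return client.text_generation(text, return_full_text=False)

# Hypothetical usage:
# print(conversation_predict("Say hello in Sinhala."))

The practical trade-off: offloading generation to the Inference API removes the need to download weights or provision GPU memory on the Space, at the cost of network latency and a dependency on the remote endpoint's availability.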