Yuchan5386 committed
Commit a7caa3a · verified · 1 Parent(s): 9c9e250

Update app.py

Files changed (1): app.py +11 -9
app.py CHANGED
@@ -1,18 +1,20 @@
+import gradio as gr
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
-model_name = "beomi/gemma-ko-2b"  # or pick a lighter model
+model_name = "beomi/gemma-ko-2b"
 
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)
+model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
 
-prompt = "인공지능이란 무엇인가?"
-with torch.no_grad():
-    tokens = tokenizer.encode(prompt, return_tensors='pt').to(device)
-    gen_tokens = model.generate(tokens, do_sample=True, temperature=0.8, max_length=64)
-    generated = tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
+def chatbot(prompt):
+    with torch.no_grad():
+        tokens = tokenizer(prompt, return_tensors='pt').to(device)
+        gen_tokens = model.generate(**tokens, do_sample=True, temperature=0.8, max_length=64)
+        return tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
 
-print(generated)
+iface = gr.Interface(fn=chatbot, inputs="text", outputs="text")
+iface.launch()