heegyu commited on
Commit
9366986
·
1 Parent(s): 357cd33

간단데모

Browse files
Files changed (2) hide show
  1. app.py +48 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ @st.cache(allow_output_mutation=True)
4
+ def get_pipe():
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ model_name = "heegyu/koalpaca-355m"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ tokenizer.truncation_side = "right"
9
+ model = AutoModelForCausalLM.from_pretrained(model_name)
10
+ return model, tokenizer
11
+
12
+ def get_response(tokenizer, model, context):
13
+ context = f"<usr>{context}\n<sys>"
14
+ inputs = tokenizer(
15
+ context,
16
+ truncation=True,
17
+ max_length=512,
18
+ return_tensors="pt")
19
+
20
+ generation_args = dict(
21
+ max_length=256,
22
+ min_length=64,
23
+ eos_token_id=2,
24
+ do_sample=True,
25
+ top_p=1.0,
26
+ early_stopping=True
27
+ )
28
+
29
+ outputs = model.generate(**inputs, **generation_args)
30
+ response = tokenizer.decode(outputs[0])
31
+ print(context)
32
+ print(response)
33
+ response = response[len(context):].replace("</s>", "")
34
+
35
+ return response
36
+
37
+ st.title("KoAlpaca-355M")
38
+
39
+ with st.spinner("loading model..."):
40
+ model, tokenizer = get_pipe()
41
+
42
+ input_ = st.text_area("질문해보세요", value="미국과 중국의 갈등의 원인이 뭐야?")
43
+ ok = st.button("물어보기")
44
+ if input_ is not None and ok and len(input_) > 0:
45
+ with st.spinner("잠시만요"):
46
+ response = get_response(tokenizer, model, input_)
47
+ st.text("대답")
48
+ st.success(response)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ streamlit
2
+ transformers
3
+ torch