heegyu commited on
Commit
f4464f9
Β·
1 Parent(s): 529af93
Files changed (2) hide show
  1. app.py +39 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+
4
+
5
+ device = "cuda:0"
6
+
7
+ model = pipeline("text-classification", "iknow-lab/ko-flan-zero-v0-0731", device=device)
8
+ model.tokenizer.truncation_side = 'left'
9
+
10
+ def inference(input, instruction, labels):
11
+ instruction = f"{input} [SEP] {instruction}"
12
+ inputs = model.tokenizer([instruction] * len(labels), labels, truncation=True, padding=True, return_tensors="pt").to(device)
13
+
14
+ scores = model.model(**inputs).logits.squeeze(1).softmax(-1).tolist()
15
+ output = dict(zip(labels, scores))
16
+
17
+ print(instruction)
18
+ print(output)
19
+ return output
20
+
21
+
22
+ def greet(content, instruction, labels):
23
+ labels = labels.split(",")
24
+ output = inference(content, instruction, labels)
25
+ return output
26
+
27
+ content = gr.TextArea(label="μž…λ ₯ λ‚΄μš©")
28
+ instruction = gr.Textbox(label="μ§€μ‹œλ¬Έ")
29
+ labels = gr.Textbox(label="라벨(μ‰Όν‘œλ‘œ ꡬ뢄)")
30
+
31
+ examples = [
32
+ ["μ˜ˆμ „μ—λŠ” μ£Όλ§λ§ˆλ‹€ κ·Ήμž₯에 λ†€λŸ¬κ°”λŠ”λ° μš”μƒˆλŠ” μ’€ μ•ˆκ°€λŠ” νŽΈμ΄μ—μš”", "λŒ“κΈ€ 주제λ₯Ό λΆ„λ₯˜ν•˜μ„Έμš”", "μ˜ν™”,λ“œλΌλ§ˆ,κ²Œμž„,μ†Œμ„€"],
33
+ ["인천발 KTX와 κ΄€λ ¨ν•œβ€ˆμ†‘λ„μ—­ λ³΅ν•©ν™˜μŠΉμ„Όν„°κ°€β€ˆμ‚¬μ‹€μƒβ€ˆλ¬΄μ‚°,β€ˆλ‹¨μˆœ μ² λ„Β·λ²„μŠ€ μœ„μ£Ό ν™˜μŠΉμ‹œμ„€λ‘œβ€ˆλ§Œλ“€μ–΄μ§„λ‹€.β€ˆμ΄ λ•Œλ¬Έμ— μΈμ²œμ‹œμ˜ 인천발 KTXβ€ˆκΈ°μ μ— μ•΅μ»€μ‹œμ„€μΈ λ³΅ν•©ν™˜μŠΉμ„Όν„°λ₯Ό ν†΅ν•œ μΈκ·Όβ€ˆμ§€μ—­β€ˆκ²½μ œβ€ˆν™œμ„±ν™”λ₯Όβ€ˆμ΄λ€„λ‚Έλ‹€λŠ” κ³„νšμ˜ 차질이 λΆˆκ°€ν”Όν•˜λ‹€.", "κ²½μ œμ— 긍정적인 λ‰΄μŠ€μΈκ°€μš”?", "예,μ•„λ‹ˆμš”"],
34
+ ["λ§ˆμ§€λ§‰μ—λŠ” k팝 곡연보고 쒋은 μΆ”μ–΅ λ‚¨μ•˜μœΌλ©΄ μ’‹κ² λ„€μš”","μš•μ„€μ΄ ν¬ν•¨λ˜μ–΄μžˆλ‚˜μš”?", "μš•μ„€μ΄ μžˆμŠ΅λ‹ˆλ‹€,μš•μ„€μ΄ μ—†μŠ΅λ‹ˆλ‹€"],
35
+ ]
36
+ gr.Interface(fn=greet,
37
+ inputs=[content, instruction, labels],
38
+ outputs=gr.Label(),
39
+ examples=examples).launch() # server_name="0.0.0.0",server_port=7860)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers
2
+ torch