File size: 2,045 Bytes
73c9f52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6c11c7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import functools

import gradio as gr
import torch
from transformers import AutoTokenizer

# Tokenizer checkpoint registered on the Hugging Face Hub; both classifiers share it.
MODEL_NAME = "beomi/KcELECTRA-base"

# Compute the inference device once instead of on every request.
_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint file for each selectable mode. The keys must match the UI dropdown choices.
_CHECKPOINTS = {
    "Malicious_comment": "best.pt",
    "Economic_article": "best2.pt",
}

# Output text for each predicted class index, per mode.
_LABELS = {
    "Malicious_comment": {
        0: ">> 악성글로 판단됩니다. 조심하세요.",
        1: ">> 악의적인 내용이 보이지 않습니다.",
    },
    "Economic_article": {
        0: "중립",
        1: "긍정",
        2: "부정",
    },
}


@functools.lru_cache(maxsize=None)
def _load_model(path):
    """Load the pickled model at *path* once, move it to _DEVICE, and set eval mode."""
    # NOTE: torch.load unpickles arbitrary objects; only load trusted local checkpoints.
    try:
        # torch>=2.6 defaults to weights_only=True, which rejects a whole
        # pickled module, so opt out explicitly for these trusted files.
        model = torch.load(path, map_location=_DEVICE, weights_only=False)
    except TypeError:
        # Older torch releases do not accept the weights_only keyword.
        model = torch.load(path, map_location=_DEVICE)
    model.eval()
    print(model)
    return model


@functools.lru_cache(maxsize=None)
def _load_tokenizer():
    """Fetch the shared tokenizer once and reuse it for every request."""
    return AutoTokenizer.from_pretrained(MODEL_NAME)


def greet(sent, mode):
    """Classify *sent* with the checkpoint selected by *mode*.

    Args:
        sent: Input sentence from the Gradio textbox.
        mode: Either "Malicious_comment" or "Economic_article".

    Returns:
        For "Malicious_comment": *sent* followed by a verdict message.
        For "Economic_article": "중립" (neutral), "긍정" (positive) or
        "부정" (negative).
        If the model predicts a class index outside the known set, a short
        "unknown label" string is returned instead.

    Raises:
        ValueError: If *mode* is not a supported mode (including None).
    """
    # Validate first, so a bad mode fails clearly instead of with UnboundLocalError.
    if mode not in _CHECKPOINTS:
        raise ValueError(
            f"unknown mode {mode!r}; expected one of {sorted(_CHECKPOINTS)}"
        )
    print(f"input_sent= {sent}")

    pt_model = _CHECKPOINTS[mode]
    print(pt_model)
    print("device:", _DEVICE)

    model = _load_model(pt_model)
    tokenizer = _load_tokenizer()

    # Tokenize the input sentence and move the tensors to the model's device.
    encoded = tokenizer(
        sent,
        return_tensors="pt",
        truncation=True,
        add_special_tokens=True,
        max_length=128,
    ).to(_DEVICE)

    with torch.no_grad():
        outputs = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            token_type_ids=encoded["token_type_ids"],
        )

    # outputs[0] holds the raw logits (no activation applied). With a single
    # input sentence, argmax over the last axis yields one class index.
    label_id = int(outputs[0].argmax(-1).item())

    message = _LABELS[mode].get(label_id)
    if message is None:
        # Never hand a tensor object back to the text output.
        return f"unknown label: {label_id}"
    if mode == "Malicious_comment":
        return sent + message
    return message
# Gradio UI: a free-text box plus a dropdown selecting which classifier to run.
text_input = "text"
mode_input = gr.Dropdown(
    choices=["Malicious_comment", "Economic_article"],
    # Preselect a mode so `greet` never receives mode=None when the user
    # submits without touching the dropdown.
    value="Malicious_comment",
)
iface = gr.Interface(
    fn=greet,
    title="Korean classification",
    description="한국어 악플 && 경제기사 긍부정 판별기",
    inputs=[text_input, mode_input],
    outputs="text",
)

# Only start the server when run as a script (`python app.py`), not on import.
if __name__ == "__main__":
    iface.launch()