# Kossistant-1 / chatbot.py
import re
import json
import math
import numpy as np
import requests
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.text import tokenizer_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences
def load_tokenizer(filename):
    with open(filename, 'r', encoding='utf-8') as f:
        return tokenizer_from_json(json.load(f))
tokenizer_q = load_tokenizer('kossistant_q.json')
tokenizer_a = load_tokenizer('kossistant_a.json')
# ๋ชจ๋ธ ๋ฐ ํŒŒ๋ผ๋ฏธํ„ฐ ๋กœ๋“œ
model = load_model('kossistant.h5', compile=False)
max_len_q = model.input_shape[0][1]
max_len_a = model.input_shape[1][1]
index_to_word = {v: k for k, v in tokenizer_a.word_index.items()}
index_to_word[0] = ''
start_token = 'start'
end_token = 'end'
# ํ† ํฐ ์ƒ˜ํ”Œ๋ง ํ•จ์ˆ˜
def sample_from_top_p_top_k(prob_dist, top_p=0.85, top_k=40, temperature=0.8,
                            repetition_penalty=1.4, generated_ids=None):
    if generated_ids is None:
        generated_ids = []
    # Temperature-scaled log-probabilities
    logits = np.log(prob_dist + 1e-9) / temperature
    # Penalize already-generated tokens (log-probs are negative, so
    # multiplying by the penalty pushes them further down)
    for idx in generated_ids:
        logits[idx] *= repetition_penalty
    # Softmax back into a probability distribution
    probs = np.exp(logits - np.max(logits))
    probs = probs / np.sum(probs)
    # Restrict to the top-k most probable tokens
    top_k_indices = np.argsort(probs)[-top_k:]
    top_k_probs = probs[top_k_indices]
    sorted_indices = top_k_indices[np.argsort(top_k_probs)[::-1]]
    sorted_probs = probs[sorted_indices]
    # Apply the nucleus (top-p) cutoff over the sorted candidates
    cumulative_probs = np.cumsum(sorted_probs)
    cutoff_index = np.searchsorted(cumulative_probs, top_p)
    final_indices = sorted_indices[:cutoff_index + 1]
    final_probs = probs[final_indices]
    final_probs = final_probs / np.sum(final_probs)
    return np.random.choice(final_indices, p=final_probs)
# Decoding
def decode_sequence_custom(input_text, max_attempts=2):
    input_seq = tokenizer_q.texts_to_sequences([input_text])
    input_seq = pad_sequences(input_seq, maxlen=max_len_q, padding='post')
    for _ in range(max_attempts + 1):
        # Seed the decoder input with the start token
        target_seq = tokenizer_a.texts_to_sequences([start_token])[0]
        target_seq = pad_sequences([target_seq], maxlen=max_len_a, padding='post')
        decoded_sentence = ''
        generated_ids = []
        for i in range(max_len_a):
            predictions = model.predict([input_seq, target_seq], verbose=0)
            prob_dist = predictions[0, i, :]
            pred_id = sample_from_top_p_top_k(prob_dist, generated_ids=generated_ids)
            generated_ids.append(pred_id)
            pred_word = index_to_word.get(pred_id, '')
            if pred_word == end_token:
                break
            decoded_sentence += pred_word + ' '
            # Feed the sampled token back in for the next decoding step
            if i + 1 < max_len_a:
                target_seq[0, i + 1] = pred_id
        cleaned = re.sub(rf'\b{end_token}\b', '', decoded_sentence)
        cleaned = re.sub(r'\s+', ' ', cleaned)
        if is_valid_response(cleaned):
            return cleaned.strip()
    return "죄송해요, 답변 생성에 실패했어요."  # "Sorry, I failed to generate a response."
def is_valid_response(response):
    # Reject answers that are too short or look like noise
    if len(response.strip()) < 2:
        return False
    # Reject runs of bare Korean jamo (e.g. "ㅋㅋㅋ")
    if re.search(r'[ㄱ-ㅎㅏ-ㅣ]{3,}', response):
        return False
    if len(response.split()) < 2:
        return False
    if response.count(' ') < 2:
        return False
    if any(tok in response.lower() for tok in ['hello', 'this', 'ㅋㅋ']):
        return False
    return True
def extract_main_query(text):
    # Use the last sentence as the query and strip common Korean particles
    sentences = re.split(r'[.?!]\s*', text)
    sentences = [s.strip() for s in sentences if s.strip()]
    if not sentences:
        return text
    last = sentences[-1]
    last = re.sub(r'[^가-힣a-zA-Z0-9 ]', '', last)
    particles = ['이', '가', '은', '는', '을', '를', '의', '에서', '에게', '한테', '보다']
    for p in particles:
        last = re.sub(rf'\b(\w+){p}\b', r'\1', last)
    return last.strip()
def get_wikipedia_summary(query):
    cleaned_query = extract_main_query(query)
    url = f"https://ko.wikipedia.org/api/rest_v1/page/summary/{cleaned_query}"
    res = requests.get(url, timeout=5)
    if res.status_code == 200:
        return res.json().get("extract", "요약 정보를 찾을 수 없습니다.")  # "No summary found."
    else:
        return "위키백과에서 정보를 가져올 수 없습니다."  # "Could not fetch information from Wikipedia."
def simple_intent_classifier(text):
    text = text.lower()
    greet_keywords = ["안녕", "반가워", "이름", "누구", "소개", "어디서 왔", "정체", "몇 살", "너 뭐야"]
    info_keywords = ["설명", "정보", "무엇", "뭐야", "어디", "누구", "왜", "어떻게", "종류", "개념"]
    math_keywords = ["더하기", "빼기", "곱하기", "나누기", "루트", "제곱", "+", "-", "*", "/", "=", "^", "√", "계산", "몇이야", "얼마야"]
    if any(kw in text for kw in greet_keywords):
        return "인사"      # greeting
    elif any(kw in text for kw in info_keywords):
        return "정보질문"  # information question
    elif any(kw in text for kw in math_keywords):
        return "수학질문"  # math question
    else:
        return "일상대화"  # small talk
def parse_math_question(text):
    # Map Korean math words to Python operators ("제곱" means squared, hence "**2")
    text = text.replace("곱하기", "*").replace("더하기", "+").replace("빼기", "-").replace("나누기", "/").replace("제곱", "**2")
    text = re.sub(r'루트\s(\d+)', r'math.sqrt(\1)', text)
    try:
        result = eval(text)
        return f"정답은 {result}입니다."  # "The answer is {result}."
    except Exception:
        return "계산할 수 없는 수식이에요. 다시 한번 확인해 주세요!"  # "That expression can't be evaluated. Please check it again!"
# ์ „์ฒด ์‘๋‹ต ํ•จ์ˆ˜
def respond(input_text):
    intent = simple_intent_classifier(input_text)
    if "/사용법" in input_text:  # "/usage" command
        return "자유롭게 사용해주세요. 딱히 제약은 없습니다."  # "Feel free to use it; there are no particular restrictions."
    if "이름" in input_text:  # asked for the bot's name
        return "제 이름은 kossistant입니다."  # "My name is kossistant."
    if "누구" in input_text:  # asked who the bot is
        return "저는 kossistant이라고 해요."  # "I'm called kossistant."
    if intent == "수학질문":
        return parse_math_question(input_text)
    if intent == "정보질문":
        # Strip question phrasing ("~에 대해 설명해줘" etc.) to get the lookup keyword
        keyword = re.sub(r"(에 대해|에 대한|에 대해서)?\s*(설명해줘|알려줘|뭐야|개념|정의|정보)?", "", input_text).strip()
        if not keyword:
            return "어떤 주제에 대해 궁금한가요?"  # "Which topic are you curious about?"
        summary = get_wikipedia_summary(keyword)
        return f"{summary}\n다른 궁금한 점 있으신가요?"  # "... Anything else you're curious about?"
    return decode_sequence_custom(input_text)
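# Minimal sketch of an interactive loop for local testing. This block is an
# assumption (not part of any serving setup); it only calls respond() defined above
# and requires the model/tokenizer files loaded at import time to be present.
if __name__ == "__main__":
    print("kossistant (type 'quit' to exit)")
    while True:
        user_input = input("> ").strip()
        if not user_input:
            continue
        if user_input.lower() in ("quit", "exit"):
            break
        print(respond(user_input))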