{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import random\n",
    "import time\n",
    "from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "model_name=\"heegyu/bluechat-v0\"\n",
    "device=\"cuda:0\" if torch.cuda.is_available() else 'cpu'\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def query(prompt, min_new_tokens=16, max_turn=4):\n",
    "    ids = tokenizer(prompt.strip(), return_tensors=\"pt\").to(device)\n",
    "    min_length = ids['input_ids'].shape[1] + min_new_tokens\n",
    "    output = model.generate(\n",
    "        **ids,\n",
    "        no_repeat_ngram_size=3,\n",
    "        eos_token_id=2, # 375=\\n 2=</s>, 0:open-end\n",
    "        max_new_tokens=128,\n",
    "        min_length=min_length,\n",
    "        do_sample=True,\n",
    "        top_p=0.7,\n",
    "        early_stopping=True\n",
    "    ) # [0]['generated_text']\n",
    "    output = tokenizer.decode(output.cpu()[0])\n",
    "    print(output)\n",
    "\n",
    "    # response = output[len(prompt):]\n",
    "    # return response.strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "0 : 안녕하세요</s>\n",
      "1 : 반가워요</s>\n",
      "0 : 요즘 좋아하는 음악 있으신가요?</s>\n",
      "1 : 최근에 들어서인지 너무 많이 들어요</s>\n",
      "0 : 음 주로 어떤거요?</s>\n",
      "1 : \n",
      " music : music songs 수록곡을 즐겨들어요</s><bot> 앗 어떤 장르를 주로 들으시나요?</s>\n",
      "1 : music songs 좋죠</s>\n",
      "bot> 저도 요즘 들어 좋아하게 된 곡들 위주로 들어요 ㅎㅎ</s>\n",
      "2 : music songs 어떤 노래들 자주 들어요?</s>\n",
      "bot> 저 music songs someone이 제일 좋더라구요 ㅎㅎ</s>\n",
      "1 : music songs는 어떤 곡들 주로 들어요?</s>\n",
      "bot> 저 music songs는 주로 music songs를 많이 들어요 ㅎㅎ</s>\n"
     ]
    }
   ],
   "source": [
    "query(\"\"\"\n",
    "0 : 안녕하세요</s>\n",
    "1 : 반가워요</s>\n",
    "0 : 요즘 좋아하는 음악 있으신가요?</s>\n",
    "1 : 최근에 들어서인지 너무 많이 들어요</s>\n",
    "0 : 음 주로 어떤거요?</s>\n",
    "1 : \n",
    "\"\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<usr> 안녕하세요\n",
      "<bot> 안녕하세요~ 저녁 드셨나요? ㅎㅎ? ㅎㅎ</s>\n"
     ]
    }
   ],
   "source": [
    "query(\"\"\"\n",
    "<usr> 안녕하세요\n",
    "<bot>\n",
    "\"\"\", 8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<usr> 안녕하세요 식사 하셨나요?\n",
      "<bot> 안녕하세요 네~ 점심 먹었어요 식사하셨나요?\n",
      "네~ 뭐드셨나요?</s>\n"
     ]
    }
   ],
   "source": [
    "query(\"\"\"\n",
    "<usr> 안녕하세요 식사 하셨나요?\n",
    "<bot>\n",
    "\"\"\", 8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<usr> 창업에 관심이 있나요?\n",
      "<bot> 네! 근데 요즘 창업에 대한 관심이 많이 떨어지더라구요</s>\n"
     ]
    }
   ],
   "source": [
    "query(\"\"\"\n",
    "<usr> 창업에 관심이 있나요?\n",
    "<bot>\n",
    "\"\"\", 8)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}