{ "cells": [ { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import random\n", "import time\n", "from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer\n", "\n", "model_name=\"heegyu/bluechat-v0\"\n", "device=\"cuda:0\" if torch.cuda.is_available() else 'cpu'\n", "model = AutoModelForCausalLM.from_pretrained(model_name)\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "\n", "def query(prompt, min_new_tokens=16, max_turn=4):\n", " ids = tokenizer(prompt.strip(), return_tensors=\"pt\").to(device)\n", " min_length = ids['input_ids'].shape[1] + min_new_tokens\n", " output = model.generate(\n", " **ids,\n", " no_repeat_ngram_size=3,\n", " eos_token_id=2, # 375=\\n 2=, 0:open-end\n", " max_new_tokens=128,\n", " min_length=min_length,\n", " do_sample=True,\n", " top_p=0.7,\n", " early_stopping=True\n", " ) # [0]['generated_text']\n", " output = tokenizer.decode(output.cpu()[0])\n", " print(output)\n", "\n", " # response = output[len(prompt):]\n", " # return response.strip()" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "0 : 안녕하세요\n", "1 : 반가워요\n", "0 : 요즘 좋아하는 음악 있으신가요?\n", "1 : 최근에 들어서인지 너무 많이 들어요\n", "0 : 음 주로 어떤거요?\n", "1 : \n", " music : music songs 수록곡을 즐겨들어요 앗 어떤 장르를 주로 들으시나요?\n", "1 : music songs 좋죠\n", "bot> 저도 요즘 들어 좋아하게 된 곡들 위주로 들어요 ㅎㅎ\n", "2 : music songs 어떤 노래들 자주 들어요?\n", "bot> 저 music songs someone이 제일 좋더라구요 ㅎㅎ\n", "1 : music songs는 어떤 곡들 주로 들어요?\n", "bot> 저 music songs는 주로 music songs를 많이 들어요 ㅎㅎ\n" ] } ], "source": [ "query(\"\"\"\n", "0 : 안녕하세요\n", "1 : 반가워요\n", "0 : 요즘 좋아하는 음악 있으신가요?\n", "1 : 최근에 들어서인지 너무 많이 들어요\n", "0 : 음 주로 어떤거요?\n", "1 : \n", "\"\"\")" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 안녕하세요\n", " 안녕하세요~ 저녁 드셨나요? ㅎㅎ? ㅎㅎ\n" ] } ], "source": [ "query(\"\"\"\n", " 안녕하세요\n", "\n", "\"\"\", 8)" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 안녕하세요 식사 하셨나요?\n", " 안녕하세요 네~ 점심 먹었어요 식사하셨나요?\n", "네~ 뭐드셨나요?\n" ] } ], "source": [ "query(\"\"\"\n", " 안녕하세요 식사 하셨나요?\n", "\n", "\"\"\", 8)" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 창업에 관심이 있나요?\n", " 네! 근데 요즘 창업에 대한 관심이 많이 떨어지더라구요\n" ] } ], "source": [ "query(\"\"\"\n", " 창업에 관심이 있나요?\n", "\n", "\"\"\", 8)" ] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }