IlyaGusev commited on
Commit
3dd3754
1 Parent(s): 5f885e5

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +98 -0
README.md ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ datasets:
3
+ - IlyaGusev/gpt_roleplay_realm
4
+ language:
5
+ - en
6
+ pipeline_tag: text-generation
7
+ ---
8
+
9
+ LLaMA 7B fine-tuned on the `gpt_roleplay_realm` dataset.
10
+
11
+ Code example:
12
 + ```python
13
# Third-party imports. `torch` must be imported explicitly: the snippet
# below references `torch.float16` when loading the model, and the
# original example omitted this import (NameError at runtime).
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Adapter repository on the Hugging Face Hub.
MODEL_NAME = "IlyaGusev/rpr_7b"
# One chat turn: role header line, then content, wrapped in BOS/EOS markers.
DEFAULT_MESSAGE_TEMPLATE = "<s>{role}\n{content}</s>\n"
18
+
19
class Conversation:
    """Chat history holder that renders its messages into one prompt string.

    The history starts with a single system message; user and bot turns
    are appended over time. ``get_prompt`` renders every turn through
    ``message_template`` and terminates the prompt with the start-token +
    bot-token pair so the model continues speaking as the bot.
    """

    def __init__(
        self,
        system_prompt,
        message_template=DEFAULT_MESSAGE_TEMPLATE,
        start_token_id=1,
        bot_token_id=9225
    ):
        # NOTE(review): ids 1 / 9225 presumably correspond to <s> and the
        # "bot" role token in the LLaMA tokenizer — confirm against vocab.
        self.message_template = message_template
        self.start_token_id = start_token_id
        self.bot_token_id = bot_token_id
        self.messages = [{
            "role": "system",
            "content": system_prompt
        }]

    def get_start_token_id(self):
        return self.start_token_id

    def get_bot_token_id(self):
        return self.bot_token_id

    def _append_turn(self, role, content):
        # Shared helper: push one turn onto the history.
        self.messages.append({"role": role, "content": content})

    def add_user_message(self, message):
        self._append_turn("user", message)

    def add_bot_message(self, message):
        self._append_turn("bot", message)

    def get_prompt(self, tokenizer):
        """Render the full history plus the trailing bot-cue tokens."""
        rendered = [self.message_template.format(**turn) for turn in self.messages]
        cue = tokenizer.decode([self.start_token_id, self.bot_token_id])
        return ("".join(rendered) + cue).strip()
60
+
61
+
62
def generate(model, tokenizer, prompt, generation_config):
    """Generate a completion for ``prompt`` and return only the new text.

    Tokenizes the prompt, moves the tensors onto the model's device, runs
    ``model.generate``, slices off the prompt tokens, and decodes the
    remainder (special tokens skipped, surrounding whitespace stripped).
    """
    batch = tokenizer(prompt, return_tensors="pt")
    batch = {name: tensor.to(model.device) for name, tensor in batch.items()}
    all_ids = model.generate(**batch, generation_config=generation_config)[0]
    # Keep only the tokens produced after the prompt.
    prompt_length = len(batch["input_ids"][0])
    new_ids = all_ids[prompt_length:]
    text = tokenizer.decode(new_ids, skip_special_tokens=True)
    return text.strip()
69
+
70
+
71
# --- Load the base model with the LoRA adapter on top ---
config = PeftConfig.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    load_in_8bit=True,           # requires the bitsandbytes package
    torch_dtype=torch.float16,
    device_map="auto"
)
model = PeftModel.from_pretrained(
    model,
    MODEL_NAME,
    torch_dtype=torch.float16
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
print(generation_config)

system_prompt = "You are Chiharu Yamada. Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology."
conversation = Conversation(system_prompt=system_prompt)
# Interactive chat loop: read one user turn from stdin, generate the bot
# reply, and keep the full history inside `conversation`. An empty line
# ends the session. (The original looped `for inp in inputs:` over an
# undefined name `inputs` and then overwrote `inp` with input() anyway.)
while True:
    inp = input()
    if not inp:
        break
    conversation.add_user_message(inp)
    prompt = conversation.get_prompt(tokenizer)
    output = generate(model, tokenizer, prompt, generation_config)
    conversation.add_bot_message(output)
    print(output)
98
+ ```