wenge-research committed on
Commit
f1a9e8d
·
1 Parent(s): 18a4ed3

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +49 -1
README.md CHANGED
@@ -30,5 +30,53 @@ tags:
30
 
31
  ## 运行方式
32
 
33
- Comming Soon~
 
 
 
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  ## 运行方式
32
 
33
+ ```python
34
+ import torch
35
+ from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig
36
+ from transformers import StoppingCriteria, StoppingCriteriaList
37
 
38
+ pretrained_model_name_or_path = "wenge-research/yayi-7b-llama2"
39
+ tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_name_or_path)
40
+ model = LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=False)
41
+
42
+ # Define the stopping criteria
43
+ class KeywordsStoppingCriteria(StoppingCriteria):
44
+ def __init__(self, keywords_ids:list):
45
+ self.keywords = keywords_ids
46
+
47
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
48
+ if input_ids[0][-1] in self.keywords:
49
+ return True
50
+ return False
51
+
52
+ stop_words = ["<|End|>", "<|YaYi|>", "<|Human|>", "</s>"]
53
+ stop_ids = [tokenizer.encode(w)[-1] for w in stop_words]
54
+ stop_criteria = KeywordsStoppingCriteria(stop_ids)
55
+
56
+ # inference
57
+ prompt = "你是谁?"
58
+ formatted_prompt = f"""<|System|>:
59
+ You are a helpful, respectful and honest assistant named YaYi developed by Beijing Wenge Technology Co.,Ltd. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
60
+
61
+ <|Human|>:
62
+ {prompt}
63
+
64
+ <|YaYi|>:
65
+ """
66
+
67
+ inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
68
+ eos_token_id = tokenizer("<|End|>").input_ids[0]
69
+ generation_config = GenerationConfig(
70
+ eos_token_id=eos_token_id,
71
+ pad_token_id=eos_token_id,
72
+ do_sample=True,
73
+ max_new_tokens=256,
74
+ temperature=0.3,
75
+ repetition_penalty=1.1,
76
+ no_repeat_ngram_size=0
77
+ )
78
+ response = model.generate(**inputs, generation_config=generation_config, stopping_criteria=StoppingCriteriaList([stop_criteria]))
79
+ response = [response[0][len(inputs.input_ids[0]):]]
80
+ response_str = tokenizer.batch_decode(response, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
81
+ print(response_str)
82
+ ```