---
license: apache-2.0
---

## 一、基于 baichuan-7b 模型进行 SFT,对齐人类意图

## 二、SFT 数据:从开源 MOSS 数据中按各类别均衡采样 15 万条数据进行 SFT

## 模型推理

Install packages:

```
pip install transformers
pip install sentencepiece
pip install vllm
```

下面的服务示例还依赖 torch、fastapi、uvicorn、jsonlines 和 peft,请一并安装。

### 使用 Hugging Face Transformers + FastAPI 启动服务,支持多轮对话

将以下代码保存为 `model_serving.py`,推荐用 `uvicorn model_serving:app --host 0.0.0.0 --port 6006` 启动(若直接运行 `python model_serving.py`,`uvicorn.run` 会按字符串 `'model_serving:app'` 再次导入该模块,模型会被加载两次)。

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import uvicorn
from fastapi import FastAPI
import jsonlines

device = 'cuda'
model_name = 'mxmax/baichuan-7b-sft-001'
max_new_tokens = 500
top_p = 0.9
temperature = 0.35
repetition_penalty = 1.0

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map={'': 0}#'auto'
).cuda()
# model = PeftModel.from_pretrained(model, adapter_name)
model.eval()
model = model.to(device)

# 输入模型的最大长度
history_max_len = 1024


def model_infer(user_input):
    history_token_ids = tokenizer('', return_tensors="pt").input_ids
    user_input_ids = tokenizer(user_input, return_tensors="pt").input_ids
    history_token_ids = torch.concat((history_token_ids, user_input_ids[:, -history_max_len:]), dim=1)
    model_input_ids = history_token_ids.to(device)
    outputs = model.generate(
        input_ids=model_input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        eos_token_id=tokenizer.eos_token_id
    )
    model_input_ids_len = model_input_ids.size(1)
    response_ids = outputs[:, model_input_ids_len:]
    response = tokenizer.batch_decode(response_ids)
    return response[0].strip().replace('', "")


app = FastAPI()


@app.get('/')
async def root():
    return {"msg": "Hello World"}


@app.post('/baichuan_sft_001')
async def baichuan_sft_001(message: dict):
    prompt = ''
    for l in message['context']:
        prompt += 'human:'+l['human']+'\nassistant:'+l['assistant']+''
    result = model_infer(prompt)
    message['context'][-1]['assistant'] = result
    return {'model_ouput':result}


if __name__ == '__main__':
    uvicorn.run('model_serving:app',host="0.0.0.0", port=6006)
```

### 使用 vLLM + FastAPI 启动服务,加速推理,支持多轮对话

将以下代码保存为 `vllm_serving.py`,同样推荐用 `uvicorn vllm_serving:app --host 0.0.0.0 --port 6006` 启动。

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import uvicorn
from fastapi import FastAPI
import jsonlines
from vllm import LLM, SamplingParams

device = 'cuda'
model_name = 'mxmax/baichuan-7b-sft-001'
max_new_tokens = 512
top_p = 0.9
temperature = 0.35
repetition_penalty = 0.1
history_max_len = 1024

sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_new_tokens, presence_penalty=repetition_penalty)

# Create an LLM.
llm = LLM(model=model_name,trust_remote_code=True,dtype='float16')
file = jsonlines.open('chat_record.json','a')

app = FastAPI()


@app.get('/')
async def root():
    return {"msg": "Hello World"}


@app.post('/baichuan_sft_001')
async def baichuan_sft_001(message: dict):
    prompt = ''
    for l in message['context']:
        prompt += 'human:'+l['human']+'\nassistant:'+l['assistant']+''
    prompt = ''+prompt[-history_max_len:]
    outputs = llm.generate([prompt], sampling_params)
    result = outputs[0].outputs[0].text
    message['context'][-1]['assistant'] = result
    return {'model_ouput':result}


if __name__ == '__main__':
    uvicorn.run('vllm_serving:app',host="0.0.0.0", port=6006)
```

两个服务的接口相同:向 `/baichuan_sft_001` 发送 POST 请求,请求体形如 `{"context": [{"human": "你好", "assistant": ""}]}`,`context` 按顺序列出各轮对话,最后一轮的 `assistant` 留空;返回结果为 `{"model_ouput": "..."}`(字段名即代码中的拼写)。

## 模型效果展示

![效果展示1](./images/1.jpg)
![效果展示2](./images/2.jpg)
![效果展示3](./images/3.jpg)
![效果展示4](./images/4.jpg)
![效果展示5](./images/5.jpg)
![效果展示6](./images/6.jpg)

## 联系方式

![微信好友二维码](./images/微信好友二维码.jpg)

加好友请备注:来自 Hugging Face 网站交流技术 + 名字

QQ 群:621725172

## 引用

```bibtex
@misc{mxmax,
  title={baichuan_sft: baichuan-7b-sft-001},
  author={Ma Xin},
  year={2023},
  howpublished={\url{https://huggingface.co/mxmax/baichuan-7b-sft-001}},
}
```