import torch
from typing import Union, Dict, Any, List
from modelscope.pipelines.builder import PIPELINES
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks
from modelscope.outputs import OutputKeys
from modelscope.pipelines.nlp.text_generation_pipeline import TextGenerationPipeline
from modelscope.models.base import Model, TorchModel
from modelscope.utils.logger import get_logger
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

@PIPELINES.register_module(Tasks.text_generation, module_name='Baichuan2-13B-chatbot-pipe')
class Baichuan13BChatTextGenerationPipeline(TextGenerationPipeline):
    def __init__(
            self,
            model: Union[Model, str],
            *args,
            **kwargs):
        # accept either an already-built model object or a model directory / model id string
        self.model = Baichuan13BChatTextGeneration(model) if isinstance(model, str) else model
        super().__init__(model=model, **kwargs)
    
    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
        # no preprocessing here; tokenization happens inside forward
        return inputs

    def _sanitize_parameters(self, **pipeline_parameters):
        # route every call-time parameter to forward; preprocess/postprocess take none
        return {}, pipeline_parameters, {}
    
    # define the forward pass: tokenize the prompt, generate, and decode the result
    def forward(self, inputs: str, **forward_params) -> Dict[str, Any]:
        output = {}
        device = self.model.model.device
        # tokenize the raw prompt and move it onto the model's device
        input_ids = self.model.tokenizer(inputs, return_tensors="pt").input_ids.to(device)
        pred = self.model.model.generate(input_ids, **forward_params)
        # decode the full generated sequence, dropping special tokens
        out = self.model.tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)
        output['text'] = out
        return output
    
    # format the outputs from the pipeline
    def postprocess(self, inputs, **kwargs) -> Dict[str, Any]:
        # forward already returns a {'text': ...} dict, so pass it through unchanged
        return inputs
    

@MODELS.register_module(Tasks.text_generation, module_name='Baichuan2-13B-Chat')
class Baichuan13BChatTextGeneration(TorchModel):
    def __init__(self, model_dir=None, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        self.logger = get_logger()
        # load the tokenizer and the model weights in fp16, sharded across available devices
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
        # alternative: load in the checkpoint's default dtype instead of fp16
        # self.model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True)
        # use the generation defaults (sampling parameters etc.) shipped with the checkpoint
        self.model.generation_config = GenerationConfig.from_pretrained(model_dir)
        self.model = self.model.eval()
    
    def forward(self, inputs: List[Dict], *args, **kwargs) -> Dict[str, Any]:
        # inputs is a list of chat messages, e.g. [{'role': 'user', 'content': '...'}]
        response = self.model.chat(self.tokenizer, inputs, *args, **kwargs)
        # append the assistant reply to the conversation history
        history = inputs.copy()
        history.append({'role': 'assistant', 'content': response})
        return {OutputKeys.RESPONSE: response, OutputKeys.HISTORY: history}
    
    def quantize(self, bits: int):
        # in-place weight quantization via the quantize() helper shipped with the Baichuan2 remote code
        self.model = self.model.quantize(bits)
        return self

    def infer(self, inputs, **kwargs):
        # plain (non-chat) generation from a raw text prompt
        device = self.model.device
        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids.to(device)
        pred = self.model.generate(input_ids, **kwargs)
        return self.tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)
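

# Minimal usage sketch (assumptions: the Baichuan2-13B-Chat weights are already available in
# `local_model_dir`, and this module has been imported so the PIPELINES/MODELS registrations
# above have run; the path and prompt below are illustrative only).
if __name__ == '__main__':
    from modelscope.pipelines import pipeline

    local_model_dir = './Baichuan2-13B-Chat'  # hypothetical local path to the downloaded model
    pipe = pipeline(
        task=Tasks.text_generation,
        model=local_model_dir,
        pipeline_name='Baichuan2-13B-chatbot-pipe')

    # plain text completion: call-time kwargs are routed to generate() via _sanitize_parameters
    result = pipe('The capital of France is', max_new_tokens=32)
    print(result['text'])

    # for multi-turn chat, Baichuan13BChatTextGeneration.forward() accepts a list of
    # messages such as [{'role': 'user', 'content': '...'}] and returns response + history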