# Spaces:
# Runtime error
# Runtime error
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author:quincy qiang
@license: Apache Licence
@file: generate.py
@time: 2023/04/17
@contact: [email protected]
@software: PyCharm
@description: coding..
"""
from typing import List, Optional

import torch
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from transformers import AutoModel, AutoTokenizer
class ChatGLMService(LLM):
    """LangChain ``LLM`` wrapper around a locally loaded ChatGLM chat model.

    Call :meth:`load_model` once before using the instance; until then
    ``tokenizer`` and ``model`` are ``None`` and :meth:`_call` refuses to run.
    """

    # Forwarded to ``model.chat`` as ``max_length``.
    max_token: int = 10000
    # Sampling parameters forwarded to ``model.chat``.
    temperature: float = 0.1
    top_p: float = 0.9
    # Conversation so far, as ``[query, response]`` pairs handed back to
    # ``model.chat`` on every call.
    history: List[List[Optional[str]]] = []
    # Populated by ``load_model``.
    tokenizer: object = None
    model: object = None

    def __init__(self, **kwargs):
        """Create the service; keyword arguments override the field defaults."""
        super().__init__(**kwargs)

    @property
    def _llm_type(self) -> str:
        """Return the identifier LangChain uses for this LLM implementation."""
        return "ChatGLM"

    def _call(self,
              prompt: str,
              stop: Optional[List[str]] = None) -> str:
        """Generate a reply to ``prompt`` and append it to ``history``.

        Args:
            prompt: Text sent to the model as the current query.
            stop: Optional stop sequences; when given, the reply is cut with
                ``enforce_stop_tokens`` before it is stored and returned.

        Returns:
            The (possibly truncated) model reply.

        Raises:
            RuntimeError: If ``load_model`` has not been called yet.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError(
                "ChatGLM model is not loaded; call load_model() first."
            )

        response, _ = self.model.chat(
            self.tokenizer,
            prompt,
            history=self.history,
            max_length=self.max_token,
            top_p=self.top_p,
            temperature=self.temperature,
        )
        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        # NOTE(review): the query slot is stored as None rather than `prompt`.
        # Presumably this keeps long prompts out of the history -- confirm
        # that the model's history formatting tolerates a None query.
        self.history = self.history + [[None, response]]
        return response

    def load_model(self,
                   model_name_or_path: str = "THUDM/chatglm-6b"):
        """Load the tokenizer and model and switch the model to eval mode.

        Args:
            model_name_or_path: Hub id or local path passed to
                ``from_pretrained``.
        """
        # SECURITY: trust_remote_code=True executes Python code shipped with
        # the model repository; only point this at sources you trust.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )
        model = AutoModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )
        if torch.cuda.is_available():
            model = model.half().cuda()
        else:
            # No GPU: stay on the CPU in float32 instead of crashing in .cuda().
            model = model.float()
        self.model = model.eval()
# if __name__ == '__main__':
#     config=LangChainCFG()
#     chatLLM = ChatGLMService()
#     chatLLM.load_model(model_name_or_path=config.llm_model_name)