# import streamlit as st
# x = st.slider('Select a value')
# st.write(x, 'squared is', x * x)
import sys
import os
import json

import torch
import transformers

assert (
    "LlamaTokenizer" in transformers._import_structure["models.llama"]
), "Please reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
# Model weights are hosted at https://huggingface.co/Shangding-Gu/Lunyu-LLM/
base_model = "Shangding-Gu/Lunyu-LLM"
load_8bit = False
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = LlamaTokenizer.from_pretrained(base_model)
model = LlamaForCausalLM.from_pretrained(
    base_model,
    load_in_8bit=load_8bit,
    torch_dtype=torch.float16,
    device_map="auto",  # requires the accelerate package
)

# Llama special-token ids: 0 = unk (reused as pad), 1 = bos, 2 = eos
model.config.pad_token_id = tokenizer.pad_token_id = 0
model.config.bos_token_id = 1
model.config.eos_token_id = 2
model.eval()


class Call_model():
    def evaluate(self, instruction):
        # Append the Alpaca-style response marker before building the chat prompt.
        final_output = self.inference(instruction + "\n\n### Response:")
        return final_output

    def inference(
        self,
        batch_data,
        input=None,
        temperature=1,
        top_p=0.95,
        top_k=40,
        num_beams=1,
        max_new_tokens=4096,
        **kwargs,
    ):
        prompts = f"""A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {batch_data} ASSISTANT:"""
        inputs = tokenizer(prompts, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            **kwargs,
        )
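        # Note: GenerationConfig defaults to do_sample=False, so with num_beams=1 the
        # decoding here is greedy and the temperature/top_p/top_k values above are
        # effectively ignored. If sampling is wanted, one minimal change would be to
        # also pass do_sample=True when building the config, e.g.:
        #     generation_config = GenerationConfig(do_sample=True, temperature=temperature, ...)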
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
        )
        s = generation_output.sequences
        output = tokenizer.batch_decode(s, skip_special_tokens=True)
        # Keep only the text that follows the "ASSISTANT:" marker of the prompt.
        output = output[0].split("ASSISTANT:")[1].strip()
        return output


if __name__ == "__main__":
    prompt = str(input("Please input:"))
    model_evaluate = Call_model()
    prompt_state = model_evaluate.evaluate(prompt)
    print(prompt_state)
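
# A minimal sketch of how this could be exposed through the Streamlit UI hinted at
# by the commented template at the top of the file. This is an assumption, not part
# of the original app: `user_prompt` is an illustrative name and streamlit must be
# installed for it to run. Kept commented out so the CLI entry point above remains
# the default.
#
# import streamlit as st
#
# model_evaluate = Call_model()
# user_prompt = st.text_input("Please input:")
# if user_prompt:
#     st.write(model_evaluate.evaluate(user_prompt))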