import os import torch import gradio as gr from transformers import GemmaTokenizer, AutoModelForCausalLM from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, TextIteratorStreamer from threading import Thread # Set an environment variable token = os.getenv('HUGGINGFACE_TOKEN') model_path= "CubeAI/Zhuji-Internet-Literature-Intelligent-Writing-Model-V1.0" tokenizer = AutoTokenizer.from_pretrained(model_path, encode_special_tokens=True, token=token) model= AutoModelForCausalLM.from_pretrained( model_path, torch_dtype= torch.bfloat16, low_cpu_mem_usage= True, token=token, #attn_implementation="flash_attention_2", device_map= "auto" ) model = torch.compile(model) model = model.eval() DESCRIPTION = '''

网文智能辅助写作 - 珠玑系列模型

我们自主研发的珠玑系列智能写作模型,专为网文创作与理解而生。基于丰富的网文场景数据,包括续写、扩写、取名等创作任务和章纲抽取等理解任务,我们训练了一系列模型参数,覆盖1B至14B不等的模型族,包括生成模型和embedding模型。

📚 基础版模型:适合初次尝试智能写作的用户,提供长篇小说创作的基础功能,助您轻松迈入智能写作的新纪元。

🚀 高级版模型:为追求更高层次创作体验的用户设计,配备更先进的文本生成技术和更精细的理解能力,让您的创作更具深度和创新。

珠玑系列模型(Zhuji-Internet-Literature-Intelligent-Writing-Model-V1.0)现已发布,包括1B、7B、14B规模的模型,基于Qwen1.5架构,旨在为您提供卓越的网文智能写作体验。

''' LICENSE = """

--- Built with NovelGen """ PLACEHOLDER = """

ai助力写作

ai辅助写作

""" css = """ h1 { text-align: center; display: block; } #duplicate-button { margin: auto; color: white; background: #1565c0; border-radius: 100vh; } """ tokenizer.chat_template = """{% for message in messages %} {% if message['role'] == 'user' %} {{'<|user|>'+ message['content'].strip() + '<|observation|>'+ '<|assistant|>'}} {% elif message['role'] == 'system' %} {{ '<|system|>' + message['content'].strip() + '<|observation|>'}} {% elif message['role'] == 'assistant' %} {{ message['content'] + '<|observation|>'}} {% endif %} {% endfor %}""".replace("\n", "").replace(" ", "") def chat_zhuji( message: str, history: list, temperature: float, max_new_tokens: int ) -> str: """ Generate a streaming response using the llama3-8b model. Args: message (str): The input message. history (list): The conversation history used by ChatInterface. temperature (float): The temperature for generating the response. max_new_tokens (int): The maximum number of new tokens to generate. Returns: str: The generated response. """ conversation = [] for user, assistant in history: conversation.extend([{"role": "system","content": "",},{"role": "user", "content": user}, {"role": "<|assistant|>", "content": assistant}]) conversation.append({"role": "user", "content": message}) input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device) streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True) generate_kwargs = dict( input_ids= input_ids, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True, penalty_alpha= 0.65, top_p= 0.90, top_k= 35, use_cache= True, eos_token_id= tokenizer.encode("<|observation|>",add_special_tokens= False), temperature=temperature, ) # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash. if temperature == 0: generate_kwargs['do_sample'] = False t = Thread(target=model.generate, kwargs=generate_kwargs) t.start() #outputs = [] #for text in streamer: # outputs.append(text) # yield "".join(outputs) partial_message = "" for new_token in streamer: if new_token != '<|observation|>': partial_message += new_token yield partial_message # Gradio block chatbot=gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface') text_box= gr.Textbox(show_copy_button= True) with gr.Blocks(fill_height=True, css=css) as demo: gr.Markdown(DESCRIPTION) gr.ChatInterface( fn=chat_zhuji, chatbot=chatbot, textbox= text_box, fill_height=True, additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False), additional_inputs=[ gr.Slider(minimum=0, maximum=1, step=0.1, value=0.95, label="Temperature", render=False), gr.Slider(minimum=128, maximum=8192*2, step=1, value=8192, label="Max new tokens", render=False ), ], examples=[ ['请给一个古代美女的外貌来一段描写'], ['请生成4个东方神功的招式名称'], ['生成一段官军和倭寇打斗的场面描写'], ['生成一个都市大女主的角色档案'], ], cache_examples=False, ) gr.Markdown(LICENSE) if __name__ == "__main__": demo.queue().launch( #server_name='0.0.0.0', #server_port=config.webui_config.port, #inbrowser=True, #share=True )