import deepspeed
import torch
from transformers import pipeline
import os
import gradio as gr
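
# dicta-il/dictalm-7b-instruct is DICTA's instruction-tuned Hebrew 7B language model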
model_id = 'dicta-il/dictalm-7b-instruct'

# Load the model and prepare the data
should_use_fast = True
print(f'should_use_fast = {should_use_fast}')

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
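# LOCAL_RANK and WORLD_SIZE are set by the DeepSpeed (or torchrun) launcher;
# the defaults above cover plain single-process runs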

generator = pipeline('text-generation', model=model_id,
                     tokenizer=model_id,
                     torch_dtype=torch.float16,
                     use_fast=should_use_fast,
                     trust_remote_code=True,
                     device_map="auto")
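# device_map="auto" lets Hugging Face Accelerate place the fp16 weights on the
# available GPU(s) automatically, spilling to CPU if memory runs out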

# Check which device is available - GPU or CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()
total_mem = 0
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    total_mem = round(torch.cuda.get_device_properties(0).total_memory / 1024**3, 1)
    print('Total Memory: ', total_mem, 'GB')

should_replace_with_kernel_inject = total_mem >= 12
print(f'should_replace_with_kernel_inject = {should_replace_with_kernel_inject}')
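
# init_inference wraps the model in DeepSpeed's inference engine: mp_size shards it
# across world_size processes, and kernel injection swaps in DeepSpeed's optimized
# CUDA kernels (enabled above only when the GPU has at least 12 GB of memory)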
ds_engine = deepspeed.init_inference(generator.model,
                                     mp_size=world_size,
                                     dtype=torch.half,
                                     replace_with_kernel_inject=should_replace_with_kernel_inject)
generator.model = ds_engine.module

# Text generation function
def chat_with_model(history):
    prompt = history[-1]["content"]
    # Note: max_length counts prompt + generated tokens in transformers pipelines
    result = generator(prompt, do_sample=True, min_length=20, max_length=64,
                       top_k=40, top_p=0.92, temperature=0.9)[0]["generated_text"]
    # Gradio's "messages" chatbot format expects the role "assistant", not "bot"
    return history + [{"role": "assistant", "content": result}]
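
# Minimal sanity check (hypothetical call, not part of the app flow):
#   history = [{"role": "user", "content": "Hello!"}]
#   history = chat_with_model(history)  # appends one {"role": "assistant", ...} entry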

# Build an advanced Gradio interface: a standard chatbot without avatar icons
with gr.Blocks(theme="default") as demo:
    gr.HTML("""
    <div style="background-color: #f5f5f5; padding: 20px; text-align: center;">
        <h1 style="color: #003366; font-family: Arial, sans-serif;">Chat with the DictaLM model</h1>
        <p style="font-family: Arial, sans-serif; color: #333;">Welcome to the first chat of its kind in our projects, letting you experience a conversation with an advanced AI model.</p>
    </div>
    """)

    chatbot = gr.Chatbot(label="Chat with the DictaLM model", type="messages")

    with gr.Row():
        user_input = gr.Textbox(placeholder="Type your message here...", label="", lines=1)
        send_button = gr.Button("Send")
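
    # Append the user's message to the history and clear the input box;
    # the model reply is then produced by the chained .then() call below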
    def user_chat(history, message):
        return history + [{"role": "user", "content": message}], ""

    # Send the message either by pressing Enter or by clicking the "Send" button
    user_input.submit(fn=user_chat, inputs=[chatbot, user_input], outputs=[chatbot, user_input], queue=False).then(
        fn=chat_with_model, inputs=chatbot, outputs=chatbot
    )
    send_button.click(fn=user_chat, inputs=[chatbot, user_input], outputs=[chatbot, user_input], queue=False).then(
        fn=chat_with_model, inputs=chatbot, outputs=chatbot
    )

demo.launch()
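# launch() starts a local Gradio server; on Hugging Face Spaces, running this
# script is the Space's entry point and the UI is served automatically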