import deepspeed
import torch
from transformers import pipeline
import os
import gradio as gr

model_id = 'dicta-il/dictalm-7b-instruct'

# Load the model and set up the generation pipeline
should_use_fast = True
print(f'should_use_fast = {should_use_fast}')

local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
generator = pipeline('text-generation', model=model_id,
                     tokenizer=model_id,
                     torch_dtype=torch.float16,
                     use_fast=should_use_fast,
                     trust_remote_code=True,
                     device_map="auto")

# Detect the device - GPU or CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

total_mem = 0
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    total_mem = round(torch.cuda.get_device_properties(0).total_memory / 1024**3, 1)
    print('Total Memory: ', total_mem, 'GB')

# Enable DeepSpeed kernel injection only when there is enough GPU memory;
# on CPU total_mem stays 0, so injection is disabled there as well.
should_replace_with_kernel_inject = total_mem >= 12
print(f'should_replace_with_kernel_inject = {should_replace_with_kernel_inject}')

# Wrap the Hugging Face model with the DeepSpeed inference engine
# (tensor-parallel across world_size ranks, fp16 weights).
ds_engine = deepspeed.init_inference(generator.model,
                                     mp_size=world_size,
                                     dtype=torch.half,
                                     replace_with_kernel_inject=should_replace_with_kernel_inject)
generator.model = ds_engine.module
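
# A quick smoke test of the wrapped pipeline (a sketch, not part of the app
# flow; uncomment to try a single prompt before starting the UI):
# print(generator('ืฉืœื•ื', do_sample=True, max_length=32)[0]['generated_text'])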

# Text-generation function: take the latest user message and append the
# model's reply to the chat history. Note that a Chatbot with
# type="messages" only accepts the roles "user" and "assistant".
def chat_with_model(history):
    prompt = history[-1]["content"]
    result = generator(prompt, do_sample=True, min_length=20, max_length=64, top_k=40, top_p=0.92, temperature=0.9)[0]["generated_text"]
    return history + [{"role": "assistant", "content": result}]
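
# A minimal round-trip sketch of the messages format this function consumes
# and produces (in the app itself the history list is managed by the Gradio
# Chatbot component below):
#   history = [{"role": "user", "content": "ืฉืœื•ื, ืžื” ืฉืœื•ืžืš?"}]
#   history = chat_with_model(history)
#   history[-1]  # -> {"role": "assistant", "content": "..."}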

# Build an advanced chat-bot interface with Gradio, styled for an academic look
with gr.Blocks(theme="default") as demo:
    gr.HTML("""
    <div style="background-color: #f5f5f5; padding: 20px; text-align: center;">
        <h1 style="color: #003366; font-family: Arial, sans-serif;">ืฆ'ืื˜ ืขื ืžื•ื“ืœ DictaLM</h1>
        <p style="font-family: Arial, sans-serif; color: #333;">ื‘ืจื•ื›ื™ื ื”ื‘ืื™ื ืœืฆ'ืื˜ ื”ืื™ื ื˜ืจืืงื˜ื™ื‘ื™ ืฉืœื ื•, ื”ืžืืคืฉืจ ืœื›ื ืœื”ืชื ืกื•ืช ื‘ืฉื™ื—ื” ืขื ืžื•ื“ืœ AI ืžืชืงื“ื.</p>
    </div>
    """)
    chatbot = gr.Chatbot(label="ืฆ'ืื˜ ืขื ืžื•ื“ืœ DictaLM", type="messages")
    with gr.Row():
        user_input = gr.Textbox(placeholder="ื”ื›ื ืก ืืช ื”ื”ื•ื“ืขื” ืฉืœืš ื›ืืŸ...", label="", lines=1)
        send_button = gr.Button("ืฉืœื—")
    
    def user_chat(history, message):
        return history + [{"role": "user", "content": message}], ""

    # Send the message both on Enter and by clicking the "Send" button
    user_input.submit(fn=user_chat, inputs=[chatbot, user_input], outputs=[chatbot, user_input], queue=False).then(
        fn=chat_with_model, inputs=chatbot, outputs=chatbot
    )
    send_button.click(fn=user_chat, inputs=[chatbot, user_input], outputs=[chatbot, user_input], queue=False).then(
        fn=chat_with_model, inputs=chatbot, outputs=chatbot
    )

demo.launch()
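
# Launch sketch (the file name app.py is an assumption): for multi-GPU runs,
# start the script through the DeepSpeed launcher so LOCAL_RANK/WORLD_SIZE
# are populated, e.g.
#   deepspeed --num_gpus 1 app.py
# Plain `python app.py` also works, falling back to the defaults read above
# (LOCAL_RANK=0, WORLD_SIZE=1).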