import os

import deepspeed
import gradio as gr
import torch
from transformers import pipeline
model_id = 'dicta-il/dictalm-7b-instruct'
# Load the model and prepare the data
should_use_fast = True
print(f'should_use_fast = {should_use_fast}')
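# LOCAL_RANK / WORLD_SIZE are set by the DeepSpeed (or torchrun) launcher;
# both default to a single-process setup when the script is run directly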
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))
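# Build the text-generation pipeline in fp16; device_map="auto" lets
# accelerate place the model weights on the available GPU(s) or the CPU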
generator = pipeline('text-generation', model=model_id,
                     tokenizer=model_id,
                     torch_dtype=torch.float16,
                     use_fast=should_use_fast,
                     trust_remote_code=True,
                     device_map="auto")
# Check the available device - GPU or CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()
total_mem = 0
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    total_mem = round(torch.cuda.get_device_properties(0).total_memory / 1024**3, 1)
    print('Total Memory: ', total_mem, 'GB')
should_replace_with_kernel_inject = total_mem >= 12
print(f'should_replace_with_kernel_inject = {should_replace_with_kernel_inject}')
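# Wrap the model with DeepSpeed inference; kernel injection is enabled only
# when the GPU reports at least 12 GB of memory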
ds_engine = deepspeed.init_inference(generator.model,
                                     mp_size=world_size,
                                     dtype=torch.half,
                                     replace_with_kernel_inject=should_replace_with_kernel_inject)
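# Swap the pipeline's model for the DeepSpeed-optimized module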
generator.model = ds_engine.module
# Text generation function
def chat_with_model(history):
    prompt = history[-1]["content"]
    result = generator(prompt, do_sample=True, min_length=20, max_length=64, top_k=40, top_p=0.92, temperature=0.9)[0]["generated_text"]
    # Gradio's "messages" format expects the role "assistant", not "bot"
    return history + [{"role": "assistant", "content": result}]
# Build an advanced Gradio interface - a chatbot in standard chat style
with gr.Blocks(theme="default") as demo:
    gr.HTML("""
    <div style="background-color: #f5f5f5; padding: 20px; text-align: center;">
        <h1 style="color: #003366; font-family: Arial, sans-serif;">ืฆ'ืื ืขื ืืืื DictaLM</h1>
        <p style="font-family: Arial, sans-serif; color: #333;">ืืจืืืื ืืืืื ืืฆ'ืื ืืืื ืืจืืงืืืื ืฉืื ื, ืืืืคืฉืจ ืืื ืืืชื ืกืืช ืืฉืืื ืขื ืืืื AI ืืชืงืื.</p>
    </div>
    """)
    chatbot = gr.Chatbot(label="ืฆ'ืื ืขื ืืืื DictaLM", type="messages")
    with gr.Row():
        user_input = gr.Textbox(placeholder="ืืื ืก ืืช ืืืืืขื ืฉืื ืืื...", label="", lines=1)
        send_button = gr.Button("ืฉืื")
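    # Append the user's message to the history and clear the input box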
    def user_chat(history, message):
        return history + [{"role": "user", "content": message}], ""
    # Send the message either by pressing Enter or by clicking the "Send" button
    user_input.submit(fn=user_chat, inputs=[chatbot, user_input], outputs=[chatbot, user_input], queue=False).then(
        fn=chat_with_model, inputs=chatbot, outputs=chatbot
    )
    send_button.click(fn=user_chat, inputs=[chatbot, user_input], outputs=[chatbot, user_input], queue=False).then(
        fn=chat_with_model, inputs=chatbot, outputs=chatbot
    )
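# Launch the app. The script is typically started directly (python app.py)
# or via the DeepSpeed launcher, e.g. `deepspeed app.py`.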
demo.launch()