import streamlit as st
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

# Base checkpoint: Vistral-7B-Chat, a Vietnamese chat model
base_model = "minhtt/vistral-7b-chat"

# 4-bit NF4 quantization settings for bitsandbytes
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)
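# Note: bfloat16 compute requires a GPU with native bf16 support (e.g. NVIDIA Ampere or newer);
# torch.float16 is a common fallback on older hardware.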

# Load the base model in 4-bit, letting Accelerate place weights across available devices
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,  # load_in_4bit is already set here, so it is not passed again as a kwarg
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# Training-oriented settings carried over from a QLoRA fine-tuning recipe;
# they are not required for inference (use_cache=False in particular slows generation)
model.config.use_cache = False
model.config.pretraining_tp = 1
model.gradient_checkpointing_enable()

# Tokenizer setup: right padding, EOS reused as the pad token
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.padding_side = 'right'
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_eos_token = True

# Text-generation pipeline; max_length=200 caps prompt plus generated tokens
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=200)
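
# Note: Streamlit reruns this whole script on every widget interaction, so the model,
# tokenizer, and pipeline above are reloaded each time. A common pattern (a sketch,
# not part of the original code) is to move the loading steps into a function cached
# with st.cache_resource so they run only once per process:
#
#     @st.cache_resource
#     def load_pipe():
#         ...  # build bnb_config, model, and tokenizer as above
#         return pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=200)
#
#     pipe = load_pipe()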

# Simple Streamlit UI: a text box for the question ("Đặt câu hỏi" = "Ask a question")
text = st.text_area("Đặt câu hỏi")

if text:
    out = pipe(text)
    # pipe() returns a list of dicts; show the generated text of the first candidate
    st.write(out[0]["generated_text"])
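
# Since the checkpoint is chat-tuned, results may improve if the raw input is wrapped
# with the model's chat template before generation (a sketch, assuming the tokenizer
# ships a chat template):
#
#     prompt = tokenizer.apply_chat_template(
#         [{"role": "user", "content": text}], tokenize=False, add_generation_prompt=True
#     )
#     out = pipe(prompt)
#
# To launch the app: streamlit run <path_to_this_file>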