BAAI
/

BoyaWu10 commited on
Commit
59d5698
·
verified ·
1 Parent(s): 34350a2

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +8 -4
README.md CHANGED
@@ -30,6 +30,9 @@ Before running the snippet, you need to install the following dependencies:
30
  pip install torch transformers accelerate pillow
31
  ```
32
 
 
 
 
33
  ```python
34
  import torch
35
  import transformers
@@ -43,12 +46,13 @@ transformers.logging.disable_progress_bar()
43
  warnings.filterwarnings('ignore')
44
 
45
  # set device
46
- torch.set_default_device('cpu') # or 'cuda'
 
47
 
48
  # create model
49
  model = AutoModelForCausalLM.from_pretrained(
50
  'BAAI/Bunny-v1_0-3B',
51
- torch_dtype=torch.float16,
52
  device_map='auto',
53
  trust_remote_code=True)
54
  tokenizer = AutoTokenizer.from_pretrained(
@@ -59,11 +63,11 @@ tokenizer = AutoTokenizer.from_pretrained(
59
  prompt = 'Why is the image funny?'
60
  text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\n{prompt} ASSISTANT:"
61
  text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
62
- input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
63
 
64
  # image, sample images can be found in images folder
65
  image = Image.open('example_2.png')
66
- image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)
67
 
68
  # generate
69
  output_ids = model.generate(
 
30
  pip install torch transformers accelerate pillow
31
  ```
32
 
33
+ If the CUDA memory is enough, it would be faster to execute this snippet by setting `CUDA_VISIBLE_DEVICES=0`.
34
+
35
+
36
  ```python
37
  import torch
38
  import transformers
 
46
  warnings.filterwarnings('ignore')
47
 
48
  # set device
49
+ device = 'cuda' # or cpu
50
+ torch.set_default_device(device)
51
 
52
  # create model
53
  model = AutoModelForCausalLM.from_pretrained(
54
  'BAAI/Bunny-v1_0-3B',
55
+ torch_dtype=torch.float16, # float32 for cpu
56
  device_map='auto',
57
  trust_remote_code=True)
58
  tokenizer = AutoTokenizer.from_pretrained(
 
63
  prompt = 'Why is the image funny?'
64
  text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\n{prompt} ASSISTANT:"
65
  text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
66
+ input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0).to(device)
67
 
68
  # image, sample images can be found in images folder
69
  image = Image.open('example_2.png')
70
+ image_tensor = model.process_images([image], model.config).to(dtype=model.dtype, device=device)
71
 
72
  # generate
73
  output_ids = model.generate(