import torch
# Check if GPU is available
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available. Using GPU.")
else:
    device = torch.device("cpu")
    print("GPU not available. Using CPU.")
# Example model loading
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "cerebras/btlm-3b-8k-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# BTLM ships custom modeling code on the Hub, so trust_remote_code is typically required
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
# Move model to the appropriate device
model.to(device)
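
# Quick inference check (illustrative sketch only; the prompt text and generation
# settings below are arbitrary additions, not part of the original snippet)
prompt = "The capital of France is"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))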