---
inference: false
---

# ethzanalytics/gpt-j-6B-8bit-sharded

This is a sharded version of hivemind/gpt-j-6B-8bit for low-RAM loading, e.g., on free Colab runtimes :)
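Sharding just splits the checkpoint into multiple files, each under a size limit, so the whole state dict never has to sit in memory at once while loading. A minimal conceptual sketch of the packing (a hypothetical helper operating on byte sizes, not the actual `transformers` implementation):

```python
def shard_state_dict(param_sizes, max_shard_bytes):
    """Greedily pack parameters into shards no larger than max_shard_bytes.

    param_sizes maps parameter names to their byte sizes (a simplification;
    the real transformers sharding works on the tensors themselves).
    """
    shards, current, current_size = [], {}, 0
    for name, size in param_sizes.items():
        # start a new shard once adding this parameter would exceed the limit
        if current and current_size + size > max_shard_bytes:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = size
        current_size += size
    if current:
        shards.append(current)
    return shards
```

With `max_shard_size="1000MB"` the same idea is applied per-tensor, producing a set of `pytorch_model-0000x-of-0000y.bin` files plus an index.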

## Usage

**NOTE:** _prior_ to loading the model, you need to "patch" `GPTJForCausalLM` so it can load the 8-bit weights. See the original hivemind/gpt-j-6B-8bit model card linked above for details on how to do this.
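The patch stores each weight matrix as 8-bit integer codes and dequantizes on the fly. As a rough illustration of that idea only (hypothetical pure-Python helpers, not the actual patch code):

```python
def quantize_8bit(weights):
    """Map a list of floats to integer codes in [0, 255] plus (scale, zero_point).

    Illustrative sketch of 8-bit weight storage; the real patch uses
    blockwise dynamic quantization over tensors.
    """
    w_min, w_max = min(weights), max(weights)
    scale = (w_max - w_min) / 255 or 1.0  # avoid a zero scale for constant inputs
    codes = [round((w - w_min) / scale) for w in weights]
    return codes, scale, w_min


def dequantize_8bit(codes, scale, zero_point):
    """Recover approximate float weights from the 8-bit codes."""
    return [c * scale + zero_point for c in codes]
```

Each value is recovered to within half a quantization step, which is why 8-bit weights cut memory roughly 4x versus fp32 at a small accuracy cost.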

```python
import transformers
from transformers import AutoTokenizer

"""
CODE TO PATCH GPTJForCausalLM GOES HERE
(the patch defines the GPTJForCausalLM used below)
"""

tokenizer = AutoTokenizer.from_pretrained("ethzanalytics/gpt-j-6B-8bit-sharded")

model = GPTJForCausalLM.from_pretrained(
    "ethzanalytics/gpt-j-6B-8bit-sharded",
    low_cpu_mem_usage=True,  # load one shard at a time to keep peak RAM low
    max_shard_size="1000MB",
)
```