---
license: mit
---

# CAT-LM: Aligned <u>C</u>ode <u>A</u>nd <u>T</u>ests Language Model

### Model Description

**CAT-LM** is a GPT-style language model with 2.7 billion parameters, trained on a corpus of Python and Java projects (~260GB). It supports a maximum sequence length of 8,192 tokens. We use a novel pretraining signal that explicitly considers the mapping between code files and their test files when that mapping is available.
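
Because the 8,192-token limit covers the prompt and the generated tokens together, it can help to cap `max_new_tokens` before calling `generate`. The following is a minimal sketch of that arithmetic; `generation_budget` and `MAX_SEQUENCE_LENGTH` are hypothetical names introduced here for illustration and are not part of the CAT-LM repository.

```python
MAX_SEQUENCE_LENGTH = 8192  # maximum sequence length stated in the model description


def generation_budget(num_prompt_tokens: int, requested_new_tokens: int = 512) -> int:
    """Hypothetical helper: cap new tokens so prompt + output fit in the context window."""
    remaining = MAX_SEQUENCE_LENGTH - num_prompt_tokens
    if remaining <= 0:
        raise ValueError(
            f"Prompt uses {num_prompt_tokens} tokens, leaving no room to generate."
        )
    # Never ask for more tokens than the window can still hold.
    return min(requested_new_tokens, remaining)
```

For a tokenized prompt, the result could be passed as `max_new_tokens=generation_budget(input_ids.shape[1])`.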

### Publication

[CAT-LM: Training Language Models on Aligned Code And Tests](https://arxiv.org/abs/2310.01602)

[Nikitha Rao](https://raonikitha.github.io)\*, [Kush Jain](https://www.kushjain.com/)\*, [Uri Alon](https://urialon.ml), [Claire Le Goues](https://clairelegoues.com), and [Vincent J. Hellendoorn](http://vhellendoorn.github.io)

38th IEEE/ACM International Conference on Automated Software Engineering (ASE 2023)

### Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('nikitharao/catlm', use_fast=False)
model = AutoModelForCausalLM.from_pretrained('nikitharao/catlm')

# The code under test comes first, followed by the <|codetestpair|> separator.
prompt = """
def add(x,y):
    \"\"\"Add two numbers x and y\"\"\"
    return x+y
<|codetestpair|>
"""

print('Input prompt:')
print(prompt)

input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# The model was trained without the `</s>` token, so remove it if the tokenizer appended one.
if tokenizer.decode(input_ids[0,-1]) == '</s>':
    input_ids = input_ids[:,:-1]

print(input_ids)
len_input = input_ids.shape[1]

sample_output = model.generate(
    input_ids,
    do_sample=True,
    max_new_tokens=512,
    top_k=50,
    top_p=0.95,
    temperature=0.2
)
# Keep only the newly generated tokens, dropping the prompt.
generated_output = sample_output[0][len_input:]
output = tokenizer.decode(generated_output, skip_special_tokens=True)
print('Output:')
print(output)
```
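
The prompt layout in the example above (the code, then `<|codetestpair|>` on its own line) can be wrapped in a small helper when prompting with many source files. This is only a sketch: `build_prompt` is a hypothetical name, not part of the CAT-LM repository, and the whitespace around the separator is copied from the example rather than taken from a documented specification.

```python
CODE_TEST_SEPARATOR = "<|codetestpair|>"  # separator token used in the usage example


def build_prompt(code: str) -> str:
    """Hypothetical helper: place the code first, then the separator on its own line."""
    # Normalize leading/trailing newlines so the layout matches the usage example.
    return "\n" + code.strip("\n") + "\n" + CODE_TEST_SEPARATOR + "\n"


if __name__ == "__main__":
    source = 'def add(x,y):\n    """Add two numbers x and y"""\n    return x+y\n'
    print(build_prompt(source))
```

The returned string can be passed to the tokenizer in place of the hand-written `prompt`.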

<b>Note:</b> The model was trained without the `</s>` token, so remove it from the input if the tokenizer appends it.
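
One way to apply this note is to compare token IDs directly instead of decoding the last token to a string. The sketch below works on a plain list of IDs; `strip_trailing_eos` is a hypothetical helper, and it assumes that `</s>` is the tokenizer's end-of-sequence token, which this card implies but does not state.

```python
from typing import List, Optional


def strip_trailing_eos(token_ids: List[int], eos_token_id: Optional[int]) -> List[int]:
    """Hypothetical helper: drop one trailing end-of-sequence ID, if present."""
    if eos_token_id is not None and token_ids and token_ids[-1] == eos_token_id:
        return token_ids[:-1]
    # Return a copy so callers can modify the result without touching the input.
    return list(token_ids)
```

With a Hugging Face tokenizer this might be called as `strip_trailing_eos(tokenizer(prompt).input_ids, tokenizer.eos_token_id)`, with the result wrapped in `torch.tensor([ids])` before it is passed to `model.generate`.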

Please see https://github.com/RaoNikitha/CAT-LM for more details.