---
license: mit
---

# CAT-LM: Aligned <u>C</u>ode <u>A</u>nd <u>T</u>ests Language Model

### Model Description

**CAT-LM** is a GPT-style language model with 2.7 billion parameters, trained on a corpus of Python and Java projects (~260 GB). It supports a maximum sequence length of 8,192 tokens. We utilize a novel pretraining signal that explicitly considers the mapping between code and test files when available.
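
To illustrate the idea, an aligned training instance concatenates a code file and its matching test file, separated by the `<|codetestpair|>` token that also appears in the usage example below. A rough, hypothetical sketch (the exact training-time formatting follows the paper):

```
# code file
def add(x, y):
    """Add two numbers x and y"""
    return x + y
<|codetestpair|>
# test file
def test_add():
    assert add(2, 3) == 5
```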

### Publication

[CAT-LM: Training Language Models on Aligned Code And Tests](https://conf.researchr.org/details/ase-2023/ase-2023-papers/59/CAT-LM-Training-Language-Models-on-Aligned-Code-And-Tests)\
[Nikitha Rao](https://raonikitha.github.io)\*, [Kush Jain](https://www.kushjain.com/)\*, [Uri Alon](https://urialon.ml), [Claire Le Goues](https://clairelegoues.com), and [Vincent J. Hellendoorn](http://vhellendoorn.github.io)\
38th IEEE/ACM International Conference on Automated Software Engineering (ASE 2023)

### Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('nikitharao/catlm', use_fast=False)
model = AutoModelForCausalLM.from_pretrained('nikitharao/catlm')

prompt = """
def add(x,y):
    \"\"\"Add two numbers x and y\"\"\"
    return x+y
<|codetestpair|>
"""

print('Input prompt:')
print(prompt)

input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# The model was trained without the `</s>` token, so strip it if the tokenizer appended one.
if tokenizer.decode(input_ids[0, -1]) == '</s>':
    input_ids = input_ids[:, :-1]

print(input_ids)
len_input = input_ids.shape[1]

sample_output = model.generate(
    input_ids,
    do_sample=True,
    max_new_tokens=512,
    top_k=50,
    top_p=0.95,
    temperature=0.2,
)
# Decode only the newly generated tokens, skipping the prompt.
generated_output = sample_output[0][len_input:]
output = tokenizer.decode(generated_output, skip_special_tokens=True)
print('Output:')
print(output)
```

<b>Note:</b> The model was trained without the `</s>` token, so it should be removed from the tokenized input if the tokenizer appends it (as done in the example above).
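
In full precision, the 2.7B-parameter checkpoint needs roughly 10.8 GB for the weights alone; if memory is tight, one option is to load it in half precision on a GPU. A minimal sketch, not from the original model card, assuming a CUDA device is available (fp16 may slightly change sampled outputs):

```python
import torch
from transformers import AutoModelForCausalLM

# Load the checkpoint in fp16 (~5.4 GB of weights instead of ~10.8 GB).
# Assumption: a CUDA GPU with sufficient memory is available.
model = AutoModelForCausalLM.from_pretrained(
    'nikitharao/catlm',
    torch_dtype=torch.float16,
).to('cuda')

# Inputs must be moved to the same device before calling `generate`:
# input_ids = input_ids.to('cuda')
```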

Please see https://github.com/RaoNikitha/CAT-LM for more details.