# Source: teachyourselfcoding, commit fa6856c ("Upload 245 files"), 2.51 kB.
# (Scraped page header kept as a comment so the file parses as Python.)
import ast
import itertools
import json
import logging
import pathlib

import trlx
import yaml
from trlx.data.configs import TRLConfig

from lang import Interpreter
logger = logging.getLogger(__name__)


class DSLDataset:
    """In-memory train/test splits of the DSL dataset.

    Each split file is read with ``json.load`` and iterated record by record;
    every record must be indexable by ``"input"`` (presumably a dict whose
    ``"input"`` is the prompt string -- confirm against the dataset files).
    """

    def __init__(self, train_path="dataset/train.json", test_path="dataset/test.json"):
        """Load both splits eagerly.

        Args:
            train_path: Path to the training split JSON. The default is
                relative to the current working directory, as before.
            test_path: Path to the test split JSON (same default convention).

        Raises:
            OSError: If a split file cannot be opened.
            json.JSONDecodeError: If a split file is not valid JSON.
        """
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(train_path, "r", encoding="utf-8") as f:
            self.train_data = json.load(f)
        with open(test_path, "r", encoding="utf-8") as f:
            self.test_data = json.load(f)
        logger.info("Successfully loaded the dataset")

    def load_datapoints(self, split="train"):
        """Return a lazy iterator over the ``"input"`` values of one split.

        Args:
            split: ``"train"`` or ``"test"``.

        Returns:
            An iterator of ``"input"`` values. For ``"train"``, inputs that
            contain ``"ERROR"`` are skipped; ``"test"`` yields every input.

        Raises:
            ValueError: If ``split`` is neither ``"train"`` nor ``"test"``.
                It is raised at call time instead of silently yielding nothing.
        """
        if split == "train":
            # NOTE(review): presumably "ERROR" marks samples whose expected
            # output is an interpreter error; they are excluded from training
            # only, the test split keeps them.
            return (dp["input"] for dp in self.train_data if "ERROR" not in dp["input"])
        if split == "test":
            return (dp["input"] for dp in self.test_data)
        raise ValueError(f"Unknown split {split!r}; expected 'train' or 'test'")
# Single DSL interpreter instance, created once at import time and called by
# reward_fn to run each generated program.
interpreter = Interpreter()
def reward_fn(samples, *, interpreter_fn=None, **kwargs):
    """Score generated DSL programs by running them and checking the result.

    Each sample is expected to contain ``Output: <literal>`` followed by
    ``Function: <program>``, e.g.
    ``"Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -4]),1)"``.
    The program is run through the interpreter, and its result is compared
    with the expected literal.

    Args:
        samples: Iterable of sample strings to score.
        interpreter_fn: Keyword-only callable used instead of the
            module-level ``interpreter``. It must return the string
            ``"ERROR"`` for programs it cannot run. ``None`` (the default)
            uses ``interpreter``.
        **kwargs: Extra keyword arguments passed by the trainer; ignored.

    Returns:
        A list with one reward per sample:

        * ``-1`` if the sample lacks an ``Output:`` or ``Function:`` marker,
          its expected output is not a Python literal, or the interpreter
          returns ``"ERROR"``.
        * ``1`` if the program's result equals the expected output.
        * ``-0.5`` otherwise.
    """
    run = interpreter if interpreter_fn is None else interpreter_fn
    rewards = []
    for sample in samples:
        function_parts = sample.split("Function:")
        output_parts = sample.split("Output:")
        if len(function_parts) < 2 or len(output_parts) < 2:
            # Generated text can be malformed. Penalize it like unparsable
            # code instead of letting an IndexError abort the training step.
            rewards.append(-1)
            continue
        code = function_parts[1].strip()
        expected_text = output_parts[1].strip().split("Function:")[0].strip()
        try:
            # literal_eval accepts only Python literals, so sample text cannot
            # execute arbitrary code the way eval() could.
            expected = ast.literal_eval(expected_text)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # The expected output cannot be recovered, so the sample cannot
            # be verified; score it like an unparsable sample.
            rewards.append(-1)
            continue

        result = run(code)
        if result == "ERROR":
            # The interpreter could not parse/run the program.
            rewards.append(-1)
        elif expected == result:
            rewards.append(1)
        else:
            # The program ran but produced the wrong output.
            rewards.append(-0.5)
    return rewards
# Default PPO hyper-parameters. The YAML file is located next to this script,
# so the lookup does not depend on the current working directory.
config_path = pathlib.Path(__file__).parent / "configs" / "trlx_ppo_config.yml"
default_config = yaml.safe_load(config_path.read_text())
def main(hparams=None):
    """Run PPO training on the DSL training prompts and save the model.

    Args:
        hparams: Optional dict of overrides merged onto ``default_config``
            by ``TRLConfig.update``. ``None`` (the default) means no overrides.
    """
    # Build a fresh dict per call instead of sharing one mutable default.
    config = TRLConfig.update(default_config, hparams if hparams is not None else {})

    # Dataset: take the first 1000 training prompts without materializing
    # the whole split first.
    dataset = DSLDataset()
    train_prompts = list(itertools.islice(dataset.load_datapoints(split="train"), 1000))

    trainer = trlx.train(
        reward_fn=reward_fn,
        prompts=train_prompts,
        config=config,
    )
    trainer.save_pretrained("dataset/trained_model")
if __name__ == "__main__":
    # Smoke-test reward_fn on one sample per reward tier before training:
    # a correct program, an unparsable program, and a parsable but wrong one.
    # NOTE(review): these are plain asserts, so `python -O` skips them.
    for sample, expected in (
        ("Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -4]),1)", [1]),
        ("Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -a]),1)", [-1]),
        ("Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -3]),1)", [-0.5]),
    ):
        assert reward_fn([sample]) == expected
    main()