Spaces:
Runtime error
Runtime error
from colbert.infra.run import Run | |
from colbert.infra.launcher import Launcher | |
from colbert.infra.config import ColBERTConfig, RunConfig | |
from colbert.training.training import train | |
class Trainer: | |
def __init__(self, triples, queries, collection, config=None): | |
self.config = ColBERTConfig.from_existing(config, Run().config) | |
self.triples = triples | |
self.queries = queries | |
self.collection = collection | |
def configure(self, **kw_args): | |
self.config.configure(**kw_args) | |
def train(self, checkpoint='bert-base-uncased'): | |
""" | |
Note that config.checkpoint is ignored. Only the supplied checkpoint here is used. | |
""" | |
# Resources don't come from the config object. They come from the input parameters. | |
# TODO: After the API stabilizes, make this "self.config.assign()" to emphasize this distinction. | |
self.configure(triples=self.triples, queries=self.queries, collection=self.collection) | |
self.configure(checkpoint=checkpoint) | |
launcher = Launcher(train) | |
self._best_checkpoint_path = launcher.launch(self.config, self.triples, self.queries, self.collection) | |
def best_checkpoint_path(self): | |
return self._best_checkpoint_path | |