# Source code for orchestrator.trainer.factory

from ..utils.module_factory import ModuleFactory, ModuleBuilder
from ..utils.exceptions import ModuleAlreadyInFactoryError
from .trainer_base import Trainer

#: default factory for trainers; the built-in trainers ('KLIFF' parametric
#: model, 'DNN', 'FitSnap' and 'ChIMES') are registered on request by
#: :meth:`TrainerBuilder.build`
trainer_factory = ModuleFactory(Trainer)


[docs] class TrainerBuilder(ModuleBuilder): """ Constructor for trainers added in the factory set the factory to be used for the builder. The default is to use the trainer_factory generated at the end of this module. A user defined ModuleFactory can optionally be supplied instead. :param factory: a trainer factory |default| :data:`trainer_factory` :type factory: ModuleFactory """
[docs] def __init__(self, factory=trainer_factory): """ constructor for the TrainerBuilder, sets the factory to build from :param factory: a trainer factory |default| :data:`trainer_factory` :type factory: ModuleFactory """ if factory.base_class.__name__ == Trainer.__name__: super().__init__(factory) else: raise Exception('Supplied factory is not for Trainers!')
[docs] def build(self, trainer_type, trainer_args=None) -> Trainer: """ Return an instance of the specified trainer The build method takes the specifier and input arguments to construct a concrete trainer instance. :param trainer_type: token of a trainer which has been added to the factory :type trainer_type: str :param trainer_args: arguments to control trainer behavior :type trainer_args: dict :returns: instantiated concrete Trainer :rtype: Trainer """ if trainer_args is None: trainer_args = {} match trainer_type: case 'KLIFF': from .kliff import ParametricModelTrainer try: trainer_factory.add_new_module( 'KLIFF', ParametricModelTrainer, ) except ModuleAlreadyInFactoryError: pass case 'DNN': from .kliff import DUNNTrainer try: trainer_factory.add_new_module('DNN', DUNNTrainer) except ModuleAlreadyInFactoryError: pass case 'FitSnap': from .fitsnap import FitSnapTrainer try: trainer_factory.add_new_module('FitSnap', FitSnapTrainer) except ModuleAlreadyInFactoryError: pass case 'ChIMES': from .chimes import ChIMESTrainer try: trainer_factory.add_new_module('ChIMES', ChIMESTrainer) except ModuleAlreadyInFactoryError: pass trainer_constructor = self.factory.select_module(trainer_type) built_class = trainer_constructor(**trainer_args) built_class.factory_token = trainer_type return built_class
#: trainer builder object which can be imported for use in other modules
trainer_builder = TrainerBuilder()