Source code for orchestrator.workflow.aiida

import math
from os import PathLike
from abc import ABC
from copy import deepcopy
from time import sleep
from typing import Union
from .workflow_base import HPCWorkflow, JobStatus
from ..utils.data_standard import METADATA_KEY

from aiida import load_profile
from aiida.orm import QueryBuilder, ProcessNode, load_node, CalcJobNode
from aiida.engine import submit
from aiida.engine.processes.builder import ProcessBuilder
from aiida.manage.configuration import get_config
from aiida.engine.daemon.client import get_daemon_client, DaemonClient


class AiidaWF(HPCWorkflow, ABC):
    """
    Workflow class for using AiiDA

    Note that ``kwargs`` set the default parameters for the workflow, but can
    be overridden by values passed into ``job_details`` provided to
    :meth:`submit_job`.
    """

    #: AiiDA process states after which a process never changes again. Jobs
    #: reaching any of these are no longer waited on by
    #: :meth:`block_until_completed`.
    _TERMINAL_STATES = frozenset({'finished', 'excepted', 'killed'})

    def __init__(self, **kwargs: dict):
        """
        Initialize the values needed for the AiiDA WF.

        :param kwargs: arguments to control workflow behavior, keys may include
            root_directory, checkpoint_file, checkpoint_name, and job_record.
            These will be defaulted to './orchestrator_workflow' and
            './orchestrator_checkpoint.json', 'workflow', and
            './job_record.pkl', respectively. AiiDA additionally uses
            'queue' -> 'pbatch', 'account' -> 'iap', 'walltime' -> '1:00',
            'nodes' -> 1, 'tasks' -> 1, and 'tasks_per_node' -> 1
        :type kwargs: dict
        :raises RuntimeError: if the AiiDA daemon is not running and cannot be
            started (see :meth:`check_daemon_status`).
        """
        super().__init__(**kwargs)
        # Make sure a daemon is available before any job gets submitted.
        self.check_daemon_status()

    def check_daemon_status(self):
        """
        Check the status of the daemon used for AiiDA.

        If the daemon's PID file is stale, it is cleaned up and the daemon is
        started again. If the daemon is not currently running, an attempt is
        made to start it. A running, healthy daemon is left untouched.

        :raises RuntimeError: if the daemon is not running and could not be
            started.
        """
        client = self.get_aiida_daemon_client()
        try:
            started = False
            # Check that the daemon has not become stale.
            if client._is_pid_file_stale:
                # If it is stale, clean up the PID file and start again.
                self.logger.info('The AiiDA daemon PID file is stale. '
                                 'Restarting the daemon now.')
                client._clean_potentially_stale_pid_file()
                client.start_daemon()
                started = True
            elif not client.is_daemon_running:
                self.logger.info('No AiiDA daemon was found to be running. '
                                 'Starting the daemon now.')
                client.start_daemon()
                started = True
            # Only report a start when this method actually started it.
            if started:
                pid = client.get_daemon_pid()
                self.logger.info(f'AiiDA daemon started ({pid=}).')
        except RuntimeError as err:
            msg = ('The daemon is not running and was unable to be started. '
                   'Fix the daemon and attempt the job after.')
            self.logger.warning(msg)
            raise RuntimeError(msg) from err

    def check_daemon_workload(self, workload_target: float = 0.9):
        """
        Scale the number of daemon workers to the current AiiDA workload.

        The number of workers is chosen so that the active processes (those
        in the 'created', 'waiting' or 'running' state) fill at most
        ``workload_target`` of the available worker process slots. At least
        one worker is always kept running.

        :param workload_target: Target fraction of the worker process slots
            that should be in use. Is used to decide the number of daemon
            workers needed. Must be positive. Default value is 0.9 (90%).
        :raises ValueError: if ``workload_target`` is not positive.
        """
        if workload_target <= 0:
            raise ValueError('workload_target must be positive, got '
                             f'{workload_target}')
        client = self.get_aiida_daemon_client()
        active_workers = client.get_numprocesses()['numprocesses']
        config = get_config()
        # get_aiida_daemon_client() has just set self.default_profile_name.
        slots_per_worker = config.get_option('daemon.worker_process_slots',
                                             scope=self.default_profile_name)
        active_processes = QueryBuilder().append(
            ProcessNode,
            filters={
                'attributes.process_state': {
                    'in': ('created', 'waiting', 'running')
                }
            }).count()
        needed_workers = math.ceil(active_processes /
                                   (workload_target * slots_per_worker))
        # Always keep one worker running, no matter how many are active now;
        # otherwise an idle period would scale the daemon down to zero.
        needed_workers = max(needed_workers, 1)
        diff = needed_workers - active_workers
        if diff > 0:
            self.logger.info(
                f'Increasing the number of AiiDA daemons by {diff}.')
            client.increase_workers(diff)
        elif diff < 0:
            client.decrease_workers(abs(diff))
            self.logger.info(
                f'Decreasing the number of AiiDA daemons by {abs(diff)}.')

    def get_aiida_daemon_client(self) -> DaemonClient:
        """
        Retrieve the daemon client of the default AiiDA profile.

        As a side effect, stores the default profile name in
        ``self.default_profile_name``; :meth:`check_daemon_workload` relies
        on it.

        :returns: the daemon client for the default profile.
        """
        config = get_config()
        self.default_profile_name = config.default_profile_name
        return get_daemon_client(self.default_profile_name)

    def checkpoint_workflow(self):
        """
        Checkpoint the workflow module into the checkpoint file.

        Save necessary internal variables into a dict with key checkpoint_name
        and write to the (json) checkpoint file for restart capabilities.
        """
        self.save_job_dict()

    def restart_workflow(self):
        """
        Restart the workflow module from the checkpoint file.

        Check if the checkpoint_file has an entry matching the checkpoint_name
        and set internal variables accordingly if so.
        """
        # set the jobs dict from the pickle file
        self.read_job_dict()

    def update_job_status(self, pks: list[int]) -> list[list]:
        """
        Query AiiDA and update the stored job status of each job.

        This helper function uses the AiiDA QueryBuilder to check updates
        about a job's progress, modifying the corresponding job_status object.
        Process states are lowercase strings: 'created', 'excepted',
        'finished', 'killed', 'running', and 'waiting'. The workflow is
        checkpointed if any state changed. The current status is returned for
        convenience.

        :param pks: list of AiiDA PKs of the jobs to check for completion
        :returns: one ``[process_state, exit_status]`` entry per element of
            ``pks``, in the same order as ``pks``
        :raises ValueError: if any PK does not correspond to a ProcessNode
        """
        if not pks:
            return []
        # Load the AiiDA profile
        load_profile()
        qb = QueryBuilder()
        # Project the id as well: the order of query results is not
        # guaranteed to follow the order of ``pks``, so rows are matched to
        # jobs by id rather than by position. ProcessNode stores the exit
        # status under the 'exit_status' attribute.
        qb.append(ProcessNode,
                  filters={'id': {
                      'in': pks
                  }},
                  project=[
                      'id', 'attributes.process_state',
                      'attributes.exit_status'
                  ])
        status_by_pk = {row[0]: list(row[1:]) for row in qb.all()}
        missing = [pk for pk in pks if pk not in status_by_pk]
        if missing:
            # Without this, the missing jobs could never be tracked and
            # block_until_completed would wait on them forever.
            raise ValueError(f'No AiiDA ProcessNode found for PK(s) {missing}')

        status_changed = False
        updated_states = []
        for pk in pks:
            status = status_by_pk[pk]
            # Get the previous job state
            known_status = self.get_job_status(pk)
            if status[0] != known_status.state:
                self.logger.info((f'Updating job {pk} state from '
                                  f'{known_status.state} to {status[0]}'))
                known_status.state = status[0]
                status_changed = True
            updated_states.append(status)
        if status_changed:
            self.checkpoint_workflow()
        return updated_states

    def block_until_completed(self, pks: Union[int, list[int]]):
        """
        Block until every given AiiDA job has reached a terminal state.

        Job states are polled every ``self.synch_check_frequency`` seconds.
        A job stops being waited on once its process state is 'finished',
        'excepted' or 'killed'. The latter two are logged as warnings because
        such jobs will never reach 'finished'. After every poll the daemon
        workload is rebalanced via :meth:`check_daemon_workload`.

        :param pks: list of AiiDA PKs of the jobs to check for completion.
            Can also pass a single ID.
        :raises TypeError: if ``pks`` is neither an int nor a list.
        """
        if isinstance(pks, list):
            remaining_jobs = list(pks)
        elif isinstance(pks, int) and not isinstance(pks, bool):
            remaining_jobs = [pks]
        else:
            raise TypeError('AiiDA PKs must be a single int or a list!')

        wait_cycle_counter = 0
        while remaining_jobs:
            job_states = self.update_job_status(remaining_jobs)
            still_pending = []
            for pk, job_state in zip(remaining_jobs, job_states):
                state = job_state[0]
                if state == 'finished':
                    self.logger.info((f'AiiDA Job <{pk}> completed, removing '
                                      f'from waiting list'))
                elif state in self._TERMINAL_STATES:
                    self.logger.warning(
                        f"AiiDA Job <{pk}> ended in state '{state}', "
                        'removing from waiting list')
                else:
                    still_pending.append(pk)
            remaining_jobs = still_pending
            # Check the daemon workload
            self.check_daemon_workload()
            if not remaining_jobs:
                break
            if wait_cycle_counter % 5 == 0:
                self.logger.info((f'Wait iteration {wait_cycle_counter}, ['
                                  f'cycle time = {self.synch_check_frequency}'
                                  f' s] with {len(remaining_jobs)} jobs left '
                                  f'to complete'))
            wait_cycle_counter += 1
            sleep(self.synch_check_frequency)
        self.logger.info(f'Jobs {pks} have completed, continuing...')

    def submit_job(self,
                   builder: ProcessBuilder,
                   job_details: Union[dict, None] = None) -> int:
        """
        Submits a job to AiiDA.

        submit_job handles job submission to AiiDA for the Oracle
        calculations. As most information should have already been defined in
        the `computer` and `code` items of AiiDA, this function will primarily
        set things such as details about the job's resources
        (``job_details``). Note that while default job resources (nodes,
        account, walltime, etc.) are present, they can be overridden by
        providing these keywords in the ``job_details`` dict for any specific
        calculation.

        NOTE(review): this method does not itself copy resource settings from
        ``job_details`` (or the ``default_*`` attributes) into ``builder``;
        only ``job_details['extra_args'][METADATA_KEY]`` is used here.
        Presumably the caller has configured the builder already - confirm.

        Creates the :class:`~.workflow_base.JobStatus` for this job, where the
        job state is initially 'created' and can be updated to 'excepted',
        'finished', 'killed', 'running', or 'waiting'. Status checks are
        performed by :meth:`~update_job_status`. Returns the AiiDA pk, which
        can be used to retrieve the present job's
        :class:`~.workflow_base.JobStatus`.

        :param builder: AiiDA builder object containing the required
            information to submit the calculation.
        :param job_details: specifics for running the job, such as number of
            nodes, queue, etc., as well as optional dependency list, if the
            job should be synchronous or asynchronous, and any other optional
            arguments, such as pre- or postambles |default| ``None``
        :returns: return job ID (AiiDA pk) to query this job status and
            location
        """
        self.logger.info('Spawning job, ID to be defined')
        # Load AiiDA profile
        load_profile()
        # Normalize job_details before reading from it, so None is accepted.
        if not job_details:
            job_details = {}
            self.logger.info((f'No job details specified, will use defaults:\n'
                              f'    N = {self.default_nodes},\n'
                              f'    A = {self.default_account},\n'
                              f'    t = {self.default_walltime},\n'
                              f'    p = {self.default_queue}\n'))
        # An explicit ``'extra_args': None`` is treated like a missing entry.
        extra_args = job_details.get('extra_args') or {}

        calc = submit(builder)
        pk = calc.pk
        self.logger.info(f'Spawning AiiDA job with PK=<{pk}>')
        job_status = JobStatus(f'{self.root_directory}/{pk}', 'created', 0)
        job_status.metadata = extra_args.get(METADATA_KEY, {})
        self.jobs[pk] = job_status
        self.checkpoint_workflow()
        return pk

    @staticmethod
    def get_job_path(pk: int) -> Union[str, PathLike]:
        """
        Given the parent pk value that an Oracle returns, will get the
        absolute path on the remote server.

        The node itself is considered if it is a :class:`CalcJobNode`;
        otherwise (and additionally) its called descendants are searched.
        When several CalcJobNodes have a remote working directory, the one
        with the highest pk is used (in practice the most recently stored),
        so the result does not depend on database iteration order.

        :param pk: AiiDA PK
        :returns: Path on the remote server where the calculation occurred.
        :raises ValueError: if no CalcJobNode with a remote working directory
            is found.
        """
        parent = load_node(pk)
        candidates = [parent] if isinstance(parent, CalcJobNode) else []
        candidates.extend(node for node in parent.called_descendants
                          if isinstance(node, CalcJobNode))
        for node in sorted(candidates, key=lambda n: n.pk, reverse=True):
            path = node.get_remote_workdir()
            # Skip calculations whose remote folder has not been set yet.
            if path:
                return path
        raise ValueError(
            f'The provided pk <{pk}> does not appear to have any '
            'CalcJobNode types associated with the descendants. '
            'Make sure the job finished completely.')