Source code for acme.backend

#
# Computational scaffolding for user-interface
#
# Copyright © 2025 Ernst Strüngmann Institute (ESI) for Neuroscience
# in Cooperation with Max Planck Society
#
# SPDX-License-Identifier: BSD-3-Clause
#

# Builtin/3rd party package imports
import time
import socket
import getpass
import datetime
import inspect
import numbers
import collections
import os
import sys
import glob
import shutil
import functools
import pickle
import logging
import multiprocessing
import psutil
import tqdm
import h5py
import dask
import dask.distributed as dd
import numpy as np
from dask_jobqueue import SLURMCluster
from typing import TYPE_CHECKING, Optional, Any, Union, List
from numpy.typing import ArrayLike

# Local imports
from . import __path__
from .dask_helpers import (
    esi_cluster_setup,
    bic_cluster_setup,
    local_cluster_setup,
    slurm_cluster_setup,
    cluster_cleanup,
    count_online_workers,
)
from .shared import user_yesno, is_esi_node, is_slurm_node, is_bic_node
from .logger import prepare_log
from .validators import validate_boolean, validate_pmap
from .config import ACMEConfig
from .memory_profiler import MemoryProfiler
from .argument_processor import ArgumentProcessor

isSpyModule = False
if "syncopy" in sys.modules:  # pragma: no cover
    isSpyModule = True
if TYPE_CHECKING:  # pragma: no cover
    from frontend import ParallelMap

__all__: List["str"] = ["ACMEdaemon"]

# Fetch logger
log = logging.getLogger("ACME")


# Main manager for parallel execution of user-defined functions
[docs] class ACMEdaemon(object): # Restrict valid class attributes __slots__ = ("results_container", "config", "processor", "profiler")
[docs] def __init__( self, pmap: "ParallelMap", n_workers: Union[int, str] = "auto", write_worker_results: bool = True, output_dir: Optional[str] = None, result_shape: Optional[tuple[Optional[int], ...]] = None, result_dtype: str = "float", single_file: bool = False, write_pickle: bool = False, dryrun: bool = False, partition: str = "auto", mem_per_worker: str = "auto", setup_timeout: int = 60, setup_interactive: bool = True, stop_client: Union[bool, str] = "auto", verbose: Optional[bool] = None, logfile: Optional[Union[bool, str]] = None, ) -> None: """ Manager class for performing concurrent user function calls Parameters ---------- pmap : :class:`~acme.ParallelMap` context manager By default, `:class:~`acme.ACMEDaemon` assumes that that the provided :class:`~acme.ParallelMap` instance has already been properly set up to process `func` (all input arguments parsed and properly formatted). All other input arguments of `:class:~`acme.ACMEDaemon` are extracted from the provided :class:`~acme.ParallelMap` instance. n_workers : int or "auto" Number of SLURM workers (=jobs) to spawn. See :class:`~acme.ParallelMap` for details. write_worker_results : bool If `True`, the return value(s) of `func` is/are saved on disk. See :class:`~acme.ParallelMap` for details. output_dir : str or None If provided, auto-generated results are stored in the given path. See :class:`~acme.ParallelMap` for details. result_shape : tuple or None If provided, results are slotted into a dataset/array with layout `result_shape`. See :class:`~acme.ParallelMap` for details. result_dtype : str Determines numerical datatype of dataset laid out by `result_shape`. See :class:`~acme.ParallelMap` for details. single_file : bool If `True`, parallel workers write to the same results container. See :class:`~acme.ParallelMap` for details. write_pickle : bool If `True`, the return value(s) of `func` is/are pickled to disk. See :class:`~acme.ParallelMap` for details. dryrun : bool If `True`, a dry-run of calling `func` is performed using a single `args`, `kwargs` tuple. See :class:`~acme.ParallelMap` for details. partition : str Name of SLURM partition to use. See :class:`~acme.ParallelMap` for details. mem_per_worker : str Memory booking for each SLURM worker. See :class:`~acme.ParallelMap` for details. setup_timeout : int Timeout period (in seconds) for SLURM workers to come online. See :class:`~acme.ParallelMap` for details. setup_interactive : bool If `True`, user input is queried in case not enough SLURM workers could be started within `setup_timeout` seconds. See :class:`~acme.ParallelMap` for details. stop_client : bool or "auto" If `"auto"`, automatically started distributed computing clients are shut down at the end of computation, while user-provided clients are left untouched. See :class:`~acme.ParallelMap` for details. verbose : None or bool If `None` (default), general run-time information as well as warnings and errors are shown. See :class:`~acme.ParallelMap` for details. logfile : None or bool or str If `None` (default) or `True`, and `write_worker_results` is `True`, all run-time information as well as errors and warnings are tracked in a log-file. See :class:`~acme.ParallelMap` for details. Returns ------- results : list If `write_worker_results` is `True`, `results` is a list of HDF5 file-names containing computed results. If `write_worker_results` is `False`, results is a list comprising the actual return values of `func`. If `:class:~`acme.ACMEDaemon` was instantiated by :class:`~acme.ParallelMap`, results are propagated back to :class:`~acme.ParallelMap`. See also -------- ParallelMap : Context manager and main user interface """ # First and foremost: ensure we got something useful to work with validate_pmap(pmap) # Create configuration self.config = ACMEConfig( func=pmap.func, argv=pmap.argv, kwargv=pmap.kwargv, n_calls=pmap.n_inputs, n_workers=n_workers, write_worker_results=write_worker_results, output_dir=output_dir, result_shape=result_shape, result_dtype=result_dtype, single_file=single_file, write_pickle=write_pickle, dryrun=dryrun, partition=partition, mem_per_worker=mem_per_worker, setup_timeout=setup_timeout, setup_interactive=setup_interactive, stop_client=stop_client, verbose=verbose, logfile=logfile, ) self.config.validate() # Set up output handler self.pre_process() # Set up argument processing helper class self.processor = ArgumentProcessor( self.config.argv, self.config.kwargv, self.config.n_calls ) # Set up memory profiling helper class self.profiler = MemoryProfiler( self.processor, self.config.acme_func, self.config.func.__name__, self.config.tqdmFormat, ) # If requested, perform single-worker dry-run (and quit if desired) if dryrun: goOn = self.profiler.perform_dryrun( output_dir=self.config.output_dir, setup_interactive=self.config.setup_interactive, ) if not goOn: log.debug("Quitting after dryrun") return log.debug("Continuing after dryrun") # Either use existing dask client or start a fresh instance self.prepare_client()
[docs] def pre_process(self) -> None: """ If `write_*` is `True` set up directories for saving output HDF5 containers (or pickle files). Warn if results are to be collected in memory """ # If automatic saving of results is requested, make necessary preparations if self.config.write_worker_results: self.setup_output() else: # If `taskID` is not an explicit kw-arg of `func` and `func` does not # accept "anonymous" `**kwargs`, don't save anything but return stuff log.debug("Automatic output processing disabled.") if self.config.kwargv.get("taskID") is None: if not isSpyModule: msg = ( "`write_worker_results` is `False` and `taskID` is not a keyword argument of %s. " + "Results will be collected in memory by caller - this might be slow and can lead " + "to excessive memory consumption. " ) log.warning(msg, self.config.func.__name__) self.config.collect_results = True # type: ignore else: self.config.kwargv["taskID"] = self.config.task_ids self.config.collect_results = False # type: ignore msg = ( "Not collecting results in memory, leaving output " + "processing to user-provided function" ) log.debug(msg) # The "raw" user-provided function is used in the computation self.config.acme_func = self.config.func log.debug("Not wrapping user-provided function but invoking it directly") # If progress tracking in a log-file was requested, set it up now prepare_log( logname="ACME", logfile=self.config.logfile, verbose=self.config.verbose ) log.debug("Set up logfile=%s", str(self.config.logfile)) return
[docs] def setup_output(self) -> None: """ Local helper for creating output directories and preparing containers """ # Unless specifically denied by the user, each worker stores results # separately with a common container file pointing to the individual # by-worker files residing in a "payload" directory if not self.config.single_file and not self.config.write_pickle: log.debug("Preparing payload directory for HDF5 containers") payloadName = f"{self.config.func.__name__}_payload" outputDir = os.path.join(self.config.output_dir, payloadName) # type: ignore else: msg = ( "Either single-file output or pickling was requested. " + "Not creating payload directory" ) log.debug(msg) outputDir = self.config.output_dir try: os.makedirs(outputDir) log.debug("Created %s", outputDir) except Exception as exc: err = "automatic creation of output folder %s failed: %s" log.error(err, outputDir, str(exc)) raise OSError(err % (outputDir, str(exc))) # Re-define or allocate key "taskID" to track concurrent processing results self.config.kwargv["taskID"] = self.config.task_ids self.config.collect_results = False # Set up correct file-extension for output files; in case of HDF5 # containers, prepare "main" file for collecting/symlinking worker results if self.config.write_pickle: fExt = "pickle" log.debug("Pickling was requested") else: fExt = "h5" self.config.results_container = os.path.join(self.config.output_dir, f"{self.config.func.__name__}.h5") # type: ignore log.debug("Using HDF5 storage %s", self.config.results_container) # By default, `results_container` is a collection of links that point to # worker-generated HDF5 containers; if `single_file` is `True`, then # `results_container` is a "real" container with actual dataset(s) if self.config.single_file: self.config.kwargv["singleFile"] = [True] self.config.kwargv["outFile"] = [self.config.results_container] log.debug("Saving results in single HDF5 container") # If no output shape provided, prepare groups for storing datasets; # otherwise allocate a single dataset w/specified dimension if self.config.result_shape is None: msg = "Created group comp_%d in single shared results container" with h5py.File(self.config.results_container, "w") as h5f: for i in self.config.task_ids: h5f.create_group(f"comp_{i}") log.debug(msg, i) else: if np.inf in self.config.result_shape: # type: ignore actShape = tuple(spec if spec is not np.inf else 1 for spec in self.config.result_shape) # type: ignore maxShape = tuple(spec if spec is not np.inf else None for spec in self.config.result_shape) # type: ignore else: actShape = self.config.result_shape # type: ignore maxShape = None msg = ( "Created unique dataset 'result_0' with shape %s " + "in single shared results container" ) with h5py.File(self.config.results_container, "w") as h5f: h5f.create_dataset( "result_0", shape=actShape, maxshape=maxShape, dtype=self.config.result_dtype, ) log.debug(msg, str(self.config.result_shape)) else: self.config.kwargv["outFile"] = [ os.path.join(outputDir, f"{self.config.func.__name__}_{taskID}.{fExt}") for taskID in self.config.task_ids ] if not self.config.write_pickle: # If no output shape provided, generate links to external datasets; # otherwise allocate a virtual dataset w/specified dimension if self.config.result_shape is None: msg = ( "Created external link comp_%d pointing to " + "%s in results container" ) with h5py.File(self.config.results_container, "w") as h5f: for i, fname in enumerate(self.config.kwargv["outFile"]): relPath = os.path.join(payloadName, os.path.basename(fname)) h5f[f"comp_{i}"] = h5py.ExternalLink(relPath, "/") log.debug(msg, i, relPath) else: VSourceShape = [spec if spec is not np.inf else None for spec in self.config.result_shape] # type: ignore VSourceShape.pop(self.config.stacking_dim) # type: ignore VSourceShape = tuple(VSourceShape) # Account for resizable datasets if None in VSourceShape: resActShape = tuple(spec if spec is not np.inf else 1 for spec in self.config.result_shape) # type: ignore resMaxShape = tuple(spec if spec is not np.inf else None for spec in self.config.result_shape) # type: ignore vsActShape = tuple( spec if spec is not None else 1 for spec in VSourceShape ) vsMaxShape = VSourceShape else: resActShape = self.config.result_shape # type: ignore resMaxShape = None vsActShape = VSourceShape vsMaxShape = None layout = h5py.VirtualLayout( shape=resActShape, dtype=self.config.result_dtype, maxshape=resMaxShape, ) # type: ignore idx = [slice(None) if spec is not np.inf else slice(h5py.h5s.UNLIMITED) for spec in self.config.result_shape] # type: ignore jdx = list(idx) jdx.pop(self.config.stacking_dim) # type: ignore msg = ( "Created virtual dataset result_0' with shape " + "%s in results container" ) for i, fname in enumerate(self.config.kwargv["outFile"]): idx[self.config.stacking_dim] = i # type: ignore relPath = os.path.join(payloadName, os.path.basename(fname)) vsource = h5py.VirtualSource( fname, "result_0", shape=vsActShape, maxshape=vsMaxShape ) layout[tuple(idx)] = vsource[tuple(jdx)] with h5py.File( self.config.results_container, "w", libver="latest" ) as h5f: h5f.create_virtual_dataset("result_0", layout) log.debug(msg, self.config.result_shape) # Include logger name in keywords so that workers can use it self.config.kwargv["logName"] = [log.name] # Wrap the user-provided func and distribute it across workers self.config.kwargv["userFunc"] = [self.config.func] self.config.acme_func = self.func_wrapper # type: ignore log.debug("Wrapping user-provided function inside func_wrapper") # Finally, attach verbosity flag to enable logging inside wrapper self.config.kwargv["logLevel"] = [log.level] return
[docs] def prepare_client(self) -> None: """ Setup or fetch dask distributed processing client. Depending on available hardware, either start a local multi-processing client or launch a worker cluster via SLURM. Also ensure that ad-hoc clients created here are stopped and worker jobs are properly released at the end of computation. However, ensure any client not created by `prepare_client` is **not** automatically cleaned up. """ # Check if a dask client is already running try: self.config.client = dd.get_client() # type: ignore log.debug("Detected running client %s", str(self.config.client)) if self.config.stop_client == "auto": self.config.stop_client = False msg = ( "Changing `stop_client` from `'auto'` to `False` " + "to not terminate external client" ) log.debug(msg) self.config.n_workers = count_online_workers(self.config.client.cluster) # type: ignore log.debug("Found %d alive workers in the client", self.config.n_workers) msg = "Attaching to parallel computing client %s" log.info(msg % (str(self.config.client))) return except ValueError: msg = "No running client detected, preparing to start a new one" log.debug(msg) if self.config.stop_client == "auto": self.config.stop_client = True # type: ignore msg = ( "Changing `stop_client` from `'auto'` to `True` " + "to clean up client started by `ParallelMap`" ) log.debug(msg) # If things are running locally, simply fire up a dask-distributed client, # otherwise go through the motions of preparing a full worker cluster if not self.config.has_slurm: # pragma: no cover log.debug("SLURM not found, Calling `local_cluster_setup`") self.config.client = local_cluster_setup(n_workers=self.config.n_workers, interactive=False) # type: ignore else: # If `partition` is "auto", attempt to heuristically determine average # memory consumption of jobs if partition == "auto": mem_per_worker = self.profiler.estimate_memory(self.config.output_dir) # All set, remaining input processing is done by respective `*_cluster_setup` routines if is_esi_node(): msg = "Running on ESI compute node, Calling `esi_cluster_setup`" log.debug(msg) self.config.client = esi_cluster_setup( partition=partition, n_workers=n_workers, # type: ignore mem_per_worker=mem_per_worker, timeout=setup_timeout, interactive=setup_interactive, start_client=True, ) # All set, remaining input processing is done by respective `*_cluster_setup` routines elif is_bic_node(): msg = "Running on CoBIC compute node, Calling `bic_cluster_setup`" log.debug(msg) self.config.client = bic_cluster_setup( partition=partition, n_workers=n_workers, # type: ignore mem_per_worker=mem_per_worker, timeout=setup_timeout, interactive=setup_interactive, start_client=True, ) # Unknown cluster node, use vanilla config else: # pragma: no cover wrng = ( "Cluster node %s not recognized. Falling back to vanilla " + "SLURM setup allocating one worker and one core per worker" ) log.warning(wrng % (socket.getfqdn())) processes_per_worker = 1 n_cores = 1 self.config.client = slurm_cluster_setup( partition=partition, # type: ignore n_cores=n_cores, n_workers=n_workers, # type: ignore processes_per_worker=processes_per_worker, mem_per_worker=mem_per_worker, n_workers_startup=1, timeout=setup_timeout, interactive=setup_interactive, interactive_wait=120, start_client=True, job_extra=[], invalid_partitions=[], ) # If startup is aborted by user, get outta here if self.config.client is None: # pragma: no cover err = "Could not start distributed computing client. " log.error(err) raise ConnectionAbortedError(err) # Set `n_workers` to no. of active workers in the initialized cluster self.config.n_workers = len(self.config.client.cluster.workers) # type: ignore log.debug( "Setting `n_workers = %d` based on active workers in %s", self.config.n_workers, str(self.config.client), ) # If single output file saving was chosen, initialize distributed # lock for shared writing to container if self.config.kwargv.get("singleFile") is not None: msg = "Initializing distributed lock for writing to single shared results container" log.debug(msg) dd.lock.Lock(name=os.path.basename(self.config.results_container)) # type: ignore return
[docs] def compute(self, debug: bool = False) -> Union[List, None]: """ Perform the actual parallel execution of `func` If `debug` is `True`, use a single-threaded dask scheduler that does not actually process anything concurrently but uses the dask framework in a sequential setup. """ validate_boolean(debug, name="debug") # If `prepare_client` has not been called yet, don't attempt to compute anything if self.config.client is None: log.debug("No parallel computing client allocated, exiting") return None # Check if the underlying parallel computing cluster hosts actually usable workers if count_online_workers(self.config.client.cluster) == 0: err = ( "no active workers found in distributed computing client %s " + "Consider running \n" + "\timport dask.distributed as dd; dd.get_client().restart()\n" + "If this fails to make workers come online, please use\n" + "\timport acme; acme.cluster_cleanup()\n" + "to shut down any defunct distributed computing clients" ) log.error(err, str(self.config.client)) raise RuntimeError(err % (str(self.config.client))) log.debug( "Found %d workers in client %s", count_online_workers(self.config.client.cluster), str(self.config.client), ) # Dask does not correctly forward the `sys.path` from the parent process # to its workers. Fix this. def init_acme(dask_worker, syspath): sys.path = list(syspath) self.config.client.register_worker_callbacks( setup=functools.partial(init_acme, syspath=sys.path) ) log.debug("Registered worker callback to forward `sys.path`") # Broadcast arguments and format keyword arguments self.config.argv, self.config.kwargv = self.processor.broadcast_arguments( self.config.client ) kwargList = self.processor.format_kwarg_list() # In case a debugging run is performed, use the single-threaded scheduler and return if debug: log.warning("Running in debug mode") with dask.config.set(scheduler="single-threaded"): log.debug("Using single-threaded scheduler to evaluate function") values = self.config.client.gather( [ self.config.client.submit( self.config.acme_func, *args, **kwargs ) for args, kwargs in zip(zip(*self.config.argv), kwargList) ] ) return values # Depending on the used dask cluster object, point to respective log info if isinstance(self.config.client.cluster, SLURMCluster): logFiles = self.config.client.cluster.job_header.split("--output=")[ 1 ].replace("%j", "{}") logDir = os.path.split(logFiles)[0] else: # pragma: no cover logFiles = [] logDir = ( os.path.dirname(self.config.client.cluster.dashboard_link) + "/info/main/workers.html" ) msg = "Preparing %d parallel calls of `%s` using %d workers" log.info( msg % (self.config.n_calls, self.config.func.__name__, self.config.n_workers) ) msg = "Log information available at %s" log.debug(msg % (logDir)) # Submit `self.config.n_calls` function calls to the cluster log.debug( "Submitting %d function calls to client %s", self.config.n_calls, str(self.config.client), ) futures = [ self.config.client.submit(self.config.acme_func, *args, **kwargs) for args, kwargs in zip(zip(*self.config.argv), kwargList) ] # Set up progress bar: the while loop ensures all futures are executed totalTasks = len(futures) pbar = tqdm.tqdm( total=totalTasks, bar_format=self.config.tqdmFormat, position=0, leave=True ) cnt = 0 while any(f.status == "pending" for f in futures): time.sleep(self.config.sleepTime) new = max(0, sum([f.status == "finished" for f in futures]) - cnt) cnt += new pbar.update(new) pbar.close() # Avoid race condition: give futures time to perform switch from 'pending' # to 'finished' so that `finishedTasks` is computed correctly log.debug("Waiting %f seconds for futures", self.config.sleepTime) time.sleep(self.config.sleepTime) # If number of 'finished' tasks is less than expected, go into # problem analysis mode: all futures that erred hav an `.exception` # method which can be used to track down the worker it was executed by # Once we know the worker, we can point to the right log file. If # futures were cancelled (by the user or the SLURM controller), # `.exception` is `None` and we can't reliably track down the # respective executing worker finishedTasks = sum([f.status == "finished" for f in futures]) if finishedTasks < totalTasks: schedulerLog = list( self.config.client.cluster.get_logs( cluster=False, scheduler=True, workers=False ).values() )[0] erredFutures = [f for f in futures if f.status == "error"] msg = "Parallel computation failed: %d/%d tasks failed or stalled. " msg = msg % (totalTasks - finishedTasks, totalTasks) msg += "Concurrent computing scheduler log info: " msg += schedulerLog + "\n" # If we're working w/`SLURMCluster`, perform the Herculean task of # tracking down which dask worker was executed by which SLURM job... if self.config.client.cluster.__class__.__name__ == "SLURMCluster": try: erredJobs = [ f.exception().last_worker.identity()["id"] for f in erredFutures ] except AttributeError: erredJobs = [] erredJobs = list(set(erredJobs)) validIDs = [ job for job in erredJobs if job in self.config.client.cluster.workers.keys() ] erredJobIDs = [ self.config.client.cluster.workers[job].job_id for job in validIDs ] errFiles = glob.glob(logDir + os.sep + "*.err") if len(erredFutures) > 0 or len(errFiles) > 0: msg += "Please consult the following SLURM log files for details:\n" if len(erredJobIDs) > 0: msg += "".join(logFiles.format(id) + "\n" for id in erredJobIDs) else: msg += "".join(logDir) msg += "".join(errfile + "\n" for errfile in errFiles) else: msg += "Please check SLURM logs in %s" % (logDir) # In case of a `LocalCluster`, syphon worker logs else: # pragma: no cover msg += "Parallel worker log details: \n" workerLogs = self.config.client.get_worker_logs().values() for wLog in workerLogs: if "Failed" in wLog: msg += wLog # Finally, raise an error and get outta here log.error(msg) raise RuntimeError(msg) # Postprocessing of results values = self.post_process(futures) # Either return collected by-worker results or the filepaths of results return values
[docs] def post_process(self, futures: dd.Future) -> Union[List, None]: """ Local helper to post-process results on disk/in-memory The return `values` is either `None` : if neither in-memory results collection or auto-writing was requested list of file-names: if `write_worker_results` is `True` list of objects: if in-memory results collection was requested """ # Deduce result output information write_worker_results = self.config.acme_func == self.func_wrapper single_file = False if write_worker_results: write_pickle = self.config.results_container is None if not write_pickle and self.config.kwargv.get("singleFile") is not None: single_file = True else: write_pickle = False msg = "Inferred that `write_worker_results = %s`, `single_file = %s`, `write_pickle = %s`" log.debug(msg, str(write_worker_results), str(single_file), str(write_pickle)) # If wanted (not recommended) collect computed results in local memory if self.config.collect_results: if not isSpyModule: log.info("Gathering results in local memory") collected = self.config.client.gather(futures) log.debug( "Gathered results from client in a %d-element list", len(collected) ) if self.config.result_shape is not None: log.debug( "Returning single NumPy array of shape %s and type %s", str(self.config.result_shape), str(self.config.result_dtype), ) values = [] arrVal = np.empty( shape=self.config.result_shape, dtype=self.config.result_dtype ) idx = [slice(None)] * len(self.config.result_shape) for i, res in enumerate(collected): if not isinstance(res, (list, tuple)): res = [res] idx[self.config.stacking_dim] = i arrVal[tuple(idx)] = res[0] for r in res[1:]: values.append(r) values.insert(0, arrVal) # If `values` is a single array, don't wrap it inside a list if len(values) == 1: values = values[0] else: log.debug("Returning a list of values") values = collected else: values = None # Prepare final output message successMsg = "SUCCESS!" # If automatic results writing was requested, perform some housekeeping if write_worker_results: finalMsg = "Results have been saved to %s" if write_pickle: log.debug("Saved results as pickle files") values = list(self.config.kwargv["outFile"]) finalMsg = finalMsg % (self.config.output_dir) log.debug("Returning a list of file-names") else: if single_file: log.debug("Saved results to single shared container") finalMsg = finalMsg % (self.config.results_container) if values is None: values = [self.config.results_container] log.debug("Returning container name as single-element list") else: log.debug("Scanning payload directory for emergency pickles") picklesFound = False values = [] for fname in self.config.kwargv["outFile"]: pklName = fname.rstrip(".h5") + ".pickle" if os.path.isfile(fname): values.append(fname) elif os.path.isfile(pklName): values.append(pklName) picklesFound = True log.debug("Found emergency pickle %s", pklName) else: missing = fname.rstrip(".h5") values.append("Missing %s" % (missing)) log.debug("Missing file %s", missing) payloadDir = os.path.dirname(values[0]) # If pickles are found, remove global `results_container` as it # would contain invalid file-links and move compute results out # of payload dir if picklesFound: os.unlink(self.config.results_container) # type: ignore wrng = ( "Some compute runs could not be saved as HDF5, " + "collection container %s has been removed as it would " + "comprise invalid file-links" ) log.warning(wrng, self.config.results_container) self.config.results_container = None # Move files out of payload dir and update return `values` target = os.path.abspath(os.path.join(payloadDir, os.pardir)) for i, fname in enumerate(values): shutil.move(fname, target) self.config.kwargv["outFile"][i] = os.path.join( target, os.path.basename(fname) ) log.debug("Moved %s to %s", fname, target) values = list(self.config.kwargv["outFile"]) log.debug("Returning a list of file-names") shutil.rmtree(payloadDir) log.debug("Deleted payload directory %s", payloadDir) successMsg = "" finalMsg = finalMsg % (target) # All good, no pickle gymnastics was needed else: # In case of multiple return values present in by-worker # containers but missing in collection container (happens # if `result_shape` is not `None` and data-sets have to # be pre-allocated), create "symlinks" to corresponding # missing returns log.debug("No emergency pickles found") if self.config.stacking_dim is not None: msg = ( "Check if additional return values " + "need to be added to container with pre-allocated dataset" ) log.debug(msg) with h5py.File(self.config.results_container, "r") as h5r: with h5py.File(values[0], "r") as h5Tmp: missingReturns = set(h5Tmp.keys()).difference( h5r.keys() ) if len(missingReturns) > 0: log.debug("Found return values to be added") with h5py.File( self.config.results_container, "a" ) as h5r: for retVal in missingReturns: for i, fname in enumerate(values): relPath = os.path.join( os.path.basename(payloadDir), os.path.basename(fname), ) h5r[f"comp_{i}/{retVal}"] = ( h5py.ExternalLink(relPath, retVal) ) log.debug( "Added return value via external link comp_%d/%s", i, retVal, ) finalMsg = finalMsg % (self.config.results_container) msg = "Container ready, links to data payload located in %s" log.debug(msg, payloadDir) log.debug("Returning a list of file-names") else: finalMsg = "Finished parallel computation" # Finally, estabslish shortcut to `results_container` (if present) for easier access self.results_container = self.config.results_container # Print final triumphant output message and force-flush all logging handlers if len(successMsg) > 0: log.announce(successMsg) # type: ignore log.info(finalMsg) for h in log.handlers: if hasattr(h, "flush"): h.flush() return values
[docs] def cleanup(self) -> None: """ Shut down any ad-hoc distributed computing clients created by `prepare_client` """ # If `prepare_client` has not been launched yet, just get outta here if self.config.client is None: log.debug("Helper `prepare_client` not yet launched, exiting") return if self.config.stop_client and self.config.client is not None: log.debug( "Found client %s, calling `cluster_cleanup`", str(self.config.client) ) cluster_cleanup(self.config.client) self.config.client = None return log.debug("Either `stop_client = False` or no client found, returning") return
[docs] @staticmethod def func_wrapper(*args: Any, **kwargs: Optional[Any]) -> None: # pragma: no cover """ If the output of `func` is saved to disk, wrap `func` with this static method to take care of filling up HDF5/pickle files If writing to HDF5 fails, use an "emergency-pickling" mechanism to try to save the output of `func` using pickle instead """ # Extract everything from `kwargs` appended by `ACMEdaemon` func = kwargs.pop("userFunc") taskID = kwargs.pop("taskID") fname = kwargs.pop("outFile") logName = kwargs.pop("logName") logLevel = kwargs.pop("logLevel") singleFile = kwargs.pop("singleFile", False) stackingDim = kwargs.pop("stackingDim", None) memEstRun = kwargs.pop("memEstRun", False) # Set up logger log = logging.getLogger(logName) log.setLevel(logLevel) # type: ignore for h in log.handlers: h.setLevel(logLevel) # type: ignore # Call user-provided function result = func(*args, **kwargs) # type: ignore # For memory estimation runs, don't start saving stuff if memEstRun: return # Save results: either (try to) use HDF5 or pickle stuff if fname.endswith(".h5"): # type: ignore grpName = "" if singleFile: lock = dd.lock.Lock(name=os.path.basename(fname)) # type: ignore lock.acquire() grpName = f"comp_{taskID}/" if not isinstance(result, (list, tuple)): result = [result] try: with h5py.File(fname, "a") as h5f: if stackingDim is None: if not all( isinstance(value, (numbers.Number, str)) for value in result ): for rk, res in enumerate(result): h5f.create_dataset(f"{grpName}result_{rk}", data=res) log.debug( "Created new dataset `result_%d` in %s", rk, fname ) else: h5f.create_dataset(grpName + "result_0", data=result) log.debug("Created new dataset `result_0` in %s", fname) else: if singleFile: dset = h5f["result_0"] idx = [slice(None)] * len(dset.shape) # type: ignore idx[stackingDim] = taskID # type: ignore if None in dset.maxshape: if len(result[0].shape) < len(idx): lenDim = list( set(result[0].shape).difference(dset.maxshape) ) if len(lenDim) == 0: lenDim = result[0].shape[0] else: lenDim = lenDim[0] actShape = tuple( spec if spec is not None else lenDim for spec in dset.maxshape ) else: actShape = list(result[0].shape) # type: ignore actShape[stackingDim] = dset.maxshape[stackingDim] # type: ignore actShape = tuple(actShape) dset.resize(actShape) dset[tuple(idx)] = result[0] log.debug( "Wrote to pre-allocated dataset `result_0` in %s", fname ) for rk, res in enumerate(result[1:]): h5f.create_dataset( f"{grpName}result_{rk + 1}", data=res ) log.debug( "Created new dataset `result_%d` in %s", rk + 1, fname, ) else: for rk, res in enumerate(result): h5f.create_dataset(f"{grpName}result_{rk}", data=res) log.debug( "Created new dataset `result_%d` in %s", rk, fname ) if singleFile: lock.release() except TypeError as exc: if ( "has no native HDF5 equivalent" in str(exc) or "One of data, shape or dtype must be specified" in str(exc) ) and not singleFile: try: os.unlink(fname) # type: ignore pname = fname.rstrip(".h5") + ".pickle" # type: ignore with open(os.path.join(pname), "wb") as pkf: pickle.dump(result, pkf) msg = ( "Could not write %s results have been pickled instead: %s. Return values are most likely " + "not suitable for storage in HDF5 containers. Original error message: %s" ) log.warning(msg, fname, pname, str(exc)) except pickle.PicklingError as pexc: err = "Unable to write %s, successive attempts to pickle results failed too: %s" log.error(err, fname, str(pexc)) else: if singleFile: err = "Could not write to %s. File potentially corrupted. Original error message: %s" lock.release() else: err = "Could not access %s. Original error message: %s" log.error(err, fname, str(exc)) raise exc except Exception as exc: if singleFile: lock.release() log.error(str(exc)) raise exc else: try: with open(os.path.join(fname), "wb") as pkf: # type: ignore pickle.dump(result, pkf) log.debug("Pickled to %s", fname) except pickle.PicklingError as pexc: err = "Could not pickle results to file %s. Original error message: %s" log.error(err, fname, str(pexc)) raise pexc return