Source code for at.tracking.track

from __future__ import annotations

__all__ = [
    "lattice_track",
    "element_track",
    "internal_lpass",
    "internal_epass",
    "internal_plpass",
]

import multiprocessing
from collections.abc import Iterable
from functools import partial
from warnings import warn

import numpy

from .atpass import atpass as _atpass, elempass as _elempass, reset_rng
from .utils import fortran_align, has_collective, format_results
from .utils import initialize_lpass, disable_varelem, variable_refs
from ..lattice import AtWarning, DConstant, random
from ..lattice import Lattice, Element, Refpts, End
from ..lattice import get_uint32_index

_imax = numpy.iinfo(int).max
_globring: list[Element] | None = None


def _atpass_fork(seed, rank, rin, **kwargs):
    """Single forked job"""
    reset_rng(rank=rank, seed=seed)
    result = _atpass(_globring, rin, **kwargs)
    return rin, result


def _atpass_spawn(ring, seed, rank, rin, **kwargs):
    """Single spawned job"""
    reset_rng(rank=rank, seed=seed)
    result = _atpass(ring, rin, **kwargs)
    return rin, result


def _pass(ring, r_in, pool_size, start_method, seed, **kwargs):
    ctx = multiprocessing.get_context(start_method)
    # Split input in as many slices as processes
    args = enumerate(numpy.array_split(r_in, pool_size, axis=1))
    # Make the ring available to forked workers; each worker then reseeds the
    # C RNG with its own rank (see _atpass_fork / _atpass_spawn)
    global _globring
    _globring = ring
    if ctx.get_start_method() == "fork":
        passfunc = partial(_atpass_fork, seed, **kwargs)
    else:
        passfunc = partial(_atpass_spawn, ring, seed, **kwargs)
    # Start the parallel jobs
    with ctx.Pool(pool_size) as pool:
        results = pool.starmap(passfunc, args)
    _globring = None
    # Gather the results
    losses = kwargs.pop("losses", False)
    return format_results(results, r_in, losses)
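

# Illustration of the particle slicing used above (a minimal sketch for
# illustration): numpy.array_split divides the (6, N) coordinate array into
# pool_size nearly equal chunks along the particle axis, so each worker
# process tracks its own subset of particles.
#
#     import numpy
#
#     r_in = numpy.zeros((6, 10), order="F")       # 10 particles
#     chunks = numpy.array_split(r_in, 4, axis=1)  # 4 worker processes
#     [c.shape for c in chunks]                    # [(6, 3), (6, 3), (6, 2), (6, 2)]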


@fortran_align
def _element_pass(element: Element, r_in, **kwargs):
    return _elempass(element, r_in, **kwargs)


@fortran_align
def _lattice_pass(
    lattice: list[Element],
    r_in,
    nturns: int = 1,
    refpts: Refpts = End,
    no_varelem=True,
    seed: int | None = None,
    **kwargs,
):
    kwargs["reuse"] = kwargs.pop("keep_lattice", False)
    if no_varelem:
        lattice = disable_varelem(lattice)
    else:
        if sum(variable_refs(lattice)) > 0:
            kwargs["reuse"] = False
    refs = get_uint32_index(lattice, refpts)
    if seed is not None:
        reset_rng(seed=seed)
    return _atpass(lattice, r_in, nturns, refpts=refs, **kwargs)
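

# A minimal usage sketch for this low-level pass function, exported at the end
# of the module as ``internal_lpass`` (``ring`` stands for an existing list of
# elements and is not defined here). It skips the bookkeeping done by
# lattice_track and returns the raw coordinate history.
#
#     import numpy
#
#     r_in = numpy.zeros((6, 2), order="F")          # 2 particles on the axis
#     r_out = internal_lpass(ring, r_in, nturns=5)   # shape (6, 2, 1, 5)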


@fortran_align
def _plattice_pass(
    lattice: list[Element],
    r_in,
    nturns: int = 1,
    refpts: Refpts = End,
    seed: int | None = None,
    pool_size: int | None = None,
    start_method: str | None = None,
    **kwargs,
):
    refpts = get_uint32_index(lattice, refpts)
    any_collective = has_collective(lattice)
    kwargs["reuse"] = kwargs.pop("keep_lattice", False)
    rshape = r_in.shape
    if len(rshape) >= 2 and rshape[1] > 1 and not any_collective:
        if seed is None:
            seed = random.common.integers(0, high=_imax, dtype=int)
        if pool_size is None:
            pool_size = min(
                len(r_in[0]), multiprocessing.cpu_count(), DConstant.patpass_poolsize
            )
        if start_method is None:
            start_method = DConstant.patpass_startmethod
        return _pass(
            lattice,
            r_in,
            pool_size,
            start_method,
            seed=seed,
            nturns=nturns,
            refpts=refpts,
            **kwargs,
        )
    else:
        if seed is not None:
            reset_rng(seed=seed)
        if any_collective:
            warn(
                AtWarning("Collective PassMethod found: use single process"),
                stacklevel=2,
            )
        else:
            warn(
                AtWarning("no parallel computation for a single particle"), stacklevel=2
            )
        return _atpass(lattice, r_in, nturns=nturns, refpts=refpts, **kwargs)
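

# The pool size and start method default to the values held in DConstant, as
# documented in lattice_track below. A possible way to set them globally,
# assuming plain attribute assignment on DConstant is allowed:
#
#     from at.lattice import DConstant
#
#     DConstant.patpass_poolsize = 8            # cap the number of worker processes
#     DConstant.patpass_startmethod = "spawn"   # force the multiprocessing start method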


def lattice_track(
    lattice: Iterable[Element],
    r_in,
    nturns: int = 1,
    refpts: Refpts = End,
    in_place: bool = False,
    **kwargs,
):
    """
    :py:func:`lattice_track` tracks particles through each element of a
    lattice or through a single Element, calling the element-specific
    tracking function specified in the Element's *PassMethod* field.

    Usage:
      >>> lattice_track(lattice, r_in)
      >>> lattice.track(r_in)

    Parameters:
        lattice: list of elements
        r_in: (6, N) array: input coordinates of N particles.
          *r_in* is modified in-place only if *in_place* is :py:obj:`True`
          and reports the coordinates at the end of the tracking. For the
          best efficiency, *r_in* should be given as F_CONTIGUOUS numpy
          array.

    Keyword arguments:
        nturns: number of turns to be tracked
        refpts: Selects the location of coordinates output.
          See ":ref:`Selecting elements in a lattice <refpts>`"
        in_place (bool): If :py:obj:`True`, *r_in* is modified in-place and
          reports the coordinates at the end of the tracking.
          (default: :py:obj:`False`)
        seed (int | None): Seed for the random generators. If :py:obj:`None`
          (default), continue the sequence
        keep_lattice (bool): Use elements persisted from a previous call.
          If :py:obj:`True`, assume that the lattice has not changed since
          the previous call.
        keep_counter (bool): Keep the turn number from the previous call.
        turn (int): Starting turn number. Ignored if *keep_counter* is
          :py:obj:`True`. The turn number is necessary to compute the
          absolute path length used in RFCavityPass.
        losses (bool): Boolean to activate loss maps output
        omp_num_threads (int): Number of OpenMP threads (default: automatic)
        use_mp (bool): Flag to activate multiprocessing
          (default: :py:obj:`False`)
        pool_size: number of processes used when *use_mp* is :py:obj:`True`.
          If :py:obj:`None`, ``min(npart, nproc)`` is used. It can be
          globally set using the variable
          *at.lattice.DConstant.patpass_poolsize*
        start_method: python multiprocessing start method. :py:obj:`None`
          uses the python default that is considered safe. Available values:
          ``'fork'``, ``'spawn'``, ``'forkserver'``. Default for Linux is
          ``'fork'``, default for macOS and Windows is ``'spawn'``.
          ``'fork'`` may be used on macOS to speed up the calculation or to
          solve runtime errors, however it is considered unsafe. Used only
          when *use_mp* is :py:obj:`True`. It can be globally set using the
          variable *at.lattice.DConstant.patpass_startmethod*

    The following keyword arguments overload the lattice values.

    Keyword arguments:
        particle (Optional[Particle]): circulating particle.
          Default: :code:`lattice.particle` if existing,
          otherwise :code:`Particle('relativistic')`
        energy (Optional[float]): lattice energy. Default 0.
        unfold_beam (bool): Activate internal beam folding. This assumes that
          the input particles are in bucket 0 and works only if all buckets
          see the same RF voltage. Default: :py:obj:`True`

    If *energy* is not available, relativistic tracking is forced,
    *rest_energy* is ignored.
    Returns:
        r_out: (6, N, R, T) array containing output coordinates of N
          particles at R reference points for T turns
        trackparam: A dictionary containing tracking input parameters with
          the following keys:

          ============== ===================================================
          **npart**      number of particles
          **rout**       final particle coordinates
          **turn**       starting turn
          **refpts**     array of indices where particle coordinates are
                         saved (only for lattice tracking)
          **nturns**     number of turns
          ============== ===================================================

        trackdata: A dictionary containing tracking data with the following
          keys:

          ============== ===================================================
          **loss_map**   recarray containing the loss_map (only for lattice
                         tracking)
          ============== ===================================================

          The **loss_map** is filled only if *losses* is :py:obj:`True`; it
          contains the following keys:

          ============== ===================================================
          **islost**     (npart,) bool array indicating lost particles
          **turn**       (npart,) int array indicating the turn at which the
                         particle is lost
          **element**    (npart,) int array indicating the element at which
                         the particle is lost
          **coord**      (npart, 6) float array giving the coordinates at
                         which the particle is lost (zero for surviving
                         particles)
          ============== ===================================================

    .. note::

       * :pycode:`lattice_track(lattice, r_in, refpts=len(line))` is the
         same as :pycode:`lattice_track(lattice, r_in)` since the reference
         point len(line) is the exit of the last element.
       * :pycode:`lattice_track(lattice, r_in, refpts=0)` is a copy of
         *r_in* since the reference point 0 is the entrance of the first
         element.
       * To resume an interrupted tracking (for instance to get intermediate
         results), one must use one of the *turn* or *keep_counter* keywords
         to ensure the continuity of the turn number.
       * For multiparticle tracking over a large number of turns, the size
         of *r_out* may increase excessively. To avoid memory issues,
         :pycode:`lattice_track(lattice, r_in, refpts=None, in_place=True)`
         can be used: an empty list is returned and the tracking results of
         the last turn are stored in *r_in*.
       * To model buckets with different RF voltages,
         :pycode:`unfold_beam=False` has to be used. The beam can be
         unfolded using the function :py:func:`.unfold_beam`. This function
         takes into account the true voltage in each bucket and distributes
         the particles in the bunches defined by :code:`ring.fillpattern`
         using a 6-D orbit search.
""" trackdata = {} trackparam = {} part_kw = ["energy", "particle"] try: npart = numpy.shape(r_in)[1] except IndexError: npart = 1 [trackparam.update((kw, kwargs.get(kw))) for kw in kwargs if kw in part_kw] trackparam.update({"npart": npart}) if not in_place: r_in = r_in.copy() lattice = initialize_lpass(lattice, nturns, kwargs) ldtype = [ ("islost", numpy.bool_), ("turn", numpy.uint32), ("elem", numpy.uint32), ("coord", numpy.float64, (6,)), ] loss_map = numpy.recarray((npart,), ldtype) lat_kw = ["turn"] [trackparam.update((kw, kwargs.get(kw))) for kw in kwargs if kw in lat_kw] trackparam.update({"refpts": get_uint32_index(lattice, refpts), "nturns": nturns}) use_mp = kwargs.pop("use_mp", False) start_method = kwargs.pop("start_method", None) pool_size = kwargs.pop("pool_size", None) if use_mp: kwargs.update({"pool_size": pool_size, "start_method": start_method}) rout = _plattice_pass(lattice, r_in, nturns=nturns, refpts=refpts, **kwargs) else: rout = _lattice_pass( lattice, r_in, nturns=nturns, refpts=refpts, no_varelem=False, **kwargs ) if kwargs.get("losses", False): rout, lm = rout lm["coord"] = lm["coord"].T for k, v in lm.items(): loss_map[k] = v trackdata.update({"loss_map": loss_map}) trackparam.update({"rout": r_in}) return rout, trackparam, trackdata


def element_track(element: Element, r_in, in_place: bool = False, **kwargs):
    """
    :py:func:`element_track` tracks particles through a single element,
    calling the element-specific tracking function specified in the
    Element's *PassMethod* field.

    Usage:
      >>> element_track(element, r_in)
      >>> element.track(r_in)

    Parameters:
        element: element to track through
        r_in: (6, N) array: input coordinates of N particles.
          For the best efficiency, *r_in* should be given as F_CONTIGUOUS
          numpy array.

    Keyword arguments:
        in_place (bool): If :py:obj:`True`, *r_in* is modified in-place and
          reports the coordinates at the end of the element.
          (default: :py:obj:`False`)
        omp_num_threads (int): Number of OpenMP threads (default: automatic)
        particle (Optional[Particle]): circulating particle.
          Default: :code:`lattice.particle` if existing,
          otherwise :code:`Particle('relativistic')`
        energy (Optional[float]): lattice energy. Default 0.

    Returns:
        r_out: (6, N) array containing the output coordinates of the N
          particles at the exit of the element
    """
    if not in_place:
        r_in = r_in.copy()
    rout = _element_pass(element, r_in, **kwargs)
    return rout
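

# A minimal usage sketch following the docstring above, assuming the standard
# at.Drift element (shown for illustration only): track 3 particles through a
# single element.
#
#     import numpy
#     from at import Drift
#
#     drift = Drift("D1", 1.0)              # 1 m drift
#     r_in = numpy.zeros((6, 3), order="F")
#     r_in[1, :] = [1.0e-5, 0.0, -1.0e-5]   # px offsets
#     r_out = element_track(drift, r_in)    # r_in itself is left unchanged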


internal_lpass = _lattice_pass
internal_epass = _element_pass
internal_plpass = _plattice_pass

Lattice.track = lattice_track
Element.track = element_track