# Source code for xradio.measurement_set.load_processing_set

import os
from typing import Dict, Union
import xarray as xr
import time


def _drop_unrequested_data_vars(ms_xdt, include_variables) -> None:
    """Drop, in place, every data variable of ``ms_xdt`` not in ``include_variables``."""
    vars_to_drop = [
        name for name in ms_xdt.ds.data_vars if name not in include_variables
    ]
    # One batched drop instead of rebuilding the dataset once per variable.
    if vars_to_drop:
        ms_xdt.ds = ms_xdt.ds.drop_vars(vars_to_drop)


def load_processing_set(
    ps_store: str,
    sel_parms: Union[dict, None] = None,
    data_group_name: Union[str, None] = None,
    include_variables: Union[list, None] = None,
    drop_variables: Union[list, None] = None,
    load_sub_datasets: bool = True,
) -> xr.DataTree:
    """Loads a processing set into memory.

    Parameters
    ----------
    ps_store : str
        String of the path and name of the processing set. For example
        '/users/user_1/uid___A002_Xf07bba_Xbe5c_target.lsrk.vis.zarr' for a file
        stored on a local file system, or
        's3://viper-test-data/Antennae_North.cal.lsrk.split.vis.zarr/' for a file
        in AWS object storage.
    sel_parms : dict, optional
        A dictionary where the keys are the names of the ms_xdt's (measurement
        set xarray data trees) and the values are slice_dicts.
        slice_dicts: A dictionary where the keys are the dimension names and the
        values are slices. For example::

            {
                'ms_v4_name_1': {'frequency': slice(0, 160, None), 'time': slice(0, 100)},
                ...
                'ms_v4_name_n': {'frequency': slice(0, 160, None), 'time': slice(0, 100)},
            }

        An empty (or ``None``) slice_dict loads that ms_xdt without slicing.
        By default None, which loads all ms_xdts.
    data_group_name : str, optional
        The name of the data group to select. By default None, which loads all
        data groups.
    include_variables : Union[list, None], optional
        The list of data variables to load into memory for example
        ['VISIBILITY', 'WEIGHT', 'FLAGS']. By default None which will load all
        data variables into memory.
    drop_variables : Union[list, None], optional
        The list of data variables to drop from memory for example
        ['VISIBILITY', 'WEIGHT', 'FLAGS']. By default None which will not drop
        any data variables from memory.
    load_sub_datasets : bool, optional
        If true sub-datasets (for example weather_xds, antenna_xds,
        pointing_xds, system_calibration_xds ...) will be loaded into memory,
        by default True.

    Returns
    -------
    xarray.DataTree
        In memory representation of processing set using xr.DataTree.
    """
    from xradio._utils.zarr.common import _get_file_system_and_items
    import s3fs
    import posixpath

    # Only the file system is needed here; the item listing is not used.
    file_system, _ = _get_file_system_and_items(ps_store)

    if sel_parms:
        ps_xdt = xr.DataTree()
        is_s3 = isinstance(file_system, s3fs.core.S3FileSystem)

        for ms_name, ms_xds_isel in sel_parms.items():
            ms_store = posixpath.join(ps_store, ms_name)
            if is_s3:
                ms_store = s3fs.S3Map(root=ms_store, s3=file_system, check=False)

            ms_xdt = xr.open_datatree(
                ms_store,
                engine="zarr",
                drop_variables=drop_variables,
                cache=False,
                chunks=None,
                consolidated=False,
            )
            # Slice before the data group selection so that only the requested
            # region is read when the tree is loaded below.
            if ms_xds_isel:
                ms_xdt = ms_xdt.isel(ms_xds_isel)
            ms_xdt = ms_xdt.xr_ms.sel(data_group_name=data_group_name)

            if include_variables is not None:
                _drop_unrequested_data_vars(ms_xdt, include_variables)

            ps_xdt[ms_name] = ms_xdt

        ps_xdt.attrs["type"] = "processing_set"
    else:
        ps_xdt = xr.open_datatree(
            ps_store,
            engine="zarr",
            drop_variables=drop_variables,
            cache=False,
            chunks=None,
        )

        if include_variables is not None or data_group_name:
            # Iterate over a snapshot of the child nodes only: iterating the
            # tree itself also yields root-level data variables, and the tree
            # is reassigned inside the loop.
            for ms_name, ms_xdt in list(ps_xdt.children.items()):
                ms_xdt = ms_xdt.xr_ms.sel(data_group_name=data_group_name)
                if include_variables is not None:
                    _drop_unrequested_data_vars(ms_xdt, include_variables)
                ps_xdt[ms_name] = ms_xdt

    if not load_sub_datasets:
        for ms_xdt in ps_xdt.children.values():
            # Snapshot the keys because entries are deleted while looping.
            for sub_xds_name in list(ms_xdt.keys()):
                if "xds" in sub_xds_name:
                    del ms_xdt[sub_xds_name]

    return ps_xdt.load()
class ProcessingSetIterator:
    """Iterate over a processing set one MS v4 (ms_xdt) at a time."""

    def __init__(
        self,
        sel_parms: dict,
        input_data_store: str,
        input_data: Union[Dict, "xr.DataTree", None] = None,
        data_group_name: str = None,
        include_variables: Union[list, None] = None,
        drop_variables: Union[list, None] = None,
        load_sub_datasets: bool = True,
        in_memory: bool = False,
    ):
        """An iterator that will go through a processing set one MS v4 at a time.

        Parameters
        ----------
        sel_parms : dict
            A dictionary where the keys are the names of the ms_xds's and the
            values are slice_dicts.
            slice_dicts: A dictionary where the keys are the dimension names and
            the values are slices. For example::

                {
                    'ms_v4_name_1': {'frequency': slice(0, 160, None), 'time': slice(0, 100)},
                    ...
                    'ms_v4_name_n': {'frequency': slice(0, 160, None), 'time': slice(0, 100)},
                }

            The keys also define the iteration order.
        input_data_store : str
            String of the path and name of the processing set. For example
            '/users/user_1/uid___A002_Xf07bba_Xbe5c_target.lsrk.vis.zarr'.
        input_data : Union[Dict, xr.DataTree, None], optional
            If the processing set is in memory already it can be supplied here.
            By default None which will make the iterator load data using the
            supplied input_data_store.
        data_group_name : str, optional
            The name of the data group to select. By default None, which loads
            all data groups.
        include_variables : Union[list, None], optional
            The list of data variables to load into memory for example
            ['VISIBILITY', 'WEIGHT', 'FLAGS']. By default None which will load
            all data variables into memory.
        drop_variables : Union[list, None], optional
            The list of data variables to drop from memory for example
            ['VISIBILITY', 'WEIGHT', 'FLAGS']. By default None which will not
            drop any data variables from memory.
        load_sub_datasets : bool, optional
            If true sub-datasets (for example weather_xds, antenna_xds,
            pointing_xds, system_calibration_xds ...) will be loaded into
            memory, by default True.
        in_memory : bool, optional
            If True, ms_xdt's are cached as they are loaded so that resetting
            the iterator does not require reloading from disk. If False, only a
            single ms_xdt is held in memory at a time. By default False.
        """
        self.input_data = input_data
        self.input_data_store = input_data_store
        self.sel_parms = sel_parms
        self.data_group_name = data_group_name
        self.include_variables = include_variables
        self.drop_variables = drop_variables
        self.load_sub_datasets = load_sub_datasets
        self.in_memory = in_memory

        self._ms_name_list = list(sel_parms.keys())
        self._index = 0
        self._current_ms_name: Union[str, None] = None
        self._current_ms_xdt: Union["xr.DataTree", None] = None
        # Only populated when in_memory is True and data is loaded from the store.
        self._cache: Dict[str, "xr.DataTree"] = {}
        # Seconds spent fetching data since the last reset(), and the largest
        # such total seen by reset() so far.
        self._load_time: float = 0.0
        self._longest_load_time: float = 0.0

    def __iter__(self):
        return self

    def reset(self):
        """Reset the iterator to the beginning.

        If ``in_memory=True``, previously loaded ms_xdt's are served from the
        cache on the next pass without reloading from disk.

        Returns
        -------
        tuple of float
            A 2-tuple ``(load_time, longest_load_time)``, where ``load_time`` is
            the total time in seconds spent loading data since the previous call
            to :meth:`reset`, and ``longest_load_time`` is the largest
            ``load_time`` returned by any call to :meth:`reset` so far.
        """
        self._index = 0
        self._current_ms_name = None
        self._current_ms_xdt = None

        load_time = self._load_time
        self._load_time = 0.0
        self._longest_load_time = max(self._longest_load_time, load_time)
        return load_time, self._longest_load_time

    def __next__(self):
        """Return the next ms_xdt, loading it from the store when required.

        Raises
        ------
        StopIteration
            When every name in ``sel_parms`` has been visited. Call
            :meth:`reset` to iterate again.
        """
        if self._index >= len(self._ms_name_list):
            raise StopIteration

        ms_name = self._ms_name_list[self._index]
        self._index += 1
        self._current_ms_name = ms_name

        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by system clock adjustments.
        load_start = time.perf_counter()

        if self.input_data is not None:
            ms_xdt = self.input_data[ms_name]
        elif self.in_memory and ms_name in self._cache:
            ms_xdt = self._cache[ms_name]
        else:
            ps_xdt = load_processing_set(
                ps_store=self.input_data_store,
                sel_parms={ms_name: self.sel_parms[ms_name]},
                data_group_name=self.data_group_name,
                include_variables=self.include_variables,
                drop_variables=self.drop_variables,
                load_sub_datasets=self.load_sub_datasets,
            )
            ms_xdt = ps_xdt[ms_name]
            if self.in_memory:
                self._cache[ms_name] = ms_xdt

        self._current_ms_xdt = ms_xdt
        self._load_time += time.perf_counter() - load_start
        return ms_xdt


def _get_rss_gb():
    """Return the resident set size of the current process in GB (1e9 bytes)."""
    import psutil

    return psutil.Process(os.getpid()).memory_info().rss / 1e9