import os
from typing import Dict, Union
import xarray as xr
import time
# [docs]  (Sphinx "view source" link marker left over from HTML extraction; not part of the code)
def load_processing_set(
    ps_store: str,
    sel_parms: Union[dict, None] = None,
    data_group_name: Union[str, None] = None,
    include_variables: Union[list, None] = None,
    drop_variables: Union[list, None] = None,
    load_sub_datasets: bool = True,
) -> xr.DataTree:
    """Loads a processing set into memory.

    Parameters
    ----------
    ps_store : str
        String of the path and name of the processing set. For example
        '/users/user_1/uid___A002_Xf07bba_Xbe5c_target.lsrk.vis.zarr' for a file
        stored on a local file system, or
        's3://viper-test-data/Antennae_North.cal.lsrk.split.vis.zarr/' for a file
        in AWS object storage.
    sel_parms : dict, optional
        A dictionary where the keys are the names of the ms_xdt's (measurement set
        xarray data trees) and the values are slice_dicts.
        slice_dicts: A dictionary where the keys are the dimension names and the
        values are slices.
        For example::

            {
                'ms_v4_name_1': {'frequency': slice(0, 160, None), 'time': slice(0, 100)},
                ...
                'ms_v4_name_n': {'frequency': slice(0, 160, None), 'time': slice(0, 100)},
            }

        An empty slice_dict loads the corresponding ms_xdt without slicing.
        By default None, which loads all ms_xdts (an empty dictionary behaves the
        same way).
    data_group_name : str, optional
        The name of the data group to select. By default None, which loads all
        data groups.
    include_variables : Union[list, None], optional
        The list of data variables to load into memory for example
        ['VISIBILITY', 'WEIGHT', 'FLAGS']. By default None which will load all
        data variables into memory.
    drop_variables : Union[list, None], optional
        The list of data variables to drop from memory for example
        ['VISIBILITY', 'WEIGHT', 'FLAGS']. By default None which will not drop
        any data variables from memory.
    load_sub_datasets : bool, optional
        If true sub-datasets (for example weather_xds, antenna_xds, pointing_xds,
        system_calibration_xds ...) will be loaded into memory, by default True.

    Returns
    -------
    xarray.DataTree
        In memory representation of processing set using xr.DataTree.
    """
    from xradio._utils.zarr.common import _get_file_system_and_items
    import s3fs
    import posixpath

    def _select_group_and_variables(ms_xdt):
        # Restrict one ms_xdt to the requested data group and, when
        # include_variables is given, to the listed data variables only.
        ms_xdt = ms_xdt.xr_ms.sel(data_group_name=data_group_name)
        if include_variables is not None:
            # Collect the names first so the dataset is not replaced while its
            # data variables are still being iterated.
            unwanted = [
                name for name in ms_xdt.ds.data_vars if name not in include_variables
            ]
            ms_xdt.ds = ms_xdt.ds.drop_vars(unwanted)
        return ms_xdt

    # Only the file system is needed here; the list of stores is not used.
    file_system, _ = _get_file_system_and_items(ps_store)

    if sel_parms:
        # Open only the requested ms_xdts and assemble them into a new tree.
        ps_xdt = xr.DataTree()
        for ms_name, ms_xds_isel in sel_parms.items():
            ms_store = posixpath.join(ps_store, ms_name)
            if isinstance(file_system, s3fs.core.S3FileSystem):
                ms_store = s3fs.S3Map(root=ms_store, s3=file_system, check=False)

            ms_xdt = xr.open_datatree(
                ms_store,
                engine="zarr",
                drop_variables=drop_variables,
                cache=False,
                chunks=None,
                consolidated=False,
            )
            # An empty/None slice_dict means "take this ms_xdt unsliced".
            if ms_xds_isel:
                ms_xdt = ms_xdt.isel(ms_xds_isel)

            ps_xdt[ms_name] = _select_group_and_variables(ms_xdt)

        ps_xdt.attrs["type"] = "processing_set"
    else:
        # No selection: open the whole processing set in one call.
        ps_xdt = xr.open_datatree(
            ps_store,
            engine="zarr",
            drop_variables=drop_variables,
            cache=False,
            chunks=None,
        )
        if include_variables is not None or data_group_name:
            # Iterate over a snapshot of the child nodes because each child is
            # reassigned on the tree inside the loop.
            for ms_name, ms_xdt in list(ps_xdt.children.items()):
                ps_xdt[ms_name] = _select_group_and_variables(ms_xdt)

    if not load_sub_datasets:
        # Remove every entry of each ms_xdt whose name contains "xds" before the
        # data is read into memory.
        for ms_xdt in ps_xdt.children.values():
            for sub_xds_name in list(ms_xdt.keys()):
                if "xds" in sub_xds_name:
                    del ms_xdt[sub_xds_name]

    return ps_xdt.load()
class ProcessingSetIterator:
    """Iterator that yields the ms_xdt's of a processing set one at a time."""

    def __init__(
        self,
        sel_parms: dict,
        input_data_store: str,
        input_data: Union[Dict, "xr.DataTree", None] = None,
        data_group_name: str = None,
        include_variables: Union[list, None] = None,
        drop_variables: Union[list, None] = None,
        load_sub_datasets: bool = True,
        in_memory: bool = False,
    ):
        """An iterator that will go through a processing set one MS v4 at a time.

        Parameters
        ----------
        sel_parms : dict
            A dictionary where the keys are the names of the ms_xds's and the
            values are slice_dicts. The keys also define which ms_xdt's are
            visited and in which order.
            slice_dicts: A dictionary where the keys are the dimension names and
            the values are slices.
            For example::

                {
                    'ms_v4_name_1': {'frequency': slice(0, 160, None), 'time': slice(0, 100)},
                    ...
                    'ms_v4_name_n': {'frequency': slice(0, 160, None), 'time': slice(0, 100)},
                }

        input_data_store : str
            String of the path and name of the processing set. For example
            '/users/user_1/uid___A002_Xf07bba_Xbe5c_target.lsrk.vis.zarr'.
        input_data : Union[Dict, xr.DataTree, None], optional
            If the processing set is in memory already it can be supplied here.
            Items are then returned as ``input_data[name]``; the slices in
            ``sel_parms`` are not applied by the iterator in that case. By
            default None which will make the iterator load data using the
            supplied input_data_store.
        data_group_name : str, optional
            The name of the data group to select. By default None, which loads
            all data groups.
        include_variables : Union[list, None], optional
            The list of data variables to load into memory for example
            ['VISIBILITY', 'WEIGHT', 'FLAGS']. By default None which will load
            all data variables into memory.
        drop_variables : Union[list, None], optional
            The list of data variables to drop from memory for example
            ['VISIBILITY', 'WEIGHT', 'FLAGS']. By default None which will not
            drop any data variables from memory.
        load_sub_datasets : bool, optional
            If true sub-datasets (for example weather_xds, antenna_xds,
            pointing_xds, system_calibration_xds ...) will be loaded into memory,
            by default True.
        in_memory : bool, optional
            If True, ms_xdt's are cached as they are loaded so that resetting the
            iterator does not require reloading from disk. If False, only a
            single ms_xdt is held in memory at a time. By default False.
        """
        self.input_data = input_data
        self.input_data_store = input_data_store
        self.sel_parms = sel_parms
        self.data_group_name = data_group_name
        self.include_variables = include_variables
        self.drop_variables = drop_variables
        self.load_sub_datasets = load_sub_datasets
        self.in_memory = in_memory

        # Iteration state.
        self._ms_name_list = list(sel_parms)
        self._index = 0
        self._current_ms_name: Union[str, None] = None
        self._current_ms_xdt: Union[xr.DataTree, None] = None

        # ms_xdt's kept between passes; only filled when in_memory is True.
        self._cache: Dict[str, xr.DataTree] = {}

        # Seconds spent obtaining data since the last reset, and the largest
        # such total reported by reset() so far.
        self._load_time: float = 0.0
        self._longest_load_time: float = 0.0

    def __iter__(self):
        return self

    def reset(self):
        """Reset the iterator to the beginning.

        If ``in_memory=True``, previously loaded ms_xdt's are served from the
        cache on the next pass without reloading from disk.

        Returns
        -------
        tuple of float
            A 2-tuple ``(load_time, longest_load_time)``, where ``load_time`` is
            the total time in seconds spent loading data since the previous call
            to :meth:`reset`, and ``longest_load_time`` is the maximum such value
            observed across all iterations of this iterator.
        """
        self._index = 0
        self._current_ms_name = None
        self._current_ms_xdt = None

        load_time = self._load_time
        self._load_time = 0.0
        self._longest_load_time = max(self._longest_load_time, load_time)
        return load_time, self._longest_load_time

    def __next__(self):
        """Return the next ms_xdt, loading it from the store when necessary.

        Raises
        ------
        StopIteration
            When every name in ``sel_parms`` has been visited.
        """
        if self._index >= len(self._ms_name_list):
            raise StopIteration

        ms_name = self._ms_name_list[self._index]
        self._index += 1
        self._current_ms_name = ms_name

        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by system clock adjustments.
        load_start = time.perf_counter()
        if self.input_data is not None:
            ms_xdt = self.input_data[ms_name]
        elif self.in_memory and ms_name in self._cache:
            ms_xdt = self._cache[ms_name]
        else:
            # Load just this one ms_xdt so that, without caching, only a single
            # ms_xdt is loaded per step.
            ps_xdt = load_processing_set(
                ps_store=self.input_data_store,
                sel_parms={ms_name: self.sel_parms[ms_name]},
                data_group_name=self.data_group_name,
                include_variables=self.include_variables,
                drop_variables=self.drop_variables,
                load_sub_datasets=self.load_sub_datasets,
            )
            ms_xdt = ps_xdt[ms_name]
            if self.in_memory:
                self._cache[ms_name] = ms_xdt

        self._current_ms_xdt = ms_xdt
        self._load_time += time.perf_counter() - load_start
        return ms_xdt
def _get_rss_gb():
    """Return the resident set size of the current process in units of 1e9 bytes."""
    import psutil

    current_process = psutil.Process(os.getpid())
    rss_bytes = current_process.memory_info().rss
    return rss_bytes / 1e9