import pandas as pd
from xradio._utils.list_and_array import to_list
import numpy as np
import xarray as xr
from xradio.measurement_set.measurement_set_xdt import get_data_group_name
# Values of a DataTree node's "type" attribute that identify it as a Processing Set node.
PS_DATASET_TYPES = {
    "processing_set",
}
class InvalidAccessorLocation(ValueError):
    """
    Error signalling that a Processing Set accessor function was invoked on a DataTree
    node that is not a processing set node.
    """
@xr.register_datatree_accessor("xr_ps")
class ProcessingSetXdt:
    """
    Accessor to Processing Set DataTree nodes. Provides Processing Set specific functionality such
    as producing a summary of the processing set (with information from all its MSv4s), or retrieving
    combined antenna or field_and_source datasets.

    Some derived results (per-data-group summaries, maximum dimensions, combined frequency
    axis) are cached in ``self.meta`` the first time they are computed.
    """

    _xdt: xr.DataTree

    def __init__(self, datatree: xr.DataTree):
        """
        Initialize the ProcessingSetXdt instance.

        Parameters
        ----------
        datatree: xarray.DataTree
            The Processing Set DataTree node to construct a ProcessingSetXdt accessor.
        """
        self._xdt = datatree
        # Cache of derived results; "summary" maps data group name -> summary DataFrame.
        self.meta = {"summary": {}}

    def _check_processing_set_node(self) -> None:
        """
        Raise InvalidAccessorLocation unless the wrapped node is a processing set node.

        Raises
        ------
        InvalidAccessorLocation
            If the node's "type" attribute is not one of PS_DATASET_TYPES.
        """
        if self._xdt.attrs.get("type") not in PS_DATASET_TYPES:
            raise InvalidAccessorLocation(
                f"{self._xdt.path} is not a processing set node."
            )

    def _default_data_group_name(self) -> str:
        """
        Return "base" if the first MSv4 has a data group with that name, otherwise the
        first data group listed in the first MSv4's "data_groups" attribute.
        """
        first_msv4 = next(iter(self._xdt.values()))
        first_data_groups = first_msv4.attrs["data_groups"]
        return "base" if "base" in first_data_groups else next(iter(first_data_groups))

    def summary(
        self,
        data_group_name: str | None = None,
        first_columns: list[str] | None = None,
    ) -> pd.DataFrame:
        """
        Generate and retrieve a summary of the Processing Set as a data frame.

        The summary includes information such as the names of the Measurement Sets,
        their intents, polarizations, spectral window names, field names, source names,
        field coordinates, start frequencies, and end frequencies.

        To prioritize certain columns depending on the context, the order in which the
        columns of the data frame are sorted can be modified from the default
        (first_columns parameter).

        Parameters
        ----------
        data_group_name : str, optional
            The data group to summarize. By default the "base" group
            is used (if found), or otherwise the first group found.
        first_columns : list[str], optional
            List of columns to be sorted first. Currently, the columns included in the
            summary frame are, in this order: "name", "scan_intents", "shape",
            "execution_block_UID", "polarization", "scan_name", "spw_name",
            "spw_intents", "field_name", "source_name", "line_name", "field_coords",
            "session_reference_UID", "scheduling_block_UID", "project_UID",
            "start_frequency", "end_frequency".
            For example, with first_columns=["spw_name", "scan_name"] one can print
            these two columns first, followed by all the other columns in their usual
            order. Unknown column names are ignored.

        Returns
        -------
        pandas.DataFrame
            A DataFrame containing the summary information of the specified data group,
            sorted by MS name.

        Raises
        ------
        InvalidAccessorLocation
            If this accessor is not on a processing set node.
        """
        self._check_processing_set_node()

        if data_group_name is None:
            data_group_name = self._default_data_group_name()

        summary_cache = self.meta["summary"]
        if data_group_name not in summary_cache:
            summary_cache[data_group_name] = self._summary(data_group_name).sort_values(
                by=["name"], ascending=True
            )
        summary = summary_cache[data_group_name]

        if first_columns:
            # dict.fromkeys de-duplicates while keeping the caller's order, so a
            # repeated name does not produce a duplicated column.
            found_columns = [
                col for col in dict.fromkeys(first_columns) if col in summary.columns
            ]
            if found_columns:
                all_columns = found_columns + summary.columns.drop(found_columns).tolist()
                summary = summary[all_columns]
        return summary

    def get_max_dims(self) -> dict[str, int]:
        """
        Determine the maximum dimensions across all Measurement Sets in the Processing Set.

        This method examines each Measurement Set's dimensions and computes the maximum
        size for each dimension across the entire Processing Set.

        For example, if the Processing Set contains two MSs with dimensions (50, 20, 30) and (10, 30, 40),
        the maximum dimensions will be (50, 30, 40).

        Returns
        -------
        dict
            A dictionary containing the maximum dimensions of the Processing Set, with dimension names as keys
            and their maximum sizes as values (empty if the Processing Set has no MSs).

        Raises
        ------
        InvalidAccessorLocation
            If this accessor is not on a processing set node.
        """
        self._check_processing_set_node()

        if "max_dims" not in self.meta:
            max_dims = {}
            for ms_xdt in self._xdt.values():
                for dim_name, size in ms_xdt.sizes.items():
                    max_dims[dim_name] = max(max_dims.get(dim_name, size), size)
            self.meta["max_dims"] = max_dims
        return self.meta["max_dims"]

    def get_freq_axis(self) -> xr.DataArray:
        """
        Combine the frequency axes of all Measurement Sets in the Processing Set.

        This method aggregates the frequency information from each Measurement Set to create
        a unified frequency axis for the entire Processing Set. Only the first frequency
        axis seen for each spectral window name is used.

        Returns
        -------
        xarray.DataArray
            The combined frequency axis of the Processing Set, sorted by frequency.

        Raises
        ------
        InvalidAccessorLocation
            If this accessor is not on a processing set node.
        AssertionError
            If the frequency "observer" frame differs between Measurement Sets.
        """
        self._check_processing_set_node()

        if "freq_axis" not in self.meta:
            first_ms_name = next(iter(self._xdt.children))
            frame = self._xdt[first_ms_name].frequency.attrs["observer"]

            seen_spw_names = set()
            freq_axis_list = []
            for ms_xdt in self._xdt.values():
                frequency = ms_xdt.frequency
                assert (
                    frame == frequency.attrs["observer"]
                ), "Frequency reference frame not consistent in Processing Set."
                spw_name = frequency.attrs["spectral_window_name"]
                if spw_name not in seen_spw_names:
                    seen_spw_names.add(spw_name)
                    freq_axis_list.append(frequency)

            self.meta["freq_axis"] = xr.concat(
                freq_axis_list, dim="frequency", join="outer"
            ).sortby("frequency")
        return self.meta["freq_axis"]

    @staticmethod
    def _describe_field_coords(field_and_source_xds: xr.Dataset, data_name: str):
        """
        Build the "field_coords" summary entry for one MSv4.

        Returns "Ephemeris" for ephemeris field datasets, "Multi-Phase-Center" when the
        direction variable spans more than one field, "---" when no direction variable
        can be determined, and otherwise a list [frame, RA (hours string), Dec (deg string)].
        """
        if field_and_source_xds.attrs["type"] == "field_and_source_ephemeris":
            return "Ephemeris"

        # The direction variable depends on the kind of correlated data. SPECTRUM is
        # checked first so that it takes precedence, as in the original implementation.
        if "SPECTRUM" in data_name:
            center_name = "FIELD_REFERENCE_CENTER_DIRECTION"
        elif "VISIBILITY" in data_name:
            center_name = "FIELD_PHASE_CENTER_DIRECTION"
        else:
            center_name = None

        if center_name is None or center_name not in field_and_source_xds:
            return "---"

        center = field_and_source_xds[center_name]
        if center["field_name"].size > 1:
            return "Multi-Phase-Center"

        from astropy.coordinates import SkyCoord
        import astropy.units as u

        ra_dec_rad = center.values[0, :]
        frame = center.attrs["frame"].lower()
        coord = SkyCoord(
            ra=ra_dec_rad[0] * u.rad, dec=ra_dec_rad[1] * u.rad, frame=frame
        )
        return [
            frame,
            coord.ra.to_string(unit=u.hour, precision=2),
            coord.dec.to_string(unit=u.deg, precision=2),
        ]

    def _summary(self, data_group_name: str = None):
        """
        Build the (unsorted-column, name-ordered) summary DataFrame for a data group.

        Parameters
        ----------
        data_group_name : str
            Data group whose "correlated_data" variable is used for the "shape" column
            and whose field_and_source dataset is used for "field_coords".

        Returns
        -------
        pandas.DataFrame
            One row per MSv4, columns as documented in `summary`.
        """
        summary_data = {
            "name": [],
            "scan_intents": [],
            "shape": [],
            "execution_block_UID": [],
            "polarization": [],
            "scan_name": [],
            "spw_name": [],
            "spw_intents": [],
            "field_name": [],
            "source_name": [],
            "line_name": [],
            "field_coords": [],
            "session_reference_UID": [],
            "scheduling_block_UID": [],
            "project_UID": [],
            "start_frequency": [],
            "end_frequency": [],
        }

        for key, value in sorted(self._xdt.items(), key=lambda item: item[0]):
            partition_info = value.xr_ms.get_partition_info()
            observation_info = value.observation_info
            data_name = value.attrs["data_groups"][data_group_name]["correlated_data"]
            frequencies = to_list(value["frequency"].values)

            summary_data["name"].append(key)
            summary_data["scan_intents"].append(partition_info["scan_intents"])
            summary_data["shape"].append(value[data_name].shape)
            summary_data["polarization"].append(value.polarization.values)
            summary_data["scan_name"].append(partition_info["scan_name"])
            summary_data["spw_name"].append(partition_info["spectral_window_name"])
            summary_data["spw_intents"].append(
                partition_info["spectral_window_intents"]
            )
            summary_data["field_name"].append(partition_info["field_name"])
            summary_data["source_name"].append(partition_info["source_name"])
            summary_data["line_name"].append(partition_info["line_name"])
            # Observation UIDs are optional; "---" marks a missing value.
            for uid_key in (
                "execution_block_UID",
                "session_reference_UID",
                "scheduling_block_UID",
                "project_UID",
            ):
                summary_data[uid_key].append(observation_info.get(uid_key, "---"))
            summary_data["start_frequency"].append(frequencies[0])
            summary_data["end_frequency"].append(frequencies[-1])

            field_and_source_xds = value.xr_ms.get_field_and_source_xds(data_group_name)
            summary_data["field_coords"].append(
                self._describe_field_coords(field_and_source_xds, data_name)
            )

        return pd.DataFrame(summary_data)

    def query(
        self, string_exact_match: bool = True, query: str | None = None, **kwargs
    ) -> xr.DataTree:
        """
        Select a subset of the Processing Set based on specified criteria.

        This method allows filtering the Processing Set by matching column names and values
        or by applying a Pandas query string. The selection criteria can target various
        attributes of the Measurement Sets such as scan_intents, polarization, spectral window names, etc.

        A data group can be selected by name by using the `data_group_name` parameter. This is applied to each Measurement Set in the Processing Set.

        Note
        ----
        This selection does not modify the actual data within the Measurement Sets. For example, if
        a Measurement Set has `field_name=['field_0','field_10','field_08']` and `ps.query(field_name='field_0')`
        is invoked, the resulting subset will still contain the original list `['field_0','field_10','field_08']`.
        The exception is data group selection, using `data_group_name`, that will select data variables only associated with the specified data group in the Measurement Set.

        Parameters
        ----------
        string_exact_match : bool, optional
            If `True`, string matching will require exact matches for string and string list columns.
            If `False`, partial matches are allowed (values that do not support substring
            matching, such as numbers, are compared for equality). Default is `True`.
        query : str, optional
            A Pandas query string to apply additional filtering. Default is `None`.
        **kwargs : dict
            Keyword arguments representing column names and their corresponding values to filter the Processing Set.
            A single `slice` value selects rows whose column lies between its start and stop.
            The special key `data_group_name` selects a data group instead of filtering rows.

        Returns
        -------
        xr.DataTree
            A new Processing Set DataTree instance containing only the Measurement Sets that match the selection criteria.

        Raises
        ------
        InvalidAccessorLocation
            If this accessor is not on a processing set node.

        Examples
        --------
        >>> # Select all MSs with scan_intents 'OBSERVE_TARGET#ON_SOURCE' and polarization 'RR' or 'LL'
        >>> selected_ps_xdt = ps_xdt.xr_ps.query(scan_intents='OBSERVE_TARGET#ON_SOURCE', polarization=['RR', 'LL'])

        >>> # Select all MSs with start_frequency greater than 100 GHz and less than 200 GHz
        >>> selected_ps_xdt = ps_xdt.xr_ps.query(query='start_frequency > 100e9 and end_frequency < 200e9')
        """
        self._check_processing_set_node()

        def row_matches(row_val, sel_vals) -> bool:
            """True if any element of the cell matches any selection value."""
            # Cells may hold scalars or sequences; normalize to a list to iterate.
            for rw in to_list(row_val):
                for s in sel_vals:
                    if string_exact_match:
                        if rw == s:
                            return True
                    else:
                        try:
                            if s in rw:
                                return True
                        except TypeError:
                            # Values that cannot be substring-matched (e.g. numbers)
                            # fall back to an equality comparison.
                            if rw == s:
                                return True
            return False

        summary_table = self.summary()
        data_group_name = None
        for key, value in kwargs.items():
            if key == "data_group_name":
                data_group_name = value
                continue

            value = to_list(value)  # make sure value is a list.
            if len(value) == 1 and isinstance(value[0], slice):
                mask = summary_table[key].between(value[0].start, value[0].stop)
            else:
                # astype(bool): on an empty table `apply` keeps the column's object
                # dtype, which pandas would treat as a column selection, not a mask.
                mask = summary_table[key].apply(row_matches, args=(value,)).astype(bool)
            summary_table = summary_table[mask]

        if query is not None:
            summary_table = summary_table.query(query)

        selected_names = set(summary_table["name"])
        sub_ps_xdt = xr.DataTree()
        for key, val in self._xdt.items():
            if key not in selected_names:
                continue
            if data_group_name is not None:
                sub_ps_xdt[key] = val.xr_ms.sel(data_group_name=data_group_name)
            else:
                sub_ps_xdt[key] = val
        sub_ps_xdt.attrs = self._xdt.attrs
        return sub_ps_xdt

    @staticmethod
    def _drop_line_info(field_and_source_xds: xr.Dataset) -> xr.Dataset:
        """
        Return the dataset without its spectral line variables and coordinates.

        Line information is a function of the spectral window, so it is not included
        when field_and_source datasets are combined across MSv4s. Missing names are
        ignored.
        """
        if "line_name" not in field_and_source_xds.coords:
            return field_and_source_xds
        return field_and_source_xds.drop_vars(
            [
                "LINE_REST_FREQUENCY",
                "LINE_SYSTEMIC_VELOCITY",
                "line_name",
                "line_label",
            ],
            errors="ignore",
        )

    def get_combined_field_and_source_xds(
        self, data_group_name: str = "base"
    ) -> xr.Dataset:
        """
        Combine all non-ephemeris `field_and_source_xds` datasets from a Processing Set for a data group into a
        single dataset.

        When the combined dataset has FIELD_PHASE_CENTER_DIRECTION, duplicate fields are
        dropped, MEAN_PHASE_CENTER_DIRECTION (mean over fields) is added, and the attribute
        "center_field_name" is set to the field closest (haversine distance) to that mean.

        Parameters
        ----------
        data_group_name : str, optional
            The data group to process. Default is "base".

        Returns
        -------
        xarray.Dataset
            combined_field_and_source_xds: Combined dataset for standard (non-ephemeris) fields.
            Empty if there are no non-ephemeris fields.

        Raises
        ------
        InvalidAccessorLocation
            If this accessor is not on a processing set node.
        """
        self._check_processing_set_node()

        combined_xds = xr.Dataset()
        for ms_xdt in self._xdt.values():
            field_and_source_xds = ms_xdt.xr_ms.get_field_and_source_xds(
                data_group_name
            )
            if field_and_source_xds.attrs["type"] == "field_and_source_ephemeris":
                continue

            field_and_source_xds = self._drop_line_info(field_and_source_xds)
            if len(combined_xds.data_vars) == 0:
                combined_xds = field_and_source_xds
            else:
                combined_xds = xr.concat(
                    [combined_xds, field_and_source_xds],
                    dim="field_name",
                    join="outer",
                )

        if (len(combined_xds.data_vars) > 0) and (
            "FIELD_PHASE_CENTER_DIRECTION" in combined_xds
        ):
            from xradio._utils.coord_math import haversine

            combined_xds = combined_xds.drop_duplicates("field_name")
            phase_center = combined_xds["FIELD_PHASE_CENTER_DIRECTION"]
            combined_xds["MEAN_PHASE_CENTER_DIRECTION"] = phase_center.mean(
                dim=["field_name"]
            )
            mean_center = combined_xds["MEAN_PHASE_CENTER_DIRECTION"]

            distance = haversine(
                phase_center.sel(sky_dir_label="ra").values,
                phase_center.sel(sky_dir_label="dec").values,
                mean_center.sel(sky_dir_label="ra").values,
                mean_center.sel(sky_dir_label="dec").values,
            )
            combined_xds.attrs["center_field_name"] = combined_xds.field_name[
                distance.argmin()
            ].values

        return combined_xds

    def get_combined_field_and_source_xds_ephemeris(
        self, data_group_name: str = "base"
    ) -> xr.Dataset:
        """
        Combine all ephemeris `field_and_source_xds` datasets from a Processing Set for a datagroup into a single dataset.

        Datasets with a "time_ephemeris" axis are interpolated onto their "time" axis before
        being concatenated along "time". When the result has FIELD_PHASE_CENTER_DIRECTION,
        FIELD_OFFSET (phase center minus SOURCE_DIRECTION, wrapped to [-pi, pi]) is added and
        "center_field_name" is set to the field whose offset is closest to zero.

        Parameters
        ----------
        data_group_name : str, optional
            The data group to process. Default is "base".

        Returns
        -------
        xarray.Dataset
            combined_ephemeris_field_and_source_xds: Combined dataset for ephemeris fields.
            Empty if there are no ephemeris fields.

        Raises
        ------
        InvalidAccessorLocation
            If this accessor is not on a processing set node.
        """
        self._check_processing_set_node()

        combined_xds = xr.Dataset()
        for ms_xdt in self._xdt.values():
            field_and_source_xds = ms_xdt.xr_ms.get_field_and_source_xds(
                data_group_name
            )
            if field_and_source_xds.attrs["type"] != "field_and_source_ephemeris":
                continue

            field_and_source_xds = self._drop_line_info(field_and_source_xds)

            if "time_ephemeris" in field_and_source_xds:
                from xradio.measurement_set._utils._utils.interpolate import (
                    interpolate_to_time,
                )

                field_and_source_xds = interpolate_to_time(
                    field_and_source_xds,
                    field_and_source_xds.time,
                    "field_and_source_xds",
                    "time_ephemeris",
                )
                del field_and_source_xds["time_ephemeris"]
                field_and_source_xds = field_and_source_xds.rename(
                    {"time_ephemeris": "time"}
                )

            field_and_source_xds = field_and_source_xds.drop_vars(
                ["OBSERVER_POSITION"], errors="ignore"
            )

            if len(combined_xds.data_vars) == 0:
                combined_xds = field_and_source_xds
            else:
                combined_xds = xr.concat(
                    [combined_xds, field_and_source_xds],
                    dim="time",
                    join="outer",
                )

        if (len(combined_xds.data_vars) > 0) and (
            "FIELD_PHASE_CENTER_DIRECTION" in combined_xds
        ):
            from xradio._utils.coord_math import haversine, wrap_to_pi

            offset = (
                combined_xds["FIELD_PHASE_CENTER_DIRECTION"]
                - combined_xds["SOURCE_DIRECTION"]
            )
            # Offsets only cover ra/dec, so keep only the first two units entries.
            offset_attrs = dict(combined_xds["FIELD_PHASE_CENTER_DIRECTION"].attrs)
            offset_attrs["units"] = offset_attrs["units"][:2]
            combined_xds["FIELD_OFFSET"] = xr.DataArray(
                wrap_to_pi(offset.sel(sky_dir_label=["ra", "dec"])).values,
                dims=["time", "sky_dir_label"],
                attrs=offset_attrs,
            )

            field_offset = combined_xds["FIELD_OFFSET"]
            distance = haversine(
                field_offset.sel(sky_dir_label="ra").values,
                field_offset.sel(sky_dir_label="dec").values,
                0.0,
                0.0,
            )
            combined_xds.attrs["center_field_name"] = combined_xds.field_name[
                distance.argmin()
            ].values

        return combined_xds

    def plot_phase_centers(
        self, label_all_fields: bool = False, data_group_name: str = "base"
    ):
        """
        Plot the phase center locations of all fields in the Processing Set.

        This method is primarily used for visualizing mosaics. It generates scatter plots of
        the phase center coordinates for both standard and ephemeris fields. The central field
        is highlighted in red based on the closest phase center calculation.

        Parameters
        ----------
        label_all_fields : bool, optional
            If `True`, all fields will be labeled on the plot. Default is `False`.
        data_group_name : str, optional
            The data group to use for processing. Default is "base".

        Returns
        -------
        None

        Raises
        ------
        InvalidAccessorLocation
            If this accessor is not on a processing set node.
        """
        self._check_processing_set_node()

        combined_field_and_source_xds = self.get_combined_field_and_source_xds(
            data_group_name
        )
        combined_ephemeris_field_and_source_xds = (
            self.get_combined_field_and_source_xds_ephemeris(data_group_name)
        )

        from matplotlib import pyplot as plt

        def setup_annotations_all(axis, scatter, field_names):
            """Label every scatter point with its field name (label_all_fields=True)."""
            coord_x, coord_y = np.array(scatter.get_offsets()).transpose()
            offset_x = np.abs(np.max(coord_x) - np.min(coord_x)) * 0.01
            offset_y = np.abs(np.max(coord_y) - np.min(coord_y)) * 0.01
            for idx, (x, y) in enumerate(zip(coord_x + offset_x, coord_y + offset_y)):
                axis.annotate(field_names[idx], (x, y), alpha=1)

        def plot_fields(xds, var_name, title, x_label, y_label, index_field_name):
            """Scatter `var_name` (ra vs dec) for all fields and highlight the center field."""
            fig = plt.figure()
            plt.title(title)
            scatter = plt.scatter(
                xds[var_name].sel(sky_dir_label="ra"),
                xds[var_name].sel(sky_dir_label="dec"),
            )

            center_field_name = xds.attrs["center_field_name"]
            if index_field_name:
                # field_name is a non-index coordinate here; index it to select by name.
                xds = xds.set_xindex("field_name")
            center_field = xds.sel(field_name=center_field_name)

            if label_all_fields:
                setup_annotations_all(fig.axes[0], scatter, xds.field_name.values)
                fig.axes[0].margins(0.2, 0.2)
                center_label = None
            else:
                center_label = center_field_name

            plt.scatter(
                center_field[var_name].sel(sky_dir_label="ra"),
                center_field[var_name].sel(sky_dir_label="dec"),
                color="red",
                label=center_label,
            )
            plt.xlabel(x_label)
            plt.ylabel(y_label)
            if not label_all_fields:
                plt.legend()
            plt.show()

        if (len(combined_field_and_source_xds.data_vars) > 0) and (
            "FIELD_PHASE_CENTER_DIRECTION" in combined_field_and_source_xds
        ):
            plot_fields(
                combined_field_and_source_xds,
                "FIELD_PHASE_CENTER_DIRECTION",
                "Field Phase Center Locations",
                "RA (rad)",
                "DEC (rad)",
                index_field_name=False,
            )

        if (len(combined_ephemeris_field_and_source_xds.data_vars) > 0) and (
            "FIELD_PHASE_CENTER_DIRECTION" in combined_ephemeris_field_and_source_xds
        ):
            plot_fields(
                combined_ephemeris_field_and_source_xds,
                "FIELD_OFFSET",
                "Offset of Field Phase Center from Source Location (Ephemeris Data)",
                "RA Offset (rad)",
                "DEC Offset (rad)",
                index_field_name=True,
            )

    def get_combined_antenna_xds(self) -> xr.Dataset:
        """
        Combine the `antenna_xds` datasets from all Measurement Sets into a single dataset.

        This method concatenates the antenna datasets from each Measurement Set along the 'antenna_name' dimension.
        Duplicate antennas are removed, and when ANTENNA_RECEPTOR_ANGLE is present, antennas
        with any NaN value are dropped.

        Returns
        -------
        xarray.Dataset
            A combined `xarray.Dataset` containing antenna information from all Measurement Sets.

        Raises
        ------
        InvalidAccessorLocation
            If this accessor is not on a processing set node.
        """
        self._check_processing_set_node()

        combined_antenna_xds = xr.Dataset()
        for ms_xdt in self._xdt.values():
            antenna_xds = ms_xdt.antenna_xds.ds
            if len(combined_antenna_xds.data_vars) == 0:
                combined_antenna_xds = antenna_xds
            else:
                combined_antenna_xds = xr.concat(
                    [combined_antenna_xds, antenna_xds],
                    dim="antenna_name",
                    data_vars="minimal",
                    coords="minimal",
                    join="outer",
                )

        # ALMA WVR antenna_xds data has a NaN value for the antenna receptor angle.
        if "ANTENNA_RECEPTOR_ANGLE" in combined_antenna_xds.data_vars:
            combined_antenna_xds = combined_antenna_xds.dropna("antenna_name")
        return combined_antenna_xds.drop_duplicates("antenna_name")

    def plot_antenna_positions_2d(
        self,
        add_antenna_labels: bool = True,
        add_antenna_stations: bool = False,
        add_elevation_plot: bool = True,
        add_continent_outlines: bool = True,
        figure_size: tuple = (12, 8),
    ):
        """
        Plot the antenna positions of all antennas in all measurement sets onto 2D grids.

        For connected arrays with a known array center the antenna positions are plotted in an
        east-west (X axis) north-south (Y axis) grid in meters centered on the array center.
        For disconnected arrays (usually VLBI) or arrays with no known array center, antenna
        positions are plotted in a longitude and latitude grid in a quasi-mercator projection.
        If `cartopy` is available, non-connected arrays are plotted along with continental outlines.
        A plot of antenna elevations above sea level is also produced together with the 2d array configuration.

        Parameters
        ----------
        add_antenna_labels: bool, optional
            If 'True', annotations are shown with a descriptive label for each antenna, default is 'True'.
        add_antenna_stations: bool, optional
            If 'True', add antenna station information to the antenna labels, default is 'False'.
        add_elevation_plot: bool, optional
            If 'True', add a plot of the elevations above sea level for each antenna, default is 'True'.
        add_continent_outlines: bool, optional
            If 'True' and `cartopy` is available, add continental outlines for the longitude and latitude
            plots for disconnected arrays or arrays for which no array center is known.
        figure_size: tuple, optional
            Controls the size of the plot in inches.

        Returns
        -------
        None

        Raises
        ------
        InvalidAccessorLocation
            If this accessor is not on a processing set node.
        ValueError
            If antenna positions are not in the Geocentric ITRS frame.
        """
        self._check_processing_set_node()

        from xradio._utils.logging import xradio_logger
        from matplotlib import pyplot as plt
        import astropy.units as ap_units
        from astropy.coordinates import EarthLocation

        def setup_annotations_for_hover(plot_axes):
            """
            Creates annotations for antennae at position and elevation axes objects.

            Parameters
            ----------
            plot_axes : matplotlib axes object list

            Returns
            -------
            dict
                dict from antenna axes -> annotation objects
            """
            annotations = {}
            for plot_ax in plot_axes:
                annotation = plot_ax.annotate(
                    "",
                    xy=(0, 0),
                    xytext=(5, 5),
                    textcoords="offset points",
                    arrowprops=dict(arrowstyle="-|>"),
                    bbox=dict(boxstyle="round", fc="w"),
                )
                annotation.set_visible(False)
                annotations[plot_ax] = annotation
            return annotations

        def update_antenna_annotation(indices, plot_obj, annotation, ant_info):
            """Point the annotation at the first hovered antenna and list every hovered antenna's info."""
            hovered = indices["ind"]
            annotation.xy = plot_obj.get_offsets()[hovered[0]]
            annotation.set_text(
                "\n\n".join(
                    "\n".join(f"{key}: {value}" for key, value in ant_info[num].items())
                    for num in hovered
                )
            )
            annotation.get_bbox_patch().set_facecolor("#e8d192")
            annotation.get_bbox_patch().set_alpha(1)

        def hover_annotation(event):
            """Show/hide hover annotations; plot_axes, plot_map, annotation_map and fig are bound below."""
            if event.inaxes not in plot_axes:
                return
            for axis in plot_axes:
                contained, indices = plot_map[axis].contains(event)
                annotation = annotation_map[axis]
                if contained:
                    update_antenna_annotation(
                        indices, plot_map[axis], annotation, info_dicts
                    )
                    annotation.set_visible(True)
                    fig.canvas.draw_idle()
                elif annotation.get_visible():
                    annotation.set_visible(False)
                    fig.canvas.draw_idle()

        try:
            import cartopy  # noqa: F401  (imported only to probe availability)
        except ImportError:
            xradio_logger().info("To include Continent outlines: `pip install cartopy`")
            cartopy_available = False
        else:
            cartopy_available = add_continent_outlines

        # Longitude and latitudes in degrees, radius in meters.
        # VLBI arrays are marked with None for their centers.
        observatory_array_centers = {
            # Connected Arrays:
            "ALMA": {
                "longitude": -67.754929,
                "latitude": -23.029,
                "radius": 6379946.0,
            },
            "VLA": {
                "longitude": np.rad2deg(-1.8782884344112576),
                "latitude": np.rad2deg(0.5916753430723376),
                "radius": 6373580.0,
            },
            "MeerKAT": {
                "longitude": 21.443889,
                "latitude": -30.711056,
                "radius": 6373681.0,
            },
            # GMRT antenna positions in the XDSes seem really weird, maybe an unit error?
            "GMRT": {
                "longitude": 74.05210298316263,
                "latitude": 19.090998273409596,
                "radius": 6377126.8,
            },
            # Connected arrays with currently unknown array centers
            "OSKAR": None,
            "ASKAP": None,
            # Disconnected arrays (VLBI)
            "VLBA": None,
            "EVN": None,
            "EHT": None,
            "LOFAR": None,
        }

        combined_antenna_xds = self.get_combined_antenna_xds()
        ant_pos = combined_antenna_xds.ANTENNA_POSITION
        ant_names = ant_pos.antenna_name.values
        station_names = ant_pos.station_name.values
        telescope_names = ant_pos.telescope_name.values

        # Names containing one of these keys are mapped onto that key.
        overall_telescope = combined_antenna_xds.attrs["overall_telescope_name"]
        for known_name in ("OSKAR", "ALMA", "VLA", "ASKAP"):
            if known_name in overall_telescope:
                overall_telescope = known_name
                break

        pos_frame = ant_pos.attrs["frame"]
        pos_system = ant_pos.attrs["coordinate_system"]
        if not (pos_frame == "ITRS" and pos_system == "geocentric"):
            raise ValueError(
                f"Don't know how to plot antenna positions in {pos_system} {pos_frame}"
            )

        # Convert antenna positions in bulk
        ant_locs = EarthLocation.from_geocentric(
            ant_pos.sel(cartesian_pos_label="x").values * ap_units.m,
            ant_pos.sel(cartesian_pos_label="y").values * ap_units.m,
            ant_pos.sel(cartesian_pos_label="z").values * ap_units.m,
        )
        ant_lon, ant_lat, ant_height = ant_locs.geodetic
        ant_lon = ant_lon.deg
        ant_lat = ant_lat.deg

        if overall_telescope in observatory_array_centers:
            array_center = observatory_array_centers[overall_telescope]
        else:
            xradio_logger().warning(
                f"Observatory {overall_telescope} not yet supported, plotting as a disconnected array"
            )
            array_center = None
        plot_as_disconnected_array = array_center is None

        if add_antenna_stations:
            ant_labels = [
                f"{ant_name} @ {station_name}"
                for ant_name, station_name in zip(ant_names, station_names)
            ]
        else:
            ant_labels = [f"{ant_name}" for ant_name in ant_names]

        fig = plt.figure(figsize=figure_size)
        main_plot_area = (
            (0.1, 0.40, 0.8, 0.55) if add_elevation_plot else (0.1, 0.1, 0.8, 0.8)
        )

        if plot_as_disconnected_array:
            if cartopy_available:
                import cartopy.crs as ccrs
                import cartopy.feature as cfeature

                ant_pos_ax = plt.axes(main_plot_area, projection=ccrs.PlateCarree())
                ant_pos_ax.add_feature(cfeature.COASTLINE)
                gl = ant_pos_ax.gridlines(draw_labels=True)
                gl.top_labels = False
                gl.right_labels = False
            else:
                ant_pos_ax = plt.axes(main_plot_area)
            ant_pos_ax.set_xlabel("Longitude [deg]")
            ant_pos_ax.set_ylabel("Latitude [deg]")
            ant_plot_pos_x = ant_lon
            ant_plot_pos_y = ant_lat
        else:
            ant_pos_ax = plt.axes(main_plot_area)
            # Local flat (equirectangular) approximation around the array center:
            # arc length = radius * angle, so the angular offsets must be in radians.
            tel_lon = array_center["longitude"]
            tel_lat = array_center["latitude"]
            tel_rad = array_center["radius"]
            ant_plot_pos_x = (
                tel_rad * np.deg2rad(ant_lon - tel_lon) * np.cos(np.deg2rad(tel_lat))
            )
            ant_plot_pos_y = tel_rad * np.deg2rad(ant_lat - tel_lat)
            ant_pos_ax.set_xlabel("East [m]")
            ant_pos_ax.set_ylabel("North [m]")

        # Build antenna info dictionaries shown by the hover annotations.
        info_dicts = []
        for i_ant, ant_name in enumerate(ant_names):
            ant_dict = {
                "Name": ant_name,
                "Station": station_names[i_ant],
                "Telescope": telescope_names[i_ant],
                "Elevation": f"{ant_height[i_ant]:.0f}",
            }
            if plot_as_disconnected_array:
                ant_dict["Longitude"] = f"{ant_plot_pos_x[i_ant]:.4f} deg"
                ant_dict["Latitude"] = f"{ant_plot_pos_y[i_ant]:.4f} deg"
            else:
                ant_dict["East offset"] = f"{ant_plot_pos_x[i_ant]:.2f} m"
                ant_dict["North offset"] = f"{ant_plot_pos_y[i_ant]:.2f} m"
            info_dicts.append(ant_dict)

        ant_pos_plot = ant_pos_ax.scatter(
            ant_plot_pos_x,
            ant_plot_pos_y,
            color="blue",
            marker="+",
            ls="",
        )
        if add_antenna_labels:
            for i_ant, ant_label in enumerate(ant_labels):
                ant_pos_ax.annotate(
                    ant_label,
                    (ant_plot_pos_x[i_ant], ant_plot_pos_y[i_ant]),
                    alpha=1,
                    xytext=(2, 2),
                    textcoords="offset points",
                )
        ant_pos_ax.set_title(f"{overall_telescope} Antenna Positions")

        plot_axes = [ant_pos_ax]
        plot_map = {ant_pos_ax: ant_pos_plot}
        if add_elevation_plot:
            ant_ids = np.arange(ant_height.shape[0])
            ant_el_ax = plt.axes((0.1, 0.05, 0.8, 0.25))
            ant_el_plot = ant_el_ax.scatter(
                ant_ids,
                ant_height,
                color="blue",
                marker="_",
                ls="",
            )
            ant_el_ax.set_xlabel("Antenna")
            ant_el_ax.set_ylabel("Antenna Elevation [m]")
            ant_el_ax.set_xticks(ant_ids, labels=ant_labels, rotation=90)
            plot_axes.append(ant_el_ax)
            plot_map[ant_el_ax] = ant_el_plot

        annotation_map = setup_annotations_for_hover(plot_axes)
        fig.canvas.mpl_connect("motion_notify_event", hover_annotation)
        plt.show()

    def plot_antenna_positions(self, label_all_antennas: bool = False):
        """
        Plot the antenna positions of all antennas in the Processing Set.

        This method generates and displays a figure with three scatter plots, displaying the antenna
        positions in different planes:

        - X vs Y
        - Y vs Z
        - X vs Z

        The antenna names are shown on hovering their positions, unless label_all_antennas is enabled.

        Parameters
        ----------
        label_all_antennas : bool, optional
            If 'True', annotations are shown with the names of every antenna next to their positions.

        Returns
        -------
        None

        Raises
        ------
        InvalidAccessorLocation
            If this accessor is not on a processing set node.
        """
        self._check_processing_set_node()

        combined_antenna_xds = self.get_combined_antenna_xds()
        from matplotlib import pyplot as plt

        antenna_positions = combined_antenna_xds["ANTENNA_POSITION"]
        antenna_names = combined_antenna_xds.antenna_name.values

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
        fig.suptitle("Antenna Positions")
        fig.subplots_adjust(
            wspace=0.25, hspace=0.25, left=0.1, right=0.95, top=0.9, bottom=0.1
        )

        # (axis, horizontal component, vertical component) for each projection plane.
        planes = [(ax1, "x", "y"), (ax2, "y", "z"), (ax3, "x", "z")]
        scatter_map = {}
        for axis, x_comp, y_comp in planes:
            scatter_map[axis] = axis.scatter(
                antenna_positions.sel(cartesian_pos_label=x_comp),
                antenna_positions.sel(cartesian_pos_label=y_comp),
            )
            axis.set_xlabel(f"{x_comp} (m)")
            axis.set_ylabel(f"{y_comp} (m)")
        ax4.axis("off")
        antenna_axes = [ax1, ax2, ax3]

        if label_all_antennas:
            # Static labels next to every antenna; no hover handler is needed.
            for axis in antenna_axes:
                coord_x, coord_y = np.array(scatter_map[axis].get_offsets()).transpose()
                offset_x = np.abs(np.max(coord_x) - np.min(coord_x)) * 0.01
                offset_y = np.abs(np.max(coord_y) - np.min(coord_y)) * 0.01
                for idx, (x, y) in enumerate(
                    zip(coord_x + offset_x, coord_y + offset_y)
                ):
                    axis.annotate(antenna_names[idx], (x, y), alpha=1)
        else:
            annotations_map = {}
            for axis in antenna_axes:
                annotation = axis.annotate(
                    "",
                    xy=(0, 0),
                    xytext=(10, 15),
                    textcoords="offset points",
                    arrowprops=dict(arrowstyle="-|>"),
                    bbox=dict(boxstyle="round", fc="w"),
                )
                annotation.set_visible(False)
                annotations_map[axis] = annotation

            def antenna_hover(event):
                """Show the names of hovered antennas; hide annotations otherwise."""
                if event.inaxes not in antenna_axes:
                    return
                for axis in antenna_axes:
                    scatter = scatter_map[axis]
                    contained, indices = scatter.contains(event)
                    annotation = annotations_map[axis]
                    if contained:
                        annotation.xy = scatter.get_offsets()[indices["ind"][0]]
                        annotation.set_text(
                            " ".join(antenna_names[num] for num in indices["ind"])
                        )
                        annotation.get_bbox_patch().set_facecolor("#e8d192")
                        annotation.get_bbox_patch().set_alpha(1)
                        annotation.set_visible(True)
                        fig.canvas.draw_idle()
                    elif annotation.get_visible():
                        annotation.set_visible(False)
                        fig.canvas.draw_idle()

            fig.canvas.mpl_connect("motion_notify_event", antenna_hover)

        plt.show()

    def get_ms_xdt(self):
        """
        Return the Measurement Set of this Processing Set, if it contains exactly one.

        Returns
        -------
        xr.DataTree
            The Measurement Set Data Tree object.

        Raises
        ------
        AssertionError
            If the Processing Set does not contain exactly one Measurement Set.
        """
        n_ms = len(self._xdt.children)
        assert n_ms == 1, (
            "Processing Set must contain exactly one Measurement Set to determine "
            f"which to return, but it contains {n_ms}."
        )
        return next(iter(self._xdt.children.values()))