"""
Roman pipeline step for image alignment.
"""
from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
import pyarrow.parquet as pq
from astropy.table import Table
from roman_datamodels import datamodels as rdm
from roman_datamodels import dqflags
from stcal.tweakreg import tweakreg
from stcal.tweakreg.tweakreg import TweakregError
from tweakwcs import RomanWCSCorrector
from romancal.assign_wcs.assign_wcs import add_s_region
from romancal.datamodels.fileio import open_dataset
from romancal.lib.save_wcs import save_wfiwcs
# LOCAL
from ..datamodels import ModelLibrary
from ..stpipe import RomanStep
if TYPE_CHECKING:
    # Only needed for the class-level annotation below; annotations are not
    # evaluated at runtime because of ``from __future__ import annotations``.
    from typing import ClassVar

# Default catalog used for absolute astrometric alignment (``abs_refcat``).
DEFAULT_ABS_REFCAT = "GAIADR3_S3"

__all__ = ["TweakRegStep"]

log = logging.getLogger(__name__)
class TweakRegStep(RomanStep):
    """
    TweakRegStep: Image alignment based on catalogs of sources detected in
    the input images.
    """

    class_alias = "tweakreg"

    spec = f"""
        catalog_format = string(default='ascii.ecsv') # Catalog output file format
        catalog_path = string(default='') # Catalog output file path
        enforce_user_order = boolean(default=False) # Align images in user specified order?
        expand_refcat = boolean(default=False) # Expand reference catalog with new sources?
        minobj = integer(default=10) # Minimum number of objects acceptable for matching
        searchrad = float(default=2.0) # The search radius in arcsec for a match
        use2dhist = boolean(default=True) # Use 2d histogram to find initial offset?
        separation = float(default=1.0) # Minimum object separation in arcsec
        tolerance = float(default=0.7) # Matching tolerance for xyxymatch in arcsec
        fitgeometry = option('shift', 'rshift', 'rscale', 'general', default='general') # Fitting geometry
        nclip = integer(min=0, default=3) # Number of clipping iterations in fit
        sigma = float(min=0.0, default=3.0) # Clipping limit in sigma units
        abs_refcat = string(default='{DEFAULT_ABS_REFCAT}') # Absolute reference catalog
        save_abs_catalog = boolean(default=False) # Write out used absolute astrometric reference catalog as a separate product
        abs_minobj = integer(default=10) # Minimum number of objects acceptable for matching when performing absolute astrometry
        abs_searchrad = float(default=6.0) # The search radius in arcsec for a match when performing absolute astrometry
        # We encourage setting this parameter to True. Otherwise, xoffset and yoffset will be set to zero.
        abs_use2dhist = boolean(default=True) # Use 2D histogram to find initial offset when performing absolute astrometry?
        abs_separation = float(default=1.0) # Minimum object separation in arcsec when performing absolute astrometry
        abs_tolerance = float(default=0.7) # Matching tolerance for xyxymatch in arcsec when performing absolute astrometry
        # Fitting geometry when performing absolute astrometry
        abs_fitgeometry = option('shift', 'rshift', 'rscale', 'general', default='general')
        abs_nclip = integer(min=0, default=3) # Number of clipping iterations in fit when performing absolute astrometry
        abs_sigma = float(min=0.0, default=3.0) # Clipping limit in sigma units when performing absolute astrometry
        output_use_model = boolean(default=True) # When saving use `DataModel.meta.filename`
        update_source_catalog_coordinates = boolean(default=False) # Update source catalog file with tweaked coordinates?
        vo_timeout = float(min=0, default=1200.) # VO catalog service timeout.
    """

    reference_file_types: ClassVar = []

    def process(self, dataset):
        """
        Align the input images using their source catalogs.

        Relative alignment between image groups runs when there is more than
        one group. Absolute alignment to ``abs_refcat`` follows. Models
        without ``meta.source_catalog`` are skipped.

        Parameters
        ----------
        dataset : object
            Any input accepted by ``open_dataset``; it is opened here with
            ``as_library=True``.

        Returns
        -------
        ModelLibrary
            The opened library. Each model has ``meta.cal_step.tweakreg`` set
            to ``"SKIPPED"`` (no source catalog), ``"COMPLETE"`` (successful
            fit; WCS and S_REGION updated) or ``"FAILED"``.

        Raises
        ------
        ValueError
            If the input is empty, or a catalog lacks ``x``/``y`` (or
            ``x_psf``/``y_psf``) columns.
        AttributeError
            If a model's source catalog metadata has neither an in-memory
            tweakreg catalog nor a catalog file name.
        """
        images = open_dataset(
            dataset, update_version=self.update_version, as_library=True
        )

        if not images:
            raise ValueError("Input must contain at least one image model.")

        log.info(
            f"Number of image groups to be aligned: {len(images.group_indices):d}."
        )
        log.info("Image groups:")
        for name in images.group_names:
            log.info(f" {name}")

        # set the first image as reference
        with images:
            ref_image = images.borrow(0)
            images.shelve(ref_image, 0, modify=False)

        # set path where the source catalog will be saved to
        if not self.catalog_path:
            self.catalog_path = os.getcwd()
        self.catalog_path = Path(self.catalog_path).as_posix()
        log.info(f"All source catalogs will be saved to: {self.catalog_path}")

        # set reference catalog name
        if not self.abs_refcat:
            self.abs_refcat = DEFAULT_ABS_REFCAT.strip().upper()
        if self.abs_refcat != DEFAULT_ABS_REFCAT:
            self.expand_refcat = True

        # build the catalogs for input images
        imcats = []
        with images:
            for i, image_model in enumerate(images):
                source_catalog = getattr(image_model.meta, "source_catalog", None)
                if source_catalog is None:
                    log.warning(
                        f"Skipping TweakReg for {image_model.meta.filename}: "
                        "no source catalog available."
                    )
                    image_model.meta.cal_step.tweakreg = "SKIPPED"
                else:
                    try:
                        catalog = self.get_tweakreg_catalog(source_catalog, image_model)
                    except AttributeError as e:
                        log.error(f"Failed to retrieve tweakreg_catalog: {e}")
                        images.shelve(image_model, i, modify=False)
                        raise e

                    # for empty catalogs, SourceCatalog omits xpsf & ypsf; add
                    # x/y columns copied from the centroid columns instead
                    if len(catalog) == 0:
                        _add_required_columns(catalog)

                    # validate catalog columns
                    if not _validate_catalog_columns(catalog):
                        # return the borrowed model before raising, as above
                        images.shelve(image_model, i, modify=False)
                        raise ValueError(
                            "'tweakreg' source catalogs must contain a header with "
                            "columns named either 'x' and 'y' or 'x_psf' and 'y_psf'. "
                            "Neither were found in the catalog provided."
                        )

                    catalog = tweakreg.filter_catalog_by_bounding_box(
                        catalog, image_model.meta.wcs.bounding_box
                    )
                    catalog = _filter_catalog(catalog)

                    if self.save_abs_catalog:
                        # NOTE(review): this writes the *image* catalog to a file
                        # named after the absolute reference catalog, and every
                        # image overwrites the same file, while absolute_align()
                        # below is called with save_abs_catalog=False. Confirm
                        # this is the intended product.
                        output_name = os.path.join(
                            self.catalog_path, f"fit_{self.abs_refcat.lower()}_ref.ecsv"
                        )
                        catalog.write(
                            output_name, format=self.catalog_format, overwrite=True
                        )

                    image_model.meta["tweakreg_catalog"] = catalog.as_array()
                    nsources = len(catalog)
                    log.info(
                        f"Using {nsources} sources from {image_model.meta.filename}."
                        if nsources
                        else f"No sources found in {image_model.meta.filename}."
                    )

                    # build image catalog
                    # catalog name
                    catalog_name = os.path.splitext(image_model.meta.filename)[
                        0
                    ].strip("_- ")
                    # catalog data
                    catalog_table = Table(image_model.meta.tweakreg_catalog)
                    catalog_table.meta["name"] = catalog_name
                    catalog = tweakreg.filter_catalog_by_bounding_box(
                        catalog_table, image_model.meta.wcs.bounding_box
                    )

                    corrector = RomanWCSCorrector(
                        wcs=image_model.meta.wcs,
                        wcsinfo={
                            "roll_ref": image_model.meta.wcsinfo.roll_ref,
                            "v2_ref": image_model.meta.wcsinfo.v2_ref,
                            "v3_ref": image_model.meta.wcsinfo.v3_ref,
                        },
                        # catalog and group_id are required meta
                        meta={
                            "catalog": catalog,
                            "name": catalog.meta.get("name"),
                            "group_id": images._model_to_group_id(image_model),
                            "model_index": i,
                        },
                    )
                    imcats.append(corrector)

                # shelve every model (skipped ones keep their SKIPPED status)
                images.shelve(image_model, i)

        # run alignment only if it was possible to build image catalogs
        if imcats:
            absolute_alignment_failed = False

            # relative alignment is only meaningful with more than one group
            if len(images.group_indices) > 1:
                try:
                    self.do_relative_alignment(imcats)
                except TweakregError as e:
                    log.warning(str(e))

            try:
                self.do_absolute_alignment(ref_image, imcats)
            except TweakregError as e:
                log.warning(str(e))
                absolute_alignment_failed = True

            # finalize step
            with images:
                for imcat in imcats:
                    model_index = imcat.meta["model_index"]
                    image_model = images.borrow(model_index)

                    fit_info = imcat.meta.get("fit_info")
                    fit_status = (
                        "" if fit_info is None else str(fit_info.get("status", ""))
                    )
                    fit_succeeded = (
                        not absolute_alignment_failed and "SUCCESS" in fit_status
                    )

                    image_model.meta["wcs_fit_results"] = _serialize_wcs_fit_results(
                        fit_info=fit_info,
                        n_detector=len(imcats),
                        force_failed_status=absolute_alignment_failed,
                    )

                    # remove source catalog
                    if "tweakreg_catalog" in image_model.meta:
                        del image_model.meta["tweakreg_catalog"]

                    # retrieve fit status and update wcs if fit is successful:
                    if fit_succeeded:
                        image_model.meta.cal_step.tweakreg = "COMPLETE"

                        # Update/create the WCS .name attribute with information
                        # on this astrometric fit as the only record that it was
                        # successful:
                        # NOTE: This .name attrib agreed upon by the JWST Cal
                        # Working Group.
                        # Current value is merely a place-holder based
                        # on HST conventions. This value should also be
                        # translated to the FITS WCSNAME keyword
                        # IF that is what gets recorded in the archive
                        # for end-user searches.
                        imcat.wcs.name = f"FIT-LVL2-{self.abs_refcat}"

                        # update WCS
                        image_model.meta.wcs = imcat.wcs

                        # update S_REGION
                        add_s_region(image_model)

                        # update source catalog coordinates if requested
                        if self.update_source_catalog_coordinates:
                            try:
                                self.update_catalog_coordinates(
                                    image_model.meta.source_catalog[
                                        "tweakreg_catalog_name"
                                    ],
                                    imcat.wcs,
                                )
                            except Exception as e:
                                log.error(
                                    f"Failed to update source catalog coordinates: {e}"
                                )
                                raise
                    else:
                        image_model.meta.cal_step.tweakreg = "FAILED"

                    images.shelve(image_model, model_index)

        return images

    def save_model(self, result, *args, **kwargs):
        """
        Save the step result.

        For a ``ModelLibrary`` result, ``save_wfiwcs`` is called first (with
        ``force=True``). Saving is then delegated to ``RomanStep.save_model``.
        """
        if isinstance(result, ModelLibrary):
            save_wfiwcs(self, result, force=True)
        super().save_model(result, *args, **kwargs)

    def _update_catalog_coordinates(self, catalog, tweaked_wcs):
        """
        Recompute sky coordinates in ``catalog`` in place from pixel columns.

        Each (x, y) -> (ra, dec) column group is updated only when all four
        columns already exist, so no columns are ever added.

        Parameters
        ----------
        catalog : Table
            Catalog whose RA/Dec columns are overwritten in place.
        tweaked_wcs : object
            WCS providing ``pixel_to_world_values(x, y)``.
        """
        # (x_col, y_col) -> (ra_col, dec_col)
        updates = [
            ("x_centroid", "y_centroid", "ra_centroid", "dec_centroid"),
            ("x_centroid", "y_centroid", "ra", "dec"),
            (
                "x_centroid_win",
                "y_centroid_win",
                "ra_centroid_win",
                "dec_centroid_win",
            ),
            ("x_psf", "y_psf", "ra_psf", "dec_psf"),
        ]
        for x_col, y_col, ra_col, dec_col in updates:
            if any(c not in catalog.colnames for c in (x_col, y_col, ra_col, dec_col)):
                # Only update existing columns to preserve the file schema.
                continue
            catalog[ra_col], catalog[dec_col] = tweaked_wcs.pixel_to_world_values(
                catalog[x_col], catalog[y_col]
            )

    def update_catalog_coordinates(self, tweakreg_catalog_name, tweaked_wcs):
        """
        Update the source catalog coordinates using the tweaked WCS while
        preserving the file's schema and metadata.

        Parameters
        ----------
        tweakreg_catalog_name : str
            Path to the source catalog file (in Parquet format) to be updated.
        tweaked_wcs : object
            WCS providing ``pixel_to_world_values(x, y)``. It is used to
            recompute the (RA, Dec) columns from the pixel columns.

        Returns
        -------
        None
            The catalog file is rewritten in place.

        Notes
        -----
        Only sky-coordinate columns that already exist, together with their
        pixel columns, are recomputed.

        The table is round-tripped through Python objects and an astropy
        ``Table``, which widens column types (e.g. float32 -> float64,
        int32 -> int64). The rebuilt table is therefore cast back to the
        original Arrow schema, so column types and field/schema metadata are
        unchanged. Updated coordinates are stored with the original column
        precision.
        """
        # Read the existing catalog using PyArrow
        pa_table = pq.read_table(tweakreg_catalog_name)
        original_metadata = pa_table.schema.metadata

        astropy_table = Table(pa_table.to_pydict())
        self._update_catalog_coordinates(astropy_table, tweaked_wcs)

        # Rebuild an Arrow table from the (partly updated) columns
        final_table = pa_table.from_pydict(
            {colname: astropy_table[colname] for colname in astropy_table.colnames}
        )
        # Restore the original column types and field metadata; column names
        # and order are unchanged because no columns are added or removed.
        final_table = final_table.cast(pa_table.schema)
        final_table = final_table.replace_schema_metadata(original_metadata)

        # Write back to file
        pq.write_table(final_table, tweakreg_catalog_name)

    def read_catalog(self, catalog_name):
        """
        Read a source catalog from a file.

        The reader is chosen from the file-name ending:

        * ``"asdf"``: opened with ``roman_datamodels`` and its
          ``source_catalog`` attribute returned;
        * ``"parquet"``: ``Table.read`` with ``format="parquet"``;
        * otherwise: ``Table.read`` with ``format=self.catalog_format``.

        Parameters
        ----------
        catalog_name : str
            The name of the catalog file to read.

        Returns
        -------
        Table
            The catalog that was read.

        Raises
        ------
        Exception
            Errors raised by ``rdm.open`` / ``Table.read`` (e.g. for an
            unsupported format) propagate unchanged.
        """
        filetype = (
            "parquet" if catalog_name.endswith("parquet") else self.catalog_format
        )
        if catalog_name.endswith("asdf"):
            # leave this for now
            # NOTE(review): the catalog is returned after the model is closed;
            # confirm its data does not depend on the open file.
            with rdm.open(catalog_name) as source_catalog_model:
                catalog = source_catalog_model.source_catalog
        else:
            catalog = Table.read(catalog_name, format=filetype)

        return catalog

    def get_tweakreg_catalog(self, source_catalog, image_model):
        """
        Retrieve the tweakreg catalog from source detection metadata.

        An in-memory ``source_catalog.tweakreg_catalog`` takes precedence.
        It is copied into a ``Table`` and then deleted from
        ``image_model.meta.source_catalog``. Otherwise the file named by
        ``source_catalog.tweakreg_catalog_name`` is read via ``read_catalog``.

        Parameters
        ----------
        source_catalog : object
            The ``meta.source_catalog`` metadata of ``image_model``.
        image_model : DataModel
            The image model associated with the source detection.

        Returns
        -------
        Table
            The retrieved tweakreg catalog.

        Raises
        ------
        AttributeError
            If neither an in-memory catalog nor a catalog file name is present.
        """
        twk_cat = getattr(source_catalog, "tweakreg_catalog", None)
        twk_cat_name = getattr(source_catalog, "tweakreg_catalog_name", None)
        image_name = getattr(
            getattr(image_model, "meta", None), "filename", "<unknown>"
        )

        if twk_cat is not None:
            log.info(
                f"Using in-memory tweakreg catalog from meta.source_catalog.tweakreg_catalog for {image_name}."
            )
            tweakreg_catalog = Table(np.asarray(twk_cat))
            # remove the in-memory catalog from the model once it is copied
            del image_model.meta.source_catalog["tweakreg_catalog"]
            return tweakreg_catalog

        if twk_cat_name is not None:
            log.info(f"Using tweakreg catalog file '{twk_cat_name}' for {image_name}.")
            return self.read_catalog(twk_cat_name)

        raise AttributeError(
            "Attribute 'meta.source_catalog.tweakreg_catalog' is missing. "
            "Please either run SourceCatalogStep or provide a custom source catalog."
        )

    def do_relative_alignment(self, imcats):
        """
        Perform relative alignment of images.

        Calls ``tweakreg.relative_align`` with this step's relative-alignment
        parameters (zero initial x/y offsets, ``clip_accum=True``).

        Parameters
        ----------
        imcats : list
            WCS correctors carrying the source catalogs to align.

        Returns
        -------
        None
        """
        tweakreg.relative_align(
            imcats,
            searchrad=self.searchrad,
            separation=self.separation,
            use2dhist=self.use2dhist,
            tolerance=self.tolerance,
            xoffset=0,
            yoffset=0,
            enforce_user_order=self.enforce_user_order,
            expand_refcat=self.expand_refcat,
            minobj=self.minobj,
            fitgeometry=self.fitgeometry,
            nclip=self.nclip,
            sigma=self.sigma,
            clip_accum=True,
        )

    def do_absolute_alignment(self, ref_image, imcats):
        """
        Perform absolute alignment of images.

        Aligns the correctors to ``self.abs_refcat`` via
        ``tweakreg.absolute_align``. The reference image's WCS, wcsinfo and
        exposure start time (as a decimal year epoch) are passed along.

        Parameters
        ----------
        ref_image : DataModel
            The reference image providing WCS information and the epoch.
        imcats : list
            WCS correctors carrying the source catalogs to align.

        Returns
        -------
        None
        """
        tweakreg.absolute_align(
            imcats,
            self.abs_refcat,
            ref_wcs=ref_image.meta.wcs,
            ref_wcsinfo=ref_image.meta.wcsinfo,
            epoch=ref_image.meta.exposure.start_time.decimalyear,
            abs_minobj=self.abs_minobj,
            abs_fitgeometry=self.abs_fitgeometry,
            abs_nclip=self.abs_nclip,
            abs_sigma=self.abs_sigma,
            abs_searchrad=self.abs_searchrad,
            abs_use2dhist=self.abs_use2dhist,
            abs_separation=self.abs_separation,
            abs_tolerance=self.abs_tolerance,
            save_abs_catalog=False,
            clip_accum=True,
            timeout=self.vo_timeout,
        )
def _validate_catalog_columns(catalog) -> bool:
    """
    Ensure the catalog provides ``x`` and ``y`` position columns.

    For each axis, an existing ``x``/``y`` column is used as-is. Otherwise
    an ``x_psf``/``y_psf`` column is renamed to ``x``/``y``. Renames are
    applied only once both axes are known to be resolvable, so a failed
    validation leaves the catalog untouched.

    Parameters
    ----------
    catalog : Table
        Catalog to check; modified in place when a rename is needed.

    Returns
    -------
    bool
        True if both position columns are (now) present, False otherwise.
    """
    renames = []
    for axis in ("x", "y"):
        if axis in catalog.colnames:
            continue
        psf_axis = f"{axis}_psf"
        if psf_axis not in catalog.colnames:
            # do not leave a half-renamed catalog behind
            return False
        renames.append((psf_axis, axis))

    for old_name, new_name in renames:
        catalog.rename_column(old_name, new_name)
    return True
def _serialize_wcs_fit_results(
fit_info: dict | None,
n_detector: int,
force_failed_status: bool = False,
):
"""
Serialize tweakreg fit metadata for storage in datamodel metadata.
Parameters
----------
fit_info : dict or None
``fit_info`` payload from an image corrector.
n_detector : int
Number of detector images used in the tweakreg solve.
force_failed_status : bool
Force output status to ``FAILED``.
Returns
-------
dict
Serialized fit results.
"""
fit_results = {}
if fit_info is not None:
fit_results = {k: _to_python_scalar_or_list(v) for k, v in fit_info.items()}
# remove unwanted keys from WCS fit results
for key in [
"eff_minobj",
"matched_ref_idx",
"matched_input_idx",
"fit_RA",
"fit_DEC",
"fitmask",
]:
fit_results.pop(key, None)
if force_failed_status:
fit_results["status"] = "FAILED"
elif not fit_results.get("status"):
fit_results["status"] = "FAILED"
if fit_results.get("nmatches") is None:
fit_results["nmatches"] = 0
fit_results["n_detector"] = n_detector
return fit_results
def _to_python_scalar_or_list(value):
    """
    Convert numpy values to Python-native scalars/lists for ASDF serialization.

    A ``numpy.ndarray`` becomes a (nested) list, and a numpy scalar becomes
    the matching Python scalar. Lists, tuples and dicts are converted element
    by element, keeping list/tuple/dict as the container type. This way numpy
    scalars nested inside them (e.g. a tuple of ``numpy.float64``) are
    converted too. Any other value is returned unchanged.
    """
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, dict):
        return {k: _to_python_scalar_or_list(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        converted = [_to_python_scalar_or_list(v) for v in value]
        return converted if isinstance(value, list) else tuple(converted)
    return value
def _add_required_columns(catalog):
    """
    Add ``x``/``y`` columns copied from ``x_centroid``/``y_centroid``.

    The centroid coordinates are always present in the standard output from
    SourceCatalogStep, so they are used as the position columns.

    Parameters
    ----------
    catalog : Table
        Catalog to update in place.

    Returns
    -------
    None
    """
    for axis in ("x", "y"):
        catalog[axis] = catalog[f"{axis}_centroid"]
def _filter_catalog(catalog):
    """
    Drop flagged sources from a catalog before it is used by tweakreg.

    Currently only rows whose ``warning_flags`` have the ``DO_NOT_USE`` bit
    set are removed. A catalog without a ``warning_flags`` column is returned
    as-is.

    Parameters
    ----------
    catalog : Table
        The catalog from which to filter flagged sources.

    Returns
    -------
    The filtered catalog
    """
    if "warning_flags" not in catalog.dtype.names:
        return catalog
    flagged = (catalog["warning_flags"] & dqflags.pixel.DO_NOT_USE) != 0
    return catalog[~flagged]