# Source code for romancal.tweakreg.tweakreg_step

"""
Roman pipeline step for image alignment.
"""

import os
import weakref
from pathlib import Path

import numpy as np
from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.table import Table
from roman_datamodels import datamodels as rdm
from tweakwcs.correctors import JWSTWCSCorrector
from tweakwcs.imalign import align_wcs
from tweakwcs.matchutils import XYXYMatch

from romancal.lib.basic_utils import is_association

# LOCAL
from ..datamodels import ModelContainer
from ..stpipe import RomanStep
from . import astrometric_utils as amutils


def _oxford_or_str_join(str_list):
    """Join the ``repr`` of each item into an English "or" list.

    Parameters
    ----------
    str_list : iterable
        Items to list. Each one is rendered with ``repr`` so that strings
        appear quoted.

    Returns
    -------
    str
        ``"N/A"`` for no items, the single quoted item for one item,
        ``"'a' or 'b'"`` for two items, and ``"'a', 'b', or 'c'"`` (with an
        Oxford comma) for three or more.
    """
    # Quote every item exactly once; all branches below reuse this list so
    # that no item is passed through ``repr`` a second time.
    quoted = [repr(item) for item in str_list]

    if not quoted:
        return "N/A"
    if len(quoted) == 1:
        # Return the string itself, not the one-element list.
        return quoted[0]
    if len(quoted) == 2:
        return f"{quoted[0]} or {quoted[1]}"
    return ", ".join(quoted[:-1]) + ", or " + quoted[-1]


# Catalog names accepted (case-insensitively) as values of ``abs_refcat``;
# any other value is treated as a path to a catalog file.
SINGLE_GROUP_REFCAT = ["GAIADR3", "GAIADR2", "GAIADR1"]
# Human-readable rendering of the names above, interpolated into the step
# ``spec`` and into the error raised for an unusable ``abs_refcat``.
_SINGLE_GROUP_REFCAT_STR = _oxford_or_str_join(SINGLE_GROUP_REFCAT)
# Catalog used when ``abs_refcat`` is empty: the first entry of the list.
DEFAULT_ABS_REFCAT = SINGLE_GROUP_REFCAT[0]
# Module-level switch read by ``TweakRegStep.process``: when true, images are
# also aligned to the absolute reference catalog.
ALIGN_TO_ABS_REFCAT = True

__all__ = ["TweakRegStep"]


class TweakRegStep(RomanStep):
    """
    TweakRegStep: Image alignment based on catalogs of sources detected in
    input images.

    The step first aligns the input image groups relative to each other and
    then, when the module-level ``ALIGN_TO_ABS_REFCAT`` switch is on, aligns
    them to an absolute reference catalog (``abs_refcat``).
    """

    class_alias = "tweakreg"

    spec = f"""
        use_custom_catalogs = boolean(default=False) # Use custom user-provided catalogs?
        catalog_format = string(default='ascii.ecsv') # Catalog output file format
        catfile = string(default='') # Name of the file with a list of custom user-provided catalogs
        catalog_path = string(default='') # Catalog output file path
        enforce_user_order = boolean(default=False) # Align images in user specified order?
        expand_refcat = boolean(default=False) # Expand reference catalog with new sources?
        minobj = integer(default=15) # Minimum number of objects acceptable for matching
        searchrad = float(default=2.0) # The search radius in arcsec for a match
        use2dhist = boolean(default=True) # Use 2d histogram to find initial offset?
        separation = float(default=1.0) # Minimum object separation in arcsec
        tolerance = float(default=0.7) # Matching tolerance for xyxymatch in arcsec
        fitgeometry = option('shift', 'rshift', 'rscale', 'general', default='rshift') # Fitting geometry
        nclip = integer(min=0, default=3) # Number of clipping iterations in fit
        sigma = float(min=0.0, default=3.0) # Clipping limit in sigma units
        abs_refcat = string(default='{DEFAULT_ABS_REFCAT}') # Absolute reference
        # catalog. Options: {_SINGLE_GROUP_REFCAT_STR}
        save_abs_catalog = boolean(default=False) # Write out used absolute astrometric reference catalog as a separate product
        abs_minobj = integer(default=15) # Minimum number of objects acceptable for matching when performing absolute astrometry
        abs_searchrad = float(default=6.0) # The search radius in arcsec for a match when performing absolute astrometry
        # We encourage setting this parameter to True. Otherwise, xoffset and yoffset will be set to zero.
        abs_use2dhist = boolean(default=True) # Use 2D histogram to find initial offset when performing absolute astrometry?
        abs_separation = float(default=0.1) # Minimum object separation in arcsec when performing absolute astrometry
        abs_tolerance = float(default=0.7) # Matching tolerance for xyxymatch in arcsec when performing absolute astrometry
        # Fitting geometry when performing absolute astrometry
        abs_fitgeometry = option('shift', 'rshift', 'rscale', 'general', default='rshift')
        abs_nclip = integer(min=0, default=3) # Number of clipping iterations in fit when performing absolute astrometry
        abs_sigma = float(min=0.0, default=3.0) # Clipping limit in sigma units when performing absolute astrometry
        output_use_model = boolean(default=True) # When saving use `DataModel.meta.filename`
    """  # noqa: E501

    reference_file_types = []
    # Reference catalog handed to the relative-alignment ``align_wcs`` call.
    refcat = None

    def process(self, input):
        """Align the input images and update their WCS.

        Parameters
        ----------
        input : str, DataModel, association, list, or ModelContainer
            Whatever ``ModelContainer`` accepts, a single ``DataModel``, or
            the name of a single ``.asdf`` file. When custom catalogs are in
            use, a string is read as an association file.

        Returns
        -------
        ModelContainer or DataModel
            The container of (possibly re-aligned) models. If a model whose
            ``meta.exposure.type`` is not ``"WFI_IMAGE"`` is encountered,
            that single model is returned immediately instead.

        Raises
        ------
        TypeError
            If the input cannot be turned into a ``ModelContainer``.
        ValueError
            If the input holds no models, a catalog lacks position columns,
            or ``abs_refcat`` is neither a supported name nor a file.
        AttributeError
            If a model has no usable source catalog attached.
        """
        use_custom_catalogs = self.use_custom_catalogs

        if use_custom_catalogs:
            catdict = _parse_catfile(self.catfile)
            # if user requested the use of custom catalogs and provided a
            # valid 'catfile' file name that has no custom catalogs,
            # turn off the use of custom catalogs:
            if catdict is not None and not catdict:
                self.log.warning(
                    "'use_custom_catalogs' is set to True but 'catfile' "
                    "contains no user catalogs."
                )
                use_custom_catalogs = False

        try:
            # ``catdict`` is only evaluated when ``use_custom_catalogs`` is
            # true, i.e. only when it has been assigned above.
            if use_custom_catalogs and catdict:
                images = ModelContainer()
                if isinstance(input, str):
                    asn_dir = os.path.dirname(input)
                    asn_data = images.read_asn(input)
                    for member in asn_data["products"][0]["members"]:
                        filename = member["expname"]
                        member["expname"] = os.path.join(asn_dir, filename)
                        if filename in catdict:
                            member["tweakreg_catalog"] = catdict[filename]
                        elif "tweakreg_catalog" in member:
                            del member["tweakreg_catalog"]

                    images.from_asn(asn_data)
                elif is_association(input):
                    images.from_asn(input)
                else:
                    images = ModelContainer(input)
                    for im in images:
                        filename = im.meta.filename
                        if filename in catdict:
                            self.log.info(
                                f"setting "
                                f"{filename}.source_detection.tweakreg_catalog_name ="
                                f" {catdict[filename]!r}"
                            )
                            # set catalog name only (no catalog data at this point)
                            im.meta["source_detection"] = {
                                "tweakreg_catalog_name": catdict[filename],
                            }
            else:
                is_single_model = isinstance(input, rdm.DataModel) or str(
                    input
                ).endswith(".asdf")
                images = (
                    ModelContainer([input])
                    if is_single_model
                    else ModelContainer(input)
                )
        except TypeError as e:
            e.args = (
                "Input to tweakreg must be a list of DataModels, an "
                "association, or an already open ModelContainer "
                "containing one or more DataModels.",
            ) + e.args[1:]
            raise

        if len(self.catalog_path) == 0:
            self.catalog_path = os.getcwd()

        self.catalog_path = Path(self.catalog_path).as_posix()
        self.log.info(f"All source catalogs will be saved to: {self.catalog_path}")

        if self.abs_refcat is None or len(self.abs_refcat.strip()) == 0:
            self.abs_refcat = DEFAULT_ABS_REFCAT

        if self.abs_refcat != DEFAULT_ABS_REFCAT:
            # Set expand_refcat to True to eliminate possibility of duplicate
            # entries when aligning to absolute astrometric reference catalog
            self.expand_refcat = True

        if len(images) == 0:
            raise ValueError("Input must contain at least one image model.")

        # Build the catalogs for input images
        for i, image_model in enumerate(images):
            if image_model.meta.exposure.type != "WFI_IMAGE":
                # Check to see if attempt to run tweakreg on non-Image data
                self.log.info("Skipping TweakReg for spectral exposure.")
                # Uncomment below once rad & input data have the cal_step tweakreg
                # image_model.meta.cal_step.tweakreg = "SKIPPED"
                return image_model

            if not hasattr(image_model.meta, "source_detection"):
                raise AttributeError(
                    "Attribute 'meta.source_detection' is missing. "
                    "Please either run SourceDetectionStep or provide a "
                    "custom source catalog."
                )

            source_detection = image_model.meta.source_detection
            if hasattr(source_detection, "tweakreg_catalog"):
                # read catalog from structured array
                catalog = Table(np.asarray(source_detection.tweakreg_catalog))
                # remove the catalog array from meta.source_detection now
                # that it has been copied into a Table
                del image_model.meta.source_detection["tweakreg_catalog"]
            elif hasattr(source_detection, "tweakreg_catalog_name"):
                catalog = Table.read(
                    source_detection.tweakreg_catalog_name,
                    format=self.catalog_format,
                )
            else:
                raise AttributeError(
                    "Attribute 'meta.source_detection.tweakreg_catalog' is "
                    "missing. Please either run SourceDetectionStep or "
                    "provide a custom source catalog."
                )

            for axis in ["x", "y"]:
                if axis not in catalog.colnames:
                    long_axis = axis + "centroid"
                    if long_axis in catalog.colnames:
                        catalog.rename_column(long_axis, axis)
                    else:
                        raise ValueError(
                            "'tweakreg' source catalogs must contain a header with "
                            "columns named either 'x' and 'y' or "
                            "'xcentroid' and 'ycentroid'."
                        )

            filename = image_model.meta.filename

            # filter out sources outside the WCS bounding box
            bb = image_model.meta.wcs.bounding_box
            x = catalog["x"]
            y = catalog["y"]
            if bb is None:
                # no bounding box: keep the sources whose image coordinates
                # map to finite world coordinates
                r, d = image_model.meta.wcs(x, y)
                mask = np.isfinite(r) & np.isfinite(d)
                catalog = catalog[mask]

                n_removed_src = np.sum(np.logical_not(mask))
                if n_removed_src:
                    self.log.info(
                        f"Removed {n_removed_src} sources from {filename}'s "
                        "catalog whose image coordinates could not be "
                        "converted to world coordinates."
                    )
            else:
                # assume image coordinates of all sources within a bounding box
                # can be converted to world coordinates.
                ((xmin, xmax), (ymin, ymax)) = bb
                mask = (x > xmin) & (x < xmax) & (y > ymin) & (y < ymax)
                catalog = catalog[mask]

                n_removed_src = np.sum(np.logical_not(mask))
                if n_removed_src:
                    self.log.info(
                        f"Removed {n_removed_src} sources from {filename}'s "
                        "catalog that were outside of the bounding box."
                    )

            # set meta.tweakreg_catalog
            image_model.meta["tweakreg_catalog"] = catalog.as_array()

            nsources = len(catalog)
            if nsources == 0:
                self.log.warning(f"No sources found in {filename}.")
            else:
                self.log.info(f"Detected {len(catalog)} sources in {filename}.")

            images[i] = image_model

        # group images by their "group id":
        grp_img = list(images.models_grouped)

        self.log.info("")
        self.log.info(f"Number of image groups to be aligned: {len(grp_img):d}.")
        self.log.info("Image groups:")

        if len(grp_img) == 1 and not ALIGN_TO_ABS_REFCAT:
            self.log.info("* Images in GROUP 1:")
            for im in grp_img[0]:
                self.log.info(f" {im.meta.filename}")
            self.log.info("")

            # we need at least two exposures to perform image alignment
            self.log.warning("At least two exposures are required for image alignment.")
            self.log.warning("Nothing to do. Skipping 'TweakRegStep'...")
            self.skip = True
            self._mark_skipped(images)
            return input

        # create a list of WCS-Catalog-Images Info and/or their Groups.
        # A single group (aligned only to the absolute catalog) and several
        # groups are handled by the same loop.
        imcats = []
        for g in grp_img:
            if len(g) == 0:
                raise AssertionError("Logical error in the pipeline code.")

            group_name = _common_name(g)
            wcsimlist = list(map(self._imodel2wcsim, g))

            self.log.info(f"* Images in GROUP '{group_name}':")
            for im in wcsimlist:
                im.meta["group_id"] = group_name
                self.log.info(f" {im.meta['name']}")
            imcats.extend(wcsimlist)

        self.log.info("")

        # align images:
        xyxymatch = XYXYMatch(
            searchrad=self.searchrad,
            separation=self.separation,
            use2dhist=self.use2dhist,
            tolerance=self.tolerance,
            xoffset=0,
            yoffset=0,
        )

        try:
            align_wcs(
                imcats,
                refcat=self.refcat,
                enforce_user_order=self.enforce_user_order,
                expand_refcat=self.expand_refcat,
                minobj=self.minobj,
                match=xyxymatch,
                fitgeom=self.fitgeometry,
                nclip=self.nclip,
                sigma=(self.sigma, "rmse"),
                clip_accum=True,
            )
        except ValueError as e:
            msg = e.args[0]
            if (
                msg != "Too few input images (or groups of images) with non-empty"
                " catalogs."
            ):
                raise

            # we need at least two exposures to perform image alignment
            self.log.warning(msg)
            self.log.warning(
                "At least two exposures are required for image alignment."
            )
            self.log.warning("Nothing to do. Skipping 'TweakRegStep'...")
            self._mark_skipped(images)
            if not ALIGN_TO_ABS_REFCAT:
                self.skip = True
                return images
        except RuntimeError as e:
            msg = e.args[0]
            if not msg.startswith("Number of output coordinates exceeded allocation"):
                raise

            self.log.error(msg)
            self.log.error(
                "Multiple sources within specified tolerance "
                "matched to a single reference source. Try to "
                "adjust 'tolerance' and/or 'separation' parameters."
            )
            self.log.warning("Skipping 'TweakRegStep'...")
            self.skip = True
            # Item assignment, as everywhere else in this step.
            # NOTE(review): attribute assignment presumably fails when the
            # 'tweakreg' key is not yet present in cal_step - confirm.
            self._mark_skipped(images)
            return images

        for imcat in imcats:
            model = imcat.meta["image_model"]()
            if model.meta.cal_step.get("tweakreg") == "SKIPPED":
                continue
            wcs = model.meta.wcs
            twcs = imcat.wcs
            if not self._is_wcs_correction_small(wcs, twcs):
                # Large corrections are typically a result of source
                # mis-matching or poorly-conditioned fit. Skip such models.
                self.log.warning(
                    "WCS has been tweaked by more than"
                    f" {10 * self.tolerance} arcsec"
                )

                self._mark_skipped(images)
                if ALIGN_TO_ABS_REFCAT:
                    self.log.warning("Skipping relative alignment (stage 1)...")
                else:
                    self.log.warning("Skipping 'TweakRegStep'...")
                    self.skip = True
                    return images

        if ALIGN_TO_ABS_REFCAT:
            # Get catalog of GAIA sources for the field
            #
            # NOTE:  If desired, the pipeline can write out the reference
            #        catalog as a separate product with a name based on
            #        whatever convention is determined by the JWST Cal Working
            #        Group.

            # Strip first so that stray whitespace ends up neither in the
            # output file name nor in the recorded WCS name.
            self.abs_refcat = self.abs_refcat.strip()
            gaia_cat_name = self.abs_refcat.upper()

            if self.save_abs_catalog:
                output_name = os.path.join(
                    self.catalog_path, f"fit_{self.abs_refcat.lower()}_ref.ecsv"
                )
            else:
                output_name = None

            # initial shift to be used with absolute astrometry
            self.abs_xoffset = 0
            self.abs_yoffset = 0

            if gaia_cat_name in SINGLE_GROUP_REFCAT:
                try:
                    ref_cat = amutils.create_astrometric_catalog(
                        images, gaia_cat_name, output=output_name
                    )
                except Exception as e:
                    # Deliberately broad: any failure while building the
                    # catalog skips the step instead of aborting the run.
                    self.log.warning(
                        "TweakRegStep cannot proceed because of an error that "
                        "occurred while fetching data from the VO server. "
                        f"Returned error message: '{e}'"
                    )
                    self.log.warning("Skipping 'TweakRegStep'...")
                    self.skip = True
                    self._mark_skipped(images)
                    return images
            elif os.path.isfile(self.abs_refcat):
                ref_cat = Table.read(self.abs_refcat)
            else:
                raise ValueError(
                    "'abs_refcat' must be a path to an "
                    "existing file name or one of the supported "
                    f"reference catalogs: {_SINGLE_GROUP_REFCAT_STR}."
                )

            # Check that there are enough GAIA sources for a reliable/valid fit
            num_ref = len(ref_cat)
            if num_ref < self.abs_minobj:
                self.log.warning(
                    f"Not enough sources ({num_ref}) in the reference catalog "
                    "for the single-group alignment step to perform a fit. "
                    f"Skipping alignment to the {self.abs_refcat} reference "
                    "catalog!"
                )
            else:
                # align images:
                # Update to separation needed to prevent confusion of sources
                # from overlapping images where centering is not consistent or
                # for the possibility that errors still exist in relative overlap.
                xyxymatch_gaia = XYXYMatch(
                    searchrad=self.abs_searchrad,
                    separation=self.abs_separation,
                    use2dhist=self.abs_use2dhist,
                    tolerance=self.abs_tolerance,
                    xoffset=self.abs_xoffset,
                    yoffset=self.abs_yoffset,
                )

                # Set group_id to same value so all get fit as one observation
                # The assigned value, 987654, has been hard-coded to make it
                # easy to recognize when alignment to GAIA was being performed
                # as opposed to the group_id values used for relative alignment
                # earlier in this step.
                for imcat in imcats:
                    imcat.meta["group_id"] = 987654
                    if (
                        "fit_info" in imcat.meta
                        and "REFERENCE" in imcat.meta["fit_info"]["status"]
                    ):
                        del imcat.meta["fit_info"]

                # Perform fit
                align_wcs(
                    imcats,
                    refcat=ref_cat,
                    enforce_user_order=True,
                    expand_refcat=False,
                    minobj=self.abs_minobj,
                    match=xyxymatch_gaia,
                    fitgeom=self.abs_fitgeometry,
                    nclip=self.abs_nclip,
                    sigma=(self.abs_sigma, "rmse"),
                    ref_tpwcs=imcats[0],
                    clip_accum=True,
                )

        # keys of the tweakwcs fit results that are not stored in the datamodel
        unwanted_keys = (
            "eff_minobj",
            "matched_ref_idx",
            "matched_input_idx",
            "fit_RA",
            "fit_DEC",
            "fitmask",
        )

        for imcat in imcats:
            image_model = imcat.meta["image_model"]()
            image_model.meta.cal_step["tweakreg"] = "COMPLETE"

            # retrieve fit status and update wcs if fit is successful.
            # "fit_info" is absent when no fit was run for this image (e.g.
            # the absolute fit was skipped), so guard the lookup.
            fit_info = imcat.meta.get("fit_info")
            if fit_info is None or "SUCCESS" not in fit_info["status"]:
                continue

            # Update/create the WCS .name attribute with information
            # on this astrometric fit as the only record that it was
            # successful:
            if ALIGN_TO_ABS_REFCAT:
                # NOTE: This .name attrib agreed upon by the JWST Cal
                #       Working Group.
                #       Current value is merely a place-holder based
                #       on HST conventions. This value should also be
                #       translated to the FITS WCSNAME keyword
                #       IF that is what gets recorded in the archive
                #       for end-user searches.
                imcat.wcs.name = f"FIT-LVL2-{self.abs_refcat}"

            # serialize object from tweakwcs
            # (typecasting numpy objects to python types so that it doesn't
            # cause an issue when saving datamodel to ASDF); unwanted keys
            # are dropped here, before the results reach the datamodel
            wcs_fit_results = {
                k: v.tolist() if isinstance(v, (np.ndarray, np.bool_)) else v
                for k, v in fit_info.items()
                if k not in unwanted_keys
            }

            # add fit results and new WCS to datamodel
            image_model.meta["wcs_fit_results"] = wcs_fit_results
            image_model.meta.wcs = imcat.wcs

        return images

    @staticmethod
    def _mark_skipped(images):
        """Set ``meta.cal_step["tweakreg"]`` to "SKIPPED" on every model."""
        for model in images:
            model.meta.cal_step["tweakreg"] = "SKIPPED"

    def _is_wcs_correction_small(self, wcs, twcs):
        """Check that the newly tweaked wcs hasn't gone off the rails.

        Parameters
        ----------
        wcs, twcs
            Original and tweaked WCS objects; both must provide
            ``footprint(axis_type="spatial")``.

        Returns
        -------
        bool
            True when every footprint point moved by less than
            ``10 * self.tolerance`` arcsec.
        """
        tolerance = 10.0 * self.tolerance * u.arcsec

        ra, dec = wcs.footprint(axis_type="spatial").T
        tra, tdec = twcs.footprint(axis_type="spatial").T
        skycoord = SkyCoord(ra=ra, dec=dec, unit="deg")
        tskycoord = SkyCoord(ra=tra, dec=tdec, unit="deg")

        separation = skycoord.separation(tskycoord)

        return (separation < tolerance).all()

    def _imodel2wcsim(self, image_model):
        """Wrap an image model and its catalog in a tweakwcs corrector.

        Parameters
        ----------
        image_model : DataModel or path-like
            The model to wrap. Anything that is not a ``DataModel`` is opened
            with ``rdm.open`` using only its base name.

        Returns
        -------
        JWSTWCSCorrector
            Corrector whose ``meta`` holds a weak reference to the model
            (``"image_model"``), the source catalog (``"catalog"``) and the
            model name (``"name"``).

        Raises
        ------
        OSError
            If the catalog cannot be read.
        ValueError
            If the catalog has neither 'x'/'y' nor 'xcentroid'/'ycentroid'
            columns.
        """
        image_model = (
            image_model
            if isinstance(image_model, rdm.DataModel)
            else rdm.open(os.path.basename(image_model))
        )
        catalog = image_model.meta.tweakreg_catalog
        model_name = os.path.splitext(image_model.meta.filename)[0].strip("_- ")

        try:
            if self.use_custom_catalogs:
                catalog_format = self.catalog_format
            else:
                catalog_format = "ascii.ecsv"

            if isinstance(catalog, str):
                # a string with the name of the catalog was provided
                catalog = Table.read(catalog, format=catalog_format)
            else:
                # catalog is a structured array, convert to astropy table:
                catalog = Table(catalog)

            # ``catalog`` is always a Table at this point, so the catalog is
            # named after the model.
            catalog.meta["name"] = model_name
        except OSError:
            self.log.error(f"Cannot read catalog {catalog}")
            # Without a readable catalog the code below cannot proceed;
            # propagate the real cause instead of failing later on a string.
            raise

        # make sure catalog has 'x' and 'y' columns
        for axis in ["x", "y"]:
            if axis not in catalog.colnames:
                long_axis = axis + "centroid"
                if long_axis in catalog.colnames:
                    catalog.rename_column(long_axis, axis)
                else:
                    raise ValueError(
                        "'tweakreg' source catalogs must contain either columns 'x' and"
                        " 'y' or 'xcentroid' and 'ycentroid'."
                    )

        # create WCSImageCatalog object:
        refang = image_model.meta.wcsinfo
        # TODO: create RSTWCSCorrector in tweakwcs
        im = JWSTWCSCorrector(
            wcs=image_model.meta.wcs,
            wcsinfo={
                "roll_ref": refang["roll_ref"],
                "v2_ref": refang["v2_ref"],
                "v3_ref": refang["v3_ref"],
            },
            meta={
                # weak reference: the corrector must not keep the model alive
                "image_model": weakref.ref(image_model),
                "catalog": catalog,
                "name": model_name,
            },
        )

        return im
def _common_name(group):
    """Return the longest common prefix of the models' file names.

    Parameters
    ----------
    group : iterable of DataModel
        Models whose ``meta.filename`` (without extension, and with leading
        and trailing ``"_"``, ``"-"`` and spaces stripped) is compared.

    Returns
    -------
    str
        Character-wise common prefix of the cleaned file names.

    Raises
    ------
    TypeError
        If any element of ``group`` is not a ``DataModel``.
    """
    file_names = []
    for im in group:
        if not isinstance(im, rdm.DataModel):
            raise TypeError("Input must be a list of datamodels list.")
        file_names.append(os.path.splitext(im.meta.filename)[0].strip("_- "))

    return os.path.commonprefix(file_names)


def _parse_catfile(catfile):
    """Read the mapping of data model names to custom catalog files.

    Each non-empty line that does not start with ``#`` holds a data model
    name, optionally followed by a catalog file name. Catalog names are
    joined to the directory of ``catfile``.

    Parameters
    ----------
    catfile : str or None
        Path to the catalog list file.

    Returns
    -------
    dict or None
        ``None`` when ``catfile`` is ``None`` or blank; otherwise a dict
        mapping each data model name to its catalog path, or to ``None``
        when the line names no catalog.

    Raises
    ------
    ValueError
        If a line has more than two whitespace-separated columns.
    """
    if catfile is None or not catfile.strip():
        return None

    catfile_dir = os.path.dirname(catfile)
    catdict = {}

    with open(catfile) as f:
        # iterate the file lazily instead of loading every line up front
        for line in f:
            sline = line.strip()
            if not sline or sline[0] == "#":
                continue

            # ``split()`` without arguments already strips each field
            data_model, *catalog = sline.split()
            if not catalog:
                catdict[data_model] = None
            elif len(catalog) == 1:
                catdict[data_model] = os.path.join(catfile_dir, catalog[0])
            else:
                raise ValueError("'catfile' can contain at most two columns.")

    return catdict