Source code for niriss_tools.grism.utils

"""
General utility functions related to handling grism data.
"""

from copy import deepcopy
from os import PathLike
from pathlib import Path

import astropy
import grizli.utils as grizli_utils
import numpy as np
from astropy.io import fits
from astropy.wcs import WCS
from grizli import utils as grizli_utils
from grizli.model import BeamCutout, GrismDisperser
from grizli.multifit import MultiBeam
from tqdm import tqdm

# Names exported by ``from niriss_tools.grism.utils import *``.
__all__ = [
    "gen_stacked_beams",
    "align_direct_images",
    "log_with_offset",
    "LINE_UP",
    "LINE_CLEAR",
]

# ANSI escape sequence: move the cursor up by one line.
LINE_UP = "\033[1A"
# ANSI escape sequence: erase the whole of the line the cursor is on.
LINE_CLEAR = "\x1b[2K"


def log_with_offset(s: str, blank_lines: int = 1, curr_line: str = "") -> None:
    """
    Rewrite an earlier console line, then return to the current one.

    Parameters
    ----------
    s : str
        The text to write on the earlier line.
    blank_lines : int, optional
        The number of blank lines separating the current line from the
        one being rewritten, by default ``1``.
    curr_line : str, optional
        The text to write on the current line afterwards, by default
        ``""``.
    """
    n_rows = blank_lines + 1

    # Climb ``n_rows`` lines and wipe the destination before writing ``s``.
    climb = n_rows * LINE_UP + LINE_CLEAR
    # Come back down one line at a time, wiping each line on arrival.
    descend = n_rows * ("\n" + LINE_CLEAR)

    print("".join((climb, s, descend, curr_line)), flush=True)
def _stack_beam_group(mb, beam_idxs, drizzler, driz_kwargs):
    """
    Drizzle one group of beams from ``mb`` into a single stacked beam.

    Parameters
    ----------
    mb : `~grizli.multifit.MultiBeam`
        The object holding the individual beams. It is only read from.
    beam_idxs : sequence of int
        Indices into ``mb.beams`` of the beams to combine.
    drizzler : callable
        The drizzle routine (``drizzlepac.adrizzle.do_driz``).
    driz_kwargs : dict
        Keyword arguments shared by every call to ``drizzler``.

    Returns
    -------
    `~grizli.model.BeamCutout`
        A deep copy of one of the input beams, with its data, WCS and
        disperser replaced by the stacked versions.
    """
    first_beam = mb.beams[beam_idxs[0]]
    target = [[mb.ra, mb.dec]]

    # Half the (x, y) size of the direct cutout.
    direct_cen = np.asarray(first_beam.direct.data["REF"].shape[::-1]) / 2

    # Offset of the object from ``direct_cen`` in every beam.
    shift_dx = np.zeros((len(beam_idxs), 2))
    for row, b_i in enumerate(beam_idxs):
        shift_dx[row] = direct_cen - np.array(
            mb.beams[b_i]
            .direct.wcs.all_world2pix(target, 1, ra_dec_order=True)
            .flatten()
        )

    # The reference beam is the one whose x-axis offset is closest to a
    # whole number of pixels. This minimises the chance of trace pixel
    # errors due to integer rounding in the grizli and grismconf code.
    frac_x = np.abs(shift_dx[:, 0] - np.round(shift_dx[:, 0]))
    new_beam = deepcopy(mb.beams[beam_idxs[np.argmin(frac_x)]])

    # Set centre of direct image to the actual coordinates
    shift_crpix = direct_cen - np.array(
        new_beam.direct.wcs.all_world2pix(
            target, 1, ra_dec_order=True
        ).flatten()
    )
    new_beam.grism.wcs = grizli_utils.transform_wcs(
        new_beam.grism.wcs,
        translation=[
            shift_crpix[0] - new_beam.beam.xoffset,
            shift_crpix[1] - new_beam.beam.yoffset,
        ],
    )
    new_beam.direct.wcs = grizli_utils.transform_wcs(
        new_beam.direct.wcs, translation=shift_crpix
    )

    # Output (image, weight, context) triplets for each drizzled product.
    sh = new_beam.sh
    direct_sh = new_beam.direct.data["REF"].shape

    outsci = np.zeros(sh, dtype=np.float32)
    outwht = np.zeros(sh, dtype=np.float32)
    outctx = np.zeros(sh, dtype=np.int32)

    outvar = np.zeros(sh, dtype=np.float32)
    outwv = np.zeros(sh, dtype=np.float32)
    outcv = np.zeros(sh, dtype=np.int32)

    outcon = np.zeros(sh, dtype=np.float32)
    outwcon = np.zeros(sh, dtype=np.float32)
    outcc = np.zeros(sh, dtype=np.int32)

    outdir = np.zeros(direct_sh, dtype=np.float32)
    outwd = np.zeros(direct_sh, dtype=np.float32)
    outcd = np.zeros(direct_sh, dtype=np.int32)

    # Measured before ``new_beam.beam`` is replaced further down.
    dir_scale = np.nanmedian(
        new_beam.direct.data["REF"] / new_beam.beam.direct
    )

    new_seg = grizli_utils.blot_nearest_exact(
        first_beam.beam.seg,
        first_beam.direct.wcs,
        new_beam.direct.wcs,
        verbose=False,
        stepsize=-1,
        scale_by_pixel_area=False,
    )

    for idx in beam_idxs:
        beam = mb.beams[idx]
        direct_wcs_i = beam.direct.wcs.copy()
        grism_wcs_i = grizli_utils.transform_wcs(
            beam.grism.wcs.copy(),
            translation=[-beam.beam.xoffset, -beam.beam.yoffset],
        )

        # Work on a copy of the inverse variance, so that zeroing the
        # non-finite weights does not modify the caller's beam.
        beam_wht = beam.ivar.copy()
        beam_wht[~np.isfinite(beam_wht)] = 0.0

        # Contamination-based down-weighting is not applied; a unit image
        # is drizzled with the same weights as the science data instead.
        unit_image = np.ones_like(beam.ivar)

        direct_image = beam.direct.data["REF"]

        # Direct image, with uniform weights
        drizzler(
            direct_image,
            direct_wcs_i,
            np.ones_like(direct_image),
            new_beam.direct.wcs,
            outdir,
            outwd,
            outcd,
            1.0,
            "cps",
            1,
            **driz_kwargs,
        )
        # Science
        drizzler(
            beam.grism.data["SCI"],
            grism_wcs_i,
            beam_wht,
            new_beam.grism.wcs,
            outsci,
            outwht,
            outctx,
            1.0,
            "cps",
            1,
            **driz_kwargs,
        )
        # Contamination
        drizzler(
            beam.contam,
            grism_wcs_i,
            beam_wht,
            new_beam.grism.wcs,
            outcon,
            outwcon,
            outcc,
            1.0,
            "cps",
            1,
            **driz_kwargs,
        )
        # Variance
        drizzler(
            unit_image,
            grism_wcs_i,
            beam_wht,
            new_beam.grism.wcs,
            outvar,
            outwv,
            outcv,
            1.0,
            "cps",
            1,
            **driz_kwargs,
        )

    # Correct for drizzle scaling
    area_ratio = 1.0 / new_beam.grism.wcs.pscale**2

    # preserve flux density
    spatial_scale = 1.0
    flux_density_scale = spatial_scale**2

    outsci *= area_ratio * flux_density_scale
    outdir *= area_ratio * flux_density_scale
    outvar *= area_ratio / outwv * flux_density_scale**2
    outcon *= area_ratio * flux_density_scale

    new_beam.grism.data["SCI"] = outsci
    new_beam.grism.data["ERR"] = np.sqrt(outvar)
    new_beam.grism.data["DQ"] = np.zeros_like(outsci)

    new_beam.contam = outcon
    new_beam.direct.data["REF"] = outdir

    new_beam.direct.header.update(grizli_utils.to_header(new_beam.direct.wcs))
    new_beam.grism.header.update(grizli_utils.to_header(new_beam.grism.wcs))

    def _median_direct_attr(name):
        # Median over the group of one attribute of each beam's direct image.
        return np.nanmedian(
            np.asarray([getattr(mb.beams[i].direct, name) for i in beam_idxs]),
            axis=0,
        )

    first_disperser = first_beam.beam
    new_beam.beam = GrismDisperser(
        id=mb.id,
        direct=outdir,
        segmentation=new_seg,
        origin=_median_direct_attr("origin"),
        pad=_median_direct_attr("pad"),
        grow=_median_direct_attr("grow"),
        beam=first_disperser.beam,
        xcenter=0,
        ycenter=0,
        conf=first_disperser.conf,
        fwcpos=first_disperser.fwcpos,
        MW_EBV=first_disperser.MW_EBV,
        xoffset=0.0,
        yoffset=0.0,
    )

    new_beam.beam.compute_model()
    new_beam.modelf = new_beam.beam.modelf
    new_beam.model = new_beam.beam.modelf.reshape(new_beam.beam.sh_beam)

    new_beam._parse_from_data(
        isJWST=True,
        contam_sn_mask=[10, 3],
        min_mask=mb.min_mask,
        min_sens=mb.min_sens,
        mask_resid=mb.mask_resid,
    )

    new_beam.direct.data["REF"] /= dir_scale
    new_beam.direct.ref_photflam = new_beam.direct.photflam

    return new_beam


def gen_stacked_beams(
    mb: str | MultiBeam,
    pixfrac: float = 1.0,
    kernel: str = "square",
    dfillval: float = 0,
    fit_trace_shift: bool = False,
    trace_shift_kwargs: dict | None = None,
    cluster_beams: bool = False,
    dbscan_kwargs: dict | None = None,
    **multibeam_kwargs,
) -> MultiBeam:
    """
    Stack individual beams with the same grism and blocking filter.

    This returns a "master" `~grizli.multifit.MultiBeam` object, with a
    single beam for each combination of grism orientation and blocking
    filter (or one per cluster of beams, if ``cluster_beams`` is set).
    The beams held by the input object are not modified, unless
    ``fit_trace_shift`` is used.

    Parameters
    ----------
    mb : str | `grizli.multifit.MultiBeam`
        The original MultiBeam object, or the location of the
        ``*beams.fits`` file.
    pixfrac : float, optional
        The fraction by which input pixels are "shrunk" before being
        drizzled onto the output image grid, given as a real number
        between 0 and 1. This specifies the size of the footprint, or
        "dropsize", of a pixel in units of the input pixel size. By
        default ``pixfrac=1.0``.
    kernel : str, optional
        The form of the kernel function used to distribute flux onto the
        separate output images, by default ``"square"``. The current
        options are ``"square"``, ``"point"``, ``"turbo"``,
        ``"gaussian"``, and ``"lanczos3"``.
    dfillval : float, optional
        The value to be assigned to output pixels that have zero weight,
        or that do not receive flux from any input pixels during
        drizzling. By default this is 0.
    fit_trace_shift : bool, optional
        Fit for a cross-dispersion offset before stacking the beams,
        using `~grizli.multifit.MultiBeam.fit_trace_shift()`. By default
        ``False``.
    trace_shift_kwargs : dict, optional
        Additional keyword arguments to pass through to
        `~grizli.multifit.MultiBeam.fit_trace_shift()` if used. By
        default no additional arguments are passed.
    cluster_beams : bool, optional
        Cluster the beams based on their detector position before
        stacking, using the DBSCAN algorithm. This can minimise residuals
        arising from the trace variation across the detector, at the cost
        of an increased number of stacked beams. By default ``False``.
    dbscan_kwargs : dict, optional
        Any additional parameters to pass through to
        `sklearn.cluster.DBSCAN`. By default
        ``dbscan_kwargs={"eps": 5}``.
    **multibeam_kwargs : dict, optional
        Any additional parameters to pass through to
        `~grizli.multifit.MultiBeam` when loading the original object.

    Returns
    -------
    `~grizli.multifit.MultiBeam`
        The stacked multibeam object.
    """
    # ``None`` stands in for the documented defaults, so that no mutable
    # object is shared between calls.
    if trace_shift_kwargs is None:
        trace_shift_kwargs = {}
    if dbscan_kwargs is None:
        dbscan_kwargs = {"eps": 5}

    if not isinstance(mb, MultiBeam):
        mb = MultiBeam(beams=mb, **multibeam_kwargs)

    if fit_trace_shift:
        mb.fit_trace_shift(**trace_shift_kwargs)

    from drizzlepac import adrizzle

    adrizzle.log.setLevel("ERROR")
    drizzler = adrizzle.do_driz

    if cluster_beams:
        from sklearn.cluster import DBSCAN

    # Settings common to every drizzle call.
    driz_kwargs = {
        "wcslin_pscale": 1.0,
        "uniqid": 1,
        "pixfrac": pixfrac,
        "kernel": kernel,
        "fillval": dfillval,
        "wcsmap": grizli_utils.WCSMapAll,
    }

    new_beam_list = []
    for pa_info in tqdm(mb.PA.values(), desc="Stacking beams"):
        for pa_beam_idxs in pa_info.values():
            if cluster_beams:
                pa_beam_idxs = np.array(pa_beam_idxs)
                # Detector coordinate centres
                detector_coords = np.array(
                    [
                        [mb.beams[b].beam.xc, mb.beams[b].beam.yc]
                        for b in pa_beam_idxs
                    ]
                )
                labels = np.array(
                    DBSCAN(**dbscan_kwargs).fit(detector_coords).labels_
                )
                # One group per distinct label returned by DBSCAN.
                grouped_beam_idxs = [
                    pa_beam_idxs[labels == label]
                    for label in np.unique(labels)
                ]
            else:
                grouped_beam_idxs = [pa_beam_idxs]

            for beam_idxs in grouped_beam_idxs:
                new_beam_list.append(
                    _stack_beam_group(mb, beam_idxs, drizzler, driz_kwargs)
                )

    new_multibeam = MultiBeam(
        new_beam_list,
        group_name=mb.group_name,
        fcontam=mb.fcontam,
        min_mask=mb.min_mask,
        min_sens=mb.min_sens,
        mask_resid=mb.mask_resid,
    )
    return new_multibeam
def _ensure_naxis_attrs(wcs) -> None:
    """Set ``wcs._naxis1`` and ``wcs._naxis2`` from ``wcs._naxis`` if absent."""
    if not hasattr(wcs, "_naxis1") and hasattr(wcs, "_naxis"):
        wcs._naxis1, wcs._naxis2 = wcs._naxis


def _blot_image_to_wcs(ablot, image_path, beam_wcs, beam_ra, beam_dec, cutout):
    """
    Blot a cutout of the first HDU of ``image_path`` onto ``beam_wcs``.

    Returns the blotted array, and the header of the original image
    updated with the keywords from ``beam_wcs.to_header()``.
    """
    with fits.open(image_path) as orig_hdul:
        orig_data = orig_hdul[0].data
        orig_header = orig_hdul[0].header

        if orig_data.dtype not in [np.float32, np.dtype(">f4")]:
            orig_data = orig_data.astype(np.float32)

        orig_wcs = WCS(orig_header, relax=True)
        orig_wcs.pscale = grizli_utils.get_wcs_pscale(orig_wcs)
        _ensure_naxis_attrs(orig_wcs)

        # Nearest zero-indexed pixel to the reference position.
        xy = np.asarray(
            np.round(orig_wcs.all_world2pix([beam_ra], [beam_dec], 0)),
            dtype=int,
        ).flatten()

        # Clip the ``[-cutout, +cutout]`` window to the image bounds.
        sh = orig_data.shape
        slx = slice(
            np.maximum(xy[0] - cutout, 0), np.minimum(xy[0] + cutout, sh[1])
        )
        sly = slice(
            np.maximum(xy[1] - cutout, 0), np.minimum(xy[1] + cutout, sh[0])
        )

        # An ``idcscale`` attribute that is present but ``None`` is removed
        # from the reference WCS before blotting.
        if hasattr(beam_wcs, "idcscale") and beam_wcs.idcscale is None:
            delattr(beam_wcs, "idcscale")
        _ensure_naxis_attrs(beam_wcs)

        blotted = ablot.do_blot(
            orig_data[sly, slx],
            orig_wcs.slice([sly, slx]),
            beam_wcs,
            1,
            coeffs=True,
            interp="sinc",
            sinscl=1.0,
            stepsize=1,
            wcsmap=grizli_utils.WCSMapAll,
        )

    orig_header.update(beam_wcs.to_header())
    return blotted, orig_header


def align_direct_images(
    ref_beam: BeamCutout,
    info_dict: dict,
    out_dir: PathLike | None = None,
    cutout: int = 200,
    overwrite: bool = False,
) -> dict:
    """
    Align a set of images to the orientation of a dispersed beam.

    Given a nested dictionary, containing both ``"sci"`` and ``"var"``
    keys pointing to the location of the images, this blots the images to
    the same coordinate system used in the direct image of a
    `grizli.model.BeamCutout`.

    Parameters
    ----------
    ref_beam : BeamCutout
        The dispersed beam to be used as a reference. All images will be
        aligned to the direct image in this beam.
    info_dict : dict
        A nested dictionary, where each value is a dictionary containing
        ``"sci"`` and ``"var"`` keys. The values for these should point
        to the location of the original FITS images to be blotted. This
        dictionary is not modified.
    out_dir : PathLike, optional
        The location in which the realigned images will be saved. This
        will default to the current working directory.
    cutout : int, optional
        Make a slice of the original image with size
        ``[-cutout,+cutout]`` around the centre position of the desired
        object, before passing to blot. By default, ``cutout=200``.
    overwrite : bool, optional
        Overwrite existing images if they exist already. By default
        ``False``, in which case an image is only generated if no file
        of the same name exists in ``out_dir``.

    Returns
    -------
    dict
        An updated version of ``info_dict``, now with the locations of
        the realigned images.
    """
    from drizzlepac.astrodrizzle import ablot

    if out_dir is not None:
        out_dir = Path(out_dir)
        out_dir.mkdir(exist_ok=True, parents=True)
    else:
        out_dir = Path.cwd()

    beam_wcs = ref_beam.direct.wcs
    # Sky position of the centre of the reference direct image.
    beam_ra, beam_dec = beam_wcs.all_pix2world(
        [(np.asarray(ref_beam.direct.sh) + 1) / 2],
        1,
    ).flatten()

    # Copy the inner dictionaries as well, so that recording the new
    # locations does not overwrite the caller's entries.
    new_info_dict = {k: v.copy() for k, v in info_dict.items()}

    for k, v in info_dict.items():
        for img_type in ["sci", "var"]:
            # The realigned image keeps the name of the original file.
            out_path = out_dir / Path(v[img_type]).name

            if overwrite or not out_path.is_file():
                blotted, new_header = _blot_image_to_wcs(
                    ablot, v[img_type], beam_wcs, beam_ra, beam_dec, cutout
                )
                new_hdul = fits.HDUList()
                new_hdul.append(fits.ImageHDU(data=blotted, header=new_header))
                new_hdul.writeto(out_path, overwrite=overwrite)

            new_info_dict[k][img_type] = str(out_path)

    return new_info_dict
def gen_psf(multibeam: MultiBeam) -> dict:
    """
    Generate a PSF aligned with the direct image in extracted beams.

    The PSF matches the rotation of the direct imaging using the
    ``"PA_APER"`` header keyword. One PSF is generated per key, using the
    first beam encountered with that key.

    Parameters
    ----------
    multibeam : MultiBeam
        The grizli-extracted multiple beams object.

    Returns
    -------
    dict
        A dictionary with keys corresponding to each unique grism and
        filter combination (formatted as ``"{pupil}-{filter}"``), and
        values of the PSF image.
    """
    import stpsf
    from astropy import units as u
    from astropy.time import Time
    from drizzlepac.astrodrizzle import ablot

    psf_aligned_images = {}

    for beam_cutout in multibeam.beams:
        beam_name = f"{beam_cutout.grism.pupil}-{beam_cutout.grism.filter}"
        if beam_name in psf_aligned_images:
            continue

        header = beam_cutout.direct.header
        beam_wcs = beam_cutout.direct.wcs

        inst = stpsf.instrument(header["INSTRUME"])
        inst.set_position_from_aperture_name("NIS_CEN")
        inst.filter = header["PUPIL"]

        # Use the wavefront measurement closest in time to the exposure.
        dateobs = Time(header["DATE-BEG"])
        inst.load_wss_opd_by_date(
            dateobs, verbose=False, plot=False, choice="closest"
        )

        psf = inst.calc_psf(
            fov_pixels=np.nanmax(beam_wcs._naxis) * 2 + 1,
        )
        psf_data = psf["DET_DIST"].data

        # Build a WCS for the PSF: centred on the object coordinates, and
        # rotated by the angle derived from ``PA_APER``.
        psf_wcs = WCS(psf["DET_DIST"])
        psf_wcs.wcs.crpix = (np.asarray(psf_data.shape) + 1) / 2
        psf_wcs.wcs.crval = [multibeam.ra, multibeam.dec]

        rotation_angle_rad = np.radians(header["PA_APER"] - 360)
        cos_rot = np.cos(rotation_angle_rad)
        sin_rot = np.sin(rotation_angle_rad)
        # ``inst.pixelscale`` is treated as arcsec/pixel and converted to
        # degrees for the CD matrix.
        pixel_scale_deg = (inst.pixelscale * u.arcsec).to(u.deg).value
        psf_wcs.wcs.cd = (
            np.array(
                [
                    [cos_rot, -sin_rot],
                    [sin_rot, cos_rot],
                ]
            )
            * pixel_scale_deg
        )
        psf_wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]
        psf_wcs.pscale = grizli_utils.get_wcs_pscale(psf_wcs)

        psf_aligned_images[beam_name] = ablot.do_blot(
            psf_data.astype(np.float32),
            psf_wcs,
            beam_wcs,
            1,
            coeffs=True,
            sinscl=1.0,
            stepsize=1,
            wcsmap=grizli_utils.WCSMapAll,
        )

    return psf_aligned_images