Source code for niriss_tools.grism.multiregion

"""
Functions and classes related to multi-region grism fitting.
"""

import multiprocessing
import os
import shutil
import sys
import traceback
from collections.abc import Callable
from copy import deepcopy
from functools import partial
from itertools import product
from multiprocessing import Lock, Manager, Pool, cpu_count, shared_memory
from multiprocessing.managers import SharedMemoryManager
from os import PathLike
from pathlib import Path
from time import time

import h5py
import matplotlib.pyplot as plt
import numpy as np
import scipy.optimize
from astropy.io import fits
from astropy.nddata import block_reduce
from astropy.table import Table
from astropy.wcs import WCS
from bagpipes_extended import AtlasFitter, AtlasGenerator
from bagpipes_extended.pipeline import generate_fit_params, load_photom_bagpipes
from grizli import utils as grizli_utils
from grizli.model import GrismDisperser
from grizli.multifit import MultiBeam, drizzle_to_wavelength
from numpy.typing import ArrayLike
from reproject import reproject_interp

import niriss_tools
from niriss_tools.grism.fitting_tools import CDNNLS, fennls, fnnls
from niriss_tools.grism.specgen import (
    CLOUDY_LINE_MAP,
    BagpipesSampler,
    ExtendedModelGalaxy,
    check_coverage,
    pre_gen_spec,
)
from niriss_tools.grism.utils import (
    LINE_UP,
    align_direct_images,
    gen_stacked_beams,
    log_with_offset,
)
from niriss_tools.pipeline.reduction import recursive_merge
from niriss_tools.sed.binning import bin_and_save

__all__ = ["MultiRegionFit", "DEFAULT_PLINE"]

pipes_sampler = None
beams_object = None

DEFAULT_PLINE = {
    "pixscale": 0.06,
    "pixfrac": 1.0,
    "size": 5,
    "kernel": "lanczos3",
}


def _init_beams(beams):
    global beams_object
    beams_object = deepcopy(beams)


def _init_pipes_sampler(fit_instructions, veldisp, beams, lock_value=None):
    global pipes_sampler
    pipes_sampler = BagpipesSampler(fit_instructions=fit_instructions, veldisp=veldisp)
    _init_beams(beams)
    global lock_var
    lock_var = lock_value


[docs] class MultiRegionFit: """ A multi-region version of `grizli.multifit.MultiBeam`. Parameters ---------- config_path : PathLike The TOML-formatted configuration file containing the parameters to use for the fit. obj_id : int The object ID number to fit. obj_z : float The redshift at which the object should be fitted. run_all : bool, optional If ``True`` (default), the fit will proceed automatically based on the setup described in the configuration file. If ``False``, the configuration file will be parsed, but the individual fitting methods must be called manually. """ def __init__( self, config_path: PathLike, obj_id: int, obj_z: float, run_all: bool = True, ): self.obj_id = obj_id self.obj_z = obj_z self.import_config(config_path) if run_all: self.run_all()
[docs] def run_all(self): """ Run the full fit based on the supplied configuration file. """ self.beam_path, self.binned_data_path = self.gen_aligned_photometry( self.binning_kwargs, use_stacks=self.use_stacks, **self.multibeam_kwargs, ) self.atlas_path = self.gen_atlas(**self.sed_fit_kwargs) self.fit_atlas( self.atlas_path, self.binned_data_path, overwrite_fit=self.overwrite_atlas_fit, n_cores=self.n_cores_atlas_fit, z_range=self.sed_fit_kwargs["z_range"], ) self.MB = MultiBeam(beams=str(self.beam_path), **self.multibeam_kwargs) self.ra, self.dec = self.MB.ra, self.MB.dec self.regions_phot_cat = Table.read(self.binned_data_path, "PHOT_CAT") self.regions_phot_cat = self.regions_phot_cat[ self.regions_phot_cat["bin_id"].astype(int) != 0 ] self.n_regions = len(self.regions_phot_cat) with fits.open(self.binned_data_path) as hdul: self.regions_seg_map = hdul["SEG_MAP"].data.copy() self.regions_seg_hdr = hdul["SEG_MAP"].header.copy() self.regions_seg_wcs = WCS(self.regions_seg_hdr) self.regions_seg_ids = np.asarray(self.regions_phot_cat["bin_id"], dtype=int) self.fit_at_z(self.obj_z, **self.grism_fit_kwargs)
[docs] def import_config(self, config_path: PathLike): """ Parse a configuration file and set class attributes accordingly. Parameters ---------- config_path : PathLike The TOML-formatted configuration file containing the parameters to use for the fit. """ import tomllib import warnings import yaml with open(config_path, "rb") as f: config = tomllib.load(f) self.root_dir = Path(config["files"].get("root_dir", "/")) self.out_dir = self.root_dir / config["files"].get("out_dir", ".") self.out_dir.mkdir(exist_ok=True, parents=True) self.extractions_dir = self.root_dir / config["files"].get( "extractions_dir", "." ) _info_dict = config["files"].get("info_dict", "") if not _info_dict: raise ValueError("`info_dict` is not present in the supplied config file.") with open(self.root_dir / _info_dict, "r") as file: self.info_dict = yaml.safe_load(file) self.pipes_dir = self.out_dir / "sed_fitting" / "pipes" self.pipes_dir.mkdir(exist_ok=True, parents=True) if not config["files"].get("atlas_dir", ""): self.atlas_dir = self.pipes_dir / "atlases" else: self.atlas_dir = self.root_dir / config["files"]["atlas_dir"] self.atlas_dir.mkdir(exist_ok=True, parents=True) self.binning_kwargs = { "bin_scheme": config["SED"].get("bin_scheme", "colour"), "target_sn": config["SED"].get("target_sn", 10), "bin_diameter": config["SED"].get("bin_diameter", 3), "sn_filter": config["SED"]["sn_filter"], **config["SED"].get("bin_kwargs", {}), } self.sed_fit_kwargs = { "z_range": config["SED"].get("z_range", 0.005), "sfh_type": config["SED"].get("sfh_type", "continuity"), "min_age_bin": config["SED"].get("min_age_bin", 20), "num_age_bins": config["SED"].get("num_age_bins", 5), "n_samples": config["SED"].get("n_samples", 1e6), "remake_atlas": config["SED"].get("remake_atlas", False), "n_cores_atlas": config["SED"].get("n_cores_atlas", 4), } self.n_cores_atlas_fit = config["SED"].get("n_cores_fit", 4) self.overwrite_atlas_fit = config["SED"].get("overwrite_fit", False) self.field_name = config["grism"]["field_name"] self.use_stacks = config["grism"].get("use_stacks", True) self.stack_beam_kwargs = config["grism"].get("stack_beam_kwargs", {}) multibeam_kwargs = config["grism"].get("multibeam_kwargs", {}) default_multib_kwargs = { "min_mask": 0.0, "min_sens": 0.0, "mask_resid": False, "verbose": False, "fcontam": 0.2, "group_name": self.field_name, } self.multibeam_kwargs = recursive_merge(default_multib_kwargs, multibeam_kwargs) self.grism_fit_kwargs = { "fit_background": config["grism"].get("fit_background", True), "poly_order": config["grism"].get("poly_order", 0), "n_samples": config["grism"].get("n_region_samples", 3), "n_iters": config["grism"].get("n_iters", 10), "bad_pa_threshold": config["grism"].get("bad_pa_threshold", 1.6), "spec_wavs": config["grism"].get("spec_wavs", None), "oversamp_factor": config["grism"].get("oversamp_factor", 1), "veldisp": config["grism"].get("veldisp", 500), "out_dir": config["grism"].get("out_dir", "multiregion"), "temp_dir": config["grism"].get("temp_dir", None), "memmap": config["grism"].get("memmap", False), "cpu_count": config["grism"].get("cpu_count", -1), "overwrite": config["grism"].get("overwrite", False), "use_lines": config["grism"].get("use_lines", CLOUDY_LINE_MAP), "save_lines": config["grism"].get("save_lines", True), "save_stacks": config["grism"].get("save_stacks", True), "pline": config["grism"].get("pline", DEFAULT_PLINE), "seed": config["grism"].get("seed", 2744), "nnls_method": config["grism"].get("nnls_method", "scipy"), "nnls_iters": config["grism"].get("nnls_iters", 10), "nnls_tol": config["grism"].get("nnls_tol", 1e-5), "n_shifted": config["grism"].get("n_shifted", 2), "n_shifted_samples": config["grism"].get("n_shifted_samples", 1), "cache_spec": config["grism"].get("cache_spec", True), }
[docs] def fit_atlas( self, atlas_path: PathLike, binned_data_path: PathLike, binned_data_hdu: str | int | None = "PHOT_CAT", bagpipes_atlas_params: dict | None = None, load_fn: Callable | None = None, overwrite_fit: bool = False, id_colname: str = "bin_id", n_cores: int = 4, obj_z: float | ArrayLike | None = None, z_range: float = 0.005, ): """ Perform a 2D SED fit to the binned photometric data. Parameters ---------- atlas_path : PathLike The location of the previously generated model grid. binned_data_path : PathLike The location of the binned photometric catalogue and segmentation map used as input for Bagpipes. binned_data_hdu : str | int | None, optional The identifier of the HDU within ``binned_data_path`` containing the photometric catalogue, by default ``"PHOT_CAT"``. bagpipes_atlas_params : dict | None, optional A dictionary containing instructions on the kind of model which should be fitted to the data. This should match the previously generated model grid. If ``None`` (default), ``self.bagpipes_atlas_params`` will be used. load_fn : Callable | None, optional A function which takes the ID as an argument and returns the model photometry. This should be in the form of an array with a column of fluxes in microjanskys and a column of flux errors in the same units. If ``None`` (default), `~bagpipes_extended.pipeline.load_photom_bagpipes` will be used. overwrite_fit : bool, optional If ``True``, then any existing posterior distributions and output catalogues will be overwritten. By default ``False``. id_colname : str, optional The name of the column in the photometric catalogue containing the bin ID, by default ``"bin_id"``. n_cores : int, optional The number of processes to use when fitting the catalogue. If set to ``0``, the code will run on a single process. If set to an integer less than 0, this will run on the number of cores returned by `multiprocessing.cpu_count`. By default, ``4`` processes will be used. obj_z : float | ArrayLike | None, optional This can be used to override the redshift used for fitting, in case of a mismatch between the model atlas and the object of interest. See `~niriss_tools.grism.MultiRegionFit.gen_atlas` for more details. z_range : float, optional As above. """ if bagpipes_atlas_params is None: bagpipes_atlas_params = self.bagpipes_atlas_params os.chdir(self.pipes_dir) if load_fn is None: load_fn = partial( load_photom_bagpipes, phot_cat=binned_data_path, cat_hdu_index=binned_data_hdu, ) fit = AtlasFitter( fit_instructions=bagpipes_atlas_params, atlas_path=atlas_path, out_path=self.pipes_dir.parent, overwrite=overwrite_fit, ) self.run_name = str(Path(binned_data_path).stem).removesuffix("_data") obs_table = Table.read(binned_data_path, hdu=binned_data_hdu) cat_IDs = np.array(obs_table[id_colname]) catalogue_out_path = fit.out_path / f"{self.run_name}.fits" if (not catalogue_out_path.is_file()) or overwrite_fit: fit.fit_catalogue( IDs=cat_IDs, load_data=load_fn, spectrum_exists=False, make_plots=False, cat_filt_list=self.filter_list, run=self.run_name, parallel=n_cores, redshifts=self.obj_z if obj_z is None else obj_z, redshift_range=z_range, n_posterior=500, ) else: fit.cat = Table.read(catalogue_out_path) self.sed_fit_cat_path = fit.out_path / f"{self.run_name}.fits"
[docs] def gen_atlas( self, obj_z: float | ArrayLike | None = None, z_range: float = 0.005, num_age_bins: int = 5, min_age_bin: float = 20, sfh_type: str = "continuity", n_samples: int | float = 1e6, remake_atlas: bool = False, n_cores_atlas: int = 4, ) -> PathLike: """ Generate the fit parameters and model grid. This method generates a dictionary of fit parameters for Bagpipes, as well as a large model grid to speed up the SED fitting. Parameters ---------- obj_z : float | ArrayLike | None, optional The redshift of the object to fit. If a scalar value is passed, and ``z_range==0.0``, the object will be fit to a single redshift value. If ``z_range!=0.0``, this will be the centre of the redshift window. If an array is passed, this explicity sets the redshift range to use for fitting. If ``None`` (default), this will be set to ``self.obj_z``. z_range : float, optional The maximum redshift range to search over, by default 0.005. To fit to a single redshift, pass a single value for ``obj_z``, and set ``z_range=0.0``. If ``obj_z`` is ``ArrayLike``, this parameter is ignored. num_age_bins : int, optional The number of age bins to fit, each of which will have a constant star formation rate following Leja+19. By default, ``5`` bins are generated. min_age_bin : float, optional The minimum age to use for the continuity SFH in Myr, i.e. the first bin will range from ``(0,min_age_bin)``. By default 20. sfh_type : str, optional The type of SFH prior to generate. Currently supports ``"continuity"`` (Leja+19, fixed age bins), ``"continuity_varied_z"`` (Leja+19, only the youngest age bin is fixed), and ``"dblplaw"``. n_samples : int | float, optional The number of samples to generate. By default ``1e6``. A useful number will typically be ``>10^5``. remake_atlas : bool, optional If ``True``, any existing model atlas with the same name will be recreated and overwritten. By default ``False``. n_cores_atlas : int, optional The number of processes to use when generating the model grid. If set to ``0``, the code will run on a single process. If set to an integer less than 0, this will run on the number of cores returned by `multiprocessing.cpu_count`. By default, ``4`` processes will be used. Returns ------- PathLike The location of the model atlas. """ default_filter_dir = ( Path(niriss_tools.__file__).parent / "data" / "filter_throughputs" ) # Create the filter directory; populate as needed filter_dir = self.pipes_dir / "filter_throughputs" filter_dir.mkdir(exist_ok=True, parents=True) for file in default_filter_dir.glob("*.txt"): shutil.copy(file, filter_dir) # Create a list of the filters used in our data self.filter_list = [] for key in self.info_dict.keys(): self.filter_list.append(str(filter_dir / f"{key}.txt")) self.bagpipes_atlas_params = generate_fit_params( obj_z=self.obj_z if obj_z is None else obj_z, z_range=z_range, num_age_bins=num_age_bins, sfh_type=sfh_type, min_age_bin=min_age_bin, ) self.atlas_run_name = ( f"z_{self.bagpipes_atlas_params["redshift"][0]}_" f"{self.bagpipes_atlas_params["redshift"][1]}_" f"{n_samples:.2E}" ) atlas_path = self.atlas_dir / f"{self.atlas_run_name}.hdf5" if not atlas_path.is_file() or remake_atlas: atlas_gen = AtlasGenerator( fit_instructions=self.bagpipes_atlas_params, filt_list=self.filter_list, phot_units="ergscma", ) atlas_gen.gen_samples(n_samples=n_samples, parallel=n_cores_atlas) atlas_gen.write_samples(filepath=atlas_path) return atlas_path
[docs] def gen_aligned_photometry( self, binning_kwargs: dict, use_stacks: bool = True, beams_path: PathLike | None = None, img_cutout: int = 500, stack_beam_kwargs: dict = {}, **multibeam_kwargs, ) -> tuple[PathLike, PathLike]: """ Align photometric data to the direct image in an extracted beam. Parameters ---------- binning_kwargs : dict Any arguments to pass to `~niriss_tools.sed.binning.bin_and_save`. use_stacks : bool, optional Whether to fit to individual beams, or beams stacked by filter and grism. By default ``True``. beams_path : PathLike | None, optional The location of a ``*beams.fits`` file to use for fitting. If ``None`` (default), this will be selected automatically based on the ``obj_id`` and directory structure specified in the configuration file. img_cutout : int, optional Make a slice of the original image with size in pixels ``[-cutout,+cutout]`` around the centre of the object, before alignment. By default, ``cutout=500``. stack_beam_kwargs : dict, optional Any additional parameters to pass through to `~niriss_tools.grism.utils.gen_stacked_beams`. **multibeam_kwargs : dict, optional Any additional parameters to pass through to `grizli.multifit.MultiBeam`. Returns ------- new_beam_path : PathLike The location of the ``*beams.fits`` used for alignment. If ``use_stacks==True``, this will be a stacked version of the input file. binned_data_path : PathLike The location of the binned data in FITS format. This contains both the segmentation map and the binned photometric catalogue. Raises ------ IOError If no ``*beams.fits`` file can be found, this method will raise an error. """ binned_data_dir = self.out_dir / "binned_data" binned_data_dir.mkdir(exist_ok=True, parents=True) new_beam_loc = ( binned_data_dir / f"{self.field_name}_{self.obj_id:0>5}.beams.fits" ) # Give a descriptive name for the binned data binned_name = ( f"{self.obj_id}_{binning_kwargs["bin_scheme"]}_" f"{binning_kwargs["bin_diameter"]}_{binning_kwargs["target_sn"]}" f"_{binning_kwargs["sn_filter"]}" ) binned_data_path = binned_data_dir / f"{binned_name}_data.fits" if not binned_data_path.is_file(): try: multib = MultiBeam( str(new_beam_loc), **multibeam_kwargs, ) except: if beams_path is None: beams_path = [ *self.extractions_dir.glob(f"**/*{self.obj_id:0>5}.beams.fits") ] if len(beams_path) >= 1: beams_path = [str(b) for b in beams_path] else: raise IOError( f"Original beams file does not exist in {self.extractions_dir}" ) if multibeam_kwargs is None: multibeam_kwargs = self.multibeam_kwargs if len(stack_beam_kwargs) == 0: stack_beam_kwargs = self.stack_beam_kwargs if use_stacks: multib = gen_stacked_beams( beams_path, **stack_beam_kwargs, **multibeam_kwargs, ) else: multib = MultiBeam( beams_path, **multibeam_kwargs, ) # Write the realigned and (stacked?) beam to a file beam_hdul = multib.write_master_fits(get_hdu=True) beam_hdul.writeto(new_beam_loc, overwrite=True) # Align all images to the new beam aligned_info_dict = align_direct_images( multib.beams[0], info_dict=self.info_dict, out_dir=binned_data_dir / f"{self.obj_id:0>5}", overwrite=False, cutout=img_cutout, ) _seg = fits.getdata(new_beam_loc, "SEG") bin_and_save( obj_id=self.obj_id, out_dir=binned_data_dir, seg_map=_seg, info_dict=aligned_info_dict, binned_name=binned_name, **binning_kwargs, ) return new_beam_loc, binned_data_path
[docs] def add_pipes_info(self, header: fits.Header) -> fits.Header: """ Update a header with information about the 2D SED fitting. Parameters ---------- header : fits.Header The original header. Returns ------- fits.Header The updated header. """ header["MRBPRUN"] = ( str(self.run_name), "The name of the bagpipes run used to generate the prior templates.", ) header["MRBPPCAT"] = ( str(self.binned_data_path), "The binned photometric catalogue and segmentation map used as input for bagpipes.", ) header["MRBPFCAT"] = ( str(self.sed_fit_cat_path), "The bagpipes output fit catalogue.", ) return header
@staticmethod def _gen_stacked_templates_from_pipes( seg_idx=None, seg_id=None, shared_seg_name=None, seg_maps_shape=None, shared_models_name=None, models_shape=None, posterior_dir=None, spectral_dir=None, n_samples=None, spec_wavs=None, beam_info=None, temp_offset=None, cont_only=False, rm_line=None, rows=None, memmap: bool = False, id_shifts=None, n_shifted_rows=None, ): seg_id = int(seg_id) try: if memmap: seg_maps = np.memmap( shared_seg_name, dtype=np.float32, shape=seg_maps_shape, mode="r+" ) temps_arr = np.memmap( shared_models_name, dtype=np.float32, shape=models_shape, mode="r+" ) else: shm_seg_maps = shared_memory.SharedMemory(name=shared_seg_name) seg_maps = np.ndarray( seg_maps_shape, dtype=np.float32, buffer=shm_seg_maps.buf ) shm_temps = shared_memory.SharedMemory(name=shared_models_name) temps_arr = np.ndarray( models_shape, dtype=np.float32, buffer=shm_temps.buf ) if spectral_dir is not None: temps_resampled = np.zeros( (len(rows) + len(id_shifts) * n_shifted_rows, spec_wavs.shape[0]) ) with h5py.File( Path(spectral_dir) / f"{seg_id}.h5", "r", ) as spec_file: temps_resampled[: len(rows), :] = np.array(spec_file["spec_data"])[ rows ] # Just make one set of templates from all possible seg ids if (id_shifts is not None) and (len(id_shifts) > 0): for s_i, s in enumerate(id_shifts): shifted_id = int((seg_id + s) % np.nanmax(seg_maps[0])) with h5py.File( Path(spectral_dir) / f"{shifted_id}.h5", "r", ) as spec_file: temps_resampled[ int(len(rows) + s_i * n_shifted_rows) : int( len(rows) + (s_i + 1) * n_shifted_rows ) : ] = np.array(spec_file["spec_data"])[rows[:n_shifted_rows]] else: with h5py.File(Path(posterior_dir) / f"{seg_id}.h5", "r") as post_file: samples2d = np.zeros( ( len(rows) + len(id_shifts) * n_shifted_rows, post_file["samples2d"].shape[1], ) ) samples2d[: len(rows), :] = np.array(post_file["samples2d"])[rows] # Just make one set of posterior samples from all possible seg ids if (id_shifts is not None) and (len(id_shifts) > 0): for s_i, s in enumerate(id_shifts): shifted_id = int((seg_id + s) % np.nanmax(seg_maps[0])) with h5py.File( Path(posterior_dir) / f"{shifted_id}.h5", "r" ) as post_file: samples2d[ int(len(rows) + s_i * n_shifted_rows) : int( len(rows) + (s_i + 1) * n_shifted_rows ) : ] = np.array(post_file["samples2d"])[rows[:n_shifted_rows]] temps_resampled = np.zeros((samples2d.shape[0], spec_wavs.shape[0])) for sample_i, sample in enumerate(samples2d): temps_resampled[sample_i] = pipes_sampler.sample( sample, spec_wavs=spec_wavs, cont_only=cont_only, rm_line=rm_line, )[1] for sample_i, temp_spec in enumerate(temps_resampled): temp_resamp_1d = np.c_[spec_wavs, temp_spec].T i0 = 0 for k_i, (k, v) in enumerate(beam_info.items()): for ib in v["list_idx"]: beam = beams_object[ib] if seg_maps.ndim == 3: beam_seg = seg_maps[ib] == seg_id else: beam_seg = seg_maps[seg_idx][ib] tmodel = beam.compute_model( spectrum_1d=temp_resamp_1d, thumb=beam.beam.direct * beam_seg, in_place=False, is_cgs=True, ) temps_arr[ (n_samples * seg_idx) + sample_i + temp_offset, i0 : i0 + np.prod(v["2d_shape"]), ] += tmodel temps_arr[ (n_samples * seg_idx) + sample_i + temp_offset, i0 : i0 + np.prod(v["2d_shape"]), ] /= len(v["list_idx"]) i0 += np.prod(v["2d_shape"]) if memmap: temps_arr.flush() except: raise Exception( shared_models_name + "".join(traceback.format_exception(*sys.exc_info())) ) @staticmethod def _gen_beam_templates_from_pipes( seg_idx=None, seg_id=None, shared_seg_name=None, seg_maps_shape=None, shared_models_name=None, models_shape=None, posterior_dir=None, spec_wavs=None, beam_info=None, cont_only=False, rm_line=None, rows=None, coeffs=None, memmap: bool = False, id_shifts=None, n_shifted_rows=None, return_line_flux=False, ): seg_id = int(seg_id) try: if memmap: seg_maps = np.memmap( shared_seg_name, dtype=np.float32, shape=seg_maps_shape, mode="r+" ) models_arr = np.memmap( shared_models_name, dtype=np.float32, shape=models_shape, mode="r+" ) else: shm_seg_maps = shared_memory.SharedMemory(name=shared_seg_name) seg_maps = np.ndarray( seg_maps_shape, dtype=np.float32, buffer=shm_seg_maps.buf ) shm_models = shared_memory.SharedMemory(name=shared_models_name) models_arr = np.ndarray( models_shape, dtype=np.float32, buffer=shm_models.buf ) with h5py.File(Path(posterior_dir) / f"{seg_id}.h5", "r") as post_file: samples2d = np.zeros( ( len(rows) + len(id_shifts) * n_shifted_rows, post_file["samples2d"].shape[1], ) ) samples2d[: len(rows), :] = np.array(post_file["samples2d"])[rows] # Just make one set of posterior samples from all possible seg ids if (id_shifts is not None) and (len(id_shifts) > 0): for s_i, s in enumerate(id_shifts): shifted_id = int((seg_id + s) % np.nanmax(seg_maps[0])) with h5py.File( Path(posterior_dir) / f"{shifted_id}.h5", "r" ) as post_file: samples2d[ int(len(rows) + s_i * n_shifted_rows) : int( len(rows) + (s_i + 1) * n_shifted_rows ) : ] = np.array(post_file["samples2d"])[rows[:n_shifted_rows]] if return_line_flux: line_fluxes = np.zeros(len(samples2d)) for sample_i, sample in enumerate(samples2d): if coeffs[f"bin_{seg_id}"][sample_i] == 0: continue out = pipes_sampler.sample( sample, spec_wavs=spec_wavs, cont_only=cont_only, rm_line=rm_line, return_line_flux=return_line_flux, ) if return_line_flux: temp_resamp_1d, line_flux = out line_fluxes[sample_i] = line_flux else: temp_resamp_1d = out i0 = 0 for k_i, (k, v) in enumerate(beam_info.items()): for ib in v["list_idx"]: if seg_maps.ndim == 3: beam_seg = seg_maps[ib] == seg_id else: beam_seg = seg_maps[seg_idx][ib] tmodel = ( beams_object[ib].compute_model( spectrum_1d=temp_resamp_1d, thumb=beams_object[ib].beam.direct * beam_seg, in_place=False, is_cgs=True, ) * coeffs[f"bin_{seg_id}"][sample_i] ) with lock_var: models_arr[i0 : i0 + np.prod(v["2d_shape"])] += tmodel i0 += np.prod(v["2d_shape"]) if memmap: models_arr.flush() if return_line_flux: return line_fluxes except: raise Exception( shared_models_name + "".join(traceback.format_exception(*sys.exc_info())) ) @staticmethod def _reduce_seg_map( seg_idx=None, seg_id=None, shared_input=None, init_shape=None, shared_output=None, output_shape=None, beam_idx=None, oversamp_factor=None, memmap: bool = False, ): if memmap: init_arr = np.memmap( shared_input, dtype=np.float32, shape=init_shape, mode="r+" ) output_arr = np.memmap( shared_output, dtype=np.float32, shape=output_shape, mode="r+" ) else: shm_init = shared_memory.SharedMemory(name=shared_input) init_arr = np.ndarray(init_shape, dtype=np.float32, buffer=shm_init.buf) shm_output = shared_memory.SharedMemory(name=shared_output) output_arr = np.ndarray( output_shape, dtype=np.float32, buffer=shm_output.buf ) output_arr[seg_idx, beam_idx, :, :] = block_reduce( init_arr == seg_id, oversamp_factor, func=np.mean, ) if memmap: output_arr.flush()
[docs] def fit_at_z( self, z: float = 0.0, fit_background: bool = True, poly_order: int = 0, n_samples: int = 3, n_iters: int = 10, bad_pa_threshold: float | None = 1.6, spec_wavs: ArrayLike | None = None, oversamp_factor: int = 1, veldisp: float = 500, direct_images: None = None, out_dir: PathLike | None = None, temp_dir: PathLike | None = None, memmap: bool = False, cpu_count: int = -1, overwrite: bool = False, use_lines: dict = CLOUDY_LINE_MAP, save_lines: bool = True, save_stacks: bool = True, pline: dict = DEFAULT_PLINE, seed: int = 2744, nnls_method: str = "scipy", nnls_iters: int = 100, nnls_tol: float = 1e-5, n_shifted: int = 2, n_shifted_samples: int = 1, cache_spec: bool = True, ): """ Fit the object at a specified redshift. Parameters ---------- z : float, optional The redshift at which the object will be fitted, by default 0. fit_background : bool, optional Fit a constant background level, by default ``True``. poly_order : int, optional Fit a polynomial function to the spectrum, with a default order of ``0``. n_samples : int, optional The number of samples to draw from the joint posterior distributions in each region, by default ``3``. n_iters : int, optional The number of iterations to perform when fitting, by default ``10``. bad_pa_threshold : float | None, optional The threshold for identifying bad PAs before fitting. By default ``1.6``, if ``None`` all beams will be used. spec_wavs : ArrayLike | None, optional The wavelength sampling to use when generating the template spectra from the `bagpipes` posterior distributions. The default value of ``None`` sets this to the :math:`0.96 - 2.3\\mu\\rm{m}` range in :math:`45\\mathring{A}` steps. oversamp_factor : int, optional The factor by which the region segmentation map is oversampled before reprojecting to the beam coordinate system. This significantly slows down the model generation, but is essential to ensure that the template spectra correspond to the correct pixels in the NIRISS detector frame, and defaults to a factor of ``1``. veldisp : float, optional The velocity dispersion of the template spectra in km/s, by default ``500``. direct_images : _type_, optional WIP, may allow for changing beam direct images at some point. By default ``None``. out_dir : PathLike | None, optional Where the output files will be written. If ``None`` (default), files will be written to ``self.out_dir/multiregion``. temp_dir : PathLike | None, optional The temporary directory to use for memmapped files (if ``memmap==True``). If ``None`` (default), the current working directory will be used. memmap : bool, optional Whether to use a memmap to store large files. By default ``False``. If ``True``, the large model array will be written to a temporary array on disk. cpu_count : int, optional The number of CPUs to use for multiprocessing, by default -1. overwrite : bool, optional If ``True``, overwrite any existing fit. By default ``False``, and will attempt to load a previous fit. use_lines : ArrayLike, optional A list of lines, for which a 2D map will be generated based on the multi-region fit. Each item in the list should be a ``dict``, containing the following keys: * ``"cloudy"`` : The name of one or more lines following the ``Cloudy`` nomenclature (`Ferland+17 <https://ui.adsabs.harvard.edu/abs/2017RMxAA..53..385F/abstract>`__). * ``"grizli"`` : The name of the emission line in `grizli`, or any other desired name. * ``"wave"`` : The rest-frame vacuum wavelength of the line. save_lines : bool, optional Save the drizzled emission line maps, matching the `grizli` output format. By default ``True``. save_stacks : bool, optional Save the stacked beams, full models, and continuum models. The output format differs from grizli in that these stacks are not drizzled to account for the subpixel shifts. By default ``True``. pline : dict, optional Parameters for generating the drizzled emission line maps. Defaults to `~niriss_tools.grism.DEFAULT_PLINE`. seed : int | None, optional The seed for the random sampling, by default 2744. If None, then a new seed will be generated each time this method is called. nnls_method : str, optional The method to use for finding the best-fit coefficients for the set of templates. Must be one of "scipy", "numba", or "adelie" (ordered in increasing speed). By default, "scipy" will be used. nnls_iters : int or ArrayLike, optional The maximum number of iterations to attempt if the NNLS solution has not converged to the specified tolerance. If more than one value is passed, a two-stage fit will be run, whereby the best-fit solution from the first ``n_iters`` attempts will be re-fit with ``nnls_iters[0]`` iterations. This is only relevant if ``nnls_method != "scipy"``. nnls_tol : float or ArrayLike, optional The desired tolerance for the NNLS solver. As with ``nnls_iters``, the second value can be used to perform a more precise fit after the initial ``n_iters`` attempts. n_shifted : int or None, optional This allows for drawing additional posterior samples from other regions of the object. Regions will be selected randomly, with no preference as to spatial or spectral coherence (they are selected by shifting each segmentation map id when generating the models, hence the name). This reduces the chance of template mismatch based on the SED fit. By default, 2 additional regions will be used. n_shifted_samples : int or None, optional This determines the number of samples drawn from each of the additional regions (i.e. the total number of samples is given by ``n_samples + n_shifted * n_shifted_samples``). By default, only 1 sample is drawn from each extra region. cache_spec : bool, optional Pre-generate the spectra before forward modelling. Can give a large speedup if running for many iterations, at the cost of additional disk space. Returns ------- tuple Exact form of return still WIP. """ try: import adelie print("Using `adelie` solver.") HAS_ADELIE = True except: HAS_ADELIE = False nnls_iters = np.atleast_1d(nnls_iters).astype(int) nnls_tol = np.atleast_1d(nnls_tol) TWO_STAGE = (len(nnls_iters) > 1) | (len(nnls_tol) > 1) if spec_wavs is None: spec_wavs = np.arange(10000.0, 23000.0, 22.5) if memmap: if temp_dir is None: temp_dir = Path.cwd() else: temp_dir = Path(temp_dir) temp_dir.mkdir(exist_ok=True, parents=True) if not out_dir: multireg_out_dir = self.out_dir / "multiregion" else: multireg_out_dir = self.out_dir / Path(out_dir) multireg_out_dir.mkdir(exist_ok=True, parents=True) if bad_pa_threshold is not None: out = self.MB.check_for_bad_PAs( chi2_threshold=bad_pa_threshold, poly_order=1, reinit=True, fit_background=True, ) fit_log, keep_dict, has_bad = out if has_bad: print(f"Has bad PA! Final list: {keep_dict}\n{fit_log}") self.MB.init_poly_coeffs(poly_order=poly_order) if fit_background: self.fit_bg = True A = np.vstack((self.MB.A_bg, self.MB.A_poly)) else: self.fit_bg = False A = self.MB.A_poly * 1 with h5py.File( self.pipes_dir / "posterior" / self.run_name / f"{int(self.regions_phot_cat["bin_id"][0])}.h5", "r", ) as test_post: fit_info_str = test_post.attrs["fit_instructions"] fit_info_str = fit_info_str.replace("array", "np.array") fit_info_str = fit_info_str.replace("float", "np.float") fit_info_str = fit_info_str.replace("np.np.", "np.") fit_instructions = eval(fit_info_str) n_post_samples = test_post["samples2d"].shape[0] # Try to allow for both memory and file-backed multiprocessing of # large arrays smm = SharedMemoryManager() smm.start() # If we are not oversampling, we can drastically reduce the memory usage # and only store the reprojected (multiregion) seg map for each beam. # If the segmentation map is not aligned to the beam direct images, and # we are oversampling to find the exact pixel overlap, the fastest # method for generating forward-modelled spectra is to pre-calculate # the segmentation map overlap for each region, and store this in a # shared memory array. if oversamp_factor == 1: oversamp_seg_maps_shape = ( self.MB.N, self.MB.beams[0].beam.sh[0], self.MB.beams[0].beam.sh[1], ) else: oversamp_seg_maps_shape = ( self.n_regions, self.MB.N, self.MB.beams[0].beam.sh[0], self.MB.beams[0].beam.sh[1], ) if memmap: oversamp_seg_maps = np.memmap( temp_dir / "memmap_oversamp_seg_maps.dat", dtype=np.float32, mode="w+", shape=oversamp_seg_maps_shape, ) else: shm_seg_maps = smm.SharedMemory( size=np.dtype(np.float32).itemsize * np.prod(oversamp_seg_maps_shape) ) oversamp_seg_maps = np.ndarray( oversamp_seg_maps_shape, dtype=np.float32, buffer=shm_seg_maps.buf, ) oversamp_seg_maps.fill(0.0) beam_info = {} start_idx = 0 oversampled_shape = ( oversamp_factor * self.MB.beams[0].beam.sh[0], oversamp_factor * self.MB.beams[0].beam.sh[1], ) if memmap: oversampled = np.memmap( temp_dir / "memmap_oversampled.dat", dtype=np.float32, mode="w+", shape=oversampled_shape, ) else: shm_oversampled = smm.SharedMemory( size=np.dtype(np.float32).itemsize * np.prod(oversampled_shape) ) oversampled = np.ndarray( oversampled_shape, dtype=np.float32, buffer=shm_oversampled.buf, ) oversampled.fill(np.nan) start_idx = 0 for i, (beam_cutout, cutout_shape) in enumerate( zip(self.MB.beams, self.MB.Nflat) ): beam_name = f"{beam_cutout.grism.pupil}-{beam_cutout.grism.filter}-{i}" if not beam_name in beam_info: beam_info[beam_name] = {} beam_info[beam_name]["2d_shape"] = beam_cutout.sh beam_info[beam_name]["list_idx"] = [] beam_info[beam_name]["flat_slice"] = [] beam_info[beam_name]["list_idx"].append(i) beam_info[beam_name]["flat_slice"].append( slice(int(start_idx), int(start_idx + cutout_shape)) ) start_idx += cutout_shape beam_wcs = deepcopy(beam_cutout.direct.wcs) beam_wcs = grizli_utils.transform_wcs(beam_wcs, scale=oversamp_factor) oversampled[:] = reproject_interp( (self.regions_seg_map, self.regions_seg_wcs), beam_wcs, oversampled_shape, return_footprint=False, order=0, ) if oversamp_factor == 1: oversamp_seg_maps[i] = oversampled continue with Pool(processes=cpu_count) as pool: multi_fn = partial( self._reduce_seg_map, shared_input=( temp_dir / "memmap_oversampled.dat" if memmap else shm_oversampled.name ), init_shape=oversampled_shape, shared_output=( temp_dir / "memmap_oversamp_seg_maps.dat" if memmap else shm_seg_maps.name ), output_shape=oversamp_seg_maps_shape, beam_idx=i, oversamp_factor=oversamp_factor, memmap=memmap, ) pool.starmap(multi_fn, enumerate(self.regions_seg_ids)) # Avoid any nan-related problems later oversamp_seg_maps[~np.isfinite(oversamp_seg_maps)] = 0 # The total number of templates NTEMP = self.n_regions * n_samples if n_shifted > 0: NTEMP += self.n_regions * n_shifted * n_shifted_samples num_stacks = len([*beam_info.keys()]) stacked_shape = np.nansum([np.prod(v["2d_shape"]) for v in beam_info.values()]) temp_offset = A.shape[0] - self.MB.N + num_stacks # This is the large array of models. Each row corresponds to a # (probably) unique template, forward-modelled across all beams, # and flattened. stacked_A_shape = (temp_offset + NTEMP, stacked_shape) print(f"{memmap=}\n\n") if memmap: stacked_A = np.memmap( temp_dir / "memmap_stacked_A.dat", dtype=np.float32, mode="w+", shape=stacked_A_shape, ) else: shm_stacked_A = smm.SharedMemory( size=np.dtype(np.float32).itemsize * np.prod(stacked_A_shape) ) stacked_A = np.ndarray( stacked_A_shape, dtype=np.float32, buffer=shm_stacked_A.buf, ) stacked_A.fill(0.0) # There are very few scenarios in which it makes sense to use the # individual beams for fitting. The computational requirements are # already considerable, and for GLASS-JWST, not stacking would mean ~6x # more memory, and at least that in computation time. stacked_scif = np.zeros(stacked_shape) stacked_ivarf = np.zeros(stacked_shape) stacked_weightf = np.zeros(stacked_shape) stacked_fit_mask = np.zeros(stacked_shape, dtype=bool) start_idx = 0 for k_i, (k, v) in enumerate(beam_info.items()): stack_idxs = np.r_["0,2", *v["flat_slice"]] stacked_scif[start_idx : start_idx + np.prod(v["2d_shape"])] = np.nanmedian( self.MB.scif[stack_idxs], axis=0, ) stacked_weightf[start_idx : start_idx + np.prod(v["2d_shape"])] = ( np.nanmedian( self.MB.weightf[stack_idxs], axis=0, ) ) stacked_fit_mask[start_idx : start_idx + np.prod(v["2d_shape"])] = np.any( self.MB.fit_mask[stack_idxs], axis=0, ) if fit_background: stacked_A[k_i, start_idx : start_idx + np.prod(v["2d_shape"])] = 1.0 stacked_A[ num_stacks : num_stacks + self.MB.A_poly.shape[0], start_idx : start_idx + np.prod(v["2d_shape"]), ] = np.nanmean(self.MB.A_poly[:, stack_idxs], axis=1) else: stacked_A[ num_stacks : num_stacks + self.MB.A_poly.shape[0], start_idx : start_idx + np.prod(v["2d_shape"]), ] = np.nanmean(self.MB.A_poly[:, stack_idxs], axis=1) stacked_ivarf[start_idx : start_idx + np.prod(v["2d_shape"])] = ( np.nanmedian(self.MB.ivarf[stack_idxs], axis=0) ) start_idx += np.prod(v["2d_shape"]) stacked_fit_mask &= np.isfinite(stacked_scif) DoF = int((stacked_weightf * stacked_fit_mask).sum()) # Allow for background fitting by including an offset if fit_background: pedestal = 0.04 else: pedestal = 0.0 y = stacked_scif[stacked_fit_mask] + pedestal y *= np.sqrt(stacked_ivarf[stacked_fit_mask]) y = y.astype(np.float32) stacked_sivarf_masked = np.sqrt(stacked_ivarf[stacked_fit_mask]) # TODO: make the output name a parameter? self.output_table_path = multireg_out_dir / ( f"{self.obj_id}_{len(np.unique(self.regions_phot_cat["bin_id"]))}" f"bins_{n_iters}iters_{n_samples}samples_z_{z}_sig_{veldisp}.ecsv" ) # The function to produce the templates - only a couple of # parameters change on each iteration. stacked_fn = partial( self._gen_stacked_templates_from_pipes, shared_seg_name=( temp_dir / "memmap_oversamp_seg_maps.dat" if memmap else shm_seg_maps.name ), seg_maps_shape=oversamp_seg_maps_shape, shared_models_name=( temp_dir / "memmap_stacked_A.dat" if memmap else shm_stacked_A.name ), models_shape=stacked_A_shape, posterior_dir=str(self.pipes_dir / "posterior" / self.run_name), n_samples=n_samples + (n_shifted * n_shifted_samples), spec_wavs=spec_wavs, beam_info=beam_info, temp_offset=temp_offset, cont_only=False, rm_line=None, memmap=memmap, n_shifted_rows=n_shifted_samples, ) if cache_spec: pre_gen_spec( self.pipes_dir, fit_instructions, spec_wavs=spec_wavs, veldisp=veldisp, run=self.run_name, cpu_count=cpu_count, ) pool_kwargs = dict( processes=cpu_count, initializer=_init_beams, initargs=(self.MB.beams,), ) else: pool_kwargs = dict( processes=cpu_count, initializer=_init_pipes_sampler, initargs=( fit_instructions, veldisp, self.MB.beams, ), ) # These column names should be fixed for all objects init_col_names = [ "iteration", "chi2", "max_nnls_iters", "solve_nnls_iters", "nnls_tol", "rows", "id_shifts", "n_shifted_samples", "fitting_time", "total_time", "unique_temp", ] total_iters = n_iters + int(TWO_STAGE) # Construct the table if it doesn't already exist try: output_table = Table.read(self.output_table_path, format="ascii.ecsv") assert np.logical_not(overwrite) except: output_table = Table( [ np.zeros(total_iters), np.full(total_iters, np.nan), *np.zeros((3, total_iters)), np.zeros((total_iters, n_samples), dtype=int), np.zeros((total_iters, n_shifted), dtype=int), *np.zeros((4, total_iters)), *np.zeros((temp_offset, total_iters)), *np.zeros( ( self.n_regions, total_iters, n_samples + (n_shifted * n_shifted_samples), ) ), ], names=init_col_names + [f"base_coeffs_{b}" for b in np.arange(temp_offset)] + [f"bin_{p}" for p in self.regions_phot_cat["bin_id"]], dtype=[int, float, int, int, float, int, int, int, float, float, int] + [float] * temp_offset + [float] * self.n_regions, ) # Check if seed was previously set, write to table if not seed = output_table.meta.get("RNGSEED", [seed])[0] output_table.meta["RNGSEED"] = (seed, "Random seed") output_table.meta["ID"] = (self.obj_id, "Object ID") output_table.meta["RA"] = (self.ra, "Right Ascension") output_table.meta["DEC"] = (self.dec, "Declination") output_table.meta["Z"] = (z, "Best-fit redshift") output_table.meta["DOF"] = (DoF, "Degrees of freedom (active pixels)") # Two RNG, so can compare with and without extra shift samples rng = np.random.default_rng(seed=seed) rng_shifts = np.random.default_rng(seed=seed) # Check if the table length matches the expected number of iterations n_prev_iters = np.max( (output_table["iteration"] + 1)[np.isfinite(output_table["chi2"])], initial=0, ) if n_prev_iters < total_iters: # If there are previous iterations, sample from the RNGs so the # seed order is preserved for x in np.arange(n_prev_iters): rows = rng.choice( np.arange(n_post_samples, dtype=int), size=n_samples, replace=False, ) id_shifts = rng_shifts.choice( np.arange(self.n_regions, dtype=int), size=n_shifted, replace=False if n_shifted <= self.n_regions else True, ) remaining_iters = total_iters - n_prev_iters output_table.write( self.output_table_path, overwrite=True, format="ascii.ecsv" ) iterations = np.arange(n_prev_iters, total_iters) for iteration in iterations: try: curr_line = ( f"Minimum chi2: {np.nanmin(output_table["chi2"]):.3f}" f"\t\t(Iteration {np.nanargmin(output_table["chi2"])})" ) except: curr_line = "Minimum chi2: ---" log_with_offset("", curr_line=curr_line) # On the final iteration, reuse the samples from the current # best-fit solution if TWO_STAGE and (iteration == iterations[-1]): best_iter = np.nanargmin(output_table["chi2"]) rows = [int(s) for s in output_table["rows"][best_iter]] id_shifts = [int(s) for s in output_table["id_shifts"][best_iter]] else: rows = rng.choice( np.arange(n_post_samples, dtype=int), size=n_samples, replace=False, ) id_shifts = rng_shifts.choice( np.arange(self.n_regions, dtype=int), size=n_shifted, replace=False if n_shifted <= self.n_regions else True, ) log_with_offset(f"Iteration {iteration}, {rows=}", curr_line=curr_line) t0 = time() # Generate the forward-modelled spectra log_with_offset(f"Generating models...", curr_line=curr_line) with multiprocessing.Pool(**pool_kwargs) as pool: for s_i, s in enumerate(self.regions_seg_ids): pool.apply_async( stacked_fn, (s_i, s), kwds={ "rows": rows, "id_shifts": id_shifts, "spectral_dir": ( str(self.pipes_dir / "spec" / self.run_name) if cache_spec else None ), }, error_callback=print, ) pool.close() pool.join() t1 = time() # print("\r" + f"Generating models... DONE {t1-t0:.3f}s", flush=True) log_with_offset( LINE_UP + f"Generating models... DONE in {t1-t0:.3f}s", curr_line=curr_line, ) # Remove any negative or zero templates ok_temp = np.sum(stacked_A, axis=1) > 0 # We need to remove duplicate templates so the NNLS solvers can # actually converge. Until such time as np.unique implements # a hash map to allow for speeding up `return_index`, we just # skip over every 199 pixels stacked_A_contig = np.ascontiguousarray(stacked_A[:, ::199]) # np.unique() finds identical items in a raveled array. To make it # see each row as a single item, we create a view of each row as a # byte string of length itemsize times number of columns in `ar` ar_row_view = stacked_A_contig.view( "|S%d" % (stacked_A_contig.itemsize * stacked_A_contig.shape[1]) ) _, unique_idxs = np.unique(ar_row_view, return_index=True) unique_temp = np.isin(np.arange(stacked_A.shape[0]), unique_idxs) del stacked_A_contig # Select only the unique templates ok_temp &= unique_temp out_coeffs = np.zeros(stacked_A.shape[0]) # Transpose the template array # stacked_Ax = stacked_A[:, stacked_fit_mask][ok_temp, :].T stacked_Ax = stacked_A[np.ix_(ok_temp, stacked_fit_mask)].T # stacked_Ax *= np.sqrt(stacked_ivarf[stacked_fit_mask][:, np.newaxis]) stacked_Ax *= stacked_sivarf_masked[:, np.newaxis] # Change the max iters and tolerance for the final iteration if TWO_STAGE and (iteration == iterations[-1]): log_with_offset("Final iteration", curr_line=curr_line) _nnls_i = nnls_iters[1] _nnls_t = nnls_tol[1] else: _nnls_i = nnls_iters[0] _nnls_t = nnls_tol[0] # print("NNLS fitting...", end="") log_with_offset("NNLS fitting... ", curr_line=curr_line) # Three different methods of fitting, each with different call # signatures and return values if nnls_method == "adelie" and HAS_ADELIE: state = adelie.solver.bvls( stacked_Ax, y, lower=np.zeros(stacked_Ax.shape[-1], dtype=np.float32), upper=np.full(stacked_Ax.shape[-1], np.inf, dtype=np.float32), max_iters=_nnls_i, tol=_nnls_t, n_threads=1, # Inter-thread communication is actually slower ) state.solve() state_iters = deepcopy(state.iters) coeffs = deepcopy(state.beta) coeffs[:num_stacks] -= pedestal del state elif nnls_method == "numba": nnls_solver = CDNNLS(stacked_Ax, y) nnls_solver.run(n_iter=_nnls_i, epsilon=_nnls_t) coeffs = nnls_solver.w coeffs[:num_stacks] -= pedestal elif nnls_method == "fnnls": coeffs = fnnls( stacked_Ax, y, tolerance=_nnls_t, max_iterations=_nnls_i, ) coeffs[:num_stacks] -= pedestal elif nnls_method == "fennls": coeffs = fennls( stacked_Ax, y, tolerance=_nnls_t, max_iterations=_nnls_i, ) coeffs[:num_stacks] -= pedestal else: coeffs, rnorm, info = scipy.optimize._nnls._nnls( stacked_Ax, y, _nnls_i ) coeffs[:num_stacks] -= pedestal t2 = time() # print("\r" + f"NNLS fitting... DONE {t2-t1:.3f}s", flush=True) log_with_offset( LINE_UP + f"NNLS fitting... DONE in {t2-t1:.3f}s", curr_line=curr_line, ) out_coeffs[ok_temp] = coeffs stacked_modelf = np.dot(out_coeffs, stacked_A) chi2 = np.nansum( ( stacked_weightf * (stacked_scif - stacked_modelf) ** 2 * stacked_ivarf )[stacked_fit_mask] ) # output_table.add_row( # [ # iteration, # chi2, # _nnls_i, # state_iters if (nnls_method == "adelie" and HAS_ADELIE) else 0, # _nnls_t, # rows, # id_shifts, # n_shifted_samples, # t2 - t1, # time() - t0, # ok_temp.sum(), # *out_coeffs[:temp_offset], # *out_coeffs[temp_offset:].reshape(self.n_regions, -1), # ] # ) output_table[iteration] = [ iteration, chi2, _nnls_i, state_iters if (nnls_method == "adelie" and HAS_ADELIE) else 0, _nnls_t, rows, id_shifts, n_shifted_samples, t2 - t1, time() - t0, ok_temp.sum(), *out_coeffs[:temp_offset], *out_coeffs[temp_offset:].reshape(self.n_regions, -1), ] output_table.write(self.output_table_path, overwrite=True) # print(f"Iteration {iteration}: chi2={chi2:.3f}\n") log_with_offset( f"Iteration {iteration}: chi2={chi2:.3f}", curr_line=curr_line ) # Reset the template array stacked_A[temp_offset:].fill(0.0) del stacked_Ax # There must be a better way to obtain the coefficients, but # slicing tables is not entirely straightforward best_iter = np.argmin(output_table["chi2"]) out_coeffs = np.asarray( [ i for d in output_table[best_iter][len(init_col_names) :] for i in np.atleast_1d(d) ] ).ravel() # Repopulate the background parameters for MultiBeam if fit_background: for k_i, (k, v) in enumerate(beam_info.items()): for ib in v["list_idx"]: self.MB.beams[ib].background = out_coeffs[k_i] best_rows = [int(s) for s in output_table["rows"][best_iter]] best_id_shifts = [int(s) for s in output_table["id_shifts"][best_iter]] # Refill array with best model print("Calculating covariance array...") with multiprocessing.Pool(**pool_kwargs) as pool: for s_i, s in enumerate(self.regions_seg_ids): pool.apply_async( stacked_fn, (s_i, s), kwds={ "rows": best_rows, "id_shifts": best_id_shifts, "spectral_dir": ( str(self.pipes_dir / "spec" / self.run_name) if cache_spec else None ), }, error_callback=print, ) pool.close() pool.join() stacked_Ax = stacked_A[:, stacked_fit_mask] ok_temp = (np.sum(stacked_Ax, axis=1) > 0) & (out_coeffs != 0) stacked_Ax = stacked_Ax[ok_temp, :].T * 1 stacked_Ax *= np.sqrt(stacked_ivarf[stacked_fit_mask][:, np.newaxis]) try: covar = grizli_utils.safe_invert(np.dot(stacked_Ax.T, stacked_Ax)) except: N = ok_temp.sum() covar = np.zeros((N, N)) covard = np.sqrt(covar.diagonal()) coeffs_errs = out_coeffs * 0.0 coeffs_errs[ok_temp] = covard chi2nu = output_table["chi2"][best_iter] / ( DoF - output_table["unique_temp"][best_iter] ) # Ensure that the array is cleaned before repopulating stacked_A[temp_offset:].fill(0.0) # Largely unmodified from the original grizli code. Included within # this particular class method to avoid dealing with SharedMemory if save_stacks: print("Generating models...") with multiprocessing.Pool( **pool_kwargs, ) as pool: for s_i, s in enumerate(self.regions_seg_ids): pool.apply_async( stacked_fn, (s_i, s), kwds={ "rows": best_rows, "id_shifts": best_id_shifts, "spectral_dir": ( str(self.pipes_dir / "spec" / self.run_name) if cache_spec else None ), }, error_callback=print, ) pool.close() pool.join() stacked_modelf = np.dot(out_coeffs, stacked_A) stacked_A[temp_offset:].fill(0.0) print("Generating continuum...") with multiprocessing.Pool( processes=cpu_count, initializer=_init_pipes_sampler, initargs=( fit_instructions, veldisp, self.MB.beams, ), ) as pool: for s_i, s in enumerate(self.regions_seg_ids): pool.apply_async( stacked_fn, (s_i, s), kwds={ "rows": best_rows, "id_shifts": best_id_shifts, "cont_only": True, }, ) pool.close() pool.join() stacked_contf = np.dot(out_coeffs, stacked_A) stacked_hdul = fits.HDUList(fits.PrimaryHDU()) start_idx = 0 for i, (k, v) in enumerate(beam_info.items()): slice_plot = slice(start_idx, start_idx + np.prod(v["2d_shape"])) hdus = [ fits.ImageHDU( data=stacked_scif[slice_plot].reshape(v["2d_shape"]), name="SCI", ), fits.ImageHDU( data=stacked_weightf[slice_plot].reshape(v["2d_shape"]), name="WHT", ), fits.ImageHDU( data=stacked_ivarf[slice_plot].reshape(v["2d_shape"]), name="IVAR", ), fits.ImageHDU( data=(stacked_fit_mask[slice_plot] * 1.0).reshape( v["2d_shape"] ), name="MASK", ), fits.ImageHDU( data=stacked_modelf[slice_plot].reshape(v["2d_shape"]), name="MODEL", ), fits.ImageHDU( data=stacked_contf[slice_plot].reshape(v["2d_shape"]), name="CONT", ), ] for h in hdus: h.header["EXTVER"] = k h.header["RA"] = (self.ra, "Right ascension") h.header["DEC"] = (self.dec, "Declination") h.header["GRISM"] = (k.split("_")[0], "Grism") h.header["CONF"] = ( self.MB.beams[0].beam.conf.conf_file, "Configuration file", ) h.header["REDSHIFT"] = (z, "Redshift used") h.header["CHI2"] = ( output_table["chi2"][best_iter], "Chi^2 statistic", ) h.header["DOF"] = (DoF, "Degrees of freedom (active pixels)") h.header["NTEMP"] = ( output_table["unique_temp"][best_iter], "Number of unique templates", ) h.header["CHI2NU"] = (chi2nu, "Reduced chi^2 statistic") h.header = self.add_pipes_info(h.header) stacked_hdul.extend(hdus) start_idx += np.prod(v["2d_shape"]) stacked_hdul.writeto( multireg_out_dir / f"regions_{self.obj_id:05d}_z_{z}_stacked.fits", output_verify="silentfix", overwrite=True, ) if save_lines: beam_models_len = 0 for k_i, (k, v) in enumerate(beam_info.items()): beam_models_len += np.prod(v["2d_shape"]) * len(v["list_idx"]) if memmap: flat_beam_models = np.memmap( temp_dir / "memmap_beams_model.dat", dtype=np.float32, mode="w+", shape=(beam_models_len), ) else: shm_beam_models = smm.SharedMemory( size=np.dtype(np.float32).itemsize * beam_models_len ) flat_beam_models = np.ndarray( (beam_models_len), dtype=np.float32, buffer=shm_beam_models.buf, ) line_hdu = None saved_lines = [] beams_fn = partial( self._gen_beam_templates_from_pipes, shared_seg_name=( temp_dir / "memmap_oversamp_seg_maps.dat" if memmap else shm_seg_maps.name ), seg_maps_shape=oversamp_seg_maps_shape, shared_models_name=( temp_dir / "memmap_beams_model.dat" if memmap else shm_beam_models.name ), models_shape=flat_beam_models.shape, posterior_dir=str(self.pipes_dir / "posterior" / self.run_name), spec_wavs=spec_wavs, beam_info=beam_info, cont_only=False, rows=best_rows, coeffs=output_table[best_iter], memmap=memmap, n_shifted_rows=n_shifted_samples, return_line_flux=True, ) lock = Lock() for l_i, l_v in enumerate(use_lines): if not check_coverage(l_v["wave"] * (1 + z)): continue print(f"Generating map for {l_v["grizli"]}...") add_hdu = None for continuum_temp in [True, False]: flat_beam_models.fill(0.0) results = [None] * self.n_regions with multiprocessing.Pool( processes=cpu_count, initializer=_init_pipes_sampler, initargs=(fit_instructions, veldisp, self.MB.beams, lock), ) as pool: for s_i, s in enumerate(self.regions_phot_cat["bin_id"]): results[s_i] = pool.apply_async( beams_fn, args=(s_i, s), kwds={ "rm_line": ( l_v["cloudy"] if continuum_temp else None ), "id_shifts": best_id_shifts, }, error_callback=print, ) results = np.asarray([r.get() for r in results]).ravel() pool.close() pool.join() i0 = 0 start_idx = 0 for k_i, (k, v) in enumerate(beam_info.items()): for ib in v["list_idx"]: self.MB.beams[ib].beam.model = flat_beam_models[ i0 : i0 + np.prod(v["2d_shape"]) ].reshape(v["2d_shape"]) i0 += np.prod(v["2d_shape"]) start_idx += np.prod(v["2d_shape"]) if not continuum_temp: for b_i, (b, b_old) in enumerate( zip(self.MB.beams, beams_copy) ): self.MB.beams[b_i].beam.model -= b_old hdu = drizzle_to_wavelength( self.MB.beams, ra=self.ra, dec=self.dec, wave=l_v["wave"] * (1 + z), fcontam=self.MB.fcontam, **pline, ) hdu[0].header["REDSHIFT"] = (z, "Redshift used") hdu[0].header["CHI2"] = ( output_table["chi2"][best_iter], "Chi^2 statistic", ) hdu[0].header["DOF"] = (DoF, "Degrees of freedom (active pixels)") hdu[0].header["NTEMP"] = ( output_table["unique_temp"][best_iter], "Number of unique templates", ) hdu[0].header["CHI2NU"] = (chi2nu, "Reduced chi^2 statistic") hdu[0].header = self.add_pipes_info(hdu[0].header) for e in [-4, -3, -2, -1]: hdu[e].header["EXTVER"] = l_v["grizli"] hdu[e].header["REDSHIFT"] = (z, "Redshift used") hdu[e].header["RESTWAVE"] = ( l_v["wave"], "Line rest wavelength", ) if add_hdu is None: add_hdu = hdu beams_copy = [b.beam.model.copy() for b in self.MB.beams] line_sn = np.nansum( results * out_coeffs[temp_offset:] ) / np.sqrt( np.nansum((results * coeffs_errs[temp_offset:]) ** 2) ) else: hdu[-3].header["EXTNAME"] = "MODEL" add_hdu.append(hdu[-3]) line_flux_i = np.nansum(hdu[-3].data) * 1e-17 line_err_i = line_flux_i / line_sn saved_lines.append(l_v["grizli"]) if line_hdu is None: line_hdu = add_hdu line_hdu[0].header["NUMLINES"] = ( 1, "Number of lines in this file", ) else: line_hdu.extend(add_hdu[-5:]) line_hdu[0].header["NUMLINES"] += 1 # Make sure DSCI extension is filled. Can be empty for # lines at the edge of the grism throughput for f_i in range(hdu[0].header["NDFILT"]): filt_i = hdu[0].header["DFILT{0:02d}".format(f_i + 1)] if hdu["DWHT", filt_i].data.max() != 0: line_hdu["DSCI", filt_i] = hdu["DSCI", filt_i] line_hdu["DWHT", filt_i] = hdu["DWHT", filt_i] li = line_hdu[0].header["NUMLINES"] line_hdu[0].header["LINE{0:03d}".format(li)] = l_v["grizli"] line_hdu[0].header["FLUX{0:03d}".format(li)] = ( line_flux_i, "Line flux, erg/s/cm2", ) line_hdu[0].header["ERR{0:03d}".format(li)] = ( line_err_i, "Line flux err, erg/s/cm2", ) if line_hdu is not None: line_hdu[0].header["HASLINES"] = ( " ".join(saved_lines), "Lines in this file", ) line_wcs = WCS(line_hdu[1].header) segm = self.MB.drizzle_segmentation(wcsobj=line_wcs) seg_hdu = fits.ImageHDU(data=segm.astype(np.int32), name="SEG") line_hdu.insert(1, seg_hdu) line_hdu.writeto( multireg_out_dir / f"regions_{self.obj_id:05d}_z_{z}_{pline.get("pixscale", 0.06)}arcsec.line.fits", output_verify="silentfix", overwrite=True, ) if "DSCI" in line_hdu: from grizli.fitting import show_drizzled_lines # s, si = 1, line_size s = 4.0e-19 / np.max( [beam.beam.total_flux for beam in self.MB.beams] ) s = np.clip(s, 0.25, 4) s /= (pline.get("pixscale", 0.06) / 0.1) ** 2 scale_linemap = 1 if scale_linemap < 0: s = -1 dscale = 1.0 / 4 fig = show_drizzled_lines( line_hdu, size_arcsec=1.6, cmap="plasma_r", scale=s * scale_linemap, dscale=s * dscale * scale_linemap, full_line_list=[ "Lya", "OII", "Hb", "OIII-5007", "Ha", "SII", "SIII-9068", "SIII-9531", ], ) fig.savefig( multireg_out_dir / f"regions_{self.obj_id:05d}_z_{z}_{pline.get("pixscale", 0.06)}arcsec.line.png", ) smm.shutdown() return