# Source code for analysis.band_fitting

#!/usr/bin/env python3.7
# -*- coding: UTF-8 -*-

"""The ``band_fitting`` module performs fits of individual bands of observed light-curves

Function Documentation
----------------------
"""

from copy import deepcopy
from pathlib import Path

import numpy as np
import sncosmo
from astropy.table import Table
from matplotlib import pyplot
from tqdm import tqdm

from . import utils

# Module-level ``sncosmo.F99Dust`` instance built with default arguments.
# NOTE(review): not referenced by any function visible in this module;
# confirm whether other code imports it before removing.
DUST = sncosmo.F99Dust()


def create_empty_table(parameters, **kwargs):
    """Create an empty table for storing fit results

    Columns:
         - obj_id
         - band
         - source
         - pre_max
         - post_max
         - vparams
         - *parameters
         - *parameters + _err
         - chisq
         - ndof
         - b_max
         - delta_15
         - message

    Args:
        parameters (iter): Parameter names to add columns for. Any iterable
            of strings is accepted, including one-shot iterators.
        Any arguments to pass ``astropy.Table``

    Returns:
        A masked astropy Table (unless ``masked`` is overridden in kwargs)
    """

    # Materialize once: the names are needed twice (value and error columns)
    # and their count sets the number of float dtypes. Iterating a generator
    # twice would silently drop the error columns and break ``len``.
    parameters = list(parameters)

    # Specify column names
    names = ['obj_id', 'band', 'source', 'pre_max', 'post_max', 'vparams']
    names += parameters + [param + '_err' for param in parameters]
    names += ['chisq', 'ndof', 'b_max', 'delta_15', 'message']

    # Specify column data types (one float per parameter and per error)
    dtype = ['U20', 'U100', 'U100', int, int, 'U100']
    dtype += [float] * (2 * len(parameters))
    dtype += [float, float, float, float, 'U10000']

    # Unless otherwise specified, we default to returning a masked table
    kwargs = deepcopy(kwargs)
    kwargs.setdefault('masked', True)
    return Table(names=names, dtype=dtype, **kwargs)
def fit_results_to_dict(data, obj_id, band_name, results, fitted_model):
    """Convert sncosmo fit results into a row for the results table

    The returned dictionary uses the column names defined by
    ``create_empty_table``.

    Args:
        data         (Table): The data used in the fit (needs a ``time`` column)
        obj_id         (str): The id of the object that was fit
        band_name      (str): The name of the band that was fit
        results     (Result): Fitting results returned by ``sncosmo``
        fitted_model (Model): A fitted ``sncosmo`` model

    Returns:
        Fit results as a dictionary
    """

    source = fitted_model.source
    row = {
        'obj_id': obj_id,
        'band': band_name,
        'source': source.name,
        'vparams': ','.join(results.vparam_names),
    }

    # Count observations before, and at or after, the fitted time of maximum
    t0_index = results.param_names.index('t0')
    t0 = results.parameters[t0_index]
    row['pre_max'] = sum(data['time'] < t0)
    row['post_max'] = sum(data['time'] >= t0)

    # Fitted parameter values, followed by their uncertainties
    row.update(zip(results.param_names, results.parameters))
    row.update(
        {name + '_err': err for name, err in results.errors.items()})

    # Goodness of fit
    chisq, ndof = utils.calc_model_chisq(data, results, fitted_model)
    row['chisq'] = np.round(chisq, 2)
    row['ndof'] = ndof

    # Peak B-band magnitude and the decline over the following 15 days
    row['b_max'] = np.round(
        fitted_model.source_peakabsmag('bessellb', 'ab'), 2)

    phase_at_peak = source.peakphase('bessellb')
    mag_at_peak = source.bandmag('bessellb', 'ab', phase_at_peak)
    mag_after_15 = source.bandmag('bessellb', 'ab', phase_at_peak + 15)
    row['delta_15'] = np.round(mag_after_15 - mag_at_peak, 3)

    # Not every fitting routine reports an exit message, so fall back
    # to the placeholder 'NONE'.
    row['message'] = getattr(results, 'message', 'NONE')
    return row
def _plot_lc(data, result, fitted_model):
    """Plot a fitted light-curve and print the chi-squared per degree of freedom

    Args:
        data         (Table): The data used in the fit
        result      (Result): The fit results
        fitted_model (Model): Model with params set to fitted values

    Returns:
        The figure returned by ``sncosmo.plot_lc``
    """

    figure = sncosmo.plot_lc(data, fitted_model, errors=result.errors)

    chisq, ndof = utils.calc_model_chisq(data, result, fitted_model)
    ratio = chisq / ndof
    print(f'chisq / ndof = {chisq} / {ndof} = {ratio}', flush=True)

    return figure
def fit_single_target(
        fit_func, data, model, priors=None, kwargs=None,
        out_table=None, show_plots=False):
    """Fit an observed light-curve in all bands at once, then band by band

    The model is first fit to the full data set. Each band is then fit
    separately, starting from the all-band fitted model, with ``t0`` and
    ``z`` removed from the varied parameters and amplitude guessing
    disabled. One row is appended to the output table per fit.

    Note that ``model`` is updated in place with the values in ``priors``.

    Args:
        fit_func   (func): Function to use to run fits (eg. ``sncosmo.fit_lc``)
        data      (Table): Table of photometric data. Must provide
            ``meta['obj_id']`` and a ``band`` column.
        model     (Model): The model to use when fitting
        priors     (dict): Priors to use when fitting (default: none)
        kwargs     (dict): Kwargs to pass ``fit_func`` when fitting. Should
            include ``vparam_names``, which is read after the all-band fit.
        out_table (Table): Append results to an existing table
        show_plots (bool): Plot and display each individual fit

    Returns:
        A table with results for each model / dataset combination
    """

    obj_id = data.meta['obj_id']

    # ``None`` is the documented default for both arguments; treat it as
    # "nothing to apply" instead of failing on ``update(None)`` / ``**None``.
    if priors is not None:
        model.update(priors)

    kwargs = {} if kwargs is None else deepcopy(kwargs)

    if out_table is None:
        out_table = create_empty_table(model.param_names)

    # Fit data in all bands
    all_result, all_fit = fit_func(data, model, **kwargs)
    all_row = fit_results_to_dict(data, obj_id, 'all', all_result, all_fit)
    out_table.add_row(all_row)

    if show_plots:
        _plot_lc(data, all_result, all_fit)
        pyplot.show()

    # Fix t0 and redshift during individual band fits
    kwargs['vparam_names'] = set(kwargs['vparam_names']) - {'t0', 'z'}

    # The amplitude from a fit to all data works as a better initial guess
    kwargs['guess_amplitude'] = False

    # Fit data in individual bands
    data = data.group_by('band')
    for band_name, band_data in zip(data.groups.keys['band'], data.groups):
        band_result, band_fit = fit_func(band_data, all_fit, **kwargs)
        band_row = fit_results_to_dict(
            band_data, obj_id, band_name, band_result, band_fit)
        out_table.add_row(band_row)

        if show_plots:
            _plot_lc(band_data, band_result, band_fit)
            # Display the band plot as well, matching the all-band fit above
            pyplot.show()

    return out_table
def _tabulate_fits_for_model(
        data_iter, model, config, fit_func, out_table, out_path=None):
    """Tabulate fit results for a collection of data tables and a single model

    Results are appended to the table specified by ``out_table``. A target
    is skipped when the table already holds a row with the same object Id
    and the same source name as ``model``. Fits that raise an exception are
    recorded as a row holding only the object Id, source name, and the
    error message.

    Args:
        data_iter  (iter): Iterable of photometric data for different SN
        model     (Model): The model to use when fitting
        config     (dict): Maps object Ids to a dict with ``priors`` and
            ``kwargs`` entries
        fit_func   (func): Function to use to run fits
        out_table (Table): Table to append results to
        out_path    (str): Optionally cache results to file in real time
    """

    source_name = str(model.source.name)

    # Track completed (object, source) pairs. Matching on the object Id alone
    # would make every model after the first skip all of its targets.
    # Values are cast to ``str`` so masked entries stay hashable.
    finished = {
        (str(row_id), str(row_source))
        for row_id, row_source
        in zip(out_table['obj_id'], out_table['source'])
    }

    for data in data_iter:
        obj_id = data.meta['obj_id']
        fit_key = (str(obj_id), source_name)
        if fit_key in finished:
            continue

        try:
            fit_single_target(
                fit_func,
                data,
                model,
                priors=config[obj_id]['priors'],
                kwargs=config[obj_id]['kwargs'],
                out_table=out_table
            )

        # ``KeyboardInterrupt`` is not an ``Exception`` subclass, so user
        # interrupts still propagate. Anything else is logged to the table
        # so that a single bad target does not abort the whole run.
        except Exception as e:
            e_str = str(e).replace("\n", "")
            e_name = type(e).__name__
            out_table.add_row({
                'obj_id': obj_id,
                'source': source_name,
                'message': f'{e_name}: {e_str}'
            })

        finished.add(fit_key)

        if out_path:
            # The cache file is rewritten after every target
            out_table.write(out_path, overwrite=True)
def tabulate_band_fits(
        data_release, models, fit_func, config=None, out_path=None):
    """Tabulate fit results for a collection of data tables

    Results already written to out_path are skipped.

    Args:
        data_release (module): An sndata data release
        models         (list): A list of sncosmo models
        fit_func       (func): Function to use to run fits
        config         (dict): Specifies priors / kwargs for fitting each model
        out_path        (str): Optionally cache results to file in real time

    Returns:
        An astropy table with fit results
    """

    # Set default kwargs
    config = deepcopy(config) or dict()

    # Resume from cached results when available. ``out_path`` is optional,
    # so only touch the file system when a path was actually given.
    if out_path is not None and Path(out_path).exists():
        out_table = Table.read(out_path)

    else:
        # Union of all model parameters, keeping first-seen order so the
        # column layout is the same from run to run.
        params = list(dict.fromkeys(
            name for m in models for name in m.param_names))
        out_table = create_empty_table(params)

    # Record the fitting function in the table meta data without clobbering
    # a value restored from an existing file. Not every callable has a
    # ``__name__`` (e.g. ``functools.partial``), hence the fallback.
    out_table.meta.setdefault(
        'fit_func', getattr(fit_func, '__name__', repr(fit_func)))

    for model in tqdm(models, desc='Models'):
        data_iter = data_release.iter_data(
            verbose={'desc': 'Targets', 'position': 1},
            filter_func=utils.filter_has_csp_data)

        _tabulate_fits_for_model(
            data_iter, model, config, fit_func, out_table, out_path)

    return out_table