__author__ = 'aymgal', 'lynevdv', 'gvernard'
import os
import copy
import logging
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, LogNorm, TwoSlopeNorm
from matplotlib.colors import ListedColormap
from getdist import plots, chains, MCSamples
from coolest.api.analysis import Analysis
from coolest.api.composable_models import *
from coolest.api import util
from coolest.api import plot_util as plut
import pandas as pd
# Global matplotlib settings: imshow panels default to a lower-left origin
# and no interpolation.
plt.rc('image', origin='lower', interpolation='none')
# Global logging settings: the root logger reports INFO messages and above.
logging.getLogger().setLevel(logging.INFO)
# TODO: separate ParametersPlotter from ModelPlotter to avoid dependencies on getdist

# Public API of this module
__all__ = ['ModelPlotter', 'MultiModelPlotter', 'ParametersPlotter']
class ModelPlotter(object):
    """Create pyplot panels from a lens model stored in the COOLEST format.

    Parameters
    ----------
    coolest_object : COOLEST
        COOLEST instance
    coolest_directory : str, optional
        Directory which contains the COOLEST template, by default None
    color_bad_values : str, optional
        Color assigned to NaN values (typically negative values in log-scale),
        by default '#222222' (dark gray)
    """

    def __init__(self, coolest_object, coolest_directory=None,
                 color_bad_values='#222222'):
        self.coolest = coolest_object
        self._directory = coolest_directory
        # Work on a copy so that `set_bad` does not alter the colormap object
        # returned by `plt.get_cmap`
        self.cmap_flux = copy.copy(plt.get_cmap('magma'))
        self.cmap_flux.set_bad(color_bad_values)
        self.cmap_mag = plt.get_cmap('viridis')
        self.cmap_conv = plt.get_cmap('cividis')
        self.cmap_res = plt.get_cmap('RdBu_r')

    @staticmethod
    def _finalize_panel(ax, im, cb_label, add_colorbar, add_scalebar,
                        scalebar_size, scalebar_color):
        """Optionally add a labelled colorbar and a lower-right scale bar to a panel."""
        if add_colorbar:
            cb = plut.nice_colorbar(im, ax=ax)
            cb.set_label(cb_label)
        if add_scalebar:
            plut.scale_bar(ax, scalebar_size, color=scalebar_color, loc='lower right')

    def plot_data_image(self, ax, title=None, norm=None, cmap=None, xylim=None,
                        neg_values_as_bad=False, add_colorbar=True,
                        add_scalebar=True, scalebar_size=1):
        """plt.imshow panel with the data image.

        Returns
        -------
        array
            The observed image pixels read from the COOLEST observation.
        """
        if cmap is None:
            cmap = self.cmap_flux
        coordinates = util.get_coordinates(self.coolest)
        extent = coordinates.plt_extent
        image = self.coolest.observation.pixels.get_pixels(directory=self._directory)
        ax, im = plut.plot_regular_grid(ax, title, image, extent=extent,
                                        cmap=cmap, norm=norm,
                                        neg_values_as_bad=neg_values_as_bad,
                                        xylim=xylim)
        self._finalize_panel(ax, im, "flux", add_colorbar, add_scalebar,
                             scalebar_size, 'white')
        return image

    def plot_surface_brightness(self, ax, title=None, coordinates=None,
                                extent_irreg=None, norm=None, cmap=None,
                                xylim=None, neg_values_as_bad=True,
                                plot_points_irreg=False, add_colorbar=True,
                                add_scalebar=False, scalebar_size=0.4,
                                kwargs_light=None,
                                plot_caustics=None, caustics_color='white', caustics_alpha=0.5,
                                coordinates_lens=None, kwargs_lens_mass=None):
        """plt.imshow panel showing the surface brightness of the (unlensed)
        lensing entity selected via kwargs_light (see ComposableLightModel docstring).

        If `coordinates` is given, the light model is evaluated on that grid;
        otherwise the model's own representation is used (regular grid or
        irregular points). Caustics can be over-plotted, which requires
        `kwargs_lens_mass`.

        Returns
        -------
        tuple
            (image, coordinates); image is None for irregular-grid models.

        Raises
        ------
        ValueError
            If the deprecated `extent_irreg` is given, or if `plot_caustics`
            is set without `kwargs_lens_mass`.
        """
        if extent_irreg is not None:
            raise ValueError("`extent_irreg` is deprecated; use `xylim` instead.")
        if kwargs_light is None:
            kwargs_light = {}
        light_model = ComposableLightModel(self.coolest, self._directory, **kwargs_light)
        caustics = []
        if plot_caustics:
            if kwargs_lens_mass is None:
                raise ValueError("`kwargs_lens_mass` must be provided to compute caustics")
            if coordinates_lens is None:
                coordinates_lens = util.get_coordinates(self.coolest).create_new_coordinates(pixel_scale_factor=0.1)
            # NOTE: here we assume that `kwargs_light` is for the source!
            mass_model = ComposableMassModel(self.coolest, self._directory, **kwargs_lens_mass)
            _, caustics = util.find_all_lens_lines(coordinates_lens, mass_model)
        if cmap is None:
            cmap = self.cmap_flux
        if coordinates is not None:
            x, y = coordinates.pixel_coordinates
            image = light_model.evaluate_surface_brightness(x, y)
            extent = coordinates.plt_extent
            ax, im = plut.plot_regular_grid(ax, title, image, extent=extent, cmap=cmap,
                                            neg_values_as_bad=neg_values_as_bad,
                                            norm=norm, xylim=xylim)
        else:
            values, extent_model, coordinates = light_model.surface_brightness(return_extra=True)
            if isinstance(values, np.ndarray) and len(values.shape) == 2:
                image = values
                ax, im = plut.plot_regular_grid(ax, title, image, extent=extent_model,
                                                cmap=cmap,
                                                neg_values_as_bad=neg_values_as_bad,
                                                norm=norm, xylim=xylim)
            else:
                points = values
                if xylim is None:
                    xylim = extent_model
                ax, im = plut.plot_irregular_grid(ax, title, points, xylim, norm=norm, cmap=cmap,
                                                  neg_values_as_bad=neg_values_as_bad,
                                                  plot_points=plot_points_irreg)
                image = None
        for caustic in caustics:
            ax.plot(caustic[0], caustic[1], lw=1, color=caustics_color, alpha=caustics_alpha)
        self._finalize_panel(ax, im, "flux", add_colorbar, add_scalebar,
                             scalebar_size, 'white')
        return image, coordinates

    def plot_model_image(self, ax, title=None,
                         norm=None, cmap=None, xylim=None, neg_values_as_bad=False,
                         add_colorbar=True, add_scalebar=True, scalebar_size=1,
                         auto_selection=False,
                         kwargs_lens_mass=None,
                         kwargs_lens_light=None,
                         kwargs_source=None,
                         **model_image_kwargs):
        """plt.imshow panel showing the surface brightness of the (lensed)
        selected lensing entities (see ComposableLensModel docstring).

        Extra keyword arguments are passed to `ComposableLensModel.model_image`.

        Returns
        -------
        array
            The model image.
        """
        if cmap is None:
            cmap = self.cmap_flux
        lens_model = ComposableLensModel(
            self.coolest, self._directory,
            auto_selection=auto_selection,
            kwargs_selection_source=kwargs_source,
            kwargs_selection_lens_mass=kwargs_lens_mass,
            kwargs_selection_lens_light=kwargs_lens_light
        )
        image, coordinates = lens_model.model_image(**model_image_kwargs)
        extent = coordinates.plt_extent
        ax, im = plut.plot_regular_grid(ax, title, image, extent=extent,
                                        cmap=cmap,
                                        neg_values_as_bad=neg_values_as_bad,
                                        norm=norm, xylim=xylim)
        self._finalize_panel(ax, im, "flux", add_colorbar, add_scalebar,
                             scalebar_size, 'white')
        return image

    def plot_model_residuals(self, ax, title=None, mask=None,
                             norm=None, cmap=None, xylim=None, add_chi2_label=False, chi2_fontsize=12,
                             kwargs_source=None,
                             kwargs_lens_mass=None,
                             kwargs_lens_light=None,
                             add_colorbar=True, add_scalebar=True, scalebar_size=1,
                             **model_image_kwargs):
        """plt.imshow panel showing the normalized model residuals image.

        The mask stored with the imaging likelihood of the COOLEST object takes
        precedence over `mask`; `mask` is used when no such mask is available.
        With `add_chi2_label`, the reduced chi^2 (sum of squared normalized
        residuals divided by the number of (masked) pixels) is written on the panel.

        Returns
        -------
        array
            The normalized residuals image.
        """
        if cmap is None:
            cmap = self.cmap_res
        if norm is None:
            norm = Normalize(-6, 6)
        ll_mask = self._get_likelihood_mask(mask)
        lens_model = ComposableLensModel(self.coolest, self._directory,
                                         kwargs_selection_source=kwargs_source,
                                         kwargs_selection_lens_mass=kwargs_lens_mass,
                                         kwargs_selection_lens_light=kwargs_lens_light)
        image, coordinates = lens_model.model_residuals(mask=ll_mask, **model_image_kwargs)
        extent = coordinates.plt_extent
        ax, im = plut.plot_regular_grid(ax, title, image, extent=extent,
                                        cmap=cmap,
                                        neg_values_as_bad=False,
                                        norm=norm, xylim=xylim)
        self._finalize_panel(ax, im, "(data $-$ model) / noise", add_colorbar,
                             add_scalebar, scalebar_size, 'black')
        if add_chi2_label is True:
            num_constraints = np.size(image) if ll_mask is None else np.sum(ll_mask)
            red_chi2 = np.sum(image**2) / num_constraints
            ax.text(0.05, 0.05, r'$\chi^2_\nu$='+f'{red_chi2:.2f}', color='black', alpha=1,
                    fontsize=chi2_fontsize, va='bottom', ha='left', transform=ax.transAxes,
                    bbox={'color': 'white', 'alpha': 0.6})
        return image

    def plot_convergence(self, ax, title=None, coordinates=None,
                         norm=None, cmap=None, xylim=None, neg_values_as_bad=False,
                         add_colorbar=True,
                         add_scalebar=True, scalebar_size=1,
                         kwargs_lens_mass=None):
        """plt.imshow panel showing the 2D convergence map associated to the
        selected lensing entities (see ComposableMassModel docstring).

        Returns
        -------
        array
            The convergence evaluated on `coordinates` (by default the
            coordinates of the COOLEST observation).
        """
        if kwargs_lens_mass is None:
            kwargs_lens_mass = {}
        mass_model = ComposableMassModel(self.coolest, self._directory,
                                         **kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_conv
        if coordinates is None:
            coordinates = util.get_coordinates(self.coolest)
        extent = coordinates.plt_extent
        x, y = coordinates.pixel_coordinates
        image = mass_model.evaluate_convergence(x, y)
        ax, im = plut.plot_regular_grid(ax, title, image, extent=extent,
                                        cmap=cmap,
                                        neg_values_as_bad=neg_values_as_bad,
                                        norm=norm, xylim=xylim)
        self._finalize_panel(ax, im, r"$\kappa$", add_colorbar, add_scalebar,
                             scalebar_size, 'white')
        return image

    def plot_convergence_diff(
            self, ax, reference_map, title=None, relative_error=True,
            norm=None, cmap=None, xylim=None, coordinates=None,
            add_colorbar=True, add_scalebar=True, scalebar_size=1,
            kwargs_lens_mass=None,
            plot_crit_lines=False, crit_lines_color='black', crit_lines_alpha=0.5):
        """plt.imshow panel showing the (absolute or relative) difference
        between `reference_map` and the 2D convergence map of the selected
        lensing entities (see ComposableMassModel docstring).

        The plotted quantity is (reference_map - kappa), divided by
        reference_map when `relative_error` is True.

        Returns
        -------
        array
            The convergence map of this model (not the difference).
        """
        if kwargs_lens_mass is None:
            kwargs_lens_mass = {}
        mass_model = ComposableMassModel(self.coolest, self._directory,
                                         **kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_res
        if norm is None:
            norm = Normalize(-1, 1)
        if coordinates is None:
            coordinates = util.get_coordinates(self.coolest)
        if plot_crit_lines:
            critical_lines, _ = util.find_all_lens_lines(coordinates, mass_model)
        extent = coordinates.plt_extent
        x, y = coordinates.pixel_coordinates
        image = mass_model.evaluate_convergence(x, y)
        if relative_error is True:
            diff = (reference_map - image) / reference_map
        else:
            diff = reference_map - image
        ax, im = plut.plot_regular_grid(ax, title, diff, extent=extent,
                                        cmap=cmap,
                                        norm=norm, xylim=xylim)
        if plot_crit_lines:
            for cline in critical_lines:
                ax.plot(cline[0], cline[1], lw=1, color=crit_lines_color, alpha=crit_lines_alpha)
        self._finalize_panel(ax, im, r"$\kappa$", add_colorbar, add_scalebar,
                             scalebar_size, 'black')
        return image

    def plot_magnification(self, ax, title=None,
                           norm=None, cmap=None, xylim=None,
                           add_colorbar=True, add_scalebar=True, scalebar_size=1,
                           coordinates=None, kwargs_lens_mass=None):
        """plt.imshow panel showing the 2D magnification map associated to the
        selected lensing entities (see ComposableMassModel docstring).

        Returns
        -------
        array
            The magnification evaluated on `coordinates` (by default the
            coordinates of the COOLEST observation).
        """
        if kwargs_lens_mass is None:
            kwargs_lens_mass = {}
        mass_model = ComposableMassModel(self.coolest, self._directory,
                                         **kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_mag
        if norm is None:
            norm = Normalize(-10, 10)
        if coordinates is None:
            coordinates = util.get_coordinates(self.coolest)
        x, y = coordinates.pixel_coordinates
        extent = coordinates.plt_extent
        image = mass_model.evaluate_magnification(x, y)
        ax, im = plut.plot_regular_grid(ax, title, image, extent=extent,
                                        cmap=cmap,
                                        norm=norm, xylim=xylim)
        self._finalize_panel(ax, im, r"$\mu$", add_colorbar, add_scalebar,
                             scalebar_size, 'white')
        return image

    def plot_magnification_diff(
            self, ax, reference_map, title=None, relative_error=True,
            norm=None, cmap=None, xylim=None,
            add_colorbar=True, add_scalebar=True, scalebar_size=1,
            coordinates=None, kwargs_lens_mass=None):
        """plt.imshow panel showing the (absolute or relative)
        difference between `reference_map` and the 2D magnification map.

        The plotted quantity is (reference_map - mu), divided by
        reference_map when `relative_error` is True.

        Returns
        -------
        array
            The magnification map of this model (not the difference).
        """
        if kwargs_lens_mass is None:
            kwargs_lens_mass = {}
        mass_model = ComposableMassModel(self.coolest, self._directory,
                                         **kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_res
        if norm is None:
            norm = Normalize(-1, 1)
        if coordinates is None:
            coordinates = util.get_coordinates(self.coolest)
        x, y = coordinates.pixel_coordinates
        extent = coordinates.plt_extent
        image = mass_model.evaluate_magnification(x, y)
        if relative_error is True:
            diff = (reference_map - image) / reference_map
        else:
            diff = reference_map - image
        ax, im = plut.plot_regular_grid(ax, title, diff, extent=extent,
                                        cmap=cmap,
                                        norm=norm, xylim=xylim)
        self._finalize_panel(ax, im, r"$\mu$", add_colorbar, add_scalebar,
                             scalebar_size, 'black')
        return image

    def _get_likelihood_mask(self, user_mask):
        """Return the mask used to compute model residuals.

        The mask stored with the 'ImagingDataLikelihood' of the COOLEST object
        takes precedence. `user_mask` is the fallback whenever no such mask is
        available: no likelihoods at all, no imaging likelihood, or no mask
        stored with it.
        """
        likelihoods = self.coolest.likelihoods
        if likelihoods is None:
            return user_mask
        try:
            img_ll_idx = likelihoods.index('ImagingDataLikelihood')
        except ValueError:
            return user_mask
        mask = likelihoods[img_ll_idx].get_mask_pixels(directory=self._directory)
        if mask is None:
            mask = user_mask
        return mask
class MultiModelPlotter(object):
    """Wrapper around a set of ModelPlotter instances to produce panels that
    consistently compare different models, evaluated on the same
    coordinates systems.

    All plotting methods take a sequence `axes` with one entry per model
    (None entries are skipped) and accept an optional `titles` keyword with
    one title per model. Selection dictionaries such as `kwargs_lens_mass`
    are given as {key: [value_model_0, value_model_1, ...]} and are split
    per model. The methods return the list of values returned by the
    underlying ModelPlotter methods for the non-skipped axes.

    Parameters
    ----------
    coolest_objects : list
        List of COOLEST instances
    coolest_directories : list, optional
        List of directories corresponding to each COOLEST instance, by default None
    kwargs_plotter : dict, optional
        Additional keyword arguments passed to ModelPlotter
    """

    def __init__(self, coolest_objects, coolest_directories=None, **kwargs_plotter):
        self.num_models = len(coolest_objects)
        if coolest_directories is None:
            coolest_directories = self.num_models * [None]
        # one ModelPlotter per model, in the same order as `coolest_objects`
        self.plotter_list = [
            ModelPlotter(coolest, coolest_directory=c_dir, **kwargs_plotter)
            for coolest, c_dir in zip(coolest_objects, coolest_directories)
        ]

    def plot_surface_brightness(self, axes, **kwargs):
        return self._plot_light_multi('plot_surface_brightness', axes, **kwargs)

    def plot_data_image(self, axes, **kwargs):
        return self._plot_data_multi(axes, **kwargs)

    def plot_model_image(self, axes, **kwargs):
        return self._plot_lens_model_multi('plot_model_image', axes, **kwargs)

    def plot_model_residuals(self, axes, **kwargs):
        return self._plot_lens_model_multi('plot_model_residuals', axes, **kwargs)

    def plot_convergence(self, axes, **kwargs):
        return self._plot_lens_model_multi('plot_convergence', axes, **kwargs)

    def plot_magnification(self, axes, **kwargs):
        return self._plot_lens_model_multi('plot_magnification', axes, **kwargs)

    def plot_convergence_diff(self, axes, *args, **kwargs):
        return self._plot_lens_model_multi('plot_convergence_diff', axes, *args, **kwargs)

    def plot_magnification_diff(self, axes, *args, **kwargs):
        return self._plot_lens_model_multi('plot_magnification_diff', axes, *args, **kwargs)

    def _plot_multi(self, method_name, axes, per_model_keys, *args, **kwargs):
        """Call `method_name` of every ModelPlotter on its own axis.

        Keyword arguments named in `per_model_keys` are dictionaries whose
        values hold one entry per model; model i receives {key: value[i]}.
        `titles` (if given) holds one title per model. Every other argument
        is forwarded unchanged to each model. Axes that are None are skipped.
        """
        assert len(axes) == self.num_models, "Inconsistent number of subplot axes"
        titles = kwargs.get('titles')
        kwargs_ = copy.deepcopy({k: v for k, v in kwargs.items() if k != 'titles'})
        image_list = []
        for i, (ax, plotter) in enumerate(zip(axes, self.plotter_list)):
            if ax is None:
                continue
            for key in per_model_keys:
                if kwargs.get(key) is not None:
                    kwargs_[key] = {k: v[i] for k, v in kwargs[key].items()}
            title = None if titles is None else titles[i]
            # `title` goes by keyword so that extra positional arguments
            # (e.g. `reference_map` of the *_diff methods) bind correctly
            image = getattr(plotter, method_name)(ax, *args, title=title, **kwargs_)
            image_list.append(image)
        return image_list

    def _plot_light_multi(self, method_name, axes, **kwargs):
        # `kwargs_lens_mass` is used for over-plotting caustics
        return self._plot_multi(method_name, axes,
                                ('kwargs_light', 'kwargs_lens_mass'), **kwargs)

    def _plot_mass_multi(self, method_name, axes, **kwargs):
        return self._plot_multi(method_name, axes, ('kwargs_lens_mass',), **kwargs)

    def _plot_lens_model_multi(self, method_name, axes, *args, **kwargs):
        return self._plot_multi(method_name, axes,
                                ('kwargs_source', 'kwargs_lens_mass', 'kwargs_lens_light'),
                                *args, **kwargs)

    def _plot_data_multi(self, axes, **kwargs):
        return self._plot_multi('plot_data_image', axes, (), **kwargs)
[docs]
class ParametersPlotter(object):
"""Handles plot of analytical models in a comparative way
Parameters
----------
parameter_id_list : array
A list of parameter unique ids obtained from lensing entities. Their order determines the order of the plot panels.
coolest_objects : array
A list of coolest objects that have a chain file associated to them.
coolest_directories : array
A list of paths matching the coolest files in 'chain_objs'.
coolest_names : array, optional
A list of labels for the coolest models in the 'chain_objs' list. Must have the same order as 'chain_objs'.
ref_coolest_objects : array, optional
A list of coolest objects that will be used as point estimates.
ref_coolest_directories : array
A list of paths matching the coolest files in 'point_estimate_objs'.
ref_coolest_names : array, optional
A list of labels for the models in the 'point_estimate_objs' list. Must have the same order as 'point_estimate_objs'.
posterior_bool_list : list, optional
List of bool to toggle errorbars on point-estimate values
colors : list, optional
List of pyplot color names to associate to each coolest model.
linestyles : list, optional
List of pyplot linesyles to associate to each coolest model.
add_multivariate_margin_samples : bool, optional
If True, will append to the list of compared models
a new chain that is resampled from the multi-variate normal distribution,
where its covariance matrix is computed from the marginalization of
all samples from all models. By default False.
num_samples_per_model_margin : int, optional
Number of samples to (randomly) draw from each model samples to concatenate
before estimating the multi-variate normal marginalization.
"""
np.random.seed(598237) # fix the random seed for reproducibility
def __init__(self, parameter_id_list, coolest_objects, coolest_directories=None, coolest_names=None,
ref_coolest_objects=None, ref_coolest_directories=None, ref_coolest_names=None,
posterior_bool_list=None, colors=None, linestyles=None,
add_multivariate_margin_samples=False, num_samples_per_model_margin=5_000):
[docs]
self.parameter_id_list = parameter_id_list
[docs]
self.coolest_objects = coolest_objects
[docs]
self.coolest_directories = coolest_directories
if coolest_names is None:
coolest_names = ["Model "+str(i) for i in range(len(coolest_objects))]
[docs]
self.coolest_names = coolest_names
[docs]
self.ref_coolest_objects = ref_coolest_objects
[docs]
self.ref_coolest_directories = ref_coolest_directories
[docs]
self.ref_coolest_names = ref_coolest_names
[docs]
self.ref_file_names = ref_coolest_names
[docs]
self.num_models = len(self.coolest_objects)
[docs]
self.num_params = len(self.parameter_id_list)
if colors is None:
colors = plt.cm.turbo(np.linspace(0.1, 0.9, self.num_models))
if linestyles is None:
linestyles = ['-']*self.num_models
[docs]
self.linestyles = linestyles
[docs]
self.ref_linestyles = ['--', ':', '-.', '-']
[docs]
self.ref_markers = ['s', '^', 'o', '*']
self._add_margin_samples = add_multivariate_margin_samples
self._ns_per_model_margin = num_samples_per_model_margin
self._color_margin = 'black'
self._label_margin = "Combined"
# self.posterior_bool_list = posterior_bool_list
# self.param_lens, self.param_source = util.split_lens_source_params(
# self.coolest_objects, self.coolest_names, lens_light=False)
[docs]
def init_getdist(self, shift_sample_list=None, settings_mcsamples=None,
add_multivariate_margin_samples=False):
"""Initializes the getdist plotter.
Parameters
----------
shift_sample_list : dict
Dictionary keyed by parameter ID to apply a uniform additive shift to
all samples of that parameters posterior distribution.
settings_mcsamples : dict, optional
Keyword arguments passed as the `settings` argument of getdist.MCSamples, by default None
Raises
------
ValueError
If the csv file containing samples is is not coma (,) separated.
"""
chains.print_load_details = False # Just to silence messages
parameter_id_set = set(self.parameter_id_list)
if shift_sample_list is None:
shift_sample_list = [None]*self.num_models
# Get the values of the point_estimates
point_estimates = []
if self.ref_coolest_objects is not None:
for coolest_obj in self.ref_coolest_objects:
values = []
for par in self.parameter_id_list:
param = coolest_obj.lensing_entities.get_parameter_from_id(par)
val = param.point_estimate.value
if val is None:
values.append(None)
else:
values.append(val)
point_estimates.append(values)
mcsamples = []
samples_margin, weights_margin = None, None
mysample_margin = None
for i in range(self.num_models):
chain_file = os.path.join(self.coolest_directories[i],self.coolest_objects[i].meta["chain_file_name"]) # Here get the chain file path for each coolest object
# Each chain file can have a different number of free parameters
f = open(chain_file)
header = f.readline()
f.close()
if ';' in header:
raise ValueError("Columns must be coma-separated (no semi-colon) in chain file.")
chain_file_headers = header.split(',')
num_cols = len(chain_file_headers)
chain_file_headers.pop() # Remove the last column name that is the probability weights
chain_file_headers_set = set(chain_file_headers)
# Check that the given parameters are a subset of those in the chain file
assert parameter_id_set.issubset(chain_file_headers_set), "Not all given parameters are free parameters for model %d (not in the chain file: %s)!" % (i,chain_file)
# Set the labels for the parameters in the chain file
labels = []
for par_id in self.parameter_id_list:
param = self.coolest_objects[i].lensing_entities.get_parameter_from_id(par_id)
labels.append(param.latex_str.strip('$'))
# Read parameter values and probability weights
column_indices = [chain_file_headers.index(par_id) for par_id in self.parameter_id_list]
columns_to_read = sorted(column_indices) + [num_cols-1] # add last one for probability weights
samples = pd.read_csv(chain_file, usecols=columns_to_read, delimiter=',')
# Re-order columns to match self.parameter_id_list and labels
sample_par_values = np.array(samples[self.parameter_id_list])
# If needed, shift samples by a constant
if shift_sample_list[i] is not None:
for param_id, value in shift_sample_list[i].items():
sample_par_values[:, self.parameter_id_list.index(param_id)] += value
logging.info(f"posterior for parameter '{param_id}' from model '{self.coolest_names[i]}' "
f"has been shifted by {value}.")
# Clean-up the probability weights
mypost = np.array(samples['probability_weights'])
min_non_zero = np.min(mypost[np.nonzero(mypost)])
sample_prob_weight = np.where(mypost<min_non_zero, min_non_zero, mypost)
#sample_prob_weight = mypost
# Create MCSamples object
mysample = MCSamples(samples=sample_par_values, names=self.parameter_id_list,
labels=labels, settings=settings_mcsamples)
mysample.reweightAddingLogLikes(-np.log(sample_prob_weight))
mcsamples.append(mysample)
# if required, aggregate the samples in a "marginalized" posterior
if self._add_margin_samples:
if i == 0:
mysample_margin = copy.deepcopy(mysample)
else:
# combine the sample such that the probability mass of each set of samples is the same
mysample_margin = mysample_margin.getCombinedSamplesWithSamples(mysample, sample_weights=(1, 1))
if self._add_margin_samples:
mcsamples.append(mysample_margin)
self._mcsamples = mcsamples
self.ref_values = point_estimates
self.ref_values_markers = [dict(zip(self.parameter_id_list, values)) for values in self.ref_values]
[docs]
def get_mcsamples_getdist(self, with_margin=False):
if not self._add_margin_samples or with_margin:
return self._mcsamples
else:
return self._mcsamples[:-1]
[docs]
def get_margin_mcsamples_getdist(self):
if not self._add_margin_samples:
return None
else:
return self._mcsamples[-1]
[docs]
def plot_triangle_getdist(self, filled_contours=True, angles_range=None,
linewidth_hist=2, linewidth_cont=2, linewidth_margin=4,
marker_linewidth=2, marker_size=15,
axes_labelsize=None, legend_fontsize=None,
**subplot_kwargs):
"""Corner array of subplots using getdist.triangle_plot method.
Parameters
----------
subplot_size : int, optional
Size of the getdist plot, by default 1
filled_contours : bool, optional
Wether or not to fill the 2D contours, by default True
angles_range : _type_, optional
Restrict the range of angle (containing 'phi' in their name) parameters, by default None
linewidth_hist : int, optional
Line width for 1D histograms, by default 2
linewidth_cont : int, optional
Line width for 2D contours, by default 1
marker_size : int, optional
Size of the reference (scatter) markers on 2D contours plots, by default 15
Returns
-------
GetDistPlotter
Instance of GetDistPlotter corresponding to the figure
"""
line_args, contour_lws, contour_ls, colors, legend_labels \
= self._prepare_getdist_plot(linewidth_hist,
lw_cont=linewidth_cont,
lw_margin=linewidth_margin)
filled_contours = [filled_contours]*len(self._mcsamples)
alphas = [1]*len(self._mcsamples)
if self._add_margin_samples:
filled_contours[-1] = True
# alphas[-1] = 0.7
# Make the plot
g = plots.get_subplot_plotter(**subplot_kwargs)
if legend_fontsize is not None:
g.settings.legend_fontsize = legend_fontsize
if axes_labelsize is not None:
g.settings.axes_labelsize = axes_labelsize
g.triangle_plot(
self._mcsamples,
params=self.parameter_id_list,
legend_labels=legend_labels,
filled=filled_contours,
colors=colors,
line_args=line_args, # TODO: issue that linewidth settings in line_args are being overwritten by contour_lws
contour_colors=self.colors,
contour_lws=contour_lws,
contour_ls=contour_ls,
alphas=alphas,
)
# Add marker lines and points
for k in range(0, len(self.ref_values)):
g.add_param_markers(self.ref_values_markers[k], color='black', ls=self.ref_linestyles[k],
lw=marker_linewidth)
for i in range(0,self.num_params):
val_x = self.ref_values[k][i]
for j in range(i+1,self.num_params):
val_y = self.ref_values[k][j]
if val_x is not None and val_y is not None:
g.subplots[j,i].scatter(val_x, val_y, s=marker_size, facecolors='black',
color='black', marker=self.ref_markers[k])
# Set default ranges for angles
if angles_range is None:
angles_range = (-90, 90)
for i in range(0, len(self.parameter_id_list)):
dum = self.parameter_id_list[i].split('-')
name = dum[-1]
if name in ['phi','phi_ext']:
xlim = g.subplots[i,i].get_xlim()
#print(xlim)
if xlim[0] < -90:
for ax in g.subplots[i:,i]:
ax.set_xlim(left=angles_range[0])
for ax in g.subplots[i,:i]:
ax.set_ylim(bottom=angles_range[0])
if xlim[1] > 90:
for ax in g.subplots[i:,i]:
ax.set_xlim(right=angles_range[1])
for ax in g.subplots[i,:i]:
ax.set_ylim(top=angles_range[1])
return g
[docs]
def plot_rectangle_getdist(self, x_param_ids, y_param_ids, subplot_size=1,
legend_ncol=None, legend_fontsize=None,
filled_contours=True, linewidth=1,
marker_size=15, axes_labelsize=None, **subplot_kwargs):
"""Array of (2D contours) subplots using getdist.rectangle_plot method.
Parameters
----------
subplot_size : int, optional
Size of the getdist plot, by default 1
filled_contours : bool, optional
Wether or not to fill the 2D contours, by default True
linewidth : int, optional
Line width for 2D contours, by default 1
marker_size : int, optional
Size of the reference (scatter) markers on 2D contours plots, by default 15
legend_ncol : number of columns in the legend
Returns
-------
GetDistPlotter
Instance of GetDistPlotter corresponding to the figure
"""
line_args, _, _, colors, legend_labels = self._prepare_getdist_plot(linewidth)
if legend_ncol is None:
legend_ncol = 3
# Make the plot
g = plots.get_subplot_plotter(**subplot_kwargs)
if legend_fontsize is not None:
g.settings.legend_fontsize = legend_fontsize
if axes_labelsize is not None:
g.settings.axes_labelsize = axes_labelsize
g.rectangle_plot(x_param_ids, y_param_ids, roots=self._mcsamples,
filled=filled_contours,
colors=colors,
legend_ncol=legend_ncol,
legend_labels=legend_labels,
line_args=line_args,
contour_colors=self.colors)
for k in range(len(self.ref_values)):
g.add_param_markers(self.ref_values_markers[k], color='black', ls=self.ref_linestyles[k], lw=linewidth)
for j, key_x in enumerate(x_param_ids):
val_x = self.ref_values_markers[k][key_x]
for i, key_y in enumerate(y_param_ids):
val_y = self.ref_values_markers[k][key_y]
if val_x is not None and val_y is not None:
g.subplots[i, j].scatter(val_x,val_y,s=marker_size,facecolors='black',color='black',marker=self.ref_markers[k])
return g
[docs]
def plot_1d_getdist(self, num_columns=None, legend_ncol=None,
legend_fontsize=None, axes_labelsize=None,
linewidth=1, **subplot_kwargs):
"""Array of 1D histogram subplots using getdist.plots_1d method.
Parameters
----------
subplot_size : int, optional
Size of the getdist plot, by default 1
linewidth : int, optional
Line width for 2D contours, by default 1
marker_size : int, optional
Size of the reference (scatter) markers on 2D contours plots, by default 15
legend_ncol : int, optional
number of columns in the legend
num_columns : int, optional
number of columns of the subplot array
Returns
-------
GetDistPlotter
Instance of GetDistPlotter corresponding to the figure
"""
line_args, _, _, colors, legend_labels = self._prepare_getdist_plot(linewidth)
if num_columns is None:
num_columns = self.num_models//2+1
if legend_ncol is None:
legend_ncol = 3
# Make the plot
g = plots.get_subplot_plotter(**subplot_kwargs)
if legend_fontsize is not None:
g.settings.legend_fontsize = legend_fontsize
if axes_labelsize is not None:
g.settings.axes_labelsize = axes_labelsize
g.plots_1d(self._mcsamples,
params=self.parameter_id_list,
legend_labels=legend_labels,
colors=colors,
share_y=True,
line_args=line_args,
nx=num_columns, legend_ncol=legend_ncol,
)
for k in range(len(self.ref_values)):
g.add_param_markers(self.ref_values_markers[k], color='black', ls=self.ref_linestyles[k], lw=linewidth)
# for k in range(0, len(self.ref_values)):
# # Add vertical and horizontal lines
# for i in range(0, self.num_params):
# val = self.ref_values[k][i]
# ax = g.subplots.flatten()[i]
# if val is not None:
# ax.axvline(val, color='black', ls=self.ref_linestyles[k], alpha=1.0, lw=1)
return g
[docs]
def plot_source(self, idx_file=0):
f,ax = self.plotting_routine(self.param_source,idx_file)
return f,ax
[docs]
def plot_lens(self, idx_file=0):
f,ax = self.plotting_routine(self.param_lens,idx_file)
return f,ax
[docs]
def plotting_routine(self, param_dict, idx_file=0, num_columns=4):
    """Plot the parameter estimates of several files side by side.

    One panel is drawn per parameter of the reference file
    ``self.file_names[idx_file]``. In each panel, every file that also
    contains that parameter is drawn at x-position equal to its index in
    ``self.file_names``. Files flagged in ``self.posterior_bool_list`` are
    drawn as the median with 16th-84th percentile error bars; the others are
    drawn as a single point estimate.

    Parameters
    ----------
    param_dict : dict
        Maps each entry of ``self.file_names`` to a dict
        ``{parameter_key: entry}``. Each ``entry`` holds ``'latex_str'`` plus
        either ``'median'``, ``'percentile_16th'`` and ``'percentile_84th'``
        (posterior files) or ``'point_estimate'``.
        It is not modified.
    idx_file : int, optional
        Index (in ``self.file_names``) of the file whose parameters define
        the panels, by default 0. For example, if file 0 contains a Sersic fit
        and file 1 a Sersic+shapelets fit, ``idx_file=0`` plots only the
        Sersic parameters of both files. ``idx_file=1`` plots all Sersic and
        shapelets parameters wherever they are available.
    num_columns : int, optional
        Number of panels per row, by default 4.

    Returns
    -------
    f : matplotlib.figure.Figure
        The figure.
    ax : numpy.ndarray
        2D array of axes of shape ``(num_lines, num_columns)``.

    Raises
    ------
    ValueError
        If ``num_columns`` is not positive or the reference file has no
        parameter to plot.
    """
    if num_columns < 1:
        raise ValueError("num_columns must be a positive integer.")
    ref_name = self.file_names[idx_file]
    ref_params = param_dict[ref_name]
    number_param = len(ref_params)
    if number_param == 0:
        raise ValueError(f"No parameter to plot for file '{ref_name}'.")

    # Ceiling division gives enough rows for every panel; the leftover slots
    # of the last row are switched off at the end
    num_lines = -(-number_param // num_columns)
    num_unused = num_lines * num_columns - number_param
    # squeeze=False keeps `ax` 2D even when a single row is needed, so the
    # ax[row, col] indexing below works for any number of parameters
    f, ax = plt.subplots(num_lines, num_columns,
                         figsize=(num_columns * 3.5, 2.5 * num_lines),
                         squeeze=False)
    # One marker per file, cycled if there are more files than markers
    markers = ['*', '.', 's', '^', '<', '>', 'v', 'p', 'P', 'X', 'D',
               '1', '2', '3', '4', '+']
    angle_keys = ('SHEAR_0_phi_ext', 'PEMD_0_phi')

    # Panels follow the reference file's parameters, so a given panel shows
    # the same parameter for every file (files are drawn in file order)
    for i, (key, ref_p) in enumerate(ref_params.items()):
        panel = ax[i // num_columns, i % num_columns]
        # The x-position only encodes the file index, so hide that axis
        panel.get_xaxis().set_visible(False)
        panel.set_ylabel(ref_p['latex_str'], fontsize=12)
        panel.tick_params(axis='y', labelsize=12)
        for j, file_name in enumerate(self.file_names):
            p = param_dict[file_name].get(key)
            if p is None:
                # This file's model does not include the parameter
                continue
            m = markers[j % len(markers)]
            if self.posterior_bool_list[j]:
                low = p['percentile_16th']
                med = p['median']
                high = p['percentile_84th']
                if key in angle_keys:
                    # Angles are treated as 180-deg periodic: a percentile
                    # that wrapped past the edge is shifted by 180 deg so the
                    # error bar stays on the correct side of the median.
                    # Local copies are used so param_dict is left untouched.
                    if low > med:
                        low -= 180.
                    if high < med:
                        high += 180.
                panel.errorbar(j, med, [[med - low], [high - med]],
                               marker=m, ls='', label=file_name)
            else:
                panel.plot(j, p['point_estimate'], marker=m, ls='',
                           label=file_name)

    ax[0, 0].legend()
    # Hide the empty trailing panels of the last row
    for idx in range(1, num_unused + 1):
        ax[-1, -idx].axis('off')
    plt.tight_layout()
    plt.show()
    return f, ax
def _prepare_getdist_plot(self, lw, lw_cont=None, lw_margin=None):
    """Build the per-model styling lists passed to the getdist plotters.

    The instance attributes (``linestyles``, ``colors``, ``coolest_names``)
    are never modified. Fresh lists are returned on every call.

    Parameters
    ----------
    lw : float
        Line width stored in each model's ``line_args`` entry.
    lw_cont : float, optional
        Per-model entry of the returned ``lw_conts`` list, by default None.
    lw_margin : float, optional
        Line width used for the extra margin-samples entry (only added when
        ``self._add_margin_samples`` is set), by default ``lw + 2``.

    Returns
    -------
    line_args : list of dict
        ``{'ls', 'lw', 'color'}`` per model, plus the margin entry if any.
    lw_conts : list
        ``lw_cont`` per model, plus ``lw_margin`` if margin samples are added
        and ``lw_cont`` is not None.
    ls_conts : list of str
        Copy of ``self.linestyles``, plus ``'-.'`` for the margin samples.
    colors : list
        Copy of ``self.colors``, plus the margin color.
    legend_labels : list
        Copy of ``self.coolest_names``, plus the margin label.
    """
    if lw_margin is None:
        lw_margin = lw + 2
    line_args = [{'ls': ls, 'lw': lw, 'color': c}
                 for ls, c in zip(self.linestyles, self.colors)]
    lw_conts = [lw_cont] * self.num_models
    # Copy, not alias: appending the margin style below must not leak into
    # self.linestyles, otherwise every call would grow it by one '-.'
    ls_conts = list(self.linestyles)
    legend_labels = copy.deepcopy(self.coolest_names)
    colors = copy.deepcopy(self.colors)
    if self._add_margin_samples:
        line_args.append({'ls': '-.', 'lw': lw_margin, 'alpha': 0.8,
                          'color': self._color_margin})
        ls_conts.append('-.')
        # NOTE(review): when lw_cont is None no margin entry is added, so
        # lw_conts is one shorter than ls_conts -- confirm callers expect it
        if lw_cont is not None:
            lw_conts.append(lw_margin)
        legend_labels.append(self._label_margin)
        colors.append(self._color_margin)
    return line_args, lw_conts, ls_conts, colors, legend_labels
# def plot_corner(parameter_id_list,
# chain_objs, chain_dirs, chain_names=None,
# point_estimate_objs=None, point_estimate_dirs=None, point_estimate_names=None,
# colors=None, labels=None, subplot_size=1, mc_samples_kwargs=None,
# filled_contours=True, angles_range=None, shift_sample_list=None):
# """
# Adding this as just a function for the moment.
# Takes a list of COOLEST files as input, which must have a chain file associated to them, and returns a corner plot.
# Parameters
# ----------
# parameter_id_list : array
# A list of parameter unique ids obtained from lensing entities. Their order determines the order of the plot panels.
# chain_objs : array
# A list of coolest objects that have a chain file associated to them.
# chain_dirs : array
# A list of paths matching the coolest files in 'chain_objs'.
# chain_names : array, optional
# A list of labels for the coolest models in the 'chain_objs' list. Must have the same order as 'chain_objs'.
# point_estimate_objs : array, optional
# A list of coolest objects that will be used as point estimates.
# point_estimate_dirs : array
# A list of paths matching the coolest files in 'point_estimate_objs'.
# point_estimate_names : array, optional
# A list of labels for the models in the 'point_estimate_objs' list. Must have the same order as 'point_estimate_objs'.
# labels : dict, optional
# A dictionary matching the parameter_id_list entries to some human-readable labels.
# Returns
# -------
# An image
# """
# chains.print_load_details = False # Just to silence messages
# parameter_id_set = set(parameter_id_list)
# Npars = len(parameter_id_list)
# Nobjs = len(chain_objs)
# # Set the chain names
# if chain_names is None:
# chain_names = ["chain "+str(i) for i in range(Nobjs)]
# if shift_sample_list is None:
# shift_sample_list = [None]*Nobjs
# # Get the values of the point_estimates
# point_estimates = []
# if point_estimate_objs is not None:
# for coolest_obj in point_estimate_objs:
# values = []
# for par in parameter_id_list:
# param = coolest_obj.lensing_entities.get_parameter_from_id(par)
# val = param.point_estimate.value
# if val is None:
# values.append(None)
# else:
# values.append(val)
# point_estimates.append(values)
# mcsamples = []
# for i in range(Nobjs):
# chain_file = os.path.join(chain_dirs[i],chain_objs[i].meta["chain_file_name"]) # Here get the chain file path for each coolest object
# # Each chain file can have a different number of free parameters
# f = open(chain_file)
# header = f.readline()
# f.close()
# if ';' in header:
# raise ValueError("Columns must be coma-separated (no semi-colon) in chain file.")
# chain_file_headers = header.split(',')
# num_cols = len(chain_file_headers)
# chain_file_headers.pop() # Remove the last column name that is the probability weights
# chain_file_headers_set = set(chain_file_headers)
# # Check that the given parameters are a subset of those in the chain file
# assert parameter_id_set.issubset(chain_file_headers_set), "Not all given parameters are free parameters for model %d (not in the chain file: %s)!" % (i,chain_file)
# # Set the labels for the parameters in the chain file
# par_labels = []
# if labels is None:
# labels = {}
# for par_id in parameter_id_list:
# if labels.get(par_id, None) is None:
# param = coolest_obj.lensing_entities.get_parameter_from_id(par_id)
# par_labels.append(param.latex_str.strip('$'))
# else:
# par_labels.append(labels[par_id])
# # Read parameter values and probability weights
# column_indices = [chain_file_headers.index(par_id) for par_id in parameter_id_list]
# columns_to_read = sorted(column_indices) + [num_cols-1] # add last one for probability weights
# samples = pd.read_csv(chain_file, usecols=columns_to_read, delimiter=',')
# # Re-order columns to match parameter_id_list and par_labels
# sample_par_values = np.array(samples[parameter_id_list])
# # If needed, shift samples by a constant
# if shift_sample_list[i] is not None:
# for param_id, value in shift_sample_list[i].items():
# sample_par_values[:, parameter_id_list.index(param_id)] += value
# print(f"INFO: posterior for parameter '{param_id}' from model '{chain_names[i]}' "
# f"has been shifted by {value}.")
# # Clean-up the probability weights
# mypost = np.array(samples['probability_weights'])
# min_non_zero = np.min(mypost[np.nonzero(mypost)])
# sample_prob_weight = np.where(mypost<min_non_zero,min_non_zero,mypost)
# #sample_prob_weight = mypost
# # Create MCSamples object
# mysample = MCSamples(samples=sample_par_values,names=parameter_id_list,labels=par_labels,settings=mc_samples_kwargs)
# mysample.reweightAddingLogLikes(-np.log(sample_prob_weight))
# mcsamples.append(mysample)
# # Make the plot
# image = plots.getSubplotPlotter(subplot_size=subplot_size)
# image.triangle_plot(mcsamples,
# params=parameter_id_list,
# legend_labels=chain_names,
# filled=filled_contours,
# colors=colors,
# line_args=[{'ls':'-', 'lw': 2, 'color': c} for c in colors],
# contour_colors=colors)
# my_linestyles = ['solid','dotted','dashed','dashdot']
# my_markers = ['s','^','o','star']
# for k in range(0,len(point_estimates)):
# # Add vertical and horizontal lines
# for i in range(0,Npars):
# val = point_estimates[k][i]
# if val is not None:
# for ax in image.subplots[i:,i]:
# ax.axvline(val,color='black',ls=my_linestyles[k],alpha=1.0,lw=1)
# for ax in image.subplots[i,:i]:
# ax.axhline(val,color='black',ls=my_linestyles[k],alpha=1.0,lw=1)
# # Add points
# for i in range(0,Npars):
# val_x = point_estimates[k][i]
# for j in range(i+1,Npars):
# val_y = point_estimates[k][j]
# if val_x is not None and val_y is not None:
# image.subplots[j,i].scatter(val_x,val_y,s=10,facecolors='black',color='black',marker=my_markers[k])
# else:
# pass
# # Set default ranges for angles
# if angles_range is None:
# angles_range = (-90, 90)
# for i in range(0,len(parameter_id_list)):
# dum = parameter_id_list[i].split('-')
# name = dum[-1]
# if name in ['phi','phi_ext']:
# xlim = image.subplots[i,i].get_xlim()
# #print(xlim)
# if xlim[0] < -90:
# for ax in image.subplots[i:,i]:
# ax.set_xlim(left=angles_range[0])
# for ax in image.subplots[i,:i]:
# ax.set_ylim(bottom=angles_range[0])
# if xlim[1] > 90:
# for ax in image.subplots[i:,i]:
# ax.set_xlim(right=angles_range[1])
# for ax in image.subplots[i,:i]:
# ax.set_ylim(top=angles_range[1])
# return image