"""The core class for evaluating datasets."""
import pickle
import warnings
from copy import deepcopy
import pandas as pd
from . import metrics, plot, utils
class Evaluator:
    """Hold a real and a synthetic dataset and evaluate them together.

    The central class in `synthgauge`: metrics are registered once (with
    their keyword arguments) and can then be executed repeatedly via
    `evaluate`, and a set of plotting wrappers visualise the two
    datasets side by side.

    Parameters
    ----------
    real : pandas.DataFrame
        Dataframe containing the real data.
    synth : pandas.DataFrame
        Dataframe containing the synthetic data.
    handle_nans : str, default "drop"
        If ``"drop"`` (default), rows containing missing values are
        removed from both datasets. Any other value leaves missing
        values in place.

    Attributes
    ----------
    feature_names : list
        Column labels present in both `real` and `synth`.
    real_data : pandas.DataFrame
        The real data after NaN handling.
    synth_data : pandas.DataFrame
        The synthetic data after NaN handling.
    combined_data : pandas.DataFrame
        Result of `synthgauge.utils.df_combine(real_data, synth_data)`.
    metric_results : dict
        Results of the most recent call to `evaluate`.

    Notes
    -----
    The dataframes supplied by the caller are never modified: when NaNs
    are dropped, the evaluator works on new dataframes.
    """

    def __init__(self, real, synth, handle_nans="drop"):
        common_feats = real.columns.intersection(synth.columns)
        ignore_feats = real.columns.union(synth.columns).difference(
            common_feats
        )

        if len(ignore_feats) > 0:
            # Column labels are not guaranteed to be strings, so convert
            # them before joining.
            ignored = ", ".join(map(str, ignore_feats))
            msg = (
                f"Features {ignored} are not common to "
                "`real` and `synth` and will be ignored in further analysis."
            )
            # stacklevel=2 attributes the warning to the caller's line.
            warnings.warn(msg, stacklevel=2)

        self.feature_names = list(common_feats)

        # Metrics is private to apply some validation
        self.__metrics = dict()  # assign metrics and kwargs
        self.metric_results = dict()  # store results

        # Handle NaNs without mutating the caller's dataframes.
        if handle_nans == "drop":
            real = real.dropna()
            synth = synth.dropna()

        self.real_data = real
        self.synth_data = synth
        self.combined_data = utils.df_combine(self.real_data, self.synth_data)

    @property
    def metrics(self):
        """dict: The registered metrics, keyed by alias.

        Each value is the dictionary of keyword arguments stored for
        that metric, including a ``"name"`` entry and, for custom
        metrics, a ``"func"`` entry. The live dictionary is returned,
        not a copy.
        """
        return self.__metrics

    def describe_numeric(self):
        """Summarise numeric features.

        Both datasets are passed through `synthgauge.utils.launder` and
        then summarised with `pandas.DataFrame.describe` restricted to
        numeric columns. The two summaries are concatenated
        column-wise, transposed and sorted by index.

        Returns
        -------
        pandas.DataFrame
            Descriptive statistics for each numeric feature, one row
            per feature and dataset.
        """
        real, synth = utils.launder(self.real_data, self.synth_data)

        return pd.concat(
            [
                real.describe(include="number"),
                synth.describe(include="number"),
            ],
            axis=1,
        ).T.sort_index()

    def describe_categorical(self):
        """Summarise categorical features.

        Both datasets are passed through `synthgauge.utils.launder` and
        then summarised with `pandas.DataFrame.describe` restricted to
        ``category`` and ``object`` columns. The two summaries are
        concatenated column-wise, transposed and sorted by index, and
        the ``top`` column is renamed to ``most_frequent``.

        Returns
        -------
        pandas.DataFrame
            Descriptive statistics for each object-type feature.
        """
        real, synth = utils.launder(self.real_data, self.synth_data)

        return (
            pd.concat(
                [
                    real.describe(include=["category", "object"]),
                    synth.describe(include=["category", "object"]),
                ],
                axis=1,
            )
            .T.sort_index()
            .rename(columns={"top": "most_frequent"})
        )

    def add_metric(self, name, alias=None, **kwargs):
        """Add a metric to the evaluator.

        Metrics and their arguments are recorded to be run at a later
        time. This allows metric customisation but ensures that the same
        metric configuration is applied consistently, i.e. once added,
        the parameters do not require resupplying for each execution of
        the metric. Supplying a metric alias allows the same metric to
        be used multiple times with different parameters.

        Note that `self.real_data` and `self.synth_data` will be passed
        automatically to the metric as its first two positional
        arguments. They should not be declared in `kwargs`.

        Parameters
        ----------
        name : str
            Name of the metric. Must match one of the attributes of
            `synthgauge.metrics`.
        alias : str, optional
            Alias to be given to this use of the metric in the results
            table. Allows the same metric to be used multiple times with
            different parameters. If not specified, `name` is used.
        **kwargs : dict, optional
            Keyword arguments for the metric. Refer to the associated
            metric documentation for details.

        Raises
        ------
        NotImplementedError
            If `synthgauge.metrics` has no attribute called `name`.
        """
        if not hasattr(metrics, name):
            raise NotImplementedError(f"Metric '{name}' is not implemented")

        kwargs["name"] = name
        key = name if alias is None else alias
        self.__metrics[key] = kwargs

    def add_custom_metric(self, alias, func, **kwargs):
        """Add a custom metric to the evaluator.

        A custom metric uses any user-defined function that accepts the
        real and synthetic dataframes as the first and second positional
        arguments, respectively. Any other parameters must be defined as
        keyword arguments. The metric function can return a value of any
        desired type although scalar numeric values are recommended, or
        `collections.namedtuple` instances when there are multiple
        outputs.

        Parameters
        ----------
        alias : str
            Alias for the metric to appear in the results table.
        func : function
            Metric function to be called during the evaluation step. It
            is called as ``func(self.real_data, self.synth_data,
            **kwargs)``. It should be defined at the top level of a
            module if the metrics are to be saved with `save_metrics`,
            because `pickle` stores functions by reference.
        **kwargs : dict, optional
            Keyword arguments to be passed to `func`.
        """
        kwargs.update({"func": func, "name": alias})
        self.__metrics[alias] = kwargs

    def copy_metrics(self, other):
        """Copy metrics from another evaluator.

        To facilitate consistent comparisons of different synthetic
        datasets, this function replaces the current metrics dictionary
        with a deep copy of the one held by another `Evaluator`.

        Parameters
        ----------
        other : Evaluator
            The other evaluator from which the metrics dictionary will
            be copied.

        Raises
        ------
        TypeError
            If `other` is not an `Evaluator`.
        """
        if not isinstance(other, Evaluator):
            raise TypeError("`other` must be of class Evaluator")

        self.__metrics = deepcopy(other.metrics)

    def save_metrics(self, filename):
        """Save the current metrics dictionary to disk via `pickle`.

        Parameters
        ----------
        filename : str
            Path to pickle file to save the metrics.
        """
        with open(filename, "wb") as f:
            pickle.dump(self.metrics, f)

    def load_metrics(self, filename, overwrite=False):
        """Load metrics from disk.

        Update or overwrite the current metric dictionary from a pickle.
        Every loaded entry must either carry a callable under ``"func"``
        (a custom metric) or name an attribute of `synthgauge.metrics`.

        Parameters
        ----------
        filename : str
            Path to metrics pickle file.
        overwrite : bool, default False
            If `True`, all current metrics will be replaced with the
            loaded metrics. Default is `False`, which will update the
            current metric dictionary with the loaded metrics.

        Raises
        ------
        ValueError
            If any loaded entry is neither a custom metric nor a metric
            available in `synthgauge.metrics`.

        Warnings
        --------
        The file is read with `pickle.load`, which can execute arbitrary
        code. Only load metric files from sources you trust.
        """
        # SECURITY: unpickling runs code embedded in the file; `filename`
        # must come from a trusted source.
        with open(filename, "rb") as f:
            new_metrics = pickle.load(f)

        def is_valid(spec):
            # Custom metrics bring their own function and are not part
            # of `synthgauge.metrics`, so they are validated separately.
            if callable(spec.get("func")):
                return True
            return getattr(metrics, spec["name"], None) is not None

        invalid_metrics = [
            str(alias)
            for alias, spec in new_metrics.items()
            if not is_valid(spec)
        ]

        if invalid_metrics:
            invalid_str = ", ".join(invalid_metrics)
            raise ValueError(f"Invalid metrics encountered in: {invalid_str}.")

        if overwrite:
            self.__metrics = new_metrics
        else:
            self.__metrics.update(new_metrics)

    def drop_metric(self, metric):
        """Drop the named metric from the metrics dictionary.

        Dropping a metric that is not registered is a no-op.

        Parameters
        ----------
        metric : str
            Key (name or alias, if specified) of the metric to remove.
        """
        self.__metrics.pop(metric, None)

    def evaluate(self, as_df=False):
        """Compute metrics for real and synth data.

        Run through the metrics dictionary and execute each with its
        corresponding arguments. The results are returned as either a
        dictionary or dataframe.

        Results are also silently stored as a dictionary in
        `self.metric_results`.

        An entry holding a ``"func"`` item is treated as a custom metric
        and that function is called; otherwise the function is looked up
        by name in `synthgauge.metrics`. The stored metric configuration
        is not altered by evaluation.

        Parameters
        ----------
        as_df : bool, default False
            If `True`, the results will be returned as a
            `pandas.DataFrame`, otherwise a dictionary is returned.
            Default is `False`.

        Returns
        -------
        pandas.DataFrame
            If `as_df` is `True`. Each row corresponds to a metric-value
            pair. Metrics with multiple values (results exposing
            `_asdict`, e.g. namedtuples) have one row per field, labelled
            ``"<alias>-<field>"``.
        dict
            If `as_df` is `False`. The keys are the metric names and
            the values are the metric values (grouped). Metrics with
            multiple values are assigned to a single key.
        """
        results = {}

        for alias, stored_kwargs in self.__metrics.items():
            # Work on a copy so popping bookkeeping keys leaves the
            # registered configuration intact for later runs.
            call_kwargs = deepcopy(stored_kwargs)
            metric_name = call_kwargs.pop("name")

            # A stored function always wins, so a custom metric whose
            # alias matches a built-in name is still dispatched properly.
            metric_func = call_kwargs.pop("func", None)
            if metric_func is None:
                metric_func = getattr(metrics, metric_name)

            results[alias] = metric_func(
                self.real_data, self.synth_data, **call_kwargs
            )

        self.metric_results = dict(results)

        if not as_df:
            return results

        tidy_results = {}
        for alias, value in self.metric_results.items():
            as_dict = getattr(value, "_asdict", None)
            if callable(as_dict):
                # Multi-valued result: one row per named field.
                for field, field_value in as_dict().items():
                    tidy_results[f"{alias}-{field}"] = field_value
            else:
                tidy_results[alias] = value

        return pd.DataFrame(tidy_results, index=["value"]).T

    def plot_histograms(self, figcols=2, figsize=None):
        """Plot grid of feature distributions.

        Convenience wrapper for `synthgauge.plot.plot_histograms`. This
        function uses the combined real and synthetic data sets and
        groups by `'source'`, plotting the features listed in
        `self.feature_names`.

        Parameters
        ----------
        figcols : int, default 2
            Forwarded to `synthgauge.plot.plot_histograms`.
        figsize : optional
            Forwarded to `synthgauge.plot.plot_histograms`.

        Returns
        -------
        object
            Whatever `synthgauge.plot.plot_histograms` returns.
        """
        return plot.plot_histograms(
            self.combined_data,
            feats=self.feature_names,
            groupby="source",
            figcols=figcols,
            figsize=figsize,
        )

    def plot_histogram3d(
        self, data, x, y, x_bins="auto", y_bins="auto", figsize=None
    ):
        """Plot 3D histogram.

        Convenience wrapper for `synthgauge.plot.plot_histogram3d`.

        Parameters
        ----------
        data : {"real", "synth", "combined"}
            Dataframe to pass to plotting function. Either `"real"` to
            pass `self.real_data`, `"synth"` to pass `self.synth_data`
            or `"combined"` to pass `self.combined_data`.
        x : str
            Column to plot along the x-axis.
        y : str
            Column to plot along the y-axis.
        x_bins : default "auto"
            Forwarded to `synthgauge.plot.plot_histogram3d`.
        y_bins : default "auto"
            Forwarded to `synthgauge.plot.plot_histogram3d`.
        figsize : optional
            Forwarded to `synthgauge.plot.plot_histogram3d`.

        Returns
        -------
        object
            Whatever `synthgauge.plot.plot_histogram3d` returns.

        Raises
        ------
        AttributeError
            If the evaluator has no attribute named ``f"{data}_data"``.
        """
        return plot.plot_histogram3d(
            getattr(self, f"{data}_data"),
            x=x,
            y=y,
            x_bins=x_bins,
            y_bins=y_bins,
            figsize=figsize,
        )

    def plot_correlation(
        self, feats=None, method="pearson", figcols=2, figsize=None, **kwargs
    ):
        """Plot a grid of correlation heatmaps.

        Convenience wrapper for `synthgauge.plot.plot_correlation`,
        called with `self.real_data`, `self.synth_data` and
        ``plot_diff=True``. Each dataset (real and synthetic) has a
        plot, as well as one for the differences in their correlations.

        Parameters
        ----------
        feats : optional
            Forwarded to `synthgauge.plot.plot_correlation`.
        method : str, default "pearson"
            Forwarded to `synthgauge.plot.plot_correlation`.
        figcols : int, default 2
            Forwarded to `synthgauge.plot.plot_correlation`.
        figsize : optional
            Forwarded to `synthgauge.plot.plot_correlation`.
        **kwargs : dict, optional
            Further keyword arguments forwarded unchanged.

        Returns
        -------
        object
            Whatever `synthgauge.plot.plot_correlation` returns.
        """
        return plot.plot_correlation(
            self.real_data,
            self.synth_data,
            feats=feats,
            method=method,
            plot_diff=True,
            figcols=figcols,
            figsize=figsize,
            **kwargs,
        )

    def plot_crosstab(self, x, y, figsize=None, **kwargs):
        """Plot pairwise cross-tabulation.

        Convenience wrapper for `synthgauge.plot.plot_crosstab`.
        Automatically sets `real` and `synth` parameters to the
        corresponding data in `self`.

        Parameters
        ----------
        x : str
            Forwarded to `synthgauge.plot.plot_crosstab` as `x`.
        y : str
            Forwarded to `synthgauge.plot.plot_crosstab` as `y`.
        figsize : optional
            Forwarded to `synthgauge.plot.plot_crosstab`.
        **kwargs : dict, optional
            Further keyword arguments forwarded unchanged.

        Returns
        -------
        object
            Whatever `synthgauge.plot.plot_crosstab` returns.
        """
        return plot.plot_crosstab(
            self.real_data,
            self.synth_data,
            x=x,
            y=y,
            figsize=figsize,
            **kwargs,
        )

    def plot_qq(self, feature, n_quantiles=None, figsize=None):
        """Plot quantile-quantile plot.

        Convenience wrapper for `synthgauge.plot.plot_qq`, called with
        `self.real_data` and `self.synth_data`.

        Parameters
        ----------
        feature : str
            Feature to plot.
        n_quantiles : optional
            Forwarded positionally to `synthgauge.plot.plot_qq`.
        figsize : optional
            Forwarded positionally to `synthgauge.plot.plot_qq`.

        Returns
        -------
        object
            Whatever `synthgauge.plot.plot_qq` returns.
        """
        return plot.plot_qq(
            self.real_data, self.synth_data, feature, n_quantiles, figsize
        )