PrecisionRecallCurveDisplay
- class skore.PrecisionRecallCurveDisplay(precision, recall, *, average_precision, estimator_name, pos_label=None, data_source=None)
Precision-recall visualization.
An instance of this class should be created by EstimatorReport.metrics.precision_recall(). You should not create an instance of this class directly.
- Parameters:
- precision : dict of list of ndarray
Precision values. The structure is:
- for binary classification: the key is the positive label, and the value is a list of ndarray, each ndarray being the precision.
- for multiclass classification: the key is the class of interest in a one-vs-rest (OvR) fashion, and the value is a list of ndarray, each ndarray being the precision.
- recall : dict of list of ndarray
Recall values. The structure is:
- for binary classification: the key is the positive label, and the value is a list of ndarray, each ndarray being the recall.
- for multiclass classification: the key is the class of interest in a one-vs-rest (OvR) fashion, and the value is a list of ndarray, each ndarray being the recall.
- average_precision : dict of list of float
Average precision. The structure is:
- for binary classification: the key is the positive label, and the value is a list of float, each float being the average precision.
- for multiclass classification: the key is the class of interest in a one-vs-rest (OvR) fashion, and the value is a list of float, each float being the average precision.
- estimator_name : str
Name of the estimator.
- pos_label : int, float, bool or str, default=None
The class considered as the positive class. If None, the class is not shown in the legend.
- data_source : {"train", "test", "X_y"}, default=None
The data source used to compute the precision-recall curve.
- Attributes:
- ax_ : matplotlib Axes
Axes with the precision-recall curve.
- figure_ : matplotlib Figure
Figure containing the curve.
- lines_ : list of matplotlib Artist
The precision-recall curve(s).
- chance_levels_ : matplotlib Artist or None
The chance level line. It is None if the chance level is not plotted.
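To make the nested structure of precision, recall and average_precision concrete, here is a minimal NumPy sketch that builds dictionaries of that shape from true labels and scores, using one-vs-rest curves for multiclass targets. It is not skore's implementation: the helpers pr_curve and build_pr_dicts are hypothetical, treating classes[1] as the positive label in the binary case is an assumption borrowed from scikit-learn's convention, and each value is wrapped in a one-element list only to mirror the documented "list of ndarray" shape. Average precision is computed without interpolation, as the sum of (R_n - R_{n-1}) * P_n.

```python
import numpy as np


def pr_curve(is_positive, scores):
    """Return (precision, recall, average_precision) for one binary problem.

    Hypothetical helper, not part of skore. Points are ordered by decreasing
    threshold, so recall is non-decreasing along the returned arrays.
    """
    order = np.argsort(scores, kind="mergesort")[::-1]
    y_sorted = np.asarray(is_positive, dtype=bool)[order]
    s_sorted = np.asarray(scores, dtype=float)[order]
    # Index of the last sample in each run of tied scores = one threshold.
    threshold_idx = np.r_[np.flatnonzero(np.diff(s_sorted)), y_sorted.size - 1]
    tps = np.cumsum(y_sorted)[threshold_idx]  # true positives per threshold
    fps = (threshold_idx + 1) - tps  # false positives per threshold
    # Prepend the (recall=0, precision=1) starting point.
    precision = np.r_[1.0, tps / (tps + fps)]
    recall = np.r_[0.0, tps / y_sorted.sum()]
    # Non-interpolated average precision: sum of (R_n - R_{n-1}) * P_n.
    average_precision = float(np.sum(np.diff(recall) * precision[1:]))
    return precision, recall, average_precision


def build_pr_dicts(y_true, y_score, classes):
    """Build dicts shaped like the precision / recall / average_precision params."""
    y_true = np.asarray(y_true)
    precision, recall, average_precision = {}, {}, {}
    if len(classes) == 2:
        # Binary: y_score holds the scores of the positive class (assumed classes[1]).
        targets = [(classes[1], np.asarray(y_score))]
    else:
        # Multiclass, one-vs-rest: column k of y_score scores classes[k].
        y_score = np.asarray(y_score)
        targets = [(cls, y_score[:, k]) for k, cls in enumerate(classes)]
    for label, scores in targets:
        p, r, ap = pr_curve(y_true == label, scores)
        # One-element lists mirror the documented "list of ..." values.
        precision[label], recall[label], average_precision[label] = [p], [r], [ap]
    return precision, recall, average_precision


# Binary toy example: true labels and positive-class scores.
precision, recall, average_precision = build_pr_dicts(
    [0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8], classes=[0, 1]
)
print(round(average_precision[1][0], 4))  # 0.8333
```

For multiclass input, pass a 2D score array with one column per class; the resulting dictionaries then have one key per class.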
Examples
>>> from sklearn.datasets import load_breast_cancer
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import train_test_split
>>> from skore import EstimatorReport
>>> X_train, X_test, y_train, y_test = train_test_split(
...     *load_breast_cancer(return_X_y=True), random_state=0
... )
>>> classifier = LogisticRegression(max_iter=10_000)
>>> report = EstimatorReport(
...     classifier,
...     X_train=X_train,
...     y_train=y_train,
...     X_test=X_test,
...     y_test=y_test,
... )
>>> display = report.metrics.precision_recall()
>>> display.plot(pr_curve_kwargs={"color": "tab:red"})
- plot(ax=None, *, estimator_name=None, pr_curve_kwargs=None, despine=True)
Plot visualization.
Extra keyword arguments for the curve(s), given through pr_curve_kwargs, are passed to matplotlib's plot.
- Parameters:
- ax : Matplotlib Axes, default=None
Axes object to plot on. If None, a new figure and axes are created.
- estimator_name : str, default=None
Name of the estimator used to plot the precision-recall curve. If None, the name inferred from the estimator is used.
- pr_curve_kwargs : dict or list of dict, default=None
Keyword arguments to be passed to matplotlib's plot for rendering the precision-recall curve(s).
- despine : bool, default=True
Whether to remove the top and right spines from the plot.
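The pr_curve_kwargs parameter accepts either a dict or a list of dicts. One plausible way to resolve it into one set of matplotlib keyword arguments per curve is sketched below. The helper normalize_pr_curve_kwargs and its rules are hypothetical, not skore's code: we assume that a single dict is reused for every curve, that a list must contain exactly one dict per curve, and that user keys override a default drawstyle="steps-post" (the step-wise style that scikit-learn's PrecisionRecallDisplay uses).

```python
from typing import Any, Dict, List, Optional, Sequence, Union

# Assumed default: step-wise drawing, consistent with a non-interpolated AP.
DEFAULT_LINE_KWARGS: Dict[str, Any] = {"drawstyle": "steps-post"}


def normalize_pr_curve_kwargs(
    pr_curve_kwargs: Optional[Union[Dict[str, Any], Sequence[Dict[str, Any]]]],
    n_curves: int,
) -> List[Dict[str, Any]]:
    """Hypothetical helper: return one matplotlib kwargs dict per curve."""
    if pr_curve_kwargs is None:
        per_curve = [{} for _ in range(n_curves)]
    elif isinstance(pr_curve_kwargs, dict):
        # Assumption: a single dict applies to every curve.
        per_curve = [pr_curve_kwargs] * n_curves
    else:
        per_curve = list(pr_curve_kwargs)
        if len(per_curve) != n_curves:
            raise ValueError(
                f"Expected {n_curves} dicts in pr_curve_kwargs, got {len(per_curve)}."
            )
    # Build fresh dicts; user-supplied keys override the defaults.
    return [{**DEFAULT_LINE_KWARGS, **kwargs} for kwargs in per_curve]


print(normalize_pr_curve_kwargs({"color": "tab:red"}, n_curves=1))
# [{'drawstyle': 'steps-post', 'color': 'tab:red'}]
```

Under these assumptions, passing {"drawstyle": "default"} switches a curve to linear interpolation, which is the trade-off described in the Notes.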
Notes
The average precision (cf. average_precision_score()) in scikit-learn is computed without any interpolation. To be consistent with this metric, the precision-recall curve is plotted without any interpolation as well (step-wise style).
You can change this style by passing the keyword argument drawstyle="default" through pr_curve_kwargs. However, the curve will then not be strictly consistent with the reported average precision.
Examples
>>> from sklearn.datasets import load_breast_cancer
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import train_test_split
>>> from skore import EstimatorReport
>>> X_train, X_test, y_train, y_test = train_test_split(
...     *load_breast_cancer(return_X_y=True), random_state=0
... )
>>> classifier = LogisticRegression(max_iter=10_000)
>>> report = EstimatorReport(
...     classifier,
...     X_train=X_train,
...     y_train=y_train,
...     X_test=X_test,
...     y_test=y_test,
... )
>>> display = report.metrics.precision_recall()
>>> display.plot(pr_curve_kwargs={"color": "tab:red"})
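The Notes say that the average precision is not interpolated and that only the step-wise curve is strictly consistent with it. A small NumPy sketch makes the gap concrete: for a hand-computed toy curve, it compares the area under the step-wise curve with the area under the linearly interpolated curve that drawstyle="default" would draw. It calls neither skore nor scikit-learn, and the data points are illustrative only.

```python
import numpy as np

# Toy precision-recall points, ordered by increasing recall. They are the
# points obtained for labels [0, 0, 1, 1] and scores [0.1, 0.4, 0.35, 0.8],
# with the (recall=0, precision=1) starting point prepended.
recall = np.array([0.0, 0.5, 0.5, 1.0, 1.0])
precision = np.array([1.0, 1.0, 0.5, 2 / 3, 0.5])

# Step-wise curve: each recall increment is weighted by the precision at its
# right end, which is exactly the non-interpolated average precision.
step_area = float(np.sum(np.diff(recall) * precision[1:]))

# Linear interpolation between points (what drawstyle="default" draws):
# the area follows the trapezoid rule.
linear_area = float(np.sum(np.diff(recall) * (precision[1:] + precision[:-1]) / 2))

print(f"step-wise area (average precision): {step_area:.4f}")  # 0.8333
print(f"linearly interpolated area: {linear_area:.4f}")  # 0.7917
```

The linear curve cuts corners between points, so the area it suggests (about 0.79 here) differs from the reported average precision (about 0.83).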