.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/technical_details/plot_cache_mechanism.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_technical_details_plot_cache_mechanism.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_technical_details_plot_cache_mechanism.py:


.. _example_cache_mechanism:

===============
Cache mechanism
===============

This example shows how :class:`~skore.EstimatorReport` and
:class:`~skore.CrossValidationReport` use caching to speed up computations.

.. GENERATED FROM PYTHON SOURCE LINES 13-15

We set an environment variable to avoid spurious warnings related to
parallelism.

.. GENERATED FROM PYTHON SOURCE LINES 16-20

.. code-block:: Python

    import os

    os.environ["POLARS_ALLOW_FORKING_THREAD"] = "1"

.. GENERATED FROM PYTHON SOURCE LINES 21-27

Loading some data
=================

First, we load the open payments dataset from `skrub`. Our goal is to predict
whether a company paid a physician, with the ultimate aim of detecting
potential conflicts of interest.

.. GENERATED FROM PYTHON SOURCE LINES 27-33

.. code-block:: Python

    from skrub.datasets import fetch_open_payments

    dataset = fetch_open_payments()
    df = dataset.X
    y = dataset.y

.. GENERATED FROM PYTHON SOURCE LINES 34-38

.. code-block:: Python

    from skrub import TableReport

    TableReport(df)


.. GENERATED FROM PYTHON SOURCE LINES 39-43

.. code-block:: Python

    import pandas as pd

    TableReport(pd.DataFrame(y))


.. GENERATED FROM PYTHON SOURCE LINES 44-46

The dataset has over 70,000 records and only categorical features, some of
which have poorly defined categories.

.. GENERATED FROM PYTHON SOURCE LINES 49-54

Caching with :class:`~skore.EstimatorReport` and :class:`~skore.CrossValidationReport`
======================================================================================

We use `skrub` to create a simple predictive model that handles these
challenges.

.. GENERATED FROM PYTHON SOURCE LINES 54-60

.. code-block:: Python

    from skrub import tabular_learner

    model = tabular_learner("classifier")
    model

.. GENERATED FROM PYTHON SOURCE LINES 61-63

This model handles all types of data: numbers, categories, dates, and missing
values. We split our dataset so that we can train the model on one part and
evaluate it on the other.

.. GENERATED FROM PYTHON SOURCE LINES 64-68

.. code-block:: Python

    from skore import train_test_split

    X_train, X_test, y_train, y_test = train_test_split(df, y, random_state=42)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ╭───────────────────────────── HighClassImbalanceWarning ──────────────────────────────╮
    │ It seems that you have a classification problem with a high class imbalance. In this │
    │ case, using train_test_split may not be a good idea because of high variability in   │
    │ the scores obtained on the test set. To tackle this challenge we suggest to use      │
    │ skore's cross_validate function.                                                     │
    ╰──────────────────────────────────────────────────────────────────────────────────────╯
    ╭───────────────────────────────── ShuffleTrueWarning ─────────────────────────────────╮
    │ We detected that the `shuffle` parameter is set to `True` either explicitly or from  │
    │ its default value. In case of time-ordered events (even if they are independent),    │
    │ this will result in inflated model performance evaluation because natural drift will │
    │ not be taken into account. We recommend setting the shuffle parameter to `False` in  │
    │ order to ensure the evaluation process is really representative of your production   │
    │ release process.                                                                     │
    ╰──────────────────────────────────────────────────────────────────────────────────────╯

.. GENERATED FROM PYTHON SOURCE LINES 69-77

Caching the predictions for fast metric computation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

First, we focus on :class:`~skore.EstimatorReport`, as the same philosophy
applies to :class:`~skore.CrossValidationReport`. Let's explore how
:class:`~skore.EstimatorReport` uses caching to speed up metric computations.
We start by training the model:

.. GENERATED FROM PYTHON SOURCE LINES 77-84

.. code-block:: Python

    from skore import EstimatorReport

    report = EstimatorReport(
        model, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
    )
    report.help()

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ╭──────────── Tools to diagnose estimator HistGradientBoostingClassifier ─────────────╮
    │ EstimatorReport                                                                     │
    │ ├── .metrics                                                                        │
    │ │   ├── .accuracy(...) (↗︎) - Compute the accuracy score.                            │
    │ │   ├── .brier_score(...) (↘︎) - Compute the Brier score.                            │
    │ │   ├── .log_loss(...) (↘︎) - Compute the log loss.                                  │
    │ │   ├── .precision(...) (↗︎) - Compute the precision score.                          │
    │ │   ├── .precision_recall(...) - Plot the precision-recall curve.                   │
    │ │   ├── .recall(...) (↗︎) - Compute the recall score.                                │
    │ │   ├── .roc(...) - Plot the ROC curve.                                             │
    │ │   ├── .roc_auc(...) (↗︎) - Compute the ROC AUC score.                              │
    │ │   ├── .custom_metric(...) - Compute a custom metric.                              │
    │ │   └── .report_metrics(...) - Report a set of metrics for our estimator.           │
    │ ├── .cache_predictions(...) - Cache estimator's predictions.                        │
    │ ├── .clear_cache(...) - Clear the cache.                                            │
    │ └── Attributes                                                                      │
    │     ├── .X_test                                                                     │
    │     ├── .X_train                                                                    │
    │     ├── .y_test                                                                     │
    │     ├── .y_train                                                                    │
    │     ├── .estimator_                                                                 │
    │     └── .estimator_name_                                                            │
    │                                                                                     │
    │                                                                                     │
    │ Legend:                                                                             │
    │ (↗︎) higher is better (↘︎) lower is better                                            │
    ╰─────────────────────────────────────────────────────────────────────────────────────╯

.. GENERATED FROM PYTHON SOURCE LINES 85-86

We compute the accuracy on our test set and measure how long it takes:

.. GENERATED FROM PYTHON SOURCE LINES 87-94

.. code-block:: Python

    import time

    start = time.time()
    result = report.metrics.accuracy()
    end = time.time()
    result

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    0.9528548123980424

.. GENERATED FROM PYTHON SOURCE LINES 95-97

.. code-block:: Python

    print(f"Time taken: {end - start:.2f} seconds")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Time taken: 1.56 seconds

.. GENERATED FROM PYTHON SOURCE LINES 98-99

For comparison, here is how scikit-learn computes the same accuracy score:

.. GENERATED FROM PYTHON SOURCE LINES 100-107

.. code-block:: Python

    from sklearn.metrics import accuracy_score

    start = time.time()
    result = accuracy_score(report.y_test, report.estimator_.predict(report.X_test))
    end = time.time()
    result

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    0.9528548123980424

.. GENERATED FROM PYTHON SOURCE LINES 108-110

.. code-block:: Python

    print(f"Time taken: {end - start:.2f} seconds")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Time taken: 1.55 seconds

.. GENERATED FROM PYTHON SOURCE LINES 111-115

Both approaches take a similar amount of time. Now, watch what happens when we
compute the accuracy again with our skore estimator report:

.. GENERATED FROM PYTHON SOURCE LINES 116-121

.. code-block:: Python

    start = time.time()
    result = report.metrics.accuracy()
    end = time.time()
    result

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    0.9528548123980424

.. GENERATED FROM PYTHON SOURCE LINES 122-124

.. code-block:: Python

    print(f"Time taken: {end - start:.2f} seconds")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Time taken: 0.00 seconds

.. GENERATED FROM PYTHON SOURCE LINES 125-127

The second calculation is nearly instantaneous because the report stores the
results of previous computations in its cache. Let's look inside the cache:

.. GENERATED FROM PYTHON SOURCE LINES 128-130

.. code-block:: Python

    report._cache

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    {(np.int64(785285855598900323), 'predict', 'test'): array(['disallowed', 'disallowed', 'disallowed', ..., 'disallowed',
           'disallowed', 'disallowed'], shape=(18390,), dtype=object), (np.int64(785285855598900323), 'accuracy_score', 'test'): 0.9528548123980424}

.. GENERATED FROM PYTHON SOURCE LINES 131-134

The cache stores both predictions and metric values, keyed by the prediction
method (here `predict`) or the metric name (here `accuracy_score`) and by the
data source (here `test`). Computing another metric that relies on the same
type of predictions is therefore faster.
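
To make this concrete, here is a minimal sketch of a report that memoizes its
predictions and metric values in a plain dictionary. It is not skore's
implementation: the class `ToyCachingReport`, its `report_id` and its
`n_predict_calls` counter are hypothetical, and the keys only mimic the tuples
in the cache dump above (an identifier, a prediction method or metric name, and
a data source). The toy data comes from scikit-learn's `make_classification`.

.. code-block:: Python

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score, precision_score
    from sklearn.model_selection import train_test_split


    class ToyCachingReport:
        """Hypothetical, simplified report that memoizes predictions and metrics."""

        def __init__(self, estimator, X_train, y_train, X_test, y_test, report_id=0):
            self.estimator = estimator.fit(X_train, y_train)
            self._data = {"train": (X_train, y_train), "test": (X_test, y_test)}
            self._id = report_id  # stands in for the integer seen in the cache keys
            self._cache = {}
            self.n_predict_calls = 0  # how often the estimator actually predicts

        def _predict(self, data_source):
            key = (self._id, "predict", data_source)
            if key not in self._cache:  # cache miss: run the expensive prediction
                X, _ = self._data[data_source]
                self.n_predict_calls += 1
                self._cache[key] = self.estimator.predict(X)
            return self._cache[key]

        def _metric(self, metric_fn, data_source, **kwargs):
            # Metric values are cached too; extra metric parameters become part
            # of the key so that different settings do not collide.
            key = (self._id, metric_fn.__name__, data_source, *sorted(kwargs.items()))
            if key not in self._cache:
                _, y_true = self._data[data_source]
                y_pred = self._predict(data_source)  # shared by every metric
                self._cache[key] = metric_fn(y_true, y_pred, **kwargs)
            return self._cache[key]

        def accuracy(self, data_source="test"):
            return self._metric(accuracy_score, data_source)

        def precision(self, data_source="test"):
            return self._metric(precision_score, data_source, average=None)


    X_toy, y_toy = make_classification(n_samples=500, random_state=0)
    X_tr, X_te, y_tr, y_te = train_test_split(X_toy, y_toy, random_state=0)
    toy_report = ToyCachingReport(LogisticRegression(), X_tr, y_tr, X_te, y_te)

    toy_report.accuracy()   # computes and caches the test predictions
    toy_report.accuracy()   # returned directly from the cache
    toy_report.precision()  # new metric, but reuses the cached predictions
    print(toy_report.n_predict_calls)  # 1

Only the first call pays for `predict`; later metrics on the same data source
reuse the stored array and only compute the metric itself.
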
Let's try the precision metric:

.. GENERATED FROM PYTHON SOURCE LINES 134-139

.. code-block:: Python

    start = time.time()
    result = report.metrics.precision()
    end = time.time()
    result

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    {'allowed': np.float64(0.6906290115532734), 'disallowed': np.float64(0.9644540344103117)}

.. GENERATED FROM PYTHON SOURCE LINES 140-142

.. code-block:: Python

    print(f"Time taken: {end - start:.2f} seconds")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Time taken: 0.06 seconds

.. GENERATED FROM PYTHON SOURCE LINES 143-148

Computing the precision takes only a fraction of the time of the first accuracy
computation because the predictions do not need to be recomputed: only the
precision metric itself is calculated. Since computing the predictions is the
bottleneck, this yields a significant speedup.

.. GENERATED FROM PYTHON SOURCE LINES 150-154

Caching all the possible predictions at once
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We can pre-compute all predictions at once using parallel processing:

.. GENERATED FROM PYTHON SOURCE LINES 154-156

.. code-block:: Python

    report.cache_predictions(n_jobs=4)

.. GENERATED FROM PYTHON SOURCE LINES 157-159

Now all possible predictions are stored, so any metric calculation is much
faster, even on a different data source such as the training set:

.. GENERATED FROM PYTHON SOURCE LINES 160-165

.. code-block:: Python

    start = time.time()
    result = report.metrics.log_loss(data_source="train")
    end = time.time()
    result

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    0.09865494232505337

.. GENERATED FROM PYTHON SOURCE LINES 166-168

.. code-block:: Python

    print(f"Time taken: {end - start:.2f} seconds")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Time taken: 0.11 seconds

.. GENERATED FROM PYTHON SOURCE LINES 169-174

Caching external data
^^^^^^^^^^^^^^^^^^^^^

The report can also work with external data.
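
Internal data can be cached under fixed names such as `"train"` and `"test"`,
but external data has no such name, so a cache has to recognize it by its
content. The following is a minimal sketch of that idea. It assumes a content
hash computed with `joblib.hash`; which hash function skore actually uses is
not shown in this example, and the `ToyExternalCache` class is hypothetical.

.. code-block:: Python

    import joblib
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression


    class ToyExternalCache:
        """Hypothetical cache for predictions on user-provided (X, y) data."""

        def __init__(self, estimator):
            self.estimator = estimator
            self._cache = {}
            self.n_predict_calls = 0

        def predict(self, X, y):
            # Hashing reads all of the data on every call, so it is never free,
            # but it lets identical data passed again reuse stored predictions.
            data_key = joblib.hash((X, y))
            key = ("predict", "X_y", data_key)
            if key not in self._cache:
                self.n_predict_calls += 1
                self._cache[key] = self.estimator.predict(X)
            return self._cache[key]


    X_ext, y_ext = make_classification(n_samples=300, random_state=0)
    ext_cache = ToyExternalCache(LogisticRegression().fit(X_ext, y_ext))

    ext_cache.predict(X_ext, y_ext)              # miss: hash + predict
    ext_cache.predict(X_ext.copy(), y_ext)       # hit: same content, new object
    ext_cache.predict(X_ext[:100], y_ext[:100])  # miss: different content
    print(ext_cache.n_predict_calls)  # 2

Keying on content rather than on object identity means that a copy of the same
data still hits the cache, at the price of one hashing pass per call.
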
We use `data_source="X_y"` to indicate that we pass external data through the
`X` and `y` parameters:

.. GENERATED FROM PYTHON SOURCE LINES 174-179

.. code-block:: Python

    start = time.time()
    result = report.metrics.log_loss(data_source="X_y", X=X_test, y=y_test)
    end = time.time()
    result

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    0.12305206715107839

.. GENERATED FROM PYTHON SOURCE LINES 180-182

.. code-block:: Python

    print(f"Time taken: {end - start:.2f} seconds")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Time taken: 1.75 seconds

.. GENERATED FROM PYTHON SOURCE LINES 183-186

This first calculation is slower than with the internal train or test sets:
the report has to compute the predictions on the new data and also compute a
hash of that data so that the predictions can be retrieved later. Let's
calculate it again:

.. GENERATED FROM PYTHON SOURCE LINES 187-192

.. code-block:: Python

    start = time.time()
    result = report.metrics.log_loss(data_source="X_y", X=X_test, y=y_test)
    end = time.time()
    result

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    0.12305206715107839

.. GENERATED FROM PYTHON SOURCE LINES 193-195

.. code-block:: Python

    print(f"Time taken: {end - start:.2f} seconds")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Time taken: 0.18 seconds

.. GENERATED FROM PYTHON SOURCE LINES 196-199

The second call is much faster because the predictions are cached; the
remaining time is mostly spent computing the hash. Let's compute the ROC AUC on
the same data:

.. GENERATED FROM PYTHON SOURCE LINES 200-205

.. code-block:: Python

    start = time.time()
    result = report.metrics.roc_auc(data_source="X_y", X=X_test, y=y_test)
    end = time.time()
    result

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    0.9439820500298637

.. GENERATED FROM PYTHON SOURCE LINES 206-208

.. code-block:: Python

    print(f"Time taken: {end - start:.2f} seconds")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Time taken: 0.21 seconds

.. GENERATED FROM PYTHON SOURCE LINES 209-212

This computation is also efficient because it boils down to two steps: hashing
the data and computing the ROC AUC metric. We save a lot of time because we do
not need to recompute the predictions.

.. GENERATED FROM PYTHON SOURCE LINES 214-218

Caching for plotting
^^^^^^^^^^^^^^^^^^^^

The cache also speeds up plots. Let's create a ROC curve. Since the predictions
were already cached by `cache_predictions`, even this first plot is fast:

.. GENERATED FROM PYTHON SOURCE LINES 218-226

.. code-block:: Python

    import matplotlib.pyplot as plt

    start = time.time()
    display = report.metrics.roc(pos_label="allowed")
    display.plot()
    end = time.time()
    plt.tight_layout()

.. image-sg:: /auto_examples/technical_details/images/sphx_glr_plot_cache_mechanism_001.png
   :alt: plot cache mechanism
   :srcset: /auto_examples/technical_details/images/sphx_glr_plot_cache_mechanism_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 227-229

.. code-block:: Python

    print(f"Time taken: {end - start:.2f} seconds")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Time taken: 0.02 seconds

.. GENERATED FROM PYTHON SOURCE LINES 230-231

The second plot is also nearly instantaneous because it uses cached data:

.. GENERATED FROM PYTHON SOURCE LINES 232-238

.. code-block:: Python

    start = time.time()
    display = report.metrics.roc(pos_label="allowed")
    display.plot()
    end = time.time()
    plt.tight_layout()

.. image-sg:: /auto_examples/technical_details/images/sphx_glr_plot_cache_mechanism_002.png
   :alt: plot cache mechanism
   :srcset: /auto_examples/technical_details/images/sphx_glr_plot_cache_mechanism_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 239-241

.. code-block:: Python

    print(f"Time taken: {end - start:.2f} seconds")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Time taken: 0.01 seconds

.. GENERATED FROM PYTHON SOURCE LINES 242-244

The cache is only used to retrieve the `display` object, not the matplotlib
figure itself.
This means that we can still customize the cached plot before displaying it:

.. GENERATED FROM PYTHON SOURCE LINES 245-248

.. code-block:: Python

    display.plot(roc_curve_kwargs={"color": "tab:orange"})
    plt.tight_layout()

.. image-sg:: /auto_examples/technical_details/images/sphx_glr_plot_cache_mechanism_003.png
   :alt: plot cache mechanism
   :srcset: /auto_examples/technical_details/images/sphx_glr_plot_cache_mechanism_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 249-250

We can also clear the cache if needed:

.. GENERATED FROM PYTHON SOURCE LINES 251-254

.. code-block:: Python

    report.clear_cache()
    report._cache

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    {}

.. GENERATED FROM PYTHON SOURCE LINES 255-262

The cache is now empty.

Caching with :class:`~skore.CrossValidationReport`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

:class:`~skore.CrossValidationReport` relies on one
:class:`~skore.EstimatorReport` per fold and therefore uses the same caching
system for each fold of the cross-validation:

.. GENERATED FROM PYTHON SOURCE LINES 263-268

.. code-block:: Python

    from skore import CrossValidationReport

    report = CrossValidationReport(model, X=df, y=y, cv_splitter=5, n_jobs=4)
    report.help()

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ╭───────────── Tools to diagnose estimator HistGradientBoostingClassifier ─────────────╮
    │ CrossValidationReport                                                                │
    │ ├── .metrics                                                                         │
    │ │   ├── .accuracy(...) (↗︎) - Compute the accuracy score.                             │
    │ │   ├── .brier_score(...) (↘︎) - Compute the Brier score.                             │
    │ │   ├── .log_loss(...) (↘︎) - Compute the log loss.                                   │
    │ │   ├── .precision(...) (↗︎) - Compute the precision score.                           │
    │ │   ├── .precision_recall(...) - Plot the precision-recall curve.                    │
    │ │   ├── .recall(...) (↗︎) - Compute the recall score.                                 │
    │ │   ├── .roc(...) - Plot the ROC curve.                                              │
    │ │   ├── .roc_auc(...) (↗︎) - Compute the ROC AUC score.                               │
    │ │   ├── .custom_metric(...) - Compute a custom metric.                               │
    │ │   └── .report_metrics(...) - Report a set of metrics for our estimator.            │
    │ ├── .cache_predictions(...) - Cache the predictions for sub-estimators               │
    │ │   reports.                                                                         │
    │ ├── .clear_cache(...) - Clear the cache.                                             │
    │ └── Attributes                                                                       │
    │     ├── .X                                                                           │
    │     ├── .y                                                                           │
    │     ├── .estimator_                                                                  │
    │     ├── .estimator_name_                                                             │
    │     ├── .estimator_reports_                                                          │
    │     └── .n_jobs                                                                      │
    │                                                                                      │
    │                                                                                      │
    │ Legend:                                                                              │
    │ (↗︎) higher is better (↘︎) lower is better                                             │
    ╰──────────────────────────────────────────────────────────────────────────────────────╯

.. GENERATED FROM PYTHON SOURCE LINES 269-273

Since a :class:`~skore.CrossValidationReport` is built from several
:class:`~skore.EstimatorReport` instances, we observe the same behaviour as
before. The first call is slow because it computes the predictions for each
fold.

.. GENERATED FROM PYTHON SOURCE LINES 274-279

.. code-block:: Python

    start = time.time()
    result = report.metrics.report_metrics(aggregate=["mean", "std"])
    end = time.time()
    result


.. rst-class:: sphx-glr-script-out

.. code-block:: none

                                     mean       std
    Metric      Label / Average
    Precision   allowed          0.399561  0.126373
                disallowed       0.959646  0.004407
    Recall      allowed          0.423438  0.084925
                disallowed       0.943480  0.050043
    ROC AUC                      0.866834  0.037982
    Brier score                  0.068296  0.038357

.. GENERATED FROM PYTHON SOURCE LINES 280-282

.. code-block:: Python

    print(f"Time taken: {end - start:.2f} seconds")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Time taken: 20.98 seconds

.. GENERATED FROM PYTHON SOURCE LINES 283-284

Subsequent calls, however, are fast because the predictions are cached.

.. GENERATED FROM PYTHON SOURCE LINES 285-290

.. code-block:: Python

    start = time.time()
    result = report.metrics.report_metrics(aggregate=["mean", "std"])
    end = time.time()
    result


.. rst-class:: sphx-glr-script-out

.. code-block:: none

                                     mean       std
    Metric      Label / Average
    Precision   allowed          0.399561  0.126373
                disallowed       0.959646  0.004407
    Recall      allowed          0.423438  0.084925
                disallowed       0.943480  0.050043
    ROC AUC                      0.866834  0.037982
    Brier score                  0.068296  0.038357

.. GENERATED FROM PYTHON SOURCE LINES 291-293

.. code-block:: Python

    print(f"Time taken: {end - start:.2f} seconds")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Time taken: 0.00 seconds

.. GENERATED FROM PYTHON SOURCE LINES 294-295

We thus observe the same caching behaviour as with
:class:`~skore.EstimatorReport`.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minute 55.838 seconds)


.. _sphx_glr_download_auto_examples_technical_details_plot_cache_mechanism.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_cache_mechanism.ipynb <plot_cache_mechanism.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_cache_mechanism.py <plot_cache_mechanism.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_cache_mechanism.zip <plot_cache_mechanism.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_