{
"cells": [
{
"cell_type": "markdown",
"id": "c855b45e",
"metadata": {},
"source": [
"# AutoML\n",
"\n",
"[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/etna-team/etna/master?filepath=examples/205-automl.ipynb)"
]
},
{
"cell_type": "markdown",
"id": "bca01a6c",
"metadata": {},
"source": [
"This notebook covers the AutoML utilities of the ETNA library.\n",
"\n",
"**Table of contents**\n",
"\n",
"- [Hyperparameters tuning](#chapter_1)\n",
" - [How Tune works](#section_1_1)\n",
" - [Example](#section_1_2)\n",
"- [General AutoML](#chapter_2)\n",
" - [How Auto works](#section_2_1)\n",
" - [Example](#section_2_2)\n",
"- [Summary](#chapter_3)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "45f65253",
"metadata": {},
"outputs": [],
"source": [
"!pip install \"etna[auto, prophet]\" -q"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6f70e872",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b858a832",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"from etna.datasets import TSDataset\n",
"from etna.metrics import SMAPE\n",
"from etna.models import LinearPerSegmentModel\n",
"from etna.pipeline import Pipeline\n",
"from etna.transforms import DateFlagsTransform\n",
"from etna.transforms import LagTransform"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e50060f6",
"metadata": {},
"outputs": [],
"source": [
"HORIZON = 14"
]
},
{
"cell_type": "markdown",
"id": "33ad7417",
"metadata": {},
"source": [
"## 1. Hyperparameters tuning "
]
},
{
"cell_type": "markdown",
"id": "4542c8eb",
"metadata": {},
"source": [
"It is a common task to tune the hyperparameters of an existing pipeline to improve its quality. For this purpose there is the `etna.auto.Tune` class, which is responsible for creating an [optuna](https://github.com/optuna/optuna) study to solve this problem.\n",
"\n",
"In the next sections we will see how it works and how to use it for your particular problems."
]
},
{
"cell_type": "markdown",
"id": "73194640",
"metadata": {},
"source": [
"### 1.1 How `Tune` works "
]
},
{
"cell_type": "markdown",
"id": "c7777ea3",
"metadata": {},
"source": [
"During initialization, `Tune` accepts a `pipeline`, its tuning parameters (`params_to_tune`), an optimization metric (`target_metric`), backtest parameters and optuna study parameters.\n",
"\n",
"In `fit`, the optuna study is created. During each trial, a sample of parameters is generated from `params_to_tune` and applied to `pipeline`. After that, the new pipeline is evaluated in a backtest and the target metric is returned to the optuna framework."
]
},
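{
"cell_type": "markdown",
"id": "7a1e5c20",
"metadata": {},
"source": [
"The trial loop described above can be sketched in plain Python. This is only a conceptual illustration under stated assumptions: it samples uniformly at random instead of using an optuna sampler, the parameter names `use_intercept` and `use_weekend_flag` are hypothetical, and `toy_backtest` is a made-up stand-in for running a real backtest and computing the target metric."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a1e5c21",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"# Hypothetical search space: parameter name -> candidate values.\n",
"search_space = {\n",
"    \"use_intercept\": [False, True],\n",
"    \"use_weekend_flag\": [False, True],\n",
"}\n",
"\n",
"\n",
"def toy_backtest(params):\n",
"    \"\"\"Made-up stand-in for a backtest; returns a metric value where lower is better.\"\"\"\n",
"    return 10.0 - 2.0 * params[\"use_intercept\"] - 1.0 * params[\"use_weekend_flag\"]\n",
"\n",
"\n",
"rng = random.Random(0)\n",
"best_params, best_value = None, float(\"inf\")\n",
"for _ in range(20):\n",
"    # 1. Sample one value per parameter from the search space.\n",
"    params = {name: rng.choice(choices) for name, choices in search_space.items()}\n",
"    # 2. Apply the sampled parameters and evaluate them (here: the toy function).\n",
"    value = toy_backtest(params)\n",
"    # 3. Report the metric and remember the best trial seen so far.\n",
"    if value < best_value:\n",
"        best_params, best_value = params, value\n",
"\n",
"print(best_params, best_value)"
]
},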
{
"cell_type": "markdown",
"id": "09e6cb8e",
"metadata": {},
"source": [
"Let's look closer at the `params_to_tune` parameter. It expects a dictionary mapping parameter names to their distributions. But how should these parameter names be chosen?"
]
},
{
"cell_type": "markdown",
"id": "5d7a777a",
"metadata": {},
"source": [
"#### 1.1.1 `set_params`"
]
},
{
"cell_type": "markdown",
"id": "cc05b85b",
"metadata": {},
"source": [
"We are going to make a little detour to explain the `set_params` method, which is supported by ETNA pipelines, models and transforms. Given a dictionary of parameters, it creates a new object from an existing one with the given parameters changed."
]
},
{
"cell_type": "markdown",
"id": "b291efa4",
"metadata": {},
"source": [
"First, we define some objects for our future examples."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9d6893b8",
"metadata": {},
"outputs": [],
"source": [
"model = LinearPerSegmentModel()\n",
"transforms = [\n",
" LagTransform(in_column=\"target\", lags=list(range(HORIZON, HORIZON + 10)), out_column=\"target_lag\"),\n",
" DateFlagsTransform(out_column=\"date_flags\"),\n",
"]\n",
"pipeline = Pipeline(model=model, transforms=transforms, horizon=HORIZON)"
]
},
{
"cell_type": "markdown",
"id": "01a57e5c",
"metadata": {},
"source": [
"Let's look at a simple example where we want to change the `fit_intercept` parameter of the `model`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "32c51370",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'fit_intercept': True,\n",
" 'kwargs': {},\n",
" '_target_': 'etna.models.linear.LinearPerSegmentModel'}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.to_dict()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "60bc963f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'fit_intercept': False,\n",
" 'kwargs': {},\n",
" '_target_': 'etna.models.linear.LinearPerSegmentModel'}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"new_model_params = {\"fit_intercept\": False}\n",
"new_model = model.set_params(**new_model_params)\n",
"new_model.to_dict()"
]
},
{
"cell_type": "markdown",
"id": "383931c2",
"metadata": {},
"source": [
"Great! As the next step, we want to change the `fit_intercept` of the `model` inside the `pipeline`."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7ff49f9a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'model': {'fit_intercept': True,\n",
" 'kwargs': {},\n",
" '_target_': 'etna.models.linear.LinearPerSegmentModel'},\n",
" 'transforms': [{'in_column': 'target',\n",
" 'lags': [14, 15, 16, 17, 18, 19, 20, 21, 22, 23],\n",
" 'out_column': 'target_lag',\n",
" '_target_': 'etna.transforms.math.lags.LagTransform'},\n",
" {'day_number_in_week': True,\n",
" 'day_number_in_month': True,\n",
" 'day_number_in_year': False,\n",
" 'week_number_in_month': False,\n",
" 'week_number_in_year': False,\n",
" 'month_number_in_year': False,\n",
" 'season_number': False,\n",
" 'year_number': False,\n",
" 'is_weekend': True,\n",
" 'special_days_in_week': (),\n",
" 'special_days_in_month': (),\n",
" 'out_column': 'date_flags',\n",
" '_target_': 'etna.transforms.timestamp.date_flags.DateFlagsTransform'}],\n",
" 'horizon': 14,\n",
" '_target_': 'etna.pipeline.pipeline.Pipeline'}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pipeline.to_dict()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "497662b6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'model': {'fit_intercept': False,\n",
" 'kwargs': {},\n",
" '_target_': 'etna.models.linear.LinearPerSegmentModel'},\n",
" 'transforms': [{'in_column': 'target',\n",
" 'lags': [14, 15, 16, 17, 18, 19, 20, 21, 22, 23],\n",
" 'out_column': 'target_lag',\n",
" '_target_': 'etna.transforms.math.lags.LagTransform'},\n",
" {'day_number_in_week': True,\n",
" 'day_number_in_month': True,\n",
" 'day_number_in_year': False,\n",
" 'week_number_in_month': False,\n",
" 'week_number_in_year': False,\n",
" 'month_number_in_year': False,\n",
" 'season_number': False,\n",
" 'year_number': False,\n",
" 'is_weekend': True,\n",
" 'special_days_in_week': (),\n",
" 'special_days_in_month': (),\n",
" 'out_column': 'date_flags',\n",
" '_target_': 'etna.transforms.timestamp.date_flags.DateFlagsTransform'}],\n",
" 'horizon': 14,\n",
" '_target_': 'etna.pipeline.pipeline.Pipeline'}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"new_pipeline_params = {\"model.fit_intercept\": False}\n",
"new_pipeline = pipeline.set_params(**new_pipeline_params)\n",
"new_pipeline.to_dict()"
]
},
{
"cell_type": "markdown",
"id": "8eba262b",
"metadata": {},
"source": [
"OK, it looks like we managed to do this. As the last step, we are going to change the `is_weekend` flag of the `DateFlagsTransform` inside our `pipeline`."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "28a1ac00",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'model': {'fit_intercept': True,\n",
" 'kwargs': {},\n",
" '_target_': 'etna.models.linear.LinearPerSegmentModel'},\n",
" 'transforms': [{'in_column': 'target',\n",
" 'lags': [14, 15, 16, 17, 18, 19, 20, 21, 22, 23],\n",
" 'out_column': 'target_lag',\n",
" '_target_': 'etna.transforms.math.lags.LagTransform'},\n",
" {'day_number_in_week': True,\n",
" 'day_number_in_month': True,\n",
" 'day_number_in_year': False,\n",
" 'week_number_in_month': False,\n",
" 'week_number_in_year': False,\n",
" 'month_number_in_year': False,\n",
" 'season_number': False,\n",
" 'year_number': False,\n",
" 'is_weekend': False,\n",
" 'special_days_in_week': (),\n",
" 'special_days_in_month': (),\n",
" 'out_column': 'date_flags',\n",
" '_target_': 'etna.transforms.timestamp.date_flags.DateFlagsTransform'}],\n",
" 'horizon': 14,\n",
" '_target_': 'etna.pipeline.pipeline.Pipeline'}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"new_pipeline_params = {\"transforms.1.is_weekend\": False}\n",
"new_pipeline = pipeline.set_params(**new_pipeline_params)\n",
"new_pipeline.to_dict()"
]
},
{
"cell_type": "markdown",
"id": "50c2bf0d",
"metadata": {},
"source": [
"As we can see, we managed to do this."
]
},
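{
"cell_type": "markdown",
"id": "7a1e5c22",
"metadata": {},
"source": [
"To make the key convention explicit, here is a minimal pure-Python sketch of how a dotted key such as `transforms.1.is_weekend` can address an entry in a nested structure like the one returned by `to_dict()`. It is an illustration only: the helper `set_by_dotted_key` and the `toy_config` dictionary are hypothetical, and this is not a description of how ETNA implements `set_params` internally."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a1e5c23",
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"\n",
"\n",
"def set_by_dotted_key(config, dotted_key, value):\n",
"    \"\"\"Return a copy of `config` with the entry addressed by `dotted_key` replaced.\"\"\"\n",
"    new_config = copy.deepcopy(config)\n",
"    *parents, last = dotted_key.split(\".\")\n",
"    node = new_config\n",
"    for part in parents:\n",
"        # Numeric parts index into lists, other parts are dictionary keys.\n",
"        node = node[int(part)] if isinstance(node, list) else node[part]\n",
"    if isinstance(node, list):\n",
"        node[int(last)] = value\n",
"    else:\n",
"        node[last] = value\n",
"    return new_config\n",
"\n",
"\n",
"# Hypothetical, simplified stand-in for the output of `pipeline.to_dict()`.\n",
"toy_config = {\n",
"    \"model\": {\"fit_intercept\": True},\n",
"    \"transforms\": [{\"lags\": [14, 15]}, {\"is_weekend\": True}],\n",
"}\n",
"updated = set_by_dotted_key(toy_config, \"transforms.1.is_weekend\", False)\n",
"print(updated[\"transforms\"][1][\"is_weekend\"])  # prints False\n",
"print(toy_config[\"transforms\"][1][\"is_weekend\"])  # prints True: the original is untouched"
]
},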
{
"cell_type": "markdown",
"id": "4deb8b8b",
"metadata": {},
"source": [
"#### 1.1.2 `params_to_tune`"
]
},
{
"cell_type": "markdown",
"id": "8c732645",
"metadata": {},
"source": [
"Let's get back to our initial question about `params_to_tune`. In our optuna study we are going to sample each parameter value from its distribution and pass it into the `pipeline.set_params` method. So, the keys of `params_to_tune` should be valid for the `set_params` method.\n",
"\n",
"Distributions are taken from `etna.distributions` and they match the `optuna.Trial.suggest_` methods."
]
},
{
"cell_type": "markdown",
"id": "f6a39f16",
"metadata": {},
"source": [
"For example, something like this will be valid for our `pipeline` defined above:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "4dab566f",
"metadata": {},
"outputs": [],
"source": [
"from etna.distributions import CategoricalDistribution\n",
"\n",
"example_params_to_tune = {\n",
" \"model.fit_intercept\": CategoricalDistribution([False, True]),\n",
" \"transforms.1.is_weekend\": CategoricalDistribution([False, True]),\n",
"}"
]
},
{
"cell_type": "markdown",
"id": "a0c81b84",
"metadata": {},
"source": [
"There is some good news: it isn't necessary for users to define `params_to_tune` themselves, because we have a default grid for many of our classes. The default grid is available by calling the `params_to_tune` method on a pipeline, model or transform. Let's check our `pipeline`:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "b493dace",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'model.fit_intercept': CategoricalDistribution(choices=[False, True]),\n",
" 'transforms.1.day_number_in_week': CategoricalDistribution(choices=[False, True]),\n",
" 'transforms.1.day_number_in_month': CategoricalDistribution(choices=[False, True]),\n",
" 'transforms.1.day_number_in_year': CategoricalDistribution(choices=[False, True]),\n",
" 'transforms.1.week_number_in_month': CategoricalDistribution(choices=[False, True]),\n",
" 'transforms.1.week_number_in_year': CategoricalDistribution(choices=[False, True]),\n",
" 'transforms.1.month_number_in_year': CategoricalDistribution(choices=[False, True]),\n",
" 'transforms.1.season_number': CategoricalDistribution(choices=[False, True]),\n",
" 'transforms.1.year_number': CategoricalDistribution(choices=[False, True]),\n",
" 'transforms.1.is_weekend': CategoricalDistribution(choices=[False, True])}"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pipeline.params_to_tune()"
]
},
{
"cell_type": "markdown",
"id": "554c5af2",
"metadata": {},
"source": [
"Now we are ready to use it in practice."
]
},
{
"cell_type": "markdown",
"id": "df2102f8",
"metadata": {},
"source": [
"### 1.2 Example "
]
},
{
"cell_type": "markdown",
"id": "535c0b18",
"metadata": {},
"source": [
"#### 1.2.1 Loading data"
]
},
{
"cell_type": "markdown",
"id": "9352eeb4",
"metadata": {},
"source": [
"Let's start by loading example data."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "0041c9ab",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" timestamp segment target\n",
"0 2019-01-01 segment_a 170\n",
"1 2019-01-02 segment_a 243\n",
"2 2019-01-03 segment_a 267\n",
"3 2019-01-04 segment_a 287\n",
"4 2019-01-05 segment_a 279"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv(\"data/example_dataset.csv\")\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "8996f93a",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"full_ts = TSDataset(df, freq=\"D\")\n",
"full_ts.plot()"
]
},
{
"cell_type": "markdown",
"id": "b1f95a3e",
"metadata": {},
"source": [
"Let's divide the current dataset into train and validation parts. We will use the validation part later to check the final results."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "d72c9f19",
"metadata": {},
"outputs": [],
"source": [
"ts, _ = full_ts.train_test_split(test_size=HORIZON * 5)"
]
},
{
"cell_type": "markdown",
"id": "3efade22",
"metadata": {},
"source": [
"#### 1.2.2 Running `Tune`"
]
},
{
"cell_type": "markdown",
"id": "1338a41f",
"metadata": {},
"source": [
"We are going to define our `Tune` object:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "5e4efd0c",
"metadata": {},
"outputs": [],
"source": [
"from etna.auto import Tune\n",
"\n",
"tune = Tune(pipeline=pipeline, target_metric=SMAPE(), horizon=HORIZON, backtest_params=dict(n_folds=5))"
]
},
{
"cell_type": "markdown",
"id": "6d61b949",
"metadata": {},
"source": [
"We used mostly default parameters for this example, but for your own experiments you might also want to set other parameters.\n",
"\n",
"For example, the `runner` parameter allows you to run tuning in parallel on a local machine, and the `storage` parameter makes it possible to store optuna results on a dedicated remote server.\n",
"\n",
"For a full list of parameters, we advise you to check our documentation."
]
},
{
"cell_type": "markdown",
"id": "50779a99",
"metadata": {},
"source": [
"Let's hide the optuna logs, because there are too many of them for a notebook."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "1d6650e3",
"metadata": {},
"outputs": [],
"source": [
"import optuna\n",
"\n",
"optuna.logging.set_verbosity(optuna.logging.CRITICAL)"
]
},
{
"cell_type": "markdown",
"id": "032a192b",
"metadata": {},
"source": [
"Let's run the tuning:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "49c86098",
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"best_pipeline = tune.fit(ts=ts, n_trials=20)"
]
},
{
"cell_type": "markdown",
"id": "b644325b",
"metadata": {},
"source": [
"The `%%capture` cell magic just hides the output."
]
},
{
"cell_type": "markdown",
"id": "218a48f2",
"metadata": {},
"source": [
"#### 1.2.3 Analysis"
]
},
{
"cell_type": "markdown",
"id": "1fd09627",
"metadata": {},
"source": [
"In the last section dedicated to `Tune` we will look at methods for result analysis."
]
},
{
"cell_type": "markdown",
"id": "3faf63b9",
"metadata": {},
"source": [
"First of all, there is the `summary` method, which shows us the results of the optuna trials."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "14525b55",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" pipeline \\\n",
"0 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"1 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"2 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"3 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"4 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"5 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"6 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"7 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"8 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"9 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"10 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"11 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"12 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"13 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"14 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"15 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"16 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"17 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"18 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"19 Pipeline(model = LinearPerSegmentModel(fit_int... \n",
"\n",
" hash Sign_median Sign_mean Sign_std \\\n",
"0 f4f02e1d5f60b8f322a4a8a622dd1c1e -0.500000 -0.478571 0.205204 \n",
"1 3d7b7af16d71a36f3b935f69e113e22d -0.457143 -0.485714 0.242437 \n",
"2 7c7932114268832a5458acfecfb453fc -0.200000 -0.271429 0.264447 \n",
"3 b7ac5f7fcf9c8959626befe263a9d561 0.000000 -0.085714 0.211248 \n",
"4 e928929f89156d88ef49e28abaf55847 -0.414286 -0.421429 0.207840 \n",
"5 3b4311d41fcaab7307235ea23b6d4599 -0.400000 -0.385714 0.396927 \n",
"6 74065ebc11c81bed6a9819d026c7cd84 -0.442857 -0.435714 0.246196 \n",
"7 b0d0420255c6117045f8254bf8f377a0 -0.442857 -0.464286 0.260167 \n",
"8 25dcd8bb095f87a1ffc499fa6a83ef5d -0.457143 -0.457143 0.265986 \n",
"9 3f1ca1759261598081fa3bb2f32fe0ac -0.414286 -0.435714 0.292654 \n",
"10 8363309e454e72993f86f10c7fc7c137 -0.157143 -0.185714 0.226779 \n",
"11 8363309e454e72993f86f10c7fc7c137 -0.157143 -0.185714 0.226779 \n",
"12 8363309e454e72993f86f10c7fc7c137 -0.157143 -0.185714 0.226779 \n",
"13 8363309e454e72993f86f10c7fc7c137 -0.157143 -0.185714 0.226779 \n",
"14 8363309e454e72993f86f10c7fc7c137 -0.157143 -0.185714 0.226779 \n",
"15 8363309e454e72993f86f10c7fc7c137 -0.157143 -0.185714 0.226779 \n",
"16 8363309e454e72993f86f10c7fc7c137 -0.157143 -0.185714 0.226779 \n",
"17 8363309e454e72993f86f10c7fc7c137 -0.157143 -0.185714 0.226779 \n",
"18 6f595f4f43b323804c04d4cea49c169b -0.414286 -0.435714 0.325242 \n",
"19 8363309e454e72993f86f10c7fc7c137 -0.157143 -0.185714 0.226779 \n",
"\n",
" Sign_percentile_5 Sign_percentile_25 Sign_percentile_75 \\\n",
"0 -0.672857 -0.621429 -0.357143 \n",
"1 -0.745714 -0.642857 -0.300000 \n",
"2 -0.581429 -0.392857 -0.078571 \n",
"3 -0.340000 -0.100000 0.014286 \n",
"4 -0.620000 -0.585714 -0.250000 \n",
"5 -0.788571 -0.514286 -0.271429 \n",
"6 -0.672857 -0.621429 -0.257143 \n",
"7 -0.725714 -0.657143 -0.250000 \n",
"8 -0.705714 -0.671429 -0.242857 \n",
"9 -0.725714 -0.657143 -0.192857 \n",
"10 -0.431429 -0.328571 -0.014286 \n",
"11 -0.431429 -0.328571 -0.014286 \n",
"12 -0.431429 -0.328571 -0.014286 \n",
"13 -0.431429 -0.328571 -0.014286 \n",
"14 -0.431429 -0.328571 -0.014286 \n",
"15 -0.431429 -0.328571 -0.014286 \n",
"16 -0.431429 -0.328571 -0.014286 \n",
"17 -0.431429 -0.328571 -0.014286 \n",
"18 -0.754286 -0.685714 -0.164286 \n",
"19 -0.431429 -0.328571 -0.014286 \n",
"\n",
" Sign_percentile_95 SMAPE_median ... MSE_percentile_75 \\\n",
"0 -0.254286 5.806429 ... 2220.282484 \n",
"1 -0.265714 5.856039 ... 2644.982216 \n",
"2 -0.061429 5.693983 ... 3457.757162 \n",
"3 0.048571 7.881275 ... 5039.841145 \n",
"4 -0.232857 6.032319 ... 3091.962427 \n",
"5 0.037143 6.653462 ... 3800.976318 \n",
"6 -0.188571 5.739626 ... 2933.246064 \n",
"7 -0.232857 6.042134 ... 2682.735922 \n",
"8 -0.208571 5.869280 ... 3098.567787 \n",
"9 -0.175714 6.608191 ... 3044.388978 \n",
"10 0.020000 5.974832 ... 2902.306123 \n",
"11 0.020000 5.974832 ... 2902.306123 \n",
"12 0.020000 5.974832 ... 2902.306123 \n",
"13 0.020000 5.974832 ... 2902.306123 \n",
"14 0.020000 5.974832 ... 2902.306123 \n",
"15 0.020000 5.974832 ... 2902.306123 \n",
"16 0.020000 5.974832 ... 2902.306123 \n",
"17 0.020000 5.974832 ... 2902.306123 \n",
"18 -0.147143 5.657316 ... 2247.347025 \n",
"19 0.020000 5.974832 ... 2902.306123 \n",
"\n",
" MSE_percentile_95 MedAE_median MedAE_mean MedAE_std \\\n",
"0 2953.865443 21.000232 22.334611 8.070926 \n",
"1 3294.855806 22.762122 23.389796 8.482028 \n",
"2 4209.624737 22.572681 23.336111 12.049564 \n",
"3 5665.228696 35.976862 33.937644 17.252826 \n",
"4 3181.592755 23.166650 25.265089 13.224461 \n",
"5 4837.444681 35.792514 32.276030 16.296588 \n",
"6 4802.299660 27.304852 24.936077 8.294963 \n",
"7 3688.168155 28.393903 25.819143 8.652993 \n",
"8 3154.538337 22.380642 24.289797 11.998603 \n",
"9 3611.477391 23.750327 26.488927 13.825791 \n",
"10 3526.513999 17.027383 21.682156 15.988286 \n",
"11 3526.513999 17.027383 21.682156 15.988286 \n",
"12 3526.513999 17.027383 21.682156 15.988286 \n",
"13 3526.513999 17.027383 21.682156 15.988286 \n",
"14 3526.513999 17.027383 21.682156 15.988286 \n",
"15 3526.513999 17.027383 21.682156 15.988286 \n",
"16 3526.513999 17.027383 21.682156 15.988286 \n",
"17 3526.513999 17.027383 21.682156 15.988286 \n",
"18 2681.501259 21.624614 22.111993 7.952462 \n",
"19 3526.513999 17.027383 21.682156 15.988286 \n",
"\n",
" MedAE_percentile_5 MedAE_percentile_25 MedAE_percentile_75 \\\n",
"0 14.955846 18.861388 24.473455 \n",
"1 14.897792 19.344439 26.807479 \n",
"2 11.235277 18.503043 27.405750 \n",
"3 14.444379 27.282228 42.632278 \n",
"4 13.001779 18.666844 29.764896 \n",
"5 13.499409 24.106508 43.962035 \n",
"6 15.108636 21.478207 30.762723 \n",
"7 15.618131 21.989342 32.223704 \n",
"8 13.252341 19.168974 27.501465 \n",
"9 14.242057 20.027917 30.211337 \n",
"10 9.110958 11.100846 27.608693 \n",
"11 9.110958 11.100846 27.608693 \n",
"12 9.110958 11.100846 27.608693 \n",
"13 9.110958 11.100846 27.608693 \n",
"14 9.110958 11.100846 27.608693 \n",
"15 9.110958 11.100846 27.608693 \n",
"16 9.110958 11.100846 27.608693 \n",
"17 9.110958 11.100846 27.608693 \n",
"18 14.197890 17.080865 26.655742 \n",
"19 9.110958 11.100846 27.608693 \n",
"\n",
" MedAE_percentile_95 state \n",
"0 31.581505 TrialState.COMPLETE \n",
"1 32.760543 TrialState.COMPLETE \n",
"2 36.505748 TrialState.COMPLETE \n",
"3 50.576005 TrialState.COMPLETE \n",
"4 40.466215 TrialState.COMPLETE \n",
"5 46.129572 TrialState.COMPLETE \n",
"6 31.447233 TrialState.COMPLETE \n",
"7 32.415490 TrialState.COMPLETE \n",
"8 38.000072 TrialState.COMPLETE \n",
"9 42.569838 TrialState.COMPLETE \n",
"10 40.770037 TrialState.COMPLETE \n",
"11 40.770037 TrialState.COMPLETE \n",
"12 40.770037 TrialState.COMPLETE \n",
"13 40.770037 TrialState.COMPLETE \n",
"14 40.770037 TrialState.COMPLETE \n",
"15 40.770037 TrialState.COMPLETE \n",
"16 40.770037 TrialState.COMPLETE \n",
"17 40.770037 TrialState.COMPLETE \n",
"18 30.708428 TrialState.COMPLETE \n",
"19 40.770037 TrialState.COMPLETE \n",
"\n",
"[20 rows x 38 columns]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tune.summary()"
]
},
{
"cell_type": "markdown",
"id": "cf987a2e",
"metadata": {},
"source": [
"Let's show only the columns we are interested in."
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "b650bfc7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" hash \\\n",
"19 8363309e454e72993f86f10c7fc7c137 \n",
"17 8363309e454e72993f86f10c7fc7c137 \n",
"16 8363309e454e72993f86f10c7fc7c137 \n",
"15 8363309e454e72993f86f10c7fc7c137 \n",
"14 8363309e454e72993f86f10c7fc7c137 \n",
"13 8363309e454e72993f86f10c7fc7c137 \n",
"12 8363309e454e72993f86f10c7fc7c137 \n",
"10 8363309e454e72993f86f10c7fc7c137 \n",
"11 8363309e454e72993f86f10c7fc7c137 \n",
"2 7c7932114268832a5458acfecfb453fc \n",
"8 25dcd8bb095f87a1ffc499fa6a83ef5d \n",
"4 e928929f89156d88ef49e28abaf55847 \n",
"0 f4f02e1d5f60b8f322a4a8a622dd1c1e \n",
"18 6f595f4f43b323804c04d4cea49c169b \n",
"1 3d7b7af16d71a36f3b935f69e113e22d \n",
"9 3f1ca1759261598081fa3bb2f32fe0ac \n",
"5 3b4311d41fcaab7307235ea23b6d4599 \n",
"6 74065ebc11c81bed6a9819d026c7cd84 \n",
"3 b7ac5f7fcf9c8959626befe263a9d561 \n",
"7 b0d0420255c6117045f8254bf8f377a0 \n",
"\n",
" pipeline SMAPE_mean \\\n",
"19 Pipeline(model = LinearPerSegmentModel(fit_int... 8.556535 \n",
"17 Pipeline(model = LinearPerSegmentModel(fit_int... 8.556535 \n",
"16 Pipeline(model = LinearPerSegmentModel(fit_int... 8.556535 \n",
"15 Pipeline(model = LinearPerSegmentModel(fit_int... 8.556535 \n",
"14 Pipeline(model = LinearPerSegmentModel(fit_int... 8.556535 \n",
"13 Pipeline(model = LinearPerSegmentModel(fit_int... 8.556535 \n",
"12 Pipeline(model = LinearPerSegmentModel(fit_int... 8.556535 \n",
"10 Pipeline(model = LinearPerSegmentModel(fit_int... 8.556535 \n",
"11 Pipeline(model = LinearPerSegmentModel(fit_int... 8.556535 \n",
"2 Pipeline(model = LinearPerSegmentModel(fit_int... 9.210183 \n",
"8 Pipeline(model = LinearPerSegmentModel(fit_int... 9.943658 \n",
"4 Pipeline(model = LinearPerSegmentModel(fit_int... 9.946866 \n",
"0 Pipeline(model = LinearPerSegmentModel(fit_int... 9.957781 \n",
"18 Pipeline(model = LinearPerSegmentModel(fit_int... 10.061742 \n",
"1 Pipeline(model = LinearPerSegmentModel(fit_int... 10.306909 \n",
"9 Pipeline(model = LinearPerSegmentModel(fit_int... 10.554444 \n",
"5 Pipeline(model = LinearPerSegmentModel(fit_int... 10.756703 \n",
"6 Pipeline(model = LinearPerSegmentModel(fit_int... 10.917164 \n",
"3 Pipeline(model = LinearPerSegmentModel(fit_int... 11.255320 \n",
"7 Pipeline(model = LinearPerSegmentModel(fit_int... 11.478760 \n",
"\n",
" state \n",
"19 TrialState.COMPLETE \n",
"17 TrialState.COMPLETE \n",
"16 TrialState.COMPLETE \n",
"15 TrialState.COMPLETE \n",
"14 TrialState.COMPLETE \n",
"13 TrialState.COMPLETE \n",
"12 TrialState.COMPLETE \n",
"10 TrialState.COMPLETE \n",
"11 TrialState.COMPLETE \n",
"2 TrialState.COMPLETE \n",
"8 TrialState.COMPLETE \n",
"4 TrialState.COMPLETE \n",
"0 TrialState.COMPLETE \n",
"18 TrialState.COMPLETE \n",
"1 TrialState.COMPLETE \n",
"9 TrialState.COMPLETE \n",
"5 TrialState.COMPLETE \n",
"6 TrialState.COMPLETE \n",
"3 TrialState.COMPLETE \n",
"7 TrialState.COMPLETE "
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tune.summary()[[\"hash\", \"pipeline\", \"SMAPE_mean\", \"state\"]].sort_values(\"SMAPE_mean\")"
]
},
{
"cell_type": "markdown",
"id": "95721277",
"metadata": {},
"source": [
"As we can see, the `hash` column contains duplicates: some trials sampled the same hyperparameters and therefore produced the same results. Such duplicates are handled specially: they are skipped during optimization and the previously computed metric values are returned.\n",
"\n",
"Duplicates in the summary can be eliminated using the `hash` column."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "7506fe96",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" hash \\\n",
"19 8363309e454e72993f86f10c7fc7c137 \n",
"2 7c7932114268832a5458acfecfb453fc \n",
"8 25dcd8bb095f87a1ffc499fa6a83ef5d \n",
"4 e928929f89156d88ef49e28abaf55847 \n",
"0 f4f02e1d5f60b8f322a4a8a622dd1c1e \n",
"18 6f595f4f43b323804c04d4cea49c169b \n",
"1 3d7b7af16d71a36f3b935f69e113e22d \n",
"9 3f1ca1759261598081fa3bb2f32fe0ac \n",
"5 3b4311d41fcaab7307235ea23b6d4599 \n",
"6 74065ebc11c81bed6a9819d026c7cd84 \n",
"3 b7ac5f7fcf9c8959626befe263a9d561 \n",
"7 b0d0420255c6117045f8254bf8f377a0 \n",
"\n",
" pipeline SMAPE_mean \\\n",
"19 Pipeline(model = LinearPerSegmentModel(fit_int... 8.556535 \n",
"2 Pipeline(model = LinearPerSegmentModel(fit_int... 9.210183 \n",
"8 Pipeline(model = LinearPerSegmentModel(fit_int... 9.943658 \n",
"4 Pipeline(model = LinearPerSegmentModel(fit_int... 9.946866 \n",
"0 Pipeline(model = LinearPerSegmentModel(fit_int... 9.957781 \n",
"18 Pipeline(model = LinearPerSegmentModel(fit_int... 10.061742 \n",
"1 Pipeline(model = LinearPerSegmentModel(fit_int... 10.306909 \n",
"9 Pipeline(model = LinearPerSegmentModel(fit_int... 10.554444 \n",
"5 Pipeline(model = LinearPerSegmentModel(fit_int... 10.756703 \n",
"6 Pipeline(model = LinearPerSegmentModel(fit_int... 10.917164 \n",
"3 Pipeline(model = LinearPerSegmentModel(fit_int... 11.255320 \n",
"7 Pipeline(model = LinearPerSegmentModel(fit_int... 11.478760 \n",
"\n",
" state \n",
"19 TrialState.COMPLETE \n",
"2 TrialState.COMPLETE \n",
"8 TrialState.COMPLETE \n",
"4 TrialState.COMPLETE \n",
"0 TrialState.COMPLETE \n",
"18 TrialState.COMPLETE \n",
"1 TrialState.COMPLETE \n",
"9 TrialState.COMPLETE \n",
"5 TrialState.COMPLETE \n",
"6 TrialState.COMPLETE \n",
"3 TrialState.COMPLETE \n",
"7 TrialState.COMPLETE "
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tune.summary()[[\"hash\", \"pipeline\", \"SMAPE_mean\", \"state\"]].sort_values(\"SMAPE_mean\").drop_duplicates(subset=\"hash\")"
]
},
{
"cell_type": "markdown",
"id": "a642c361",
"metadata": {},
"source": [
"The second method, `top_k`, is useful when you want to look at the best pipelines tried, without duplicates."
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "6f707553",
"metadata": {},
"outputs": [],
"source": [
"top_3_pipelines = tune.top_k(k=3)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "7fd2b238",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Pipeline(model = LinearPerSegmentModel(fit_intercept = True, ), transforms = [LagTransform(in_column = 'target', lags = [14, 15, 16, 17, 18, 19, 20, 21, 22, 23], out_column = 'target_lag', ), DateFlagsTransform(day_number_in_week = False, day_number_in_month = True, day_number_in_year = False, week_number_in_month = True, week_number_in_year = False, month_number_in_year = False, season_number = False, year_number = False, is_weekend = True, special_days_in_week = (), special_days_in_month = (), out_column = 'date_flags', )], horizon = 14, ),\n",
" Pipeline(model = LinearPerSegmentModel(fit_intercept = True, ), transforms = [LagTransform(in_column = 'target', lags = [14, 15, 16, 17, 18, 19, 20, 21, 22, 23], out_column = 'target_lag', ), DateFlagsTransform(day_number_in_week = False, day_number_in_month = True, day_number_in_year = False, week_number_in_month = True, week_number_in_year = False, month_number_in_year = False, season_number = False, year_number = False, is_weekend = False, special_days_in_week = (), special_days_in_month = (), out_column = 'date_flags', )], horizon = 14, ),\n",
" Pipeline(model = LinearPerSegmentModel(fit_intercept = False, ), transforms = [LagTransform(in_column = 'target', lags = [14, 15, 16, 17, 18, 19, 20, 21, 22, 23], out_column = 'target_lag', ), DateFlagsTransform(day_number_in_week = True, day_number_in_month = False, day_number_in_year = True, week_number_in_month = False, week_number_in_year = False, month_number_in_year = False, season_number = False, year_number = True, is_weekend = False, special_days_in_week = (), special_days_in_month = (), out_column = 'date_flags', )], horizon = 14, )]"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"top_3_pipelines"
]
},
{
"cell_type": "markdown",
"id": "15ef8f10",
"metadata": {},
"source": [
"## 2. General AutoML "
]
},
{
"cell_type": "markdown",
"id": "fef33f7e",
"metadata": {},
"source": [
"Hyperparameter tuning is useful, but it can be too narrow. In this section we turn our attention to the general AutoML pipeline.\n",
"ETNA provides the `etna.auto.Auto` class for automatic pipeline selection. It can be useful for quickly creating a good baseline for your forecasting task."
]
},
{
"cell_type": "markdown",
"id": "8c1763e2",
"metadata": {},
"source": [
"### 2.1 How `Auto` works "
]
},
{
"cell_type": "markdown",
"id": "4e20092d",
"metadata": {},
"source": [
"`Auto` takes init parameters similar to those of `Tune`, but instead of a `pipeline` it works with a `pool`. A pool is, in general, just a list of pipelines.\n",
"\n",
"During `fit` there are two stages:\n",
"\n",
"- pool stage,\n",
"- tuning stage.\n",
"\n",
"The pool stage is responsible for checking every pipeline in the given `pool`. For each pipeline we run a backtest and compute `target_metric`. The results are saved in an optuna study.\n",
"\n",
"The tuning stage takes the `tune_size` best pipelines according to the results of the pool stage and then runs `Tune` with the default `params_to_tune` on them sequentially, from best to worst.\n",
"\n",
"The limit parameters `n_trials` and `timeout` are shared between the pool and tuning stages. First, we run the pool stage with the given `n_trials` and `timeout`. After that, the remaining values are divided equally among the `tune_size` tuning steps."
]
},
{
"cell_type": "markdown",
"id": "96b2fb38",
"metadata": {},
"source": [
"### 2.2 Example "
]
},
{
"cell_type": "markdown",
"id": "02b2c527",
"metadata": {},
"source": [
"We will move straight to the example."
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "ea97e2f7",
"metadata": {},
"outputs": [],
"source": [
"from etna.auto import Auto\n",
"\n",
"auto = Auto(target_metric=SMAPE(), horizon=HORIZON, backtest_params=dict(n_folds=5))"
]
},
{
"cell_type": "markdown",
"id": "83fe5077",
"metadata": {},
"source": [
"We used mostly default parameters, including the default pool. There is also a default `sampler`, but to make results more reproducible we fixed the `seed`."
]
},
{
"cell_type": "markdown",
"id": "aa87e050",
"metadata": {},
"source": [
"Let's start the fitting. We can start by running only the pool stage."
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "47ccd63b",
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"best_pool_pipeline = auto.fit(ts=ts, tune_size=0)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "d972dfb5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" hash \\\n",
"9 af8088ac0abfde46e93a8dbb407a2117 \n",
"0 d8215d95e2c6c9a4b4fdacf3fa77dddc \n",
"2 8f640faabcac0552153ca19337179f3b \n",
"16 d6a44adb551f1aec09ef37c14aed260f \n",
"18 16eb77200eb2fd5dc1f6f2a5067884cd \n",
"1 4c07749e913403906cd033e4882fc4f9 \n",
"12 6e2eb71d033b6d0607f5b6d0a7596ce9 \n",
"8 6bb58e7ce09eab00448d5732240ec2ec \n",
"4 a640ddfb767ea0cbf31751ddda6e36ee \n",
"15 cfeb21bcf2e922a390ade8be9d845e0d \n",
"10 a5e036978ef9cc9f297c9eb2c280af05 \n",
"13 2e36e0b9cb67a43bb1bf96fa2ccf718f \n",
"5 8b9f5fa09754a80f17380dec2b998f1d \n",
"3 d62c0579459d4a1b88aea8ed6effdf4e \n",
"6 5916e5b653295271c79caae490618ee9 \n",
"19 5a91b6c8acc2c461913df44fd1429375 \n",
"7 403b3e18012af5ff9815b408f5c2e47d \n",
"17 6cf8605e6c513053ac4f5203e330c59d \n",
"14 53e90ae4cf7f1f71e6396107549c25ef \n",
"11 90b31b54cb8c01867be05a3320852682 \n",
"\n",
" pipeline SMAPE_mean \\\n",
"9 Pipeline(model = CatBoostPerSegmentModel(itera... 5.057438 \n",
"0 Pipeline(model = NaiveModel(lag = 7, ), transf... 5.164436 \n",
"2 Pipeline(model = HoltWintersModel(trend = 'add... 5.931951 \n",
"16 Pipeline(model = SeasonalMovingAverageModel(wi... 6.197182 \n",
"18 Pipeline(model = HoltWintersModel(trend = 'add... 6.347734 \n",
"1 Pipeline(model = SeasonalMovingAverageModel(wi... 6.529721 \n",
"12 Pipeline(model = ProphetModel(growth = 'linear... 7.792707 \n",
"8 Pipeline(model = CatBoostMultiSegmentModel(ite... 7.814187 \n",
"4 Pipeline(model = CatBoostMultiSegmentModel(ite... 7.816528 \n",
"15 Pipeline(model = ProphetModel(growth = 'linear... 7.867342 \n",
"10 Pipeline(model = AutoARIMAModel(), transforms ... 8.297048 \n",
"13 Pipeline(model = LinearMultiSegmentModel(fit_i... 9.205423 \n",
"5 Pipeline(model = LinearPerSegmentModel(fit_int... 10.997462 \n",
"3 Pipeline(model = MovingAverageModel(window = 1... 11.317256 \n",
"6 Pipeline(model = MovingAverageModel(window = 2... 12.028916 \n",
"19 Pipeline(model = ElasticPerSegmentModel(alpha ... 12.213320 \n",
"7 Pipeline(model = MovingAverageModel(window = 4... 12.243011 \n",
"17 Pipeline(model = HoltWintersModel(trend = None... 15.473118 \n",
"14 Pipeline(model = NaiveModel(lag = 1, ), transf... 19.361078 \n",
"11 Pipeline(model = ElasticMultiSegmentModel(alph... 35.971241 \n",
"\n",
" state study \n",
"9 TrialState.COMPLETE pool \n",
"0 TrialState.COMPLETE pool \n",
"2 TrialState.COMPLETE pool \n",
"16 TrialState.COMPLETE pool \n",
"18 TrialState.COMPLETE pool \n",
"1 TrialState.COMPLETE pool \n",
"12 TrialState.COMPLETE pool \n",
"8 TrialState.COMPLETE pool \n",
"4 TrialState.COMPLETE pool \n",
"15 TrialState.COMPLETE pool \n",
"10 TrialState.COMPLETE pool \n",
"13 TrialState.COMPLETE pool \n",
"5 TrialState.COMPLETE pool \n",
"3 TrialState.COMPLETE pool \n",
"6 TrialState.COMPLETE pool \n",
"19 TrialState.COMPLETE pool \n",
"7 TrialState.COMPLETE pool \n",
"17 TrialState.COMPLETE pool \n",
"14 TrialState.COMPLETE pool \n",
"11 TrialState.COMPLETE pool "
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"auto.summary()[[\"hash\", \"pipeline\", \"SMAPE_mean\", \"state\", \"study\"]].sort_values(\"SMAPE_mean\")"
]
},
{
"cell_type": "markdown",
"id": "ff62ced9",
"metadata": {},
"source": [
"We can continue our training. The pool stage is over, so only the tuning stage will run. If we don't want to wait forever, we should limit the tuning by setting `n_trials` or `timeout`.\n",
"\n",
"We also set some parameters for `optuna.Study.optimize`:\n",
"\n",
"- `gc_after_trial=True`: to prevent `fit` from increasing memory consumption;\n",
"- `catch=(Exception,)`: to prevent failing if some trials are erroneous."
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "13a1861a",
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"best_tuning_pipeline = auto.fit(ts=ts, tune_size=3, n_trials=100, gc_after_trial=True, catch=(Exception,))"
]
},
{
"cell_type": "markdown",
"id": "09f78f63",
"metadata": {},
"source": [
"Let's look at the results."
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "95c854eb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" hash \\\n",
"56 419fc80cf634ba0888c4f899f666ad45 \n",
"89 731ccb72a473bec81789b7f186001ddd \n",
"97 9c302769456b4adb9143f11c582f7264 \n",
"88 182c748af70287ab3a12bf32c03320f5 \n",
"96 4f426335c0eb00d847d9dd1e0a421415 \n",
"98 2cafd0750f191e7ab2d4160da50a7c64 \n",
"9 af8088ac0abfde46e93a8dbb407a2117 \n",
"75 382825866425cac211691205a9537c95 \n",
"95 c2a8d498fe35873d060e173e1af042d5 \n",
"91 035f8e28180bc7491a30b3d0d67060c9 \n",
"\n",
" pipeline SMAPE_mean \\\n",
"56 Pipeline(model = HoltWintersModel(trend = 'mul... 4.769471 \n",
"89 Pipeline(model = CatBoostPerSegmentModel(itera... 4.899715 \n",
"97 Pipeline(model = CatBoostPerSegmentModel(itera... 4.927197 \n",
"88 Pipeline(model = CatBoostPerSegmentModel(itera... 4.941247 \n",
"96 Pipeline(model = CatBoostPerSegmentModel(itera... 4.977773 \n",
"98 Pipeline(model = CatBoostPerSegmentModel(itera... 5.056993 \n",
"9 Pipeline(model = CatBoostPerSegmentModel(itera... 5.057438 \n",
"75 Pipeline(model = CatBoostPerSegmentModel(itera... 5.081609 \n",
"95 Pipeline(model = CatBoostPerSegmentModel(itera... 5.117583 \n",
"91 Pipeline(model = CatBoostPerSegmentModel(itera... 5.135956 \n",
"\n",
" state study \n",
"56 TrialState.COMPLETE tuning/8f640faabcac0552153ca19337179f3b \n",
"89 TrialState.COMPLETE tuning/af8088ac0abfde46e93a8dbb407a2117 \n",
"97 TrialState.COMPLETE tuning/af8088ac0abfde46e93a8dbb407a2117 \n",
"88 TrialState.COMPLETE tuning/af8088ac0abfde46e93a8dbb407a2117 \n",
"96 TrialState.COMPLETE tuning/af8088ac0abfde46e93a8dbb407a2117 \n",
"98 TrialState.COMPLETE tuning/af8088ac0abfde46e93a8dbb407a2117 \n",
"9 TrialState.COMPLETE pool \n",
"75 TrialState.COMPLETE tuning/af8088ac0abfde46e93a8dbb407a2117 \n",
"95 TrialState.COMPLETE tuning/af8088ac0abfde46e93a8dbb407a2117 \n",
"91 TrialState.COMPLETE tuning/af8088ac0abfde46e93a8dbb407a2117 "
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"auto.summary()[[\"hash\", \"pipeline\", \"SMAPE_mean\", \"state\", \"study\"]].sort_values(\"SMAPE_mean\").drop_duplicates(\n",
" subset=(\"hash\", \"study\")\n",
").head(10)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "640269ba",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Pipeline(model = HoltWintersModel(trend = 'mul', damped_trend = False, seasonal = 'mul', seasonal_periods = None, initialization_method = 'estimated', initial_level = None, initial_trend = None, initial_seasonal = None, use_boxcox = True, bounds = None, dates = None, freq = None, missing = 'none', smoothing_level = None, smoothing_trend = None, smoothing_seasonal = None, damping_trend = None, ), transforms = [], horizon = 14, ),\n",
" Pipeline(model = CatBoostPerSegmentModel(iterations = None, depth = 9, learning_rate = 0.0435214895575014, logging_level = 'Silent', l2_leaf_reg = 1.588756097852857, thread_count = None, random_strength = 0.0001602176189749599, ), transforms = [LagTransform(in_column = 'target', lags = [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28], out_column = None, ), DateFlagsTransform(day_number_in_week = True, day_number_in_month = False, day_number_in_year = False, week_number_in_month = False, week_number_in_year = False, month_number_in_year = False, season_number = False, year_number = False, is_weekend = True, special_days_in_week = [], special_days_in_month = [], out_column = None, )], horizon = 14, ),\n",
" Pipeline(model = CatBoostPerSegmentModel(iterations = None, depth = 10, learning_rate = 0.066387199945575, logging_level = 'Silent', l2_leaf_reg = 3.8476771557403033, thread_count = None, random_strength = 2.6976801196146113e-05, ), transforms = [LagTransform(in_column = 'target', lags = [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28], out_column = None, ), DateFlagsTransform(day_number_in_week = True, day_number_in_month = False, day_number_in_year = False, week_number_in_month = False, week_number_in_year = False, month_number_in_year = False, season_number = False, year_number = False, is_weekend = True, special_days_in_week = [], special_days_in_month = [], out_column = None, )], horizon = 14, ),\n",
" Pipeline(model = CatBoostPerSegmentModel(iterations = None, depth = 8, learning_rate = 0.1368955392889537, logging_level = 'Silent', l2_leaf_reg = 1.8121398100968207, thread_count = None, random_strength = 1.0292981436693363e-05, ), transforms = [LagTransform(in_column = 'target', lags = [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28], out_column = None, ), DateFlagsTransform(day_number_in_week = True, day_number_in_month = True, day_number_in_year = True, week_number_in_month = False, week_number_in_year = False, month_number_in_year = False, season_number = False, year_number = False, is_weekend = True, special_days_in_week = [], special_days_in_month = [], out_column = None, )], horizon = 14, ),\n",
" Pipeline(model = CatBoostPerSegmentModel(iterations = None, depth = 10, learning_rate = 0.04930475651736648, logging_level = 'Silent', l2_leaf_reg = 1.2938317623739193, thread_count = None, random_strength = 0.00020141074677370956, ), transforms = [LagTransform(in_column = 'target', lags = [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28], out_column = None, ), DateFlagsTransform(day_number_in_week = True, day_number_in_month = False, day_number_in_year = False, week_number_in_month = False, week_number_in_year = False, month_number_in_year = False, season_number = False, year_number = False, is_weekend = True, special_days_in_week = [], special_days_in_month = [], out_column = None, )], horizon = 14, )]"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"auto.top_k(k=5)"
]
},
{
"cell_type": "markdown",
"id": "7451f135",
"metadata": {},
"source": [
"If we look at the `study` column, we will see that the best trial from the tuning stage is better than the best trial from the pool stage. This means that the tuning stage was successful and improved the final result.\n",
"\n",
"Let's compare the best pipelines from the pool and tuning stages on the hold-out part of the initial `ts`."
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "ce8953ab",
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"best_pool_metrics, _, _ = best_pool_pipeline.backtest(ts=full_ts, metrics=[SMAPE()], n_folds=5)\n",
"best_tuning_metrics, _, _ = best_tuning_pipeline.backtest(ts=full_ts, metrics=[SMAPE()], n_folds=5)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "7a42cc84",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best pool SMAPE: 8.262\n",
"Best tuning SMAPE: 8.188\n"
]
}
],
"source": [
"best_pool_smape = best_pool_metrics[\"SMAPE\"].mean()\n",
"best_tuning_smape = best_tuning_metrics[\"SMAPE\"].mean()\n",
"print(f\"Best pool SMAPE: {best_pool_smape:.3f}\")\n",
"print(f\"Best tuning SMAPE: {best_tuning_smape:.3f}\")"
]
},
{
"cell_type": "markdown",
"id": "3f41537f",
"metadata": {},
"source": [
"As we can see, the results are slightly better after the tuning stage, but the difference is small and may not be statistically significant. Results on your datasets could be different."
]
},
{
"cell_type": "markdown",
"id": "3322d9c2",
"metadata": {},
"source": [
"## 3. Summary "
]
},
{
"cell_type": "markdown",
"id": "39b4a081",
"metadata": {},
"source": [
"In this notebook we discussed how AutoML works in the ETNA library and how to use it. There are two supported scenarios:\n",
"\n",
"- Tuning your existing pipeline;\n",
"- Automatically searching for a pipeline for your forecasting task."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}