Fit CatBoost model using extended features

  • Takes the model input data for flight delays
  • Split data based on external train/test data file
  • Define catboost model
  • Perform randomized search on selected parameter space
  • Retrain model for more iterations using optimal parameters
  • Save model as pickle
  • – save to mlflow –
  • Write the predictions to a csv file

Parameters


  • input_file: Filepath of model input data of flight delays
  • train_test_file: Filepath of train/test csv file with columns [“id”, “model_set”]
  • output_predictions: Filepath of the csv file to write the model predictions to

Returns


Trained CatBoost regression model (hyperparameters chosen by randomized search) that predicts the flight delay in seconds (scheduleDelaySeconds).

[2]:
# model params
# Model input table for flight delays (loaded with read_csv_data in "Read data").
input_file = "../lvt-schiphol-assignment-snakemake/data/model_input/delays_extended_input.csv"
# Table assigning each "id" to a "model_set" ("train" or "test").
train_test_file = "../lvt-schiphol-assignment-snakemake/data/model_input/train_test__0.2__timeseries.csv"
# Path of the predictions csv; its parent directory (`output_dir`) is also
# used as CatBoost's train_dir.
output_predictions = "./predictions.csv"

# mlflow params
# NOTE(review): these settings are not read by any of the cells shown here —
# confirm where (or whether) MLflow logging is performed.
log_mlflow = True
mlflow_tracking_uri = "../mlruns"
mlflow_experiment = "from_script"
mlflow_run = "catboost_simple"
[3]:
from pathlib import Path

# Absolute path of the directory that will contain the predictions csv;
# CatBoost training artefacts are written below it too.
output_dir = (
    Path(output_predictions)
    .parent
    .absolute()
)
output_dir
[3]:
WindowsPath('C:/Users/lodew/qualogy/schiphol-code-assignment/scripts')

Imports

[4]:
import pandas as pd
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.base import BaseEstimator, TransformerMixin

# catboost
from catboost import Pool
from catboost import CatBoostRegressor

import matplotlib.pyplot as plt
import seaborn as sns

import sys
sys.path.append("../")

from src.data.google_storage_io import read_csv_data, write_csv_data
from src.evaluation.metrics import get_regression_metrics
from src.evaluation.regression import make_regression_metrics_by_group, make_regression_metrics_by_datetime
from src.evaluation.predictions import make_predictions_dataframe
[5]:
# Default size (width, height in inches) for all matplotlib figures below.
plt.rcParams.update({"figure.figsize": (16, 8)})

Read data

[6]:
%%time
df = read_csv_data(input_file)
train_test = read_csv_data(train_test_file)
Reading file from local directory
File:   ../lvt-schiphol-assignment-snakemake/data/model_input/delays_extended_input.csv

Reading file from local directory
File:   ../lvt-schiphol-assignment-snakemake/data/model_input/train_test__0.2__timeseries.csv

Wall time: 1.51 s
[7]:
%%time

def split_train_test(df, train_test, target="scheduleDelaySeconds"):
    """
    Split the model input into train/test feature matrices and targets.

    Parameters
    ----------
    df : pandas.DataFrame
        Model input data; must contain an "id" column and the ``target`` column.
    train_test : pandas.DataFrame
        Set assignment with columns ["id", "model_set"], where "model_set" is
        "train" or "test". Each id may appear at most once.
    target : str
        Name of the target column to split off from the features.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)``. The X frames keep every column
        of ``df`` except ``target`` (including "id").

    Raises
    ------
    pandas.errors.MergeError
        If an id occurs more than once in ``train_test`` (a ValueError subclass).
    ValueError
        If some rows of ``df`` are not labelled "train" or "test".
    KeyError
        If there are no "train" rows or no "test" rows at all (from ``get_group``).
    """
    # Attach the train/test label to every row. validate="many_to_one" fails
    # fast on duplicate ids in `train_test`, which would otherwise silently
    # duplicate rows of `df` in the merge.
    df_labelled = pd.merge(df, train_test, on="id", how="left", validate="many_to_one")
    # Rows without a label get a NaN "model_set" and are dropped by groupby;
    # that case is caught by the row-count check below.
    df_set_groups = df_labelled.groupby("model_set")

    # get data per train/test set
    df_train = df_set_groups.get_group("train").drop(columns="model_set")
    df_test = df_set_groups.get_group("test").drop(columns="model_set")

    # split target from features
    X_train, y_train = df_train.drop(columns=[target]), df_train[target]
    X_test, y_test = df_test.drop(columns=[target]), df_test[target]

    print(f"""
        Split data shapes
        Input: {df.shape}
        Train: {X_train.shape},\t {y_train.shape}
        Test:  {X_test.shape},\t {y_test.shape}
        """)

    # Explicit check instead of `assert`, which is skipped under `python -O`.
    n_split = len(X_train) + len(X_test)
    if n_split != len(df):
        raise ValueError(
            f"{len(df) - n_split} of {len(df)} input rows were not assigned to the "
            f"'train' or 'test' set; check that every id in `df` has a "
            f"'train'/'test' label in `train_test`."
        )

    return X_train, X_test, y_train, y_test

# Split the loaded input into train/test features and targets
# using the id -> model_set assignment.
X_train, X_test, y_train, y_test = split_train_test(df=df, train_test=train_test)



        Split data shapes
        Input: (487714, 24)
        Train: (389439, 23),     (389439,)
        Test:  (98275, 23),      (98275,)

Wall time: 1.1 s

Prediction model

Define model

[8]:
class CatBoostChain(BaseEstimator, TransformerMixin):
    """
    CatBoost regressor wrapped as a scikit-learn transformer.

    ``fit`` trains the wrapped ``CatBoostRegressor`` on a derived target.
    ``transform`` returns a copy of ``X`` with the model's predictions
    appended as an extra column, so the estimator can be chained.

    NOTE(review): this class is not used by the cells shown in this notebook.
    ``fit`` reads the columns "distance", "Passed" and "seconds", which are not
    among the flight-delay feature columns selected in this notebook. It looks
    carried over from another project, so confirm before using it here.

    Parameters
    ----------
    catboost_kwargs : dict
        Keyword arguments passed to ``CatBoostRegressor``. Its "cat_features"
        parameter is used when building the training/evaluation pools.
    early_stopping_rounds : int or None
        Forwarded to ``CatBoostRegressor.fit``.
    prediction_col : str
        Name of the column that ``transform`` appends with the predictions.
    """

    def __init__(self, catboost_kwargs, early_stopping_rounds=None,
                 prediction_col="catboost_prediction"):
        # scikit-learn's get_params()/clone()/repr read every __init__ argument
        # back from an attribute of the same name, so store them unchanged.
        self.catboost_kwargs = catboost_kwargs
        self.early_stopping_rounds = early_stopping_rounds
        self.prediction_col = prediction_col
        self.catboost = CatBoostRegressor(**catboost_kwargs)

    def fit(self, X, y, **kwargs):
        """
        Fit the wrapped CatBoost model.

        Expects 'y' to be the race finish times, not avg_speed_left. The
        training target is derived as
        ``(X["distance"] - X["Passed"]) / (y - X["seconds"])``; infinite or
        NaN results are replaced by 0.

        Parameters
        ----------
        X : pandas.DataFrame
            Features. They must contain the columns used in the target
            derivation above.
        y : pandas.Series
            Finish times (``y.values`` is used).
        **kwargs
            Extra keyword arguments forwarded to ``CatBoostRegressor.fit``.

        Returns
        -------
        self
        """
        # derive the avg-speed-left training target from the finish times
        y = ((X["distance"] - X["Passed"]) / (y.values - X["seconds"])) \
            .replace([np.inf, -np.inf], np.nan) \
            .fillna(0).values

        # Hold out 20% of the rows (unseeded random split) as the eval set
        # for early stopping.
        X_pool_train, X_pool_eval, y_pool_train, y_pool_eval = train_test_split(X, y, test_size=0.2)

        cat_features = self.catboost.get_param("cat_features")
        train_pool = Pool(data=X_pool_train, label=y_pool_train, cat_features=cat_features)
        eval_pool = Pool(data=X_pool_eval, label=y_pool_eval, cat_features=cat_features)

        self.catboost.fit(train_pool, eval_set=eval_pool,
                          early_stopping_rounds=self.early_stopping_rounds, **kwargs)

        return self

    def transform(self, X):
        """
        Return a copy of ``X`` with the fitted model's predictions appended.

        ``TransformerMixin.fit_transform`` calls ``fit(...).transform(X)``, so
        this method is required for the class to act as a transformer.
        ``X`` is expected to be a pandas DataFrame with the columns the model
        was fitted on. The predictions are on the derived (avg speed left)
        scale used in ``fit``.
        """
        X_out = X.copy()
        X_out[self.prediction_col] = self.catboost.predict(X)
        return X_out

Select features for catboost model

[9]:
# Identifier/timestamp columns (kept separately as metadata for the prediction
# output) and other columns not used as model features.
columns_to_drop = ["id", "scheduleDateTime", "actualOffBlockTime", "year", "month", "quarter"]
X_train_meta = X_train[["id", "scheduleDateTime"]]
X_test_meta = X_test[["id", "scheduleDateTime"]]

# `drop` returns a new, independent DataFrame, so assigning the categorical
# columns below does not trigger SettingWithCopyWarning (unlike assigning into
# a `df[[cols]]` selection). errors="ignore" mirrors the previous filter, which
# silently skipped columns that are absent; the column order is unchanged.
X_train = X_train.drop(columns=columns_to_drop, errors="ignore")
X_test = X_test.drop(columns=columns_to_drop, errors="ignore")

# type categorical features for catboost
cat_features = [
    'aircraftRegistration', 'airlineCode', 'terminal', 'serviceType',
    'final_destination', 'Country', 'City', 'DST',
    'dayofweek', 'dayofmonth', 'weekofyear', 'hour', 'minutes'
]

X_train[cat_features] = X_train[cat_features].astype(str).astype('category')
X_test[cat_features] = X_test[cat_features].astype(str).astype('category')

# Index.equals also handles differing lengths, whereas the element-wise `==`
# comparison raises ValueError when the numbers of columns differ.
assert X_test.columns.equals(X_train.columns), "train and test feature columns differ"

X_train.columns
[9]:
Index(['aircraftRegistration', 'airlineCode', 'terminal', 'serviceType',
       'final_destination', 'Country', 'City', 'Latitude', 'Longitude',
       'Altitude', 'DST', 'destination_distance', 'dayofweek', 'dayofmonth',
       'weekofyear', 'hour', 'minutes'],
      dtype='object')

Perform random search for optimal parameters

  • Use random search instead of an exhaustive grid search, for faster results
[10]:
# Split the training data once more: 80% to fit on and 20% as the evaluation
# pool (used for early stopping when training the final model).
X_pool_train, X_pool_eval, y_pool_train, y_pool_eval = train_test_split(
    X_train, y_train, test_size=0.2
)
train_pool = Pool(data=X_pool_train, label=y_pool_train, cat_features=cat_features)
eval_pool = Pool(data=X_pool_eval, label=y_pool_eval, cat_features=cat_features)

# Starting CatBoost settings for the search. The short run (10 iterations)
# keeps every trial cheap; the searched parameters are overridden by the grid.
catboost_kwargs = dict(
    verbose=1,
    iterations=10,
    depth=4,
    learning_rate=1,
    loss_function="MAE",
    l2_leaf_reg=4,
    train_dir=str(Path(output_dir, "catboost_random_search")),
    cat_features=cat_features,
)

# Candidate values, narrowed down by earlier trial and error.
grid = dict(
    learning_rate=[0.1, 0.3, 0.5],
    depth=[4, 6, 10],
    l2_leaf_reg=[1, 3, 5, 7, 9],
)
search_model = CatBoostRegressor(**catboost_kwargs)
randomized_search_result = search_model.randomized_search(grid, X=train_pool, plot=True)

bestTest = 804.7577782
bestIteration = 9

0:      loss: 804.7577782       best: 804.7577782 (0)   total: 953ms    remaining: 8.57s

bestTest = 763.8192678
bestIteration = 9

1:      loss: 763.8192678       best: 763.8192678 (1)   total: 1.7s     remaining: 6.79s

bestTest = 757.4576161
bestIteration = 9

2:      loss: 757.4576161       best: 757.4576161 (2)   total: 2.42s    remaining: 5.64s

bestTest = 763.8192678
bestIteration = 9

3:      loss: 763.8192678       best: 757.4576161 (2)   total: 3.13s    remaining: 4.69s

bestTest = 757.4576161
bestIteration = 9

4:      loss: 757.4576161       best: 757.4576161 (2)   total: 3.85s    remaining: 3.85s

bestTest = 763.8192678
bestIteration = 9

5:      loss: 763.8192678       best: 757.4576161 (2)   total: 4.59s    remaining: 3.06s

bestTest = 757.4576161
bestIteration = 9

6:      loss: 757.4576161       best: 757.4576161 (2)   total: 5.29s    remaining: 2.27s

bestTest = 800.7015278
bestIteration = 9

7:      loss: 800.7015278       best: 757.4576161 (2)   total: 6.15s    remaining: 1.54s

bestTest = 756.1623348
bestIteration = 9

8:      loss: 756.1623348       best: 756.1623348 (8)   total: 7.01s    remaining: 779ms

bestTest = 792.3006764
bestIteration = 9

9:      loss: 792.3006764       best: 756.1623348 (8)   total: 8.13s    remaining: 0us
Estimating final quality...

Train model

[11]:
# Start from the search-time settings, train for longer, write CatBoost's
# training artefacts to output_dir and apply the best parameters from the
# randomized search.
updated_catboost_kwargs = dict(catboost_kwargs)
updated_catboost_kwargs.update(
    iterations=200,
    train_dir=str(output_dir),
    **randomized_search_result['params']
)

# Fresh model with the final parameters; early stopping is monitored on the
# held-out eval_pool.
model = CatBoostRegressor(**updated_catboost_kwargs)
model.fit(train_pool, eval_set=eval_pool, early_stopping_rounds=50, plot=True)
0:      learn: 756.4992915      test: 736.1896539       best: 736.1896539 (0)   total: 486ms    remaining: 1m 36s
1:      learn: 716.4095588      test: 689.1406306       best: 689.1406306 (1)   total: 1.01s    remaining: 1m 40s
2:      learn: 697.8411741      test: 668.3668963       best: 668.3668963 (2)   total: 1.27s    remaining: 1m 23s
3:      learn: 692.7714155      test: 663.3020848       best: 663.3020848 (3)   total: 1.63s    remaining: 1m 19s
4:      learn: 678.4410044      test: 645.9118846       best: 645.9118846 (4)   total: 2.06s    remaining: 1m 20s
5:      learn: 665.8758141      test: 631.6815334       best: 631.6815334 (5)   total: 2.36s    remaining: 1m 16s
6:      learn: 659.9978535      test: 624.8699020       best: 624.8699020 (6)   total: 2.82s    remaining: 1m 17s
7:      learn: 657.8774503      test: 622.9807212       best: 622.9807212 (7)   total: 3.06s    remaining: 1m 13s
8:      learn: 655.4896950      test: 620.3435796       best: 620.3435796 (8)   total: 3.36s    remaining: 1m 11s
9:      learn: 653.2327483      test: 618.2358122       best: 618.2358122 (9)   total: 3.6s     remaining: 1m 8s
10:     learn: 651.1965779      test: 616.2311812       best: 616.2311812 (10)  total: 3.94s    remaining: 1m 7s
11:     learn: 650.5140676      test: 615.5689544       best: 615.5689544 (11)  total: 4.26s    remaining: 1m 6s
12:     learn: 649.0516013      test: 613.7866925       best: 613.7866925 (12)  total: 4.63s    remaining: 1m 6s
13:     learn: 648.4775319      test: 613.2001549       best: 613.2001549 (13)  total: 4.95s    remaining: 1m 5s
14:     learn: 647.8060426      test: 612.5047631       best: 612.5047631 (14)  total: 5.35s    remaining: 1m 5s
15:     learn: 647.1291244      test: 611.8994895       best: 611.8994895 (15)  total: 5.63s    remaining: 1m 4s
16:     learn: 645.8213327      test: 610.5029021       best: 610.5029021 (16)  total: 5.98s    remaining: 1m 4s
17:     learn: 644.8812865      test: 609.5610178       best: 609.5610178 (17)  total: 6.39s    remaining: 1m 4s
18:     learn: 644.5974207      test: 609.3279597       best: 609.3279597 (18)  total: 6.79s    remaining: 1m 4s
19:     learn: 644.1965673      test: 608.9084781       best: 608.9084781 (19)  total: 7.21s    remaining: 1m 4s
20:     learn: 642.3113860      test: 607.1499557       best: 607.1499557 (20)  total: 7.51s    remaining: 1m 3s
21:     learn: 640.8101524      test: 605.1263016       best: 605.1263016 (21)  total: 7.88s    remaining: 1m 3s
22:     learn: 639.1655080      test: 603.3016212       best: 603.3016212 (22)  total: 8.24s    remaining: 1m 3s
23:     learn: 638.9469651      test: 603.1433498       best: 603.1433498 (23)  total: 8.56s    remaining: 1m 2s
24:     learn: 638.7524524      test: 602.9733992       best: 602.9733992 (24)  total: 8.89s    remaining: 1m 2s
25:     learn: 638.4115124      test: 602.6926753       best: 602.6926753 (25)  total: 9.43s    remaining: 1m 3s
26:     learn: 637.9200366      test: 602.2545671       best: 602.2545671 (26)  total: 9.87s    remaining: 1m 3s
27:     learn: 637.7586167      test: 602.1181496       best: 602.1181496 (27)  total: 10.1s    remaining: 1m 2s
28:     learn: 637.6390371      test: 602.0260108       best: 602.0260108 (28)  total: 10.4s    remaining: 1m 1s
29:     learn: 637.4954822      test: 601.9018914       best: 601.9018914 (29)  total: 10.8s    remaining: 1m
30:     learn: 637.3643015      test: 601.7997421       best: 601.7997421 (30)  total: 11.1s    remaining: 1m
31:     learn: 637.2449287      test: 601.6952385       best: 601.6952385 (31)  total: 11.5s    remaining: 1m
32:     learn: 636.9031647      test: 601.3616772       best: 601.3616772 (32)  total: 11.8s    remaining: 59.9s
33:     learn: 636.6963216      test: 601.2079507       best: 601.2079507 (33)  total: 12.3s    remaining: 59.9s
34:     learn: 636.2495884      test: 600.8071639       best: 600.8071639 (34)  total: 12.6s    remaining: 59.3s
35:     learn: 635.9025845      test: 600.4896005       best: 600.4896005 (35)  total: 13s      remaining: 59.1s
36:     learn: 635.7190227      test: 600.3220853       best: 600.3220853 (36)  total: 13.3s    remaining: 58.4s
37:     learn: 635.5344418      test: 600.1087458       best: 600.1087458 (37)  total: 13.7s    remaining: 58.5s
38:     learn: 634.8079986      test: 599.4618361       best: 599.4618361 (38)  total: 14.1s    remaining: 58.4s
39:     learn: 634.4043165      test: 599.0290034       best: 599.0290034 (39)  total: 14.6s    remaining: 58.5s
40:     learn: 634.3035510      test: 598.9302567       best: 598.9302567 (40)  total: 15s      remaining: 58s
41:     learn: 634.1347057      test: 598.7614733       best: 598.7614733 (41)  total: 15.4s    remaining: 58s
42:     learn: 633.9306509      test: 598.6202432       best: 598.6202432 (42)  total: 15.7s    remaining: 57.4s
43:     learn: 633.7490183      test: 598.4112566       best: 598.4112566 (43)  total: 16.1s    remaining: 57.1s
44:     learn: 633.2496051      test: 598.0410102       best: 598.0410102 (44)  total: 16.6s    remaining: 57.1s
45:     learn: 632.7809645      test: 597.6645594       best: 597.6645594 (45)  total: 16.9s    remaining: 56.6s
46:     learn: 632.6504159      test: 597.5792268       best: 597.5792268 (46)  total: 17.3s    remaining: 56.5s
47:     learn: 632.3499691      test: 597.2903086       best: 597.2903086 (47)  total: 17.8s    remaining: 56.2s
48:     learn: 632.0798885      test: 597.0893884       best: 597.0893884 (48)  total: 18.1s    remaining: 55.8s
49:     learn: 631.9518132      test: 596.9704911       best: 596.9704911 (49)  total: 18.6s    remaining: 55.8s
50:     learn: 631.8371254      test: 596.8837289       best: 596.8837289 (50)  total: 19s      remaining: 55.5s
51:     learn: 631.6555558      test: 596.7125866       best: 596.7125866 (51)  total: 19.4s    remaining: 55.2s
52:     learn: 631.3500733      test: 596.5806829       best: 596.5806829 (52)  total: 19.7s    remaining: 54.8s
53:     learn: 631.2160011      test: 596.4982592       best: 596.4982592 (53)  total: 20.1s    remaining: 54.4s
54:     learn: 631.1155251      test: 596.4052152       best: 596.4052152 (54)  total: 20.5s    remaining: 54.1s
55:     learn: 630.9322951      test: 596.2872326       best: 596.2872326 (55)  total: 20.9s    remaining: 53.7s
56:     learn: 630.8313626      test: 596.2052357       best: 596.2052357 (56)  total: 21.3s    remaining: 53.4s
57:     learn: 630.5880527      test: 596.0234432       best: 596.0234432 (57)  total: 21.6s    remaining: 52.8s
58:     learn: 630.4350416      test: 595.9243026       best: 595.9243026 (58)  total: 22s      remaining: 52.5s
59:     learn: 630.1765985      test: 595.6232071       best: 595.6232071 (59)  total: 22.3s    remaining: 52s
60:     learn: 629.9420989      test: 595.3269454       best: 595.3269454 (60)  total: 22.7s    remaining: 51.8s
61:     learn: 629.7855614      test: 595.1980082       best: 595.1980082 (61)  total: 23.1s    remaining: 51.5s
62:     learn: 629.6821856      test: 595.1140363       best: 595.1140363 (62)  total: 23.5s    remaining: 51.1s
63:     learn: 629.3721379      test: 594.7708591       best: 594.7708591 (63)  total: 23.9s    remaining: 50.8s
64:     learn: 629.2935584      test: 594.7005343       best: 594.7005343 (64)  total: 24.3s    remaining: 50.4s
65:     learn: 629.1437203      test: 594.5599510       best: 594.5599510 (65)  total: 24.6s    remaining: 50s
66:     learn: 629.0597314      test: 594.4965610       best: 594.4965610 (66)  total: 25s      remaining: 49.6s
67:     learn: 628.9604825      test: 594.4300063       best: 594.4300063 (67)  total: 25.3s    remaining: 49.2s
68:     learn: 628.6144429      test: 594.0584278       best: 594.0584278 (68)  total: 25.8s    remaining: 49s
69:     learn: 628.3600205      test: 593.7500512       best: 593.7500512 (69)  total: 26.3s    remaining: 48.8s
70:     learn: 627.9630414      test: 593.2422692       best: 593.2422692 (70)  total: 26.8s    remaining: 48.6s
71:     learn: 627.8194149      test: 593.1220129       best: 593.1220129 (71)  total: 27.1s    remaining: 48.1s
72:     learn: 627.6032628      test: 592.9470811       best: 592.9470811 (72)  total: 27.4s    remaining: 47.7s
73:     learn: 627.0955542      test: 592.5350740       best: 592.5350740 (73)  total: 27.8s    remaining: 47.3s
74:     learn: 626.9570505      test: 592.4246023       best: 592.4246023 (74)  total: 28.3s    remaining: 47.1s
75:     learn: 626.8580474      test: 592.3451068       best: 592.3451068 (75)  total: 28.6s    remaining: 46.6s
76:     learn: 626.6240930      test: 592.1134100       best: 592.1134100 (76)  total: 29.2s    remaining: 46.6s
77:     learn: 626.5238310      test: 591.9830204       best: 591.9830204 (77)  total: 29.7s    remaining: 46.5s
78:     learn: 626.2930463      test: 591.7926196       best: 591.7926196 (78)  total: 30.1s    remaining: 46.2s
79:     learn: 626.1785614      test: 591.7165412       best: 591.7165412 (79)  total: 30.5s    remaining: 45.7s
80:     learn: 626.1020838      test: 591.6535367       best: 591.6535367 (80)  total: 30.8s    remaining: 45.3s
81:     learn: 626.0103928      test: 591.5821192       best: 591.5821192 (81)  total: 31.3s    remaining: 45s
82:     learn: 625.8423460      test: 591.4126421       best: 591.4126421 (82)  total: 31.8s    remaining: 44.8s
83:     learn: 625.7572063      test: 591.3751926       best: 591.3751926 (83)  total: 32.2s    remaining: 44.5s
84:     learn: 625.6648629      test: 591.2986017       best: 591.2986017 (84)  total: 32.6s    remaining: 44.1s
85:     learn: 624.8146807      test: 590.5119781       best: 590.5119781 (85)  total: 33s      remaining: 43.8s
86:     learn: 624.5953493      test: 590.2107524       best: 590.2107524 (86)  total: 33.4s    remaining: 43.3s
87:     learn: 624.3865336      test: 590.0259422       best: 590.0259422 (87)  total: 33.7s    remaining: 42.9s
88:     learn: 624.1381544      test: 589.7534605       best: 589.7534605 (88)  total: 34.2s    remaining: 42.7s
89:     learn: 624.0491247      test: 589.7121752       best: 589.7121752 (89)  total: 34.8s    remaining: 42.5s
90:     learn: 623.7640891      test: 589.4765438       best: 589.4765438 (90)  total: 35.2s    remaining: 42.1s
91:     learn: 623.6698931      test: 589.3804272       best: 589.3804272 (91)  total: 35.5s    remaining: 41.7s
92:     learn: 623.5908364      test: 589.3442882       best: 589.3442882 (92)  total: 35.8s    remaining: 41.2s
93:     learn: 623.5306758      test: 589.3090298       best: 589.3090298 (93)  total: 36.2s    remaining: 40.8s
94:     learn: 623.3227559      test: 589.0772829       best: 589.0772829 (94)  total: 36.5s    remaining: 40.3s
95:     learn: 623.1951940      test: 588.9860376       best: 588.9860376 (95)  total: 36.9s    remaining: 40s
96:     learn: 622.9879567      test: 588.8807494       best: 588.8807494 (96)  total: 37.2s    remaining: 39.5s
97:     learn: 622.8090526      test: 588.7342637       best: 588.7342637 (97)  total: 37.6s    remaining: 39.1s
98:     learn: 622.5838785      test: 588.7075079       best: 588.7075079 (98)  total: 37.9s    remaining: 38.7s
99:     learn: 622.4825836      test: 588.6086717       best: 588.6086717 (99)  total: 38.2s    remaining: 38.2s
100:    learn: 622.2791306      test: 588.3271311       best: 588.3271311 (100) total: 38.7s    remaining: 37.9s
101:    learn: 622.2020300      test: 588.2731782       best: 588.2731782 (101) total: 39.2s    remaining: 37.6s
102:    learn: 621.9820212      test: 588.1600497       best: 588.1600497 (102) total: 39.6s    remaining: 37.3s
103:    learn: 621.6717145      test: 588.0134407       best: 588.0134407 (103) total: 40s      remaining: 36.9s
104:    learn: 621.4919530      test: 587.8427733       best: 587.8427733 (104) total: 40.4s    remaining: 36.6s
105:    learn: 621.4406313      test: 587.7976833       best: 587.7976833 (105) total: 40.8s    remaining: 36.2s
106:    learn: 620.8423283      test: 587.2740215       best: 587.2740215 (106) total: 41.2s    remaining: 35.8s
107:    learn: 620.7610185      test: 587.1424969       best: 587.1424969 (107) total: 41.5s    remaining: 35.4s
108:    learn: 620.6532462      test: 587.0451113       best: 587.0451113 (108) total: 41.9s    remaining: 34.9s
109:    learn: 620.5812943      test: 587.0190482       best: 587.0190482 (109) total: 42.3s    remaining: 34.6s
110:    learn: 619.9161984      test: 586.3519279       best: 586.3519279 (110) total: 42.6s    remaining: 34.1s
111:    learn: 619.8554406      test: 586.2735098       best: 586.2735098 (111) total: 43s      remaining: 33.8s
112:    learn: 619.7443666      test: 586.2191013       best: 586.2191013 (112) total: 43.4s    remaining: 33.4s
113:    learn: 619.6755077      test: 586.1544571       best: 586.1544571 (113) total: 43.8s    remaining: 33.1s
114:    learn: 619.5537904      test: 586.0764185       best: 586.0764185 (114) total: 44.3s    remaining: 32.8s
115:    learn: 619.2903614      test: 586.0431799       best: 586.0431799 (115) total: 44.7s    remaining: 32.4s
116:    learn: 619.2025000      test: 585.9869016       best: 585.9869016 (116) total: 45.1s    remaining: 32s
117:    learn: 618.7112464      test: 585.3193213       best: 585.3193213 (117) total: 45.5s    remaining: 31.6s
118:    learn: 618.5342821      test: 585.1961176       best: 585.1961176 (118) total: 45.9s    remaining: 31.3s
119:    learn: 618.3582463      test: 585.0500864       best: 585.0500864 (119) total: 46.3s    remaining: 30.9s
120:    learn: 618.2471939      test: 584.9732303       best: 584.9732303 (120) total: 46.8s    remaining: 30.5s
121:    learn: 618.0731285      test: 584.8395377       best: 584.8395377 (121) total: 47.2s    remaining: 30.2s
122:    learn: 618.0000942      test: 584.7620554       best: 584.7620554 (122) total: 47.8s    remaining: 29.9s
123:    learn: 617.9257270      test: 584.7316075       best: 584.7316075 (123) total: 48.3s    remaining: 29.6s
124:    learn: 617.7787591      test: 584.6011831       best: 584.6011831 (124) total: 48.8s    remaining: 29.3s
125:    learn: 617.6997120      test: 584.5398819       best: 584.5398819 (125) total: 49.2s    remaining: 28.9s
126:    learn: 617.5244009      test: 584.3529823       best: 584.3529823 (126) total: 49.6s    remaining: 28.5s
127:    learn: 617.2962930      test: 584.1669826       best: 584.1669826 (127) total: 50s      remaining: 28.1s
128:    learn: 617.1825785      test: 584.1385620       best: 584.1385620 (128) total: 50.4s    remaining: 27.7s
129:    learn: 617.1099402      test: 584.0210716       best: 584.0210716 (129) total: 50.8s    remaining: 27.3s
130:    learn: 617.0114904      test: 583.9671104       best: 583.9671104 (130) total: 51.2s    remaining: 27s
131:    learn: 616.8188259      test: 583.8113980       best: 583.8113980 (131) total: 51.5s    remaining: 26.5s
132:    learn: 616.7738320      test: 583.7953743       best: 583.7953743 (132) total: 51.9s    remaining: 26.2s
133:    learn: 616.5926336      test: 583.7213538       best: 583.7213538 (133) total: 52.3s    remaining: 25.8s
134:    learn: 616.5084778      test: 583.7059402       best: 583.7059402 (134) total: 52.7s    remaining: 25.4s
135:    learn: 616.4312025      test: 583.6415595       best: 583.6415595 (135) total: 53.3s    remaining: 25.1s
136:    learn: 616.3240607      test: 583.5173499       best: 583.5173499 (136) total: 53.7s    remaining: 24.7s
137:    learn: 616.2783275      test: 583.5019564       best: 583.5019564 (137) total: 54.2s    remaining: 24.3s
138:    learn: 616.2333421      test: 583.4770307       best: 583.4770307 (138) total: 54.6s    remaining: 23.9s
139:    learn: 616.2008139      test: 583.4749795       best: 583.4749795 (139) total: 55s      remaining: 23.6s
140:    learn: 616.0859712      test: 583.3754610       best: 583.3754610 (140) total: 55.5s    remaining: 23.2s
141:    learn: 616.0157786      test: 583.3133219       best: 583.3133219 (141) total: 55.9s    remaining: 22.8s
142:    learn: 615.9341868      test: 583.2771837       best: 583.2771837 (142) total: 56.4s    remaining: 22.5s
143:    learn: 615.8869803      test: 583.2243588       best: 583.2243588 (143) total: 57s      remaining: 22.2s
144:    learn: 615.8266985      test: 583.2070106       best: 583.2070106 (144) total: 57.4s    remaining: 21.8s
145:    learn: 615.7423962      test: 583.1519266       best: 583.1519266 (145) total: 57.7s    remaining: 21.4s
146:    learn: 615.4880554      test: 582.9063992       best: 582.9063992 (146) total: 58.2s    remaining: 21s
147:    learn: 615.3819690      test: 582.8068397       best: 582.8068397 (147) total: 58.5s    remaining: 20.6s
148:    learn: 615.3397580      test: 582.7699536       best: 582.7699536 (148) total: 59s      remaining: 20.2s
149:    learn: 615.1919542      test: 582.6876877       best: 582.6876877 (149) total: 59.4s    remaining: 19.8s
150:    learn: 615.1626818      test: 582.6696281       best: 582.6696281 (150) total: 59.8s    remaining: 19.4s
151:    learn: 615.0243761      test: 582.5860315       best: 582.5860315 (151) total: 1m       remaining: 19s
152:    learn: 614.8813938      test: 582.5034830       best: 582.5034830 (152) total: 1m       remaining: 18.6s
153:    learn: 614.8605463      test: 582.4908016       best: 582.4908016 (153) total: 1m 1s    remaining: 18.2s
154:    learn: 614.7724205      test: 582.4161990       best: 582.4161990 (154) total: 1m 1s    remaining: 17.8s
155:    learn: 614.7300791      test: 582.4021538       best: 582.4021538 (155) total: 1m 1s    remaining: 17.4s
156:    learn: 614.3910751      test: 582.1378535       best: 582.1378535 (156) total: 1m 1s    remaining: 17s
157:    learn: 614.2645807      test: 582.0170679       best: 582.0170679 (157) total: 1m 2s    remaining: 16.5s
158:    learn: 614.1563803      test: 581.9316131       best: 581.9316131 (158) total: 1m 2s    remaining: 16.1s
159:    learn: 613.9578124      test: 581.7698210       best: 581.7698210 (159) total: 1m 2s    remaining: 15.7s
160:    learn: 613.7906887      test: 581.5912194       best: 581.5912194 (160) total: 1m 3s    remaining: 15.3s
161:    learn: 613.7044189      test: 581.5036286       best: 581.5036286 (161) total: 1m 3s    remaining: 14.9s
162:    learn: 613.5849392      test: 581.4266365       best: 581.4266365 (162) total: 1m 4s    remaining: 14.6s
163:    learn: 613.3722696      test: 581.2292450       best: 581.2292450 (163) total: 1m 4s    remaining: 14.2s
164:    learn: 613.1556289      test: 581.1037576       best: 581.1037576 (164) total: 1m 5s    remaining: 13.8s
165:    learn: 613.0290732      test: 580.9696302       best: 580.9696302 (165) total: 1m 5s    remaining: 13.4s
166:    learn: 612.9695643      test: 580.9225562       best: 580.9225562 (166) total: 1m 5s    remaining: 13s
167:    learn: 612.9182516      test: 580.8944932       best: 580.8944932 (167) total: 1m 6s    remaining: 12.6s
168:    learn: 612.8759733      test: 580.8487442       best: 580.8487442 (168) total: 1m 6s    remaining: 12.2s
169:    learn: 612.7284283      test: 580.7110513       best: 580.7110513 (169) total: 1m 7s    remaining: 11.8s
170:    learn: 612.6469981      test: 580.6450549       best: 580.6450549 (170) total: 1m 7s    remaining: 11.4s
171:    learn: 612.4922876      test: 580.4813941       best: 580.4813941 (171) total: 1m 7s    remaining: 11s
172:    learn: 612.3682492      test: 580.4154322       best: 580.4154322 (172) total: 1m 8s    remaining: 10.6s
173:    learn: 612.1525769      test: 580.2475661       best: 580.2475661 (173) total: 1m 8s    remaining: 10.2s
174:    learn: 612.0985593      test: 580.2149161       best: 580.2149161 (174) total: 1m 8s    remaining: 9.85s
175:    learn: 611.9614065      test: 580.0693887       best: 580.0693887 (175) total: 1m 9s    remaining: 9.46s
176:    learn: 611.8774389      test: 580.0103718       best: 580.0103718 (176) total: 1m 9s    remaining: 9.07s
177:    learn: 611.7511840      test: 579.9108673       best: 579.9108673 (177) total: 1m 10s   remaining: 8.68s
178:    learn: 611.7252624      test: 579.9017319       best: 579.9017319 (178) total: 1m 10s   remaining: 8.29s
179:    learn: 611.6226393      test: 579.8035861       best: 579.8035861 (179) total: 1m 11s   remaining: 7.89s
180:    learn: 611.4549767      test: 579.6788475       best: 579.6788475 (180) total: 1m 11s   remaining: 7.51s
181:    learn: 611.3454769      test: 579.5541309       best: 579.5541309 (181) total: 1m 12s   remaining: 7.13s
182:    learn: 611.3054737      test: 579.5266276       best: 579.5266276 (182) total: 1m 12s   remaining: 6.73s
183:    learn: 611.2498650      test: 579.4875620       best: 579.4875620 (183) total: 1m 12s   remaining: 6.33s
184:    learn: 611.1533009      test: 579.3775673       best: 579.3775673 (184) total: 1m 13s   remaining: 5.92s
185:    learn: 610.9020615      test: 579.1870376       best: 579.1870376 (185) total: 1m 13s   remaining: 5.54s
186:    learn: 610.8316611      test: 579.0995204       best: 579.0995204 (186) total: 1m 14s   remaining: 5.16s
187:    learn: 610.7659888      test: 579.0790918       best: 579.0790918 (187) total: 1m 14s   remaining: 4.76s
188:    learn: 610.7128587      test: 579.0734995       best: 579.0734995 (188) total: 1m 15s   remaining: 4.37s
189:    learn: 610.6606708      test: 579.0397298       best: 579.0397298 (189) total: 1m 15s   remaining: 3.97s
190:    learn: 610.6550951      test: 579.0353334       best: 579.0353334 (190) total: 1m 15s   remaining: 3.58s
191:    learn: 610.5822829      test: 578.9660564       best: 578.9660564 (191) total: 1m 16s   remaining: 3.18s
192:    learn: 610.5296158      test: 578.9261770       best: 578.9261770 (192) total: 1m 16s   remaining: 2.78s
193:    learn: 610.4417085      test: 578.8428978       best: 578.8428978 (193) total: 1m 17s   remaining: 2.39s
194:    learn: 610.3924358      test: 578.8206193       best: 578.8206193 (194) total: 1m 17s   remaining: 1.99s
195:    learn: 610.1421457      test: 578.5998934       best: 578.5998934 (195) total: 1m 17s   remaining: 1.59s
196:    learn: 610.0201709      test: 578.4677762       best: 578.4677762 (196) total: 1m 18s   remaining: 1.19s
197:    learn: 609.8806418      test: 578.3458600       best: 578.3458600 (197) total: 1m 18s   remaining: 795ms
198:    learn: 609.8425944      test: 578.3160699       best: 578.3160699 (198) total: 1m 18s   remaining: 397ms
199:    learn: 609.7698163      test: 578.3045065       best: 578.3045065 (199) total: 1m 19s   remaining: 0us

bestTest = 578.3045065
bestIteration = 199

[11]:
<catboost.core.CatBoostRegressor at 0x1a6c7b74c48>

Evaluate model

[12]:
def datetime_to_date(dt):
    """
    Convert datetime-like values to their calendar date in UTC.

    ``dt`` is expected to be a pandas Series (the ``.dt`` accessor is used).
    Values are converted to UTC first, so the date is the UTC date, which can
    differ from the local date near midnight.
    """
    return pd.to_datetime(dt, utc=True).dt.date


def datetime_to_date_hour(dt):
    """
    Convert datetime-like values to UTC timestamps floored to the hour.

    ``dt`` is expected to be a pandas Series (the ``.dt`` accessor is used).
    """
    # "h" is the hourly alias; the upper-case "H" is deprecated since pandas 2.2
    # and "h" is also accepted by older pandas versions.
    return pd.to_datetime(dt, utc=True).dt.floor('h')

[13]:
# create predictions on train/test sets
df_predictions = make_predictions_dataframe(model, X_train, X_test, y_train, y_test,
                                           meta_train = X_train_meta,
                                           meta_test = X_test_meta)
df_predictions
[13]:
id scheduleDateTime y yhat error model_set
0 123414481790510775 2018-01-01 03:30:00+01:00 -480.0 -197.459328 282.540672 train
1 123414479288269149 2018-01-01 06:00:00+01:00 -98.0 -100.360379 -2.360379 train
2 123414479666542945 2018-01-01 06:05:00+01:00 -300.0 -120.601478 179.398522 train
3 123414479288365061 2018-01-01 06:05:00+01:00 -300.0 -121.494374 178.505626 train
4 123414479288274329 2018-01-01 06:15:00+01:00 694.0 176.538920 -517.461080 train
... ... ... ... ... ... ...
487709 124763270719624901 2018-07-12 17:25:00+02:00 80.0 280.244578 200.244578 test
487710 124763272032451639 2018-07-12 17:25:00+02:00 80.0 281.817452 201.817452 test
487711 124763270368084713 2018-07-12 17:25:00+02:00 12.0 278.095160 266.095160 test
487712 124763270625998761 2018-07-12 17:45:00+02:00 -1935.0 -2.120094 1932.879906 test
487713 124763271129903067 2018-07-12 17:50:00+02:00 -8690.0 31.526172 8721.526172 test

487714 rows × 6 columns

[14]:
%%time

# Overall metrics per model set, then the same metrics over time at daily
# ("D") and hourly ("H") resolution.
df_metrics_long = make_regression_metrics_by_group(df_predictions, group_cols=["model_set"])
df_daily_metrics_long = make_regression_metrics_by_datetime(
    df_predictions, freq="D", alias="schedule_date"
)
df_hourly_metrics_long = make_regression_metrics_by_datetime(
    df_predictions, freq="H", alias="schedule_date"
)
Wall time: 9.12 s
[15]:
# Peek at the first rows of the long-format overall metrics.
df_metrics_long.head(n=5)
[15]:
model_set variable value
0 test mae 841.293334
1 train mae 558.490217
2 test mape 365.325530
3 train mape 448.172767
4 test rmse 2429.517313
[16]:
import plotly.express as px

# One facet row per metric, one coloured line per model set.
fig = px.line(
    df_hourly_metrics_long,
    x="schedule_date",
    y="value",
    facet_row="variable",
    color="model_set",
    width=1200,
    height=1200,
    title="Hourly prediction metrics",
)
# Date-typed x axis with a range slider for zooming; hover snaps along x.
fig.update_layout(
    xaxis_rangeslider_visible=True,
    xaxis_type="date",
    hovermode="x",
)
# Metrics live on very different scales, so each facet gets its own y axis.
fig.update_yaxes(matches=None)
fig.show()

Plot some prediction results

[17]:
def _mean_by_schedule_period(df_predictions, to_period):
    """Average the numeric prediction columns per model set and schedule period.

    Parameters
    ----------
    df_predictions : pandas.DataFrame
        Predictions with at least ``id``, ``scheduleDateTime`` and
        ``model_set`` columns. It is not modified.
    to_period : callable
        Maps the ``scheduleDateTime`` Series to the grouping key stored in
        ``schedule_date`` (e.g. a date or an hour-truncated timestamp).

    Returns
    -------
    pandas.DataFrame
        One row per (``model_set``, ``schedule_date``) with the mean of every
        numeric column.
    """
    # ``drop`` returns a new frame, so adding ``schedule_date`` below does not
    # leak a helper column into the caller's predictions frame.
    df = df_predictions.drop(columns="id")
    df["schedule_date"] = to_period(df["scheduleDateTime"])
    # numeric_only=True: non-numeric columns such as ``scheduleDateTime`` are
    # skipped explicitly (pandas >= 2.0 raises instead of dropping them).
    return (
        df.groupby(["model_set", "schedule_date"])
        .mean(numeric_only=True)
        .reset_index()
    )


def predictions_daily_mean(df_predictions):
    """Mean of the numeric prediction columns per model set and UTC date."""
    return _mean_by_schedule_period(df_predictions, datetime_to_date)


def predictions_hourly_mean(df_predictions):
    """Mean of the numeric prediction columns per model set and UTC hour."""
    return _mean_by_schedule_period(df_predictions, datetime_to_date_hour)


def get_safe_ylim(y, q=0.05, q2=None):
    """Return ``(low, high)`` plot limits from quantiles of ``y``.

    Parameters
    ----------
    y : array-like
        Values to derive limits from; NaNs are ignored.
    q : float
        Lower quantile (default 0.05).
    q2 : float, optional
        Upper quantile; defaults to ``1 - q`` for a symmetric trim.
    """
    if q2 is None:
        q2 = 1 - q
    # nanquantile: a single NaN would otherwise turn both limits into NaN.
    return (np.nanquantile(y, q), np.nanquantile(y, q2))


# Daily means of actuals, predictions and errors. The y limits are trimmed
# to quantiles (see get_safe_ylim) so a few extreme days don't flatten the plot.
df_daily_mean = predictions_daily_mean(df_predictions)
y_ylim = get_safe_ylim(df_daily_mean["y"])
error_ylim = get_safe_ylim(df_daily_mean["error"])

df_daily_mean[["schedule_date", "y", "yhat", "model_set"]].plot(x="schedule_date", ylim=y_ylim)
df_daily_mean[["schedule_date", "error", "model_set"]].plot(x="schedule_date", ylim=error_ylim)

# Hourly means, plotted separately for each model_set group.
df_hourly_mean = predictions_hourly_mean(df_predictions)
y_ylim = get_safe_ylim(df_hourly_mean["y"])
error_ylim = get_safe_ylim(df_hourly_mean["error"])

df_hourly_mean[["schedule_date", "y", "yhat", "model_set"]].groupby("model_set").plot(x="schedule_date", ylim=y_ylim)
df_hourly_mean[["schedule_date", "error", "model_set"]].groupby("model_set").plot(x="schedule_date", ylim=error_ylim)

plt.show()

# Interactive view of the hourly mean error for both model sets.
fig = px.line(df_hourly_mean, x="schedule_date", y="error", color="model_set",
             width=1200, height=1200, title="Hourly total prediction error")
fig.update_yaxes(matches=None)
fig.update_xaxes(matches=None)
fig.show()
../_images/scripts_model__catboost_simple_25_0.png
../_images/scripts_model__catboost_simple_25_1.png
../_images/scripts_model__catboost_simple_25_2.png
../_images/scripts_model__catboost_simple_25_3.png
../_images/scripts_model__catboost_simple_25_4.png
../_images/scripts_model__catboost_simple_25_5.png

Write output to output directory

[18]:
import joblib, pickle
from pathlib import Path
[ ]:
# All outputs are written next to the configured predictions file.
model_file = str(Path(output_dir, "model.pkl"))
# Use the file name from ``output_predictions`` rather than a hard-coded one,
# so the parameter fully controls where predictions are written.
predictions_file = str(Path(output_dir, Path(output_predictions).name))
overall_metrics_file = str(Path(output_dir, "overall_metrics_long.csv"))
daily_metrics_file = str(Path(output_dir, "daily_metrics_long.csv"))
hourly_metrics_file = str(Path(output_dir, "hourly_metrics_long.csv"))

Pickle output files for MLflow artifacts

  • Model serialized with joblib
  • Model data or a sample thereof
[19]:
# model_file already includes output_dir; joining it again is redundant and
# would nest the directory twice if model_file were ever relative.
joblib.dump(model, model_file)
[19]:
['C:\\Users\\lodew\\qualogy\\schiphol-code-assignment\\scripts\\model.pkl']

Write output to CSV

Both local and Google Storage paths are handled

[20]:
# Write predictions and all metric tables as CSV (without the pandas index).
outputs_to_write = [
    (df_predictions, predictions_file),
    (df_metrics_long, overall_metrics_file),
    (df_daily_metrics_long, daily_metrics_file),
    (df_hourly_metrics_long, hourly_metrics_file),
]
for frame, target_path in outputs_to_write:
    write_csv_data(frame, target_path, index=False)
Writing file to local directory
File:   C:\Users\lodew\qualogy\schiphol-code-assignment\scripts\predictions.csv

Writing file to local directory
File:   C:\Users\lodew\qualogy\schiphol-code-assignment\scripts\overall_metrics_long.csv

Writing file to local directory
File:   C:\Users\lodew\qualogy\schiphol-code-assignment\scripts\daily_metrics_long.csv

Writing file to local directory
File:   C:\Users\lodew\qualogy\schiphol-code-assignment\scripts\hourly_metrics_long.csv

Log to MLflow

[21]:
import mlflow

# Only contact the tracking backend when logging is enabled via the
# ``log_mlflow`` parameter.
if log_mlflow:
    mlflow.set_tracking_uri(mlflow_tracking_uri)
    mlflow.set_experiment(mlflow_experiment)

    print(f"Logging to experiment: {mlflow_experiment}")
    print(f"Run name: {mlflow_run}")

    with mlflow.start_run(run_name=mlflow_run):
        mlflow.log_param("Input file", input_file)
        mlflow.log_param("Train-test file", train_test_file)

        # One MLflow metric per (metric, model set) pair, named
        # "<metric>__<model_set>", e.g. "mae__test".
        for variable, model_set, value in zip(
            df_metrics_long["variable"],
            df_metrics_long["model_set"],
            df_metrics_long["value"],
        ):
            mlflow.log_metric("__".join([variable, model_set]), value)

        # Attach the output files, including the pickled model, to the run.
        print("Logging artifacts")
        for artifact_file in (
            predictions_file,
            overall_metrics_file,
            daily_metrics_file,
            hourly_metrics_file,
            model_file,
        ):
            mlflow.log_artifact(artifact_file)

INFO: 'from_script' does not exist. Creating a new experiment
Logging to experiment: from_script
Run name: catboost_simple
Logging artifacts
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-21-a6e1f7e0ce6b> in <module>
     18     # log artifacts
     19     print("Logging artifacts")
---> 20     mlflow.log_artifact(predictions_file)
     21     mlflow.log_artifact(overall_metrics_file)
     22     mlflow.log_artifact(daily_metrics_file)

NameError: name 'predictions_file' is not defined

Overview of the output data

[ ]:
# Print a structural summary (columns, dtypes, non-null counts) of the predictions frame.
df_predictions.info()