Fit CatBoost model using extended features¶
- Take the model input data for flight delays
- Split the data based on an external train/test data file
- Define the CatBoost model
- Perform a randomized search over the selected parameter space
- Retrain the model for more iterations using the optimal parameters
- Save the model as a pickle file
- Log parameters, metrics and artifacts to MLflow
- Write the prediction output to CSV
Parameters¶
- input_file: Filepath of model input data of flight delays
- train_test_file: Filepath of train/test csv file with columns [“id”, “model_set”]
- output_predictions: Filepath of the output csv file with predictions; its parent directory is used as the output directory for all other files
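A minimal sketch of how a train/test file with the columns ["id", "model_set"] could be produced. The file name `train_test__0.2__timeseries.csv` suggests a time-ordered split with a test fraction of 0.2, but that is an assumption; `make_timeseries_split` is a hypothetical helper and the sample rows are made up.

```python
import pandas as pd


def make_timeseries_split(df, test_size=0.2, time_col="scheduleDateTime"):
    """Label the most recent `test_size` fraction of rows as 'test' (hypothetical helper)."""
    ordered = df.sort_values(time_col, kind="stable").reset_index(drop=True)
    n_test = int(round(len(ordered) * test_size))
    n_train = len(ordered) - n_test
    labels = ["train"] * n_train + ["test"] * n_test
    return pd.DataFrame({"id": ordered["id"].to_numpy(), "model_set": labels})


# Made-up flights, deliberately not in time order.
flights = pd.DataFrame({
    "id": [11, 12, 13, 14, 15],
    "scheduleDateTime": pd.to_datetime(
        ["2018-01-03", "2018-01-01", "2018-01-05", "2018-01-02", "2018-01-04"]
    ),
})

split = make_timeseries_split(flights, test_size=0.2)
print(split)
```

Splitting on time rather than at random keeps future flights out of the training set, which is usually what is wanted when the model has to predict delays for flights that have not departed yet.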
Returns¶
Trained CatBoost regression model (saved as model.pkl), together with csv files containing the predictions and the overall, daily and hourly evaluation metrics.
[2]:
# model params
input_file = "../lvt-schiphol-assignment-snakemake/data/model_input/delays_extended_input.csv"
train_test_file = "../lvt-schiphol-assignment-snakemake/data/model_input/train_test__0.2__timeseries.csv"
output_predictions = "./predictions.csv"
# mlflow params
log_mlflow = True
mlflow_tracking_uri = "../mlruns"
mlflow_experiment = "from_script"
mlflow_run = "catboost_simple"
[3]:
from pathlib import Path
output_dir = Path(output_predictions).parent.absolute()
output_dir
[3]:
WindowsPath('C:/Users/lodew/qualogy/schiphol-code-assignment/scripts')
Imports¶
[4]:
import pandas as pd
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.base import BaseEstimator, TransformerMixin
# catboost
from catboost import Pool
from catboost import CatBoostRegressor
import matplotlib.pyplot as plt
import seaborn as sns
import sys
sys.path.append("../")
from src.data.google_storage_io import read_csv_data, write_csv_data
from src.evaluation.metrics import get_regression_metrics
from src.evaluation.regression import make_regression_metrics_by_group, make_regression_metrics_by_datetime
from src.evaluation.predictions import make_predictions_dataframe
[5]:
plt.rcParams["figure.figsize"] = (16, 8)
Read data¶
[6]:
%%time
df = read_csv_data(input_file)
train_test = read_csv_data(train_test_file)
Reading file from local directory
File: ../lvt-schiphol-assignment-snakemake/data/model_input/delays_extended_input.csv
Reading file from local directory
File: ../lvt-schiphol-assignment-snakemake/data/model_input/train_test__0.2__timeseries.csv
Wall time: 1.51 s
[7]:
%%time
def split_train_test(df, train_test, target="scheduleDelaySeconds"):
    # merge by `id` and group by train/test set labels
    df_set_groups = pd.merge(df, train_test, on="id", how="left").groupby("model_set")
    # get data per train/test set
    df_train = df_set_groups.get_group("train").drop(columns="model_set")
    df_test = df_set_groups.get_group("test").drop(columns="model_set")
    # split target from features
    X_train, y_train = df_train.drop(columns=[target]), df_train[target]
    X_test, y_test = df_test.drop(columns=[target]), df_test[target]
    print(f"""
    Split data shapes
    Input: {df.shape}
    Train: {X_train.shape},\t {y_train.shape}
    Test: {X_test.shape},\t {y_test.shape}
    """)
    # assert that we haven't dropped values at this stage
    # failed assert could indicate duplicate ids found in the data
    assert (len(X_train) + len(X_test)) == len(df)
    return X_train, X_test, y_train, y_test
# split data
X_train, X_test, y_train, y_test = split_train_test(df, train_test)
Split data shapes
Input: (487714, 24)
Train: (389439, 23), (389439,)
Test: (98275, 23), (98275,)
Wall time: 1.1 s
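The assert in `split_train_test` only fires after the merge has already happened. As a hedged alternative, pandas can check the key relationship during the merge itself through the `validate` argument. The frames below are made-up stand-ins for `df` and `train_test`, and `safe_merge` is a hypothetical helper.

```python
import pandas as pd

# Made-up stand-ins for the model input and the train/test label file.
df = pd.DataFrame({"id": [1, 2, 3], "scheduleDelaySeconds": [10.0, -5.0, 0.0]})
train_test_ok = pd.DataFrame({"id": [1, 2, 3], "model_set": ["train", "train", "test"]})
train_test_dup = pd.DataFrame(
    {"id": [1, 2, 2, 3], "model_set": ["train", "train", "test", "test"]}
)


def safe_merge(df, train_test):
    """Merge on `id` and fail early if an id occurs more than once on either side."""
    return pd.merge(df, train_test, on="id", how="left", validate="one_to_one")


merged = safe_merge(df, train_test_ok)

try:
    safe_merge(df, train_test_dup)
    duplicate_detected = False
except pd.errors.MergeError:
    duplicate_detected = True

print(len(merged), duplicate_detected)
```

Note that this only covers duplicate ids. Rows without a matching label still get a missing `model_set` after the left merge and are dropped by the `groupby`, which is the other case the assert guards against.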
Prediction model¶
Define model¶
[8]:
# NOTE: this class is not used in the remainder of this notebook; the final model
# below is a plain CatBoostRegressor. Its fit() refers to columns ("distance",
# "Passed", "seconds") that are not among the flight delay features selected below.
class CatBoostChain(BaseEstimator, TransformerMixin):
    """
    Catboost estimator that uses .transform() to append output predictions
    """
    def __init__(self, catboost_kwargs, early_stopping_rounds=None):
        self.catboost = CatBoostRegressor(**catboost_kwargs)
        self.early_stopping_rounds = early_stopping_rounds

    def fit(self, X, y, **kwargs):
        """
        Expects `y` to be the race finish times, not avg_speed_left
        """
        # calculate avg speed left
        y = ((X["distance"] - X["Passed"]) / (y.values - X["seconds"])) \
            .replace([np.inf, -np.inf], np.nan) \
            .fillna(0).values
        X_pool_train, X_pool_eval, y_pool_train, y_pool_eval = train_test_split(X, y, test_size=0.2)
        train_pool = Pool(data=X_pool_train,
                          label=y_pool_train,
                          cat_features=self.catboost.get_param("cat_features"))
        # baseline=X_dist_train["yhat_finish_time"])
        eval_pool = Pool(data=X_pool_eval,
                         label=y_pool_eval,
                         cat_features=self.catboost.get_param("cat_features"))
        self.catboost.fit(train_pool, eval_set=eval_pool, early_stopping_rounds=self.early_stopping_rounds, **kwargs)
        return self
Select features for catboost model¶
[9]:
columns_to_drop = ["id", "scheduleDateTime", "actualOffBlockTime", "year", "month", "quarter"]
X_train_meta = X_train[["id", "scheduleDateTime"]]
X_test_meta = X_test[["id", "scheduleDateTime"]]
X_train = X_train[[col for col in X_train.columns if col not in columns_to_drop]]
X_test = X_test[[col for col in X_test.columns if col not in columns_to_drop]]
# type categorical features for catboost
cat_features = [
    'aircraftRegistration', 'airlineCode', 'terminal', 'serviceType',
    'final_destination', 'Country', 'City', 'DST',
    'dayofweek', 'dayofmonth', 'weekofyear', 'hour', 'minutes'
]
X_train[cat_features] = X_train[cat_features].astype(str).astype('category')
X_test[cat_features] = X_test[cat_features].astype(str).astype('category')
assert all(X_test.columns == X_train.columns)
X_train.columns
[9]:
Index(['aircraftRegistration', 'airlineCode', 'terminal', 'serviceType',
'final_destination', 'Country', 'City', 'Latitude', 'Longitude',
'Altitude', 'DST', 'destination_distance', 'dayofweek', 'dayofmonth',
'weekofyear', 'hour', 'minutes'],
dtype='object')
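The cast `.astype(str).astype('category')` used above turns numeric codes such as `hour` into string labels before they are marked as categorical. A minimal sketch on made-up data follows; it assumes, to the best of my understanding, that CatBoost expects categorical values to be integers or strings, so the string cast is a safe common denominator. The explicit `.copy()` is an addition that avoids assigning into a slice of another frame.

```python
import pandas as pd

# Made-up feature frame with one numeric-coded categorical and one true numeric column.
X = pd.DataFrame({
    "hour": [6, 17, 6],
    "airlineCode": ["KL", "HV", "KL"],
    "destination_distance": [350.5, 1200.0, 350.5],
})
cat_cols = ["hour", "airlineCode"]

X = X.copy()  # work on an independent copy before assigning columns
X[cat_cols] = X[cat_cols].astype(str).astype("category")

hour_values = X["hour"].tolist()
print(hour_values)
print(X.dtypes)
```

The numeric column `destination_distance` is left untouched, so the model can still treat it as a continuous feature.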
Perform random search for optimal parameters¶
- Random search is used instead of an exhaustive grid search for faster results
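To make the trade-off concrete, the sketch below counts the combinations in the parameter grid used in this notebook and draws ten of them at random, which matches the ten trials visible in the search log. It only illustrates the idea with the standard library and is not how CatBoost samples internally; the seed is arbitrary.

```python
import itertools
import random

# Same parameter space as the notebook's `grid`.
grid = {
    "learning_rate": [0.1, 0.3, 0.5],
    "depth": [4, 6, 10],
    "l2_leaf_reg": [1, 3, 5, 7, 9],
}

# Exhaustive grid search would evaluate every combination.
all_combinations = [
    dict(zip(grid, values)) for values in itertools.product(*grid.values())
]

# Random search evaluates only a sample of them.
rng = random.Random(42)  # arbitrary seed, for reproducibility of the sketch
sampled = rng.sample(all_combinations, k=10)

print(len(all_combinations), len(sampled))
```

With 45 possible combinations, ten trials cover roughly a fifth of the grid, which is where the speed-up comes from, at the cost of possibly missing the best combination.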
[10]:
# create catboost input data
X_pool_train, X_pool_eval, y_pool_train, y_pool_eval = train_test_split(X_train, y_train, test_size=0.2)
train_pool = Pool(data=X_pool_train,
                  label=y_pool_train,
                  cat_features=cat_features)
eval_pool = Pool(data=X_pool_eval,
                 label=y_pool_eval,
                 cat_features=cat_features)
# set initial catboost kwargs for random search
catboost_kwargs = {
    "verbose": 1,
    "iterations": 10,
    "depth": 4,
    "learning_rate": 1,
    "loss_function": "MAE",
    "l2_leaf_reg": 4,
    "train_dir": str(Path(output_dir, "catboost_random_search")),
    "cat_features": cat_features}
# sensible values for random search after trial-error
grid = {'learning_rate': [0.1, 0.3, 0.5],
        'depth': [4, 6, 10],
        'l2_leaf_reg': [1, 3, 5, 7, 9]}
search_model = CatBoostRegressor(**catboost_kwargs)
randomized_search_result = search_model.randomized_search(grid,
                                                          X=train_pool,
                                                          plot=True)
bestTest = 804.7577782
bestIteration = 9
0: loss: 804.7577782 best: 804.7577782 (0) total: 953ms remaining: 8.57s
bestTest = 763.8192678
bestIteration = 9
1: loss: 763.8192678 best: 763.8192678 (1) total: 1.7s remaining: 6.79s
bestTest = 757.4576161
bestIteration = 9
2: loss: 757.4576161 best: 757.4576161 (2) total: 2.42s remaining: 5.64s
bestTest = 763.8192678
bestIteration = 9
3: loss: 763.8192678 best: 757.4576161 (2) total: 3.13s remaining: 4.69s
bestTest = 757.4576161
bestIteration = 9
4: loss: 757.4576161 best: 757.4576161 (2) total: 3.85s remaining: 3.85s
bestTest = 763.8192678
bestIteration = 9
5: loss: 763.8192678 best: 757.4576161 (2) total: 4.59s remaining: 3.06s
bestTest = 757.4576161
bestIteration = 9
6: loss: 757.4576161 best: 757.4576161 (2) total: 5.29s remaining: 2.27s
bestTest = 800.7015278
bestIteration = 9
7: loss: 800.7015278 best: 757.4576161 (2) total: 6.15s remaining: 1.54s
bestTest = 756.1623348
bestIteration = 9
8: loss: 756.1623348 best: 756.1623348 (8) total: 7.01s remaining: 779ms
bestTest = 792.3006764
bestIteration = 9
9: loss: 792.3006764 best: 756.1623348 (8) total: 8.13s remaining: 0us
Estimating final quality...
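The best combination found by the search is read from `randomized_search_result['params']` and merged into the base keyword arguments before retraining. A minimal sketch of that merge with plain dictionaries follows; the values in `best_params` are hypothetical, since the actual result is not printed in this notebook.

```python
# Base arguments, abbreviated from the search cell above.
base_kwargs = {"iterations": 10, "depth": 4, "learning_rate": 1, "l2_leaf_reg": 4}

# Hypothetical search result; the real values come from randomized_search_result["params"].
best_params = {"depth": 6, "learning_rate": 0.1, "l2_leaf_reg": 3}

# Later entries win, so the search result overrides the base values
# and the explicit iteration count overrides both.
final_kwargs = {**base_kwargs, **best_params, "iterations": 200}

# The dict(base, key=value, **other) form raises if `other` repeats an explicit keyword.
try:
    dict(base_kwargs, iterations=200, **{"iterations": 50})
    duplicate_raises = False
except TypeError:
    duplicate_raises = True

print(final_kwargs, duplicate_raises)
```

The `{**a, **b}` form is a little more forgiving than `dict(a, key=value, **b)` if the search space ever grows to include a parameter that is also set explicitly.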
Train model¶
[11]:
# update catboost parameters with randomized search results
updated_catboost_kwargs = dict(
    catboost_kwargs,
    iterations = 200,
    train_dir = str(output_dir),
    **randomized_search_result['params'])
# create new catboost object with updated parameters
model = CatBoostRegressor(**updated_catboost_kwargs)
# train final model
model.fit(train_pool, eval_set = eval_pool, early_stopping_rounds=50, plot=True)
0: learn: 756.4992915 test: 736.1896539 best: 736.1896539 (0) total: 486ms remaining: 1m 36s
1: learn: 716.4095588 test: 689.1406306 best: 689.1406306 (1) total: 1.01s remaining: 1m 40s
2: learn: 697.8411741 test: 668.3668963 best: 668.3668963 (2) total: 1.27s remaining: 1m 23s
3: learn: 692.7714155 test: 663.3020848 best: 663.3020848 (3) total: 1.63s remaining: 1m 19s
4: learn: 678.4410044 test: 645.9118846 best: 645.9118846 (4) total: 2.06s remaining: 1m 20s
5: learn: 665.8758141 test: 631.6815334 best: 631.6815334 (5) total: 2.36s remaining: 1m 16s
6: learn: 659.9978535 test: 624.8699020 best: 624.8699020 (6) total: 2.82s remaining: 1m 17s
7: learn: 657.8774503 test: 622.9807212 best: 622.9807212 (7) total: 3.06s remaining: 1m 13s
8: learn: 655.4896950 test: 620.3435796 best: 620.3435796 (8) total: 3.36s remaining: 1m 11s
9: learn: 653.2327483 test: 618.2358122 best: 618.2358122 (9) total: 3.6s remaining: 1m 8s
10: learn: 651.1965779 test: 616.2311812 best: 616.2311812 (10) total: 3.94s remaining: 1m 7s
11: learn: 650.5140676 test: 615.5689544 best: 615.5689544 (11) total: 4.26s remaining: 1m 6s
12: learn: 649.0516013 test: 613.7866925 best: 613.7866925 (12) total: 4.63s remaining: 1m 6s
13: learn: 648.4775319 test: 613.2001549 best: 613.2001549 (13) total: 4.95s remaining: 1m 5s
14: learn: 647.8060426 test: 612.5047631 best: 612.5047631 (14) total: 5.35s remaining: 1m 5s
15: learn: 647.1291244 test: 611.8994895 best: 611.8994895 (15) total: 5.63s remaining: 1m 4s
16: learn: 645.8213327 test: 610.5029021 best: 610.5029021 (16) total: 5.98s remaining: 1m 4s
17: learn: 644.8812865 test: 609.5610178 best: 609.5610178 (17) total: 6.39s remaining: 1m 4s
18: learn: 644.5974207 test: 609.3279597 best: 609.3279597 (18) total: 6.79s remaining: 1m 4s
19: learn: 644.1965673 test: 608.9084781 best: 608.9084781 (19) total: 7.21s remaining: 1m 4s
20: learn: 642.3113860 test: 607.1499557 best: 607.1499557 (20) total: 7.51s remaining: 1m 3s
21: learn: 640.8101524 test: 605.1263016 best: 605.1263016 (21) total: 7.88s remaining: 1m 3s
22: learn: 639.1655080 test: 603.3016212 best: 603.3016212 (22) total: 8.24s remaining: 1m 3s
23: learn: 638.9469651 test: 603.1433498 best: 603.1433498 (23) total: 8.56s remaining: 1m 2s
24: learn: 638.7524524 test: 602.9733992 best: 602.9733992 (24) total: 8.89s remaining: 1m 2s
25: learn: 638.4115124 test: 602.6926753 best: 602.6926753 (25) total: 9.43s remaining: 1m 3s
26: learn: 637.9200366 test: 602.2545671 best: 602.2545671 (26) total: 9.87s remaining: 1m 3s
27: learn: 637.7586167 test: 602.1181496 best: 602.1181496 (27) total: 10.1s remaining: 1m 2s
28: learn: 637.6390371 test: 602.0260108 best: 602.0260108 (28) total: 10.4s remaining: 1m 1s
29: learn: 637.4954822 test: 601.9018914 best: 601.9018914 (29) total: 10.8s remaining: 1m
30: learn: 637.3643015 test: 601.7997421 best: 601.7997421 (30) total: 11.1s remaining: 1m
31: learn: 637.2449287 test: 601.6952385 best: 601.6952385 (31) total: 11.5s remaining: 1m
32: learn: 636.9031647 test: 601.3616772 best: 601.3616772 (32) total: 11.8s remaining: 59.9s
33: learn: 636.6963216 test: 601.2079507 best: 601.2079507 (33) total: 12.3s remaining: 59.9s
34: learn: 636.2495884 test: 600.8071639 best: 600.8071639 (34) total: 12.6s remaining: 59.3s
35: learn: 635.9025845 test: 600.4896005 best: 600.4896005 (35) total: 13s remaining: 59.1s
36: learn: 635.7190227 test: 600.3220853 best: 600.3220853 (36) total: 13.3s remaining: 58.4s
37: learn: 635.5344418 test: 600.1087458 best: 600.1087458 (37) total: 13.7s remaining: 58.5s
38: learn: 634.8079986 test: 599.4618361 best: 599.4618361 (38) total: 14.1s remaining: 58.4s
39: learn: 634.4043165 test: 599.0290034 best: 599.0290034 (39) total: 14.6s remaining: 58.5s
40: learn: 634.3035510 test: 598.9302567 best: 598.9302567 (40) total: 15s remaining: 58s
41: learn: 634.1347057 test: 598.7614733 best: 598.7614733 (41) total: 15.4s remaining: 58s
42: learn: 633.9306509 test: 598.6202432 best: 598.6202432 (42) total: 15.7s remaining: 57.4s
43: learn: 633.7490183 test: 598.4112566 best: 598.4112566 (43) total: 16.1s remaining: 57.1s
44: learn: 633.2496051 test: 598.0410102 best: 598.0410102 (44) total: 16.6s remaining: 57.1s
45: learn: 632.7809645 test: 597.6645594 best: 597.6645594 (45) total: 16.9s remaining: 56.6s
46: learn: 632.6504159 test: 597.5792268 best: 597.5792268 (46) total: 17.3s remaining: 56.5s
47: learn: 632.3499691 test: 597.2903086 best: 597.2903086 (47) total: 17.8s remaining: 56.2s
48: learn: 632.0798885 test: 597.0893884 best: 597.0893884 (48) total: 18.1s remaining: 55.8s
49: learn: 631.9518132 test: 596.9704911 best: 596.9704911 (49) total: 18.6s remaining: 55.8s
50: learn: 631.8371254 test: 596.8837289 best: 596.8837289 (50) total: 19s remaining: 55.5s
51: learn: 631.6555558 test: 596.7125866 best: 596.7125866 (51) total: 19.4s remaining: 55.2s
52: learn: 631.3500733 test: 596.5806829 best: 596.5806829 (52) total: 19.7s remaining: 54.8s
53: learn: 631.2160011 test: 596.4982592 best: 596.4982592 (53) total: 20.1s remaining: 54.4s
54: learn: 631.1155251 test: 596.4052152 best: 596.4052152 (54) total: 20.5s remaining: 54.1s
55: learn: 630.9322951 test: 596.2872326 best: 596.2872326 (55) total: 20.9s remaining: 53.7s
56: learn: 630.8313626 test: 596.2052357 best: 596.2052357 (56) total: 21.3s remaining: 53.4s
57: learn: 630.5880527 test: 596.0234432 best: 596.0234432 (57) total: 21.6s remaining: 52.8s
58: learn: 630.4350416 test: 595.9243026 best: 595.9243026 (58) total: 22s remaining: 52.5s
59: learn: 630.1765985 test: 595.6232071 best: 595.6232071 (59) total: 22.3s remaining: 52s
60: learn: 629.9420989 test: 595.3269454 best: 595.3269454 (60) total: 22.7s remaining: 51.8s
61: learn: 629.7855614 test: 595.1980082 best: 595.1980082 (61) total: 23.1s remaining: 51.5s
62: learn: 629.6821856 test: 595.1140363 best: 595.1140363 (62) total: 23.5s remaining: 51.1s
63: learn: 629.3721379 test: 594.7708591 best: 594.7708591 (63) total: 23.9s remaining: 50.8s
64: learn: 629.2935584 test: 594.7005343 best: 594.7005343 (64) total: 24.3s remaining: 50.4s
65: learn: 629.1437203 test: 594.5599510 best: 594.5599510 (65) total: 24.6s remaining: 50s
66: learn: 629.0597314 test: 594.4965610 best: 594.4965610 (66) total: 25s remaining: 49.6s
67: learn: 628.9604825 test: 594.4300063 best: 594.4300063 (67) total: 25.3s remaining: 49.2s
68: learn: 628.6144429 test: 594.0584278 best: 594.0584278 (68) total: 25.8s remaining: 49s
69: learn: 628.3600205 test: 593.7500512 best: 593.7500512 (69) total: 26.3s remaining: 48.8s
70: learn: 627.9630414 test: 593.2422692 best: 593.2422692 (70) total: 26.8s remaining: 48.6s
71: learn: 627.8194149 test: 593.1220129 best: 593.1220129 (71) total: 27.1s remaining: 48.1s
72: learn: 627.6032628 test: 592.9470811 best: 592.9470811 (72) total: 27.4s remaining: 47.7s
73: learn: 627.0955542 test: 592.5350740 best: 592.5350740 (73) total: 27.8s remaining: 47.3s
74: learn: 626.9570505 test: 592.4246023 best: 592.4246023 (74) total: 28.3s remaining: 47.1s
75: learn: 626.8580474 test: 592.3451068 best: 592.3451068 (75) total: 28.6s remaining: 46.6s
76: learn: 626.6240930 test: 592.1134100 best: 592.1134100 (76) total: 29.2s remaining: 46.6s
77: learn: 626.5238310 test: 591.9830204 best: 591.9830204 (77) total: 29.7s remaining: 46.5s
78: learn: 626.2930463 test: 591.7926196 best: 591.7926196 (78) total: 30.1s remaining: 46.2s
79: learn: 626.1785614 test: 591.7165412 best: 591.7165412 (79) total: 30.5s remaining: 45.7s
80: learn: 626.1020838 test: 591.6535367 best: 591.6535367 (80) total: 30.8s remaining: 45.3s
81: learn: 626.0103928 test: 591.5821192 best: 591.5821192 (81) total: 31.3s remaining: 45s
82: learn: 625.8423460 test: 591.4126421 best: 591.4126421 (82) total: 31.8s remaining: 44.8s
83: learn: 625.7572063 test: 591.3751926 best: 591.3751926 (83) total: 32.2s remaining: 44.5s
84: learn: 625.6648629 test: 591.2986017 best: 591.2986017 (84) total: 32.6s remaining: 44.1s
85: learn: 624.8146807 test: 590.5119781 best: 590.5119781 (85) total: 33s remaining: 43.8s
86: learn: 624.5953493 test: 590.2107524 best: 590.2107524 (86) total: 33.4s remaining: 43.3s
87: learn: 624.3865336 test: 590.0259422 best: 590.0259422 (87) total: 33.7s remaining: 42.9s
88: learn: 624.1381544 test: 589.7534605 best: 589.7534605 (88) total: 34.2s remaining: 42.7s
89: learn: 624.0491247 test: 589.7121752 best: 589.7121752 (89) total: 34.8s remaining: 42.5s
90: learn: 623.7640891 test: 589.4765438 best: 589.4765438 (90) total: 35.2s remaining: 42.1s
91: learn: 623.6698931 test: 589.3804272 best: 589.3804272 (91) total: 35.5s remaining: 41.7s
92: learn: 623.5908364 test: 589.3442882 best: 589.3442882 (92) total: 35.8s remaining: 41.2s
93: learn: 623.5306758 test: 589.3090298 best: 589.3090298 (93) total: 36.2s remaining: 40.8s
94: learn: 623.3227559 test: 589.0772829 best: 589.0772829 (94) total: 36.5s remaining: 40.3s
95: learn: 623.1951940 test: 588.9860376 best: 588.9860376 (95) total: 36.9s remaining: 40s
96: learn: 622.9879567 test: 588.8807494 best: 588.8807494 (96) total: 37.2s remaining: 39.5s
97: learn: 622.8090526 test: 588.7342637 best: 588.7342637 (97) total: 37.6s remaining: 39.1s
98: learn: 622.5838785 test: 588.7075079 best: 588.7075079 (98) total: 37.9s remaining: 38.7s
99: learn: 622.4825836 test: 588.6086717 best: 588.6086717 (99) total: 38.2s remaining: 38.2s
100: learn: 622.2791306 test: 588.3271311 best: 588.3271311 (100) total: 38.7s remaining: 37.9s
101: learn: 622.2020300 test: 588.2731782 best: 588.2731782 (101) total: 39.2s remaining: 37.6s
102: learn: 621.9820212 test: 588.1600497 best: 588.1600497 (102) total: 39.6s remaining: 37.3s
103: learn: 621.6717145 test: 588.0134407 best: 588.0134407 (103) total: 40s remaining: 36.9s
104: learn: 621.4919530 test: 587.8427733 best: 587.8427733 (104) total: 40.4s remaining: 36.6s
105: learn: 621.4406313 test: 587.7976833 best: 587.7976833 (105) total: 40.8s remaining: 36.2s
106: learn: 620.8423283 test: 587.2740215 best: 587.2740215 (106) total: 41.2s remaining: 35.8s
107: learn: 620.7610185 test: 587.1424969 best: 587.1424969 (107) total: 41.5s remaining: 35.4s
108: learn: 620.6532462 test: 587.0451113 best: 587.0451113 (108) total: 41.9s remaining: 34.9s
109: learn: 620.5812943 test: 587.0190482 best: 587.0190482 (109) total: 42.3s remaining: 34.6s
110: learn: 619.9161984 test: 586.3519279 best: 586.3519279 (110) total: 42.6s remaining: 34.1s
111: learn: 619.8554406 test: 586.2735098 best: 586.2735098 (111) total: 43s remaining: 33.8s
112: learn: 619.7443666 test: 586.2191013 best: 586.2191013 (112) total: 43.4s remaining: 33.4s
113: learn: 619.6755077 test: 586.1544571 best: 586.1544571 (113) total: 43.8s remaining: 33.1s
114: learn: 619.5537904 test: 586.0764185 best: 586.0764185 (114) total: 44.3s remaining: 32.8s
115: learn: 619.2903614 test: 586.0431799 best: 586.0431799 (115) total: 44.7s remaining: 32.4s
116: learn: 619.2025000 test: 585.9869016 best: 585.9869016 (116) total: 45.1s remaining: 32s
117: learn: 618.7112464 test: 585.3193213 best: 585.3193213 (117) total: 45.5s remaining: 31.6s
118: learn: 618.5342821 test: 585.1961176 best: 585.1961176 (118) total: 45.9s remaining: 31.3s
119: learn: 618.3582463 test: 585.0500864 best: 585.0500864 (119) total: 46.3s remaining: 30.9s
120: learn: 618.2471939 test: 584.9732303 best: 584.9732303 (120) total: 46.8s remaining: 30.5s
121: learn: 618.0731285 test: 584.8395377 best: 584.8395377 (121) total: 47.2s remaining: 30.2s
122: learn: 618.0000942 test: 584.7620554 best: 584.7620554 (122) total: 47.8s remaining: 29.9s
123: learn: 617.9257270 test: 584.7316075 best: 584.7316075 (123) total: 48.3s remaining: 29.6s
124: learn: 617.7787591 test: 584.6011831 best: 584.6011831 (124) total: 48.8s remaining: 29.3s
125: learn: 617.6997120 test: 584.5398819 best: 584.5398819 (125) total: 49.2s remaining: 28.9s
126: learn: 617.5244009 test: 584.3529823 best: 584.3529823 (126) total: 49.6s remaining: 28.5s
127: learn: 617.2962930 test: 584.1669826 best: 584.1669826 (127) total: 50s remaining: 28.1s
128: learn: 617.1825785 test: 584.1385620 best: 584.1385620 (128) total: 50.4s remaining: 27.7s
129: learn: 617.1099402 test: 584.0210716 best: 584.0210716 (129) total: 50.8s remaining: 27.3s
130: learn: 617.0114904 test: 583.9671104 best: 583.9671104 (130) total: 51.2s remaining: 27s
131: learn: 616.8188259 test: 583.8113980 best: 583.8113980 (131) total: 51.5s remaining: 26.5s
132: learn: 616.7738320 test: 583.7953743 best: 583.7953743 (132) total: 51.9s remaining: 26.2s
133: learn: 616.5926336 test: 583.7213538 best: 583.7213538 (133) total: 52.3s remaining: 25.8s
134: learn: 616.5084778 test: 583.7059402 best: 583.7059402 (134) total: 52.7s remaining: 25.4s
135: learn: 616.4312025 test: 583.6415595 best: 583.6415595 (135) total: 53.3s remaining: 25.1s
136: learn: 616.3240607 test: 583.5173499 best: 583.5173499 (136) total: 53.7s remaining: 24.7s
137: learn: 616.2783275 test: 583.5019564 best: 583.5019564 (137) total: 54.2s remaining: 24.3s
138: learn: 616.2333421 test: 583.4770307 best: 583.4770307 (138) total: 54.6s remaining: 23.9s
139: learn: 616.2008139 test: 583.4749795 best: 583.4749795 (139) total: 55s remaining: 23.6s
140: learn: 616.0859712 test: 583.3754610 best: 583.3754610 (140) total: 55.5s remaining: 23.2s
141: learn: 616.0157786 test: 583.3133219 best: 583.3133219 (141) total: 55.9s remaining: 22.8s
142: learn: 615.9341868 test: 583.2771837 best: 583.2771837 (142) total: 56.4s remaining: 22.5s
143: learn: 615.8869803 test: 583.2243588 best: 583.2243588 (143) total: 57s remaining: 22.2s
144: learn: 615.8266985 test: 583.2070106 best: 583.2070106 (144) total: 57.4s remaining: 21.8s
145: learn: 615.7423962 test: 583.1519266 best: 583.1519266 (145) total: 57.7s remaining: 21.4s
146: learn: 615.4880554 test: 582.9063992 best: 582.9063992 (146) total: 58.2s remaining: 21s
147: learn: 615.3819690 test: 582.8068397 best: 582.8068397 (147) total: 58.5s remaining: 20.6s
148: learn: 615.3397580 test: 582.7699536 best: 582.7699536 (148) total: 59s remaining: 20.2s
149: learn: 615.1919542 test: 582.6876877 best: 582.6876877 (149) total: 59.4s remaining: 19.8s
150: learn: 615.1626818 test: 582.6696281 best: 582.6696281 (150) total: 59.8s remaining: 19.4s
151: learn: 615.0243761 test: 582.5860315 best: 582.5860315 (151) total: 1m remaining: 19s
152: learn: 614.8813938 test: 582.5034830 best: 582.5034830 (152) total: 1m remaining: 18.6s
153: learn: 614.8605463 test: 582.4908016 best: 582.4908016 (153) total: 1m 1s remaining: 18.2s
154: learn: 614.7724205 test: 582.4161990 best: 582.4161990 (154) total: 1m 1s remaining: 17.8s
155: learn: 614.7300791 test: 582.4021538 best: 582.4021538 (155) total: 1m 1s remaining: 17.4s
156: learn: 614.3910751 test: 582.1378535 best: 582.1378535 (156) total: 1m 1s remaining: 17s
157: learn: 614.2645807 test: 582.0170679 best: 582.0170679 (157) total: 1m 2s remaining: 16.5s
158: learn: 614.1563803 test: 581.9316131 best: 581.9316131 (158) total: 1m 2s remaining: 16.1s
159: learn: 613.9578124 test: 581.7698210 best: 581.7698210 (159) total: 1m 2s remaining: 15.7s
160: learn: 613.7906887 test: 581.5912194 best: 581.5912194 (160) total: 1m 3s remaining: 15.3s
161: learn: 613.7044189 test: 581.5036286 best: 581.5036286 (161) total: 1m 3s remaining: 14.9s
162: learn: 613.5849392 test: 581.4266365 best: 581.4266365 (162) total: 1m 4s remaining: 14.6s
163: learn: 613.3722696 test: 581.2292450 best: 581.2292450 (163) total: 1m 4s remaining: 14.2s
164: learn: 613.1556289 test: 581.1037576 best: 581.1037576 (164) total: 1m 5s remaining: 13.8s
165: learn: 613.0290732 test: 580.9696302 best: 580.9696302 (165) total: 1m 5s remaining: 13.4s
166: learn: 612.9695643 test: 580.9225562 best: 580.9225562 (166) total: 1m 5s remaining: 13s
167: learn: 612.9182516 test: 580.8944932 best: 580.8944932 (167) total: 1m 6s remaining: 12.6s
168: learn: 612.8759733 test: 580.8487442 best: 580.8487442 (168) total: 1m 6s remaining: 12.2s
169: learn: 612.7284283 test: 580.7110513 best: 580.7110513 (169) total: 1m 7s remaining: 11.8s
170: learn: 612.6469981 test: 580.6450549 best: 580.6450549 (170) total: 1m 7s remaining: 11.4s
171: learn: 612.4922876 test: 580.4813941 best: 580.4813941 (171) total: 1m 7s remaining: 11s
172: learn: 612.3682492 test: 580.4154322 best: 580.4154322 (172) total: 1m 8s remaining: 10.6s
173: learn: 612.1525769 test: 580.2475661 best: 580.2475661 (173) total: 1m 8s remaining: 10.2s
174: learn: 612.0985593 test: 580.2149161 best: 580.2149161 (174) total: 1m 8s remaining: 9.85s
175: learn: 611.9614065 test: 580.0693887 best: 580.0693887 (175) total: 1m 9s remaining: 9.46s
176: learn: 611.8774389 test: 580.0103718 best: 580.0103718 (176) total: 1m 9s remaining: 9.07s
177: learn: 611.7511840 test: 579.9108673 best: 579.9108673 (177) total: 1m 10s remaining: 8.68s
178: learn: 611.7252624 test: 579.9017319 best: 579.9017319 (178) total: 1m 10s remaining: 8.29s
179: learn: 611.6226393 test: 579.8035861 best: 579.8035861 (179) total: 1m 11s remaining: 7.89s
180: learn: 611.4549767 test: 579.6788475 best: 579.6788475 (180) total: 1m 11s remaining: 7.51s
181: learn: 611.3454769 test: 579.5541309 best: 579.5541309 (181) total: 1m 12s remaining: 7.13s
182: learn: 611.3054737 test: 579.5266276 best: 579.5266276 (182) total: 1m 12s remaining: 6.73s
183: learn: 611.2498650 test: 579.4875620 best: 579.4875620 (183) total: 1m 12s remaining: 6.33s
184: learn: 611.1533009 test: 579.3775673 best: 579.3775673 (184) total: 1m 13s remaining: 5.92s
185: learn: 610.9020615 test: 579.1870376 best: 579.1870376 (185) total: 1m 13s remaining: 5.54s
186: learn: 610.8316611 test: 579.0995204 best: 579.0995204 (186) total: 1m 14s remaining: 5.16s
187: learn: 610.7659888 test: 579.0790918 best: 579.0790918 (187) total: 1m 14s remaining: 4.76s
188: learn: 610.7128587 test: 579.0734995 best: 579.0734995 (188) total: 1m 15s remaining: 4.37s
189: learn: 610.6606708 test: 579.0397298 best: 579.0397298 (189) total: 1m 15s remaining: 3.97s
190: learn: 610.6550951 test: 579.0353334 best: 579.0353334 (190) total: 1m 15s remaining: 3.58s
191: learn: 610.5822829 test: 578.9660564 best: 578.9660564 (191) total: 1m 16s remaining: 3.18s
192: learn: 610.5296158 test: 578.9261770 best: 578.9261770 (192) total: 1m 16s remaining: 2.78s
193: learn: 610.4417085 test: 578.8428978 best: 578.8428978 (193) total: 1m 17s remaining: 2.39s
194: learn: 610.3924358 test: 578.8206193 best: 578.8206193 (194) total: 1m 17s remaining: 1.99s
195: learn: 610.1421457 test: 578.5998934 best: 578.5998934 (195) total: 1m 17s remaining: 1.59s
196: learn: 610.0201709 test: 578.4677762 best: 578.4677762 (196) total: 1m 18s remaining: 1.19s
197: learn: 609.8806418 test: 578.3458600 best: 578.3458600 (197) total: 1m 18s remaining: 795ms
198: learn: 609.8425944 test: 578.3160699 best: 578.3160699 (198) total: 1m 18s remaining: 397ms
199: learn: 609.7698163 test: 578.3045065 best: 578.3045065 (199) total: 1m 19s remaining: 0us
bestTest = 578.3045065
bestIteration = 199
[11]:
<catboost.core.CatBoostRegressor at 0x1a6c7b74c48>
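The fit above was called with `early_stopping_rounds=50`, yet the log ends with `bestIteration = 199`, the last of the 200 iterations. The stopping rule therefore never triggered and the evaluation loss was still decreasing, so more iterations might help. The sketch below shows a simplified patience rule of this kind on made-up loss values; it illustrates the concept and is not CatBoost's own overfitting detector.

```python
def early_stop_index(eval_losses, patience):
    """Return (last_iteration_run, best_iteration) under a simple patience rule.

    Training stops once `patience` iterations have passed without a new best loss.
    """
    best_loss = float("inf")
    best_iter = -1
    for i, loss in enumerate(eval_losses):
        if loss < best_loss:
            best_loss, best_iter = loss, i
        elif i - best_iter >= patience:
            return i, best_iter
    return len(eval_losses) - 1, best_iter


# Made-up losses: the best value is at iteration 2, then no improvement.
stalled = early_stop_index([5.0, 4.0, 3.0, 3.5, 3.2, 3.1, 3.4], patience=3)

# Made-up losses that keep improving, as in the training log above.
improving = early_stop_index([5.0, 4.0, 3.0, 2.0], patience=3)

print(stalled, improving)
```

In the second case the best iteration equals the last one, which is the same pattern as in the log above.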
Evaluate model¶
[12]:
def datetime_to_date(dt):
    return pd.to_datetime(dt, utc=True).dt.date

def datetime_to_date_hour(dt):
    return pd.to_datetime(dt, utc=True).dt.floor('H')
[13]:
# create predictions on train/test sets
df_predictions = make_predictions_dataframe(model, X_train, X_test, y_train, y_test,
                                            meta_train = X_train_meta,
                                            meta_test = X_test_meta)
df_predictions
[13]:
| | id | scheduleDateTime | y | yhat | error | model_set |
|---|---|---|---|---|---|---|
| 0 | 123414481790510775 | 2018-01-01 03:30:00+01:00 | -480.0 | -197.459328 | 282.540672 | train |
| 1 | 123414479288269149 | 2018-01-01 06:00:00+01:00 | -98.0 | -100.360379 | -2.360379 | train |
| 2 | 123414479666542945 | 2018-01-01 06:05:00+01:00 | -300.0 | -120.601478 | 179.398522 | train |
| 3 | 123414479288365061 | 2018-01-01 06:05:00+01:00 | -300.0 | -121.494374 | 178.505626 | train |
| 4 | 123414479288274329 | 2018-01-01 06:15:00+01:00 | 694.0 | 176.538920 | -517.461080 | train |
| ... | ... | ... | ... | ... | ... | ... |
| 487709 | 124763270719624901 | 2018-07-12 17:25:00+02:00 | 80.0 | 280.244578 | 200.244578 | test |
| 487710 | 124763272032451639 | 2018-07-12 17:25:00+02:00 | 80.0 | 281.817452 | 201.817452 | test |
| 487711 | 124763270368084713 | 2018-07-12 17:25:00+02:00 | 12.0 | 278.095160 | 266.095160 | test |
| 487712 | 124763270625998761 | 2018-07-12 17:45:00+02:00 | -1935.0 | -2.120094 | 1932.879906 | test |
| 487713 | 124763271129903067 | 2018-07-12 17:50:00+02:00 | -8690.0 | 31.526172 | 8721.526172 | test |
487714 rows × 6 columns
[14]:
%%time
df_metrics_long = make_regression_metrics_by_group(df_predictions, group_cols = ["model_set"])
df_daily_metrics_long = make_regression_metrics_by_datetime(df_predictions, freq="D", alias="schedule_date")
df_hourly_metrics_long = make_regression_metrics_by_datetime(df_predictions, freq="H", alias="schedule_date")
Wall time: 9.12 s
[15]:
df_metrics_long.head()
[15]:
| | model_set | variable | value |
|---|---|---|---|
| 0 | test | mae | 841.293334 |
| 1 | train | mae | 558.490217 |
| 2 | test | mape | 365.325530 |
| 3 | train | mape | 448.172767 |
| 4 | test | rmse | 2429.517313 |
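
The metrics in this table come from the project's own `get_regression_metrics`, whose code is not shown here. The sketch below uses the conventional definitions of MAE, RMSE and MAPE as an assumption. The sign convention `error = yhat - y` matches the predictions table above (for row 0, -197.459328 - (-480.0) = 282.540672). Skipping zero targets in the MAPE is also an assumption; the sample values are made up.

```python
import numpy as np


def regression_metrics(y, yhat):
    """Conventional MAE, RMSE and MAPE (in percent); assumed definitions."""
    y = np.asarray(y, dtype=float)
    yhat = np.asarray(yhat, dtype=float)
    error = yhat - y
    nonzero = y != 0  # MAPE is undefined where the true value is zero
    return {
        "mae": float(np.mean(np.abs(error))),
        "rmse": float(np.sqrt(np.mean(error ** 2))),
        "mape": float(np.mean(np.abs(error[nonzero] / y[nonzero])) * 100),
    }


metrics = regression_metrics(y=[100.0, -50.0, 200.0], yhat=[110.0, -40.0, 170.0])
print(metrics)
```

Because delays are often close to zero seconds, a percentage error divides by very small numbers, which is one plausible reason for MAPE values of several hundred percent; MAE is easier to interpret for this target.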
[16]:
import plotly.express as px
fig = px.line(df_hourly_metrics_long, x="schedule_date", y="value", facet_row="variable", color="model_set",
              width=1200, height=1200, title="Hourly prediction metrics")
# Add range slider
fig.update_layout(
    xaxis=dict(
        rangeslider=dict(
            visible=True
        ),
        type="date"
    ),
    hovermode="x"
)
fig.update_yaxes(matches=None)
# fig.update_xaxes(matches=None)
fig.show()
Plot some prediction results¶
[17]:
def predictions_daily_mean(df_predictions):
    df_predictions["schedule_date"] = datetime_to_date(df_predictions["scheduleDateTime"])
    df_predictions = df_predictions.drop(columns="id")
    # numeric_only=True keeps y, yhat and error and skips the non-numeric scheduleDateTime column
    df_daily_mean = df_predictions.groupby(["model_set", "schedule_date"]).mean(numeric_only=True).reset_index()
    return df_daily_mean

def predictions_hourly_mean(df_predictions):
    df_predictions["schedule_date"] = datetime_to_date_hour(df_predictions["scheduleDateTime"])
    df_predictions = df_predictions.drop(columns="id")
    df_hourly_mean = df_predictions.groupby(["model_set", "schedule_date"]).mean(numeric_only=True).reset_index()
    return df_hourly_mean

def get_safe_ylim(y, q=0.05, q2=None):
    # clip the y-axis to the [q, q2] quantile range so outliers do not dominate the plot
    if q2 is None:
        q2 = 1 - q
    return (np.quantile(y, q), np.quantile(y, q2))
df_daily_mean = predictions_daily_mean(df_predictions)
y_ylim = get_safe_ylim(df_daily_mean["y"])
error_ylim = get_safe_ylim(df_daily_mean["error"])
df_daily_mean[["schedule_date", "y", "yhat", "model_set"]].plot(x="schedule_date", ylim=y_ylim)
df_daily_mean[["schedule_date", "error", "model_set"]].plot(x="schedule_date", ylim=error_ylim)
df_hourly_mean = predictions_hourly_mean(df_predictions)
y_ylim = get_safe_ylim(df_hourly_mean["y"])
error_ylim = get_safe_ylim(df_hourly_mean["error"])
df_hourly_mean[["schedule_date", "y", "yhat", "model_set"]].groupby("model_set").plot(x="schedule_date", ylim=y_ylim)
df_hourly_mean[["schedule_date", "error", "model_set"]].groupby("model_set").plot(x="schedule_date", ylim=error_ylim)
df_hourly_mean[["schedule_date", "y", "yhat", "error", "model_set"]]
plt.show()
fig = px.line(df_hourly_mean, x="schedule_date", y="error", color="model_set",
              width=1200, height=1200, title="Hourly total prediction error")
fig.update_yaxes(matches=None)
fig.update_xaxes(matches=None)
fig.show()
Write output to output directory¶
[18]:
import joblib, pickle
from pathlib import Path
[ ]:
model_file = str(Path(output_dir, "model.pkl"))
predictions_file = str(Path(output_dir, "predictions.csv"))
overall_metrics_file = str(Path(output_dir, "overall_metrics_long.csv"))
daily_metrics_file = str(Path(output_dir, "daily_metrics_long.csv"))
hourly_metrics_file = str(Path(output_dir, "hourly_metrics_long.csv"))
Pickle output files for mlflow artifacts¶
- Pipeline serialized with joblib
- Model data or sample thereof
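A minimal sketch of the joblib round trip, using a plain dictionary as a stand-in for the trained model so that it runs without CatBoost. The stand-in and its values are hypothetical; any picklable object is handled the same way.

```python
import tempfile
from pathlib import Path

import joblib

# Hypothetical stand-in for the trained model.
model_stub = {"name": "catboost_simple", "params": {"depth": 6, "learning_rate": 0.1}}

with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = Path(tmp_dir, "model.pkl")
    joblib.dump(model_stub, model_path)   # serialize to disk
    restored = joblib.load(model_path)    # read it back

round_trip_ok = restored == model_stub
print(round_trip_ok)
```

Loading the file back right after writing it is a cheap check that the artifact logged to MLflow can actually be restored.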
[19]:
joblib.dump(model, model_file)  # model_file already includes output_dir
[19]:
['C:\\Users\\lodew\\qualogy\\schiphol-code-assignment\\scripts\\model.pkl']
Write output to CSV¶
Both local and Google Storage destinations are handled
[20]:
# write output file
write_csv_data(df_predictions, predictions_file, index=False)
write_csv_data(df_metrics_long, overall_metrics_file, index=False)
write_csv_data(df_daily_metrics_long, daily_metrics_file, index=False)
write_csv_data(df_hourly_metrics_long, hourly_metrics_file, index=False)
Writing file to local directory
File: C:\Users\lodew\qualogy\schiphol-code-assignment\scripts\predictions.csv
Writing file to local directory
File: C:\Users\lodew\qualogy\schiphol-code-assignment\scripts\overall_metrics_long.csv
Writing file to local directory
File: C:\Users\lodew\qualogy\schiphol-code-assignment\scripts\daily_metrics_long.csv
Writing file to local directory
File: C:\Users\lodew\qualogy\schiphol-code-assignment\scripts\hourly_metrics_long.csv
Log to MLflow¶
[21]:
import mlflow
mlflow.set_tracking_uri(mlflow_tracking_uri)
mlflow.set_experiment(mlflow_experiment)
print(f"Logging to experiment: {mlflow_experiment}")
print(f"Run name: {mlflow_run}")
with mlflow.start_run(run_name=mlflow_run):
    mlflow.log_param("Input file", input_file)
    mlflow.log_param("Train-test file", train_test_file)
    # Model metadata
    for idx, metric_row in df_metrics_long.iterrows():
        metric_name = "__".join([metric_row["variable"], metric_row["model_set"]])
        mlflow.log_metric(metric_name, metric_row["value"])
    # log artifacts
    print("Logging artifacts")
    mlflow.log_artifact(predictions_file)
    mlflow.log_artifact(overall_metrics_file)
    mlflow.log_artifact(daily_metrics_file)
    mlflow.log_artifact(hourly_metrics_file)
INFO: 'from_script' does not exist. Creating a new experiment
Logging to experiment: from_script
Run name: catboost_simple
Logging artifacts
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
<ipython-input-21-a6e1f7e0ce6b> in <module>
18 # log artifacts
19 print("Logging artifacts")
---> 20 mlflow.log_artifact(predictions_file)
21 mlflow.log_artifact(overall_metrics_file)
22 mlflow.log_artifact(daily_metrics_file)
NameError: name 'predictions_file' is not defined
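The NameError above is consistent with the cell that defines `predictions_file` and the other file paths showing no execution count (`[ ]:`), so those names did not exist in the session that produced this output. One hedged way to make the logging step more robust is to resolve the artifact paths right before logging and check that the files exist. `resolve_artifacts` below is a hypothetical helper, demonstrated on a temporary directory.

```python
import tempfile
from pathlib import Path


def resolve_artifacts(output_dir, names):
    """Return (found, missing) lists of artifact paths under output_dir (hypothetical helper)."""
    paths = [Path(output_dir, name) for name in names]
    found = [str(p) for p in paths if p.is_file()]
    missing = [str(p) for p in paths if not p.is_file()]
    return found, missing


with tempfile.TemporaryDirectory() as tmp_dir:
    # Only one of the two expected files is written in this demonstration.
    Path(tmp_dir, "predictions.csv").write_text("id,yhat\n")
    found, missing = resolve_artifacts(
        tmp_dir, ["predictions.csv", "daily_metrics_long.csv"]
    )
    found_names = [Path(p).name for p in found]
    missing_names = [Path(p).name for p in missing]

print(found_names, missing_names)
```

In the notebook, each entry of `found` would be passed to `mlflow.log_artifact`, and a non-empty `missing` list could be reported before the run is started.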
Overview of the output data¶
[ ]:
df_predictions.info()