Create train/test split of the data

This notebook takes a DataFrame with at least the columns ['id', 'scheduleDateTime'] and creates a train/test split.

  1. Get the id column of the data and split it into a train/test set
  2. Split using a strategy of either ‘sample’ or ‘timeseries’

TODO:

  • Stratification based on important groups we have yet to identify
  • Decide whether validation splits should also be made here or if OK to do downstream

Parameters


  • input_file: Filepath of flights data in format received from Schiphol
  • output_file: Filepath to write the output csv file with the id and model_set columns
  • strategy: One of [‘sample’, ‘timeseries’]
  • test_size: (Optional) Default 0.3. Fraction to use as test data between 0 and 1
  • val_size: (Optional) Default 0.1. Fraction to use as validation data between 0 and 1 (validation set not currently implemented)

If strategy == ‘timeseries’, the data is sorted on the scheduleDateTime column and the last test_size fraction of the rows is used as the test set.
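The ‘timeseries’ rule can be illustrated on a toy DataFrame. This is a minimal sketch, not the notebook's own cell: the names df_toy, df_sorted and n_test and all values are made up for illustration.

```python
import pandas as pd

# Toy data with the two required columns (illustrative values only)
df_toy = pd.DataFrame({
    "id": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    "scheduleDateTime": pd.date_range("2018-01-01", periods=10, freq="D"),
}).sample(frac=1, random_state=0)  # shuffle rows to show that sorting matters

test_size = 0.3

# Sort chronologically, then take the last test_size fraction as the test set
df_sorted = df_toy.sort_values("scheduleDateTime").reset_index(drop=True)
n_test = int(len(df_sorted) * test_size)  # number of rows in the test set

# Note: this slicing assumes n_test > 0; with n_test == 0, iloc[:-0] would be empty
train_ids = df_sorted["id"].iloc[:-n_test]
test_ids = df_sorted["id"].iloc[-n_test:]
```

With this rule every test row is scheduled later than every train row, which is the property a random sample does not give.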

Returns


Output format

id                   |   model_set |
123414481790510775   |  train      |
123414479288269149   |  train      |
123414479666542945   |  test       |
123414479288365061   |  test       |
123414479288274329   |  validation |    # validation set not currently implemented
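Since the validation set is not implemented yet, the following is only a sketch of one way the three labels in the table above could be produced for the ‘sample’ strategy. It assumes val_size is meant as a fraction of the full data; the ids, the fixed random_state and the names train_val_ids, val_fraction and df_sets are illustrative assumptions.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

ids = pd.Series(range(100), name="id")  # illustrative ids, not real flight ids
test_size, val_size = 0.3, 0.1

# First carve off the test set from the full data
train_val_ids, test_ids = train_test_split(
    ids, test_size=test_size, random_state=42)

# val_size refers to the full data (assumption), so rescale it
# to a fraction of the rows that remain after removing the test set
val_fraction = val_size / (1 - test_size)
train_ids, val_ids = train_test_split(
    train_val_ids, test_size=val_fraction, random_state=42)

# Build the same id / model_set layout as the output format above
df_sets = pd.concat([
    pd.DataFrame({"id": train_ids.values, "model_set": "train"}),
    pd.DataFrame({"id": test_ids.values, "model_set": "test"}),
    pd.DataFrame({"id": val_ids.values, "model_set": "validation"}),
], ignore_index=True)
```

Splitting in two steps keeps the test set identical whether or not a validation set is requested, so downstream results stay comparable.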

Script parameters

[7]:
# parameters
input_file = "../lvt-schiphol-assignment-snakemake/data/model_input/delays_extended_input.csv"
output_file = "train_test__sample__0.2.csv"
strategy = 'sample'
test_size = 0.2
[8]:
assert 0 < test_size < 1
assert strategy in ['sample', 'timeseries']

Imports

[9]:
import pandas as pd
from sklearn.model_selection import train_test_split

import sys
sys.path.append("../")

from src.data.google_storage_io import read_csv_data, write_csv_data

Load data

[10]:
%%time

df = read_csv_data(input_file)

# subset only columns we need
df = df[["id", "scheduleDateTime"]]
df["scheduleDateTime"] = pd.to_datetime(df["scheduleDateTime"])
df.head()
Reading file from local directory
File:   ../lvt-schiphol-assignment-snakemake/data/model_input/delays_extended_input.csv

Wall time: 8.22 s
[10]:
id scheduleDateTime
0 123414481790510775 2018-01-01 03:30:00+01:00
1 123414479288269149 2018-01-01 06:00:00+01:00
2 123414479666542945 2018-01-01 06:05:00+01:00
3 123414479288365061 2018-01-01 06:05:00+01:00
4 123414479288274329 2018-01-01 06:15:00+01:00

Make train/test split

[11]:
print(f"Strategy: {strategy}")

if strategy == 'sample':
    # sample like usual
    train_ids, test_ids = train_test_split(df["id"], test_size=test_size)
elif strategy == 'timeseries':
    # select last fraction of data as test set
    df = df.sort_values("scheduleDateTime").reset_index(drop=True)
    # use the test_size parameter rather than a hard-coded fraction
    n_test = int(len(df) * test_size)
    train_ids, test_ids = df.iloc[:-n_test]["id"], df.iloc[-n_test:]["id"]

df_train_test = pd.concat([
    pd.DataFrame(dict(id = train_ids.values, model_set = "train")),
    pd.DataFrame(dict(id = test_ids.values, model_set = "test"))
])

df_train_test
Strategy: sample
[11]:
id model_set
0 123583078680134091 train
1 124594671792954447 train
2 123990526571117207 train
3 124397975056975081 train
4 123421504625349921 train
... ... ...
97538 123899201831170143 test
97539 124060775885955111 test
97540 124039700039122673 test
97541 123885150856850759 test
97542 124552523009520195 test

487714 rows × 2 columns

Visualize train/test split over time

Show the number of samples per day to understand the train/test distribution

[16]:
import plotly.express as px
df_plot = pd.merge(
        df_train_test,
        df, on="id", how="left") \
    .assign(
        schedule_date = lambda d: pd.to_datetime(d["scheduleDateTime"], utc=True).dt.date) \
    .groupby(["schedule_date", "model_set"])["id"].count().reset_index(name="n_samples")

px.line(df_plot, x="schedule_date", y="n_samples", color="model_set")
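The same per-day counts can also be inspected as a plain table, which is handy when no plot is available. This is a small sketch on made-up data; df_toy, df_sets_toy and counts are illustrative names and do not come from the notebook.

```python
import pandas as pd

# Made-up flights with timezone-aware schedule times, converted to UTC
df_toy = pd.DataFrame({
    "id": [1, 2, 3, 4],
    "scheduleDateTime": pd.to_datetime([
        "2018-01-01 03:30:00+01:00", "2018-01-01 06:00:00+01:00",
        "2018-01-02 06:05:00+01:00", "2018-01-02 06:15:00+01:00",
    ], utc=True),
})
# Made-up split labels in the id / model_set layout
df_sets_toy = pd.DataFrame({
    "id": [1, 2, 3, 4],
    "model_set": ["train", "test", "train", "train"],
})

# Count samples per (date, model_set) and pivot model_set into columns
counts = (
    df_sets_toy.merge(df_toy, on="id", how="left")
    .assign(schedule_date=lambda d: d["scheduleDateTime"].dt.date)
    .groupby(["schedule_date", "model_set"])["id"].count()
    .unstack(fill_value=0)
)
print(counts)
```

For a ‘timeseries’ split one would expect the test column to be zero on all but the last dates, while a ‘sample’ split should show both columns populated throughout.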

Writing to both a local directory and Google Storage is handled

[11]:
# write output file
write_csv_data(df_train_test, output_file, index=False)
Writing file to local directory
File:   processed_flights.csv

Overview of the output data

[13]:
df_train_test.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 487716 entries, 0 to 487715
Data columns (total 8 columns):
id                      487716 non-null int64
aircraftRegistration    487713 non-null object
airlineCode             486503 non-null float64
terminal                477391 non-null float64
serviceType             482937 non-null object
scheduleDateTime        487716 non-null datetime64[ns, Europe/Amsterdam]
actualOffBlockTime      487716 non-null datetime64[ns, Europe/Amsterdam]
scheduleDelaySeconds    487716 non-null float64
dtypes: datetime64[ns, Europe/Amsterdam](2), float64(3), int64(1), object(2)
memory usage: 29.8+ MB