Create train/test split of the data
This notebook loads flights data with at least the columns ['id', 'scheduleDateTime'] and creates a train/test split.
- Get the id column of the data and split it into a train/test set
- Splitting strategy: either ‘sample’ or ‘timeseries’
TODO:
- Stratification based on important groups we have yet to identify
- Decide whether validation splits should also be made here or whether it is fine to make them downstream
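One way the validation TODO could be resolved in this notebook is a two-stage random split driven by the test_size and val_size parameters. The sketch below assumes both fractions refer to the full data set; the function name three_way_split and its random_state argument are hypothetical and not part of the notebook.

```python
from typing import Optional

import pandas as pd
from sklearn.model_selection import train_test_split


def three_way_split(ids: pd.Series, test_size: float = 0.3, val_size: float = 0.1,
                    random_state: Optional[int] = None) -> pd.DataFrame:
    """Hypothetical helper: assign each id to 'train', 'validation' or 'test'.

    Assumes test_size and val_size are fractions of the *full* data set.
    """
    assert 0 < test_size < 1 and 0 < val_size < 1 and test_size + val_size < 1
    # Stage 1: carve the test set off the full data
    rest_ids, test_ids = train_test_split(ids, test_size=test_size,
                                          random_state=random_state)
    # Stage 2: rescale val_size so it is a fraction of the full data,
    # not of the remaining (1 - test_size) part
    val_frac_of_rest = val_size / (1 - test_size)
    train_ids, val_ids = train_test_split(rest_ids, test_size=val_frac_of_rest,
                                          random_state=random_state)
    return pd.concat([
        pd.DataFrame({"id": train_ids.values, "model_set": "train"}),
        pd.DataFrame({"id": val_ids.values, "model_set": "validation"}),
        pd.DataFrame({"id": test_ids.values, "model_set": "test"}),
    ], ignore_index=True)
```

Splitting off the test set first keeps it identical to what a plain train/test split with the same seed would produce, so adding a validation set later would not reshuffle the test data.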
Parameters
- input_file: Filepath of the flights data, in the format received from Schiphol
- output_file: Filepath to write the output CSV file with minimal modelling input
- strategy: One of [‘sample’, ‘timeseries’]
- test_size: (Optional) Default 0.3. Fraction of the data to use as the test set, strictly between 0 and 1
- val_size: (Optional) Default 0.1. Fraction of the data to use as the validation set, between 0 and 1 (validation split not currently implemented)
If strategy == ‘timeseries’, the data is sorted on the scheduleDateTime column and the last test_size fraction of rows is taken as the test set.
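A minimal sketch of the ‘timeseries’ strategy as a standalone function, assuming the input has 'id' and 'scheduleDateTime' columns; the name timeseries_split is hypothetical. Because int() truncates, the test set can be up to one row smaller than an exact test_size fraction, and flights sharing a timestamp at the boundary may end up on both sides.

```python
import pandas as pd


def timeseries_split(df: pd.DataFrame, test_size: float,
                     time_col: str = "scheduleDateTime") -> pd.DataFrame:
    """Hypothetical helper: sort by time and label the last test_size fraction as test."""
    assert 0 < test_size < 1
    # A stable sort keeps the original order of rows with equal timestamps
    ordered = df.sort_values(time_col, kind="stable").reset_index(drop=True)
    n_test = int(len(ordered) * test_size)
    split_at = len(ordered) - n_test  # position of the first test row
    labels = ["train"] * split_at + ["test"] * n_test
    return pd.DataFrame({"id": ordered["id"].values, "model_set": labels})
```

Computing split_at instead of slicing with iloc[-n_test:] avoids the edge case where n_test is 0 and iloc[:-0] would return an empty training set.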
Returns
Output format:

| id | model_set |
|---|---|
| 123414481790510775 | train |
| 123414479288269149 | train |
| 123414479666542945 | test |
| 123414479288365061 | test |
| 123414479288274329 | validation |

The validation set is not currently implemented.
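Before downstream steps consume this table, its shape can be checked. This sketch assumes the allowed labels are exactly those in the output format above and that every input id should be assigned to a set; check_split_output is a hypothetical helper, not part of the project.

```python
import pandas as pd

# Assumed label vocabulary, taken from the output format table
ALLOWED_SETS = {"train", "test", "validation"}


def check_split_output(df_split: pd.DataFrame, all_ids: pd.Series) -> None:
    """Hypothetical helper: raise AssertionError if the id/model_set table is malformed."""
    assert list(df_split.columns) == ["id", "model_set"]
    # Each id must appear exactly once, so no flight sits in two sets
    assert df_split["id"].is_unique, "an id is assigned to more than one set"
    # Every input id must be assigned, and no unknown ids may appear
    assert set(df_split["id"]) == set(all_ids), "ids missing or unexpected"
    unknown = set(df_split["model_set"]) - ALLOWED_SETS
    assert not unknown, f"unknown model_set labels: {unknown}"
```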
Script parameters
[7]:
# parameters
input_file = "../lvt-schiphol-assignment-snakemake/data/model_input/delays_extended_input.csv"
output_file = "train_test__sample__0.2.csv"
strategy = 'sample'
test_size = 0.2
[8]:
assert 0 < test_size < 1, "test_size must be strictly between 0 and 1"
assert strategy in ['sample', 'timeseries'], f"unknown strategy: {strategy}"
Imports
[9]:
import pandas as pd
from sklearn.model_selection import train_test_split
import sys
sys.path.append("../")
from src.data.google_storage_io import read_csv_data, write_csv_data
Load data
[10]:
%%time
df = read_csv_data(input_file)
# subset only the columns we need; .copy() avoids a SettingWithCopyWarning on the next line
df = df[["id", "scheduleDateTime"]].copy()
df["scheduleDateTime"] = pd.to_datetime(df["scheduleDateTime"])
df.head()
Reading file from local directory
File: ../lvt-schiphol-assignment-snakemake/data/model_input/delays_extended_input.csv
Wall time: 8.22 s
[10]:
|  | id | scheduleDateTime |
|---|---|---|
| 0 | 123414481790510775 | 2018-01-01 03:30:00+01:00 |
| 1 | 123414479288269149 | 2018-01-01 06:00:00+01:00 |
| 2 | 123414479666542945 | 2018-01-01 06:05:00+01:00 |
| 3 | 123414479288365061 | 2018-01-01 06:05:00+01:00 |
| 4 | 123414479288274329 | 2018-01-01 06:15:00+01:00 |
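The project's read_csv_data helper is not shown in this notebook. For a local file, a plain pandas alternative could read only the two needed columns via usecols, which may reduce load time on a wide file. This is an assumption-based sketch, not the helper's implementation; load_ids_and_times is hypothetical. It parses timestamps to UTC because Europe/Amsterdam times carry +01:00 in winter and +02:00 in summer.

```python
import pandas as pd


def load_ids_and_times(path_or_buffer) -> pd.DataFrame:
    """Hypothetical local-only loader: read just 'id' and 'scheduleDateTime'."""
    # usecols skips parsing every other column in the file
    df = pd.read_csv(path_or_buffer, usecols=["id", "scheduleDateTime"])
    # utc=True turns mixed +01:00/+02:00 offsets into a single tz-aware dtype
    df["scheduleDateTime"] = pd.to_datetime(df["scheduleDateTime"], utc=True)
    return df
```

If local-time dates are needed later, the UTC column can be converted back with .dt.tz_convert("Europe/Amsterdam").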
Make train/test split
[11]:
print(f"Strategy: {strategy}")
if strategy == 'sample':
    # random sample, like a usual train/test split
    train_ids, test_ids = train_test_split(df["id"], test_size=test_size)
elif strategy == 'timeseries':
    # select the last test_size fraction of the data (by schedule time) as the test set
    df = df.sort_values("scheduleDateTime").reset_index(drop=True)
    n_test = int(len(df) * test_size)
    split_at = len(df) - n_test
    train_ids, test_ids = df.iloc[:split_at]["id"], df.iloc[split_at:]["id"]
df_train_test = pd.concat([
    pd.DataFrame(dict(id=train_ids.values, model_set="train")),
    pd.DataFrame(dict(id=test_ids.values, model_set="test")),
])
df_train_test
Strategy: sample
[11]:
|  | id | model_set |
|---|---|---|
| 0 | 123583078680134091 | train |
| 1 | 124594671792954447 | train |
| 2 | 123990526571117207 | train |
| 3 | 124397975056975081 | train |
| 4 | 123421504625349921 | train |
| ... | ... | ... |
| 97538 | 123899201831170143 | test |
| 97539 | 124060775885955111 | test |
| 97540 | 124039700039122673 | test |
| 97541 | 123885150856850759 | test |
| 97542 | 124552523009520195 | test |
487714 rows × 2 columns
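The ‘sample’ branch calls train_test_split without random_state, so rerunning the notebook produces a different split and therefore a different output file. A sketch of a seeded variant; the function name sample_split and the seed value used below are illustrative assumptions.

```python
import pandas as pd
from sklearn.model_selection import train_test_split


def sample_split(ids: pd.Series, test_size: float, random_state: int) -> pd.DataFrame:
    """Hypothetical helper: random split that is identical across runs for a fixed seed."""
    train_ids, test_ids = train_test_split(ids, test_size=test_size,
                                           random_state=random_state)
    return pd.concat([
        pd.DataFrame({"id": train_ids.values, "model_set": "train"}),
        pd.DataFrame({"id": test_ids.values, "model_set": "test"}),
    ], ignore_index=True)
```

For example, sample_split(df["id"], test_size, random_state=7) returns the same assignment on every run, which makes models trained on different days comparable.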
Visualize train/test split over time
Show the number of samples per day to understand how the train and test sets are distributed over time.
[16]:
import plotly.express as px
# count samples per (date, model_set); dates are taken in UTC
df_plot = (
    pd.merge(df_train_test, df, on="id", how="left")
    .assign(schedule_date=lambda d: pd.to_datetime(d["scheduleDateTime"], utc=True).dt.date)
    .groupby(["schedule_date", "model_set"])["id"]
    .count()
    .reset_index(name="n_samples")
)
px.line(df_plot, x="schedule_date", y="n_samples", color="model_set")
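The plot lets you eyeball whether the train and test sets overlap in time. For the ‘timeseries’ strategy this can also be asserted directly: the latest training flight should not be scheduled after the earliest test flight. Under ‘sample’ the check is expected to fail, since both sets span the whole period. The helper name assert_no_temporal_overlap is hypothetical.

```python
import pandas as pd


def assert_no_temporal_overlap(df_split: pd.DataFrame, df_times: pd.DataFrame) -> None:
    """Hypothetical check: every train flight is scheduled no later than every test flight."""
    merged = df_split.merge(df_times[["id", "scheduleDateTime"]], on="id", how="left")
    last_train = merged.loc[merged["model_set"] == "train", "scheduleDateTime"].max()
    first_test = merged.loc[merged["model_set"] == "test", "scheduleDateTime"].min()
    # <= rather than <, because flights sharing a timestamp can straddle the split
    assert last_train <= first_test, (
        f"train ends at {last_train}, after test starts at {first_test}")
```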
Write the output file; both local paths and Google Storage are handled.
[11]:
# write output file
write_csv_data(df_train_test, output_file, index=False)
Writing file to local directory
File: processed_flights.csv
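The ids are 18-digit integers: they fit in int64 but exceed the 2**53 range in which float64 represents integers exactly. Any reader that loads them as floats (for example, pandas does this when an integer column contains a missing value) would corrupt them and break joins on id. A sketch of a round-trip check through an in-memory CSV; roundtrip_ids is a hypothetical helper.

```python
import io

import pandas as pd


def roundtrip_ids(df_split: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical check: write the split to CSV in memory and read it back."""
    buf = io.StringIO()
    df_split.to_csv(buf, index=False)
    buf.seek(0)
    # An explicit int64 dtype fails loudly instead of silently parsing ids as floats
    return pd.read_csv(buf, dtype={"id": "int64"})
```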
Overview of the output data
[13]:
df_train_test.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 487716 entries, 0 to 487715
Data columns (total 8 columns):
id 487716 non-null int64
aircraftRegistration 487713 non-null object
airlineCode 486503 non-null float64
terminal 477391 non-null float64
serviceType 482937 non-null object
scheduleDateTime 487716 non-null datetime64[ns, Europe/Amsterdam]
actualOffBlockTime 487716 non-null datetime64[ns, Europe/Amsterdam]
scheduleDelaySeconds 487716 non-null float64
dtypes: datetime64[ns, Europe/Amsterdam](2), float64(3), int64(1), object(2)
memory usage: 29.8+ MB
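A quick way to confirm that the realised split matches test_size is to tabulate row counts and fractions per model_set. A minimal sketch; split_summary is a hypothetical helper.

```python
import pandas as pd


def split_summary(df_split: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical helper: row count and fraction of the data per model_set."""
    counts = df_split["model_set"].value_counts()
    # Both columns share the model_set labels as their index
    return pd.DataFrame({"n_rows": counts, "fraction": counts / counts.sum()})
```

Calling split_summary(df_train_test) should show a test fraction close to test_size for either strategy.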