Random Splitting is Wicked

    You have probably seen something like this in every machine learning tutorial out there:

    # Randomly carve out 33% of the rows as a test set.
    from sklearn.model_selection import train_test_split

    X_train, X_test, y_train, y_test = train_test_split(
        X, y,
        test_size=0.33,
        random_state=42)

    But there is a problem: it is rare for the rows to be independent.

    Take, for example, trying to predict the arrival delays of flights on a particular day: the instances/rows will be highly correlated. This can leak information between the training and test datasets.

    Plus, unless we set random_state, train_test_split will produce completely different splits every time it is run. This is a problem when we care about reproducibility in our machine learning workflow.

    This is where repeatable splitting comes in handy: splitting the data in a way that works regardless of programming language or random seeds, while also making sure that correlated rows fall into the same split.


    The solution is to first identify a column that captures the correlation relationship between rows. Then we hash that column and use the last few digits of the hash to split the data.

    So, in a time-series dataset, where the rows are often correlated, we can pass the date column to the Farm Fingerprint hashing algorithm to split the available data into the required splits.

    SELECT
      airline,
      feature_1,
      feature_2,
      feature_3,
      feature_4
    FROM
      `timeseries-data`.airline_ontime_data.flights
    WHERE
      ABS(MOD(FARM_FINGERPRINT(date), 10)) < 8 -- 80% for TRAIN

    Here, we compute the hash using the FARM_FINGERPRINT function and then use the modulo function to find an arbitrary 80% subset of the rows.

    This is now repeatable: because the FARM_FINGERPRINT function returns the same value any time it is invoked on a specific date, we can be sure we will get the same 80% of the data each time.
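
    As a quick sketch of the same idea outside BigQuery (a hypothetical example: MD5 here is only a stand-in for Farm Fingerprint, so the buckets will not match BigQuery's), any deterministic hash gives a repeatable split:

    import hashlib

    def hash_bucket(value: str, buckets: int = 10) -> int:
        """Deterministically map a value to a bucket in [0, buckets)."""
        # MD5 is a stand-in; reproducing BigQuery's FARM_FINGERPRINT
        # exactly would require a Farm Hash implementation.
        digest = hashlib.md5(value.encode("utf-8")).hexdigest()
        return int(digest, 16) % buckets

    # Rows sharing a date always land in the same split, on every run.
    split = "train" if hash_bucket("2021-03-14") < 8 else "test"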

    But there are some considerations when choosing which column to split on:

    • Rows on the same date tend to be correlated. Correlation is the biggest factor in the selection of column(s) on which to split.
    • date is not an input to the model, even though it is used as the criterion for splitting. We can't use an actual model input as the field on which to split: if we train on 80% of the data, the trained model will never have seen the 20% of that column's values that fall only in the test set.
    • There have to be enough date values. A rule of thumb is to shoot for 3-5x the denominator of the modulo; here the denominator is 10, so we want 40 or so unique dates.
    • The label has to be well distributed among the dates. To be safe, look at the distribution graphs and make sure that all three splits have a similar distribution of labels.

    Kolmogorov-Smirnov Test

    To check whether the label distributions are similar across the three datasets, plot the cumulative distribution functions of the label in the three datasets and find the maximum distance between each pair.

    The smaller the maximum distance, the better the split.
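
    As a sketch, SciPy's two-sample Kolmogorov-Smirnov statistic computes exactly this maximum distance between two empirical CDFs (the label arrays below are made up for illustration):

    import numpy as np
    from scipy.stats import ks_2samp

    # Hypothetical label values for each split.
    rng = np.random.default_rng(0)
    train_y = rng.normal(10, 3, size=8000)
    valid_y = rng.normal(10, 3, size=1000)
    test_y = rng.normal(10, 3, size=1000)

    # ks_2samp returns the maximum distance between the two CDFs
    # (and a p-value); a smaller statistic indicates a better split.
    for name, labels in [("validation", valid_y), ("test", test_y)]:
        stat, p_value = ks_2samp(train_y, labels)
        print(f"train vs {name}: max CDF distance = {stat:.4f}")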

    Trade-Offs and Alternatives

    Single Query

    We can use a single query to generate the training, validation, and test splits:

    CREATE OR REPLACE TABLE mydataset.mytable AS
    SELECT
      airline,
      feature_1,
      feature_2,
      feature_3,
      feature_4,
      CASE(ABS(MOD(FARM_FINGERPRINT(date), 10)))
           WHEN 9 THEN 'test'
           WHEN 8 THEN 'validation'
           ELSE 'training' END AS split_col
    FROM
      `timeseries-data`.airline_ontime_data.flights

    Random split

    If the rows are not correlated, we can hash the entire row of data by converting it to a string and hashing that string:

    SELECT
      airline,
      feature_1,
      feature_2,
      feature_3,
      feature_4
    FROM
      `timeseries-data`.airline_ontime_data.flights f
    WHERE
      ABS(MOD(FARM_FINGERPRINT(TO_JSON_STRING(f)), 10)) < 8

    Duplicate rows will always fall in the same split. If that's not the behavior we want, add a unique ID column to the SELECT query.

    Split on multiple columns

    It might happen that a combination of multiple columns is correlated, say the date and the arrival airport. In that case, we can simply concatenate the fields (creating a feature cross) before computing the hash.

    CREATE OR REPLACE TABLE mydataset.mytable AS
    SELECT
      airline,
      feature_1,
      feature_2,
      feature_3,
      arrival_airport
    FROM
      `timeseries-data`.airline_ontime_data.flights
    WHERE
      ABS(MOD(FARM_FINGERPRINT(CONCAT(date, arrival_airport)), 10)) < 8

    If we split on a feature cross of multiple columns, we can use arrival_airport (or any other feature used in conjunction) as one of the inputs to the model, since there will be examples of any particular airport in both the training and test sets.

    Repeatable sampling

    If we wanted to create a smaller dataset out of a bigger one (say, for local development), how would we go about doing it repeatably? Say we have a dataset of 50 million flights and we want a smaller dataset of about one million: how would we pick 1 in 50 flights, and then take 80% of those as training?

    What we cannot do is:

    SELECT
      airline,
      feature_1,
      feature_2,
      feature_3,
      feature_4
    FROM
      `timeseries-data`.airline_ontime_data.flights f
    WHERE
      ABS(MOD(FARM_FINGERPRINT(date), 50)) = 0
      AND ABS(MOD(FARM_FINGERPRINT(date), 10)) < 8

    This doesn't work: we cannot pick 1 in 50 rows and then pick 8 in 10 of those, because hash values that are divisible by 50 are also divisible by 10. The second modulo is therefore always 0, and the second condition keeps all of the sampled rows instead of 80% of them.

    What we can do, however, is:

    SELECT
      airline,
      feature_1,
      feature_2,
      feature_3,
      feature_4
    FROM
      `timeseries-data`.airline_ontime_data.flights f
    WHERE
      ABS(MOD(FARM_FINGERPRINT(date), 50)) = 0
      AND ABS(MOD(FARM_FINGERPRINT(date), 500)) < 400

    In this query, the 500 is 50 * 10, and 400 is 50 * 8 (80% as training).

    The first modulo picks 1 in 50 rows. For those rows, the hash modulo 500 is one of the ten multiples of 50 (0, 50, ..., 450), so keeping values below 400 picks 8 in 10 of them.

    For validation, you can change the filter to:

      ABS(MOD(FARM_FINGERPRINT(date), 50)) = 0
      AND ABS(MOD(FARM_FINGERPRINT(date), 500)) BETWEEN 400 AND 449 -- only the 400 (8*50) bucket qualifies; 450 is left for test
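
    As a quick sanity check of this bucket arithmetic (a standalone simulation with random integers, not the flights data):

    import random

    random.seed(0)
    hashes = [random.getrandbits(63) for _ in range(1_000_000)]

    # First filter: 1 in 50 rows (the 2% sample).
    sample = [h for h in hashes if h % 50 == 0]

    # Of those, h % 500 is a multiple of 50 in {0, 50, ..., 450};
    # keeping values < 400 selects 8 of those 10 buckets, i.e. 80%.
    train = [h for h in sample if h % 500 < 400]

    print(len(sample) / len(hashes))  # ~0.02
    print(len(train) / len(sample))   # ~0.8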

    Sequential split

    In the case of time-series models, a very common approach is to use sequential splits of data. The idea is to assign blocks or intervals of the series to the various splits, preserving the correlation among the examples within each split.

    A sequential split of data is also necessary in fast-moving environments such as fraud detection or spam detection, even if the goal is not to predict the future value of a time series. There, the goal is instead to adapt quickly to new data and predict behavior in the near future.

    Another instance where a sequential split of data is needed is when there are high correlations between successive times and we need to take seasonality into account. Take weather forecasting, for example: a day's weather depends heavily on the previous days' weather and follows year-long seasonal patterns.

    To do a sequential split in this case, we can place the first 20 days of every month in the training dataset, the next 5 days in the validation dataset, and the last 5 days in the test dataset.
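
    A minimal pandas sketch of this day-of-month scheme (the DataFrame and column names here are illustrative):

    import pandas as pd

    # Illustrative data: one row per day for a year.
    df = pd.DataFrame({"date": pd.date_range("2021-01-01", "2021-12-31")})

    day = df["date"].dt.day
    train = df[day <= 20]                 # first 20 days of each month
    valid = df[(day > 20) & (day <= 25)]  # next 5 days
    test = df[day > 25]                   # remaining days of the month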

    Stratified split

    In the above example, the splitting was required to happen after the dataset was stratified (by month). That means we need the distribution of each category of examples to remain the same across the splits, matching the distribution in the complete, unsplit dataset.

    The larger the dataset, the less concerned we have to be with stratification. In large-scale machine learning, the need to stratify is therefore uncommon, except in the case of skewed datasets.
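
    If stratification is needed, scikit-learn's train_test_split can stratify a random split on the label; here is a sketch with made-up skewed data:

    import numpy as np
    from sklearn.model_selection import train_test_split

    # Made-up skewed binary labels: roughly 10% positives.
    X = np.arange(1000).reshape(-1, 1)
    y = (np.arange(1000) % 10 == 0).astype(int)

    # stratify=y keeps the label proportions equal in both splits.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=42)

    print(y_train.mean(), y_test.mean())  # both ~0.10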

    Unstructured data

    Performing repeatable splitting on structured data is quite straightforward. In the case of unstructured data such as images, videos, or audio clips, we can achieve the same by hashing metadata fields, as in the sketch below.
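
    For instance, to split medical images so that all images of the same patient fall in the same split, we could hash a patient ID taken from the metadata (a hypothetical sketch; the patient_id field and the MD5 stand-in are assumptions):

    import hashlib

    def assign_split(patient_id: str) -> str:
        # Hash the metadata field that captures the correlation,
        # so every image of one patient stays in one split.
        bucket = int(hashlib.md5(patient_id.encode()).hexdigest(), 16) % 10
        if bucket < 8:
            return "train"
        return "validation" if bucket == 8 else "test"

    print(assign_split("patient-00123"))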

    It is worth noting that many problems of poor ML performance can be addressed by designing the data split (and data collection) with potential correlations in mind.

    Hope you learned something new.

    This is Anurag Dhadse, signing off.