A little brainteaser (or I’m an idiot)

This took me waaay too long to work out today, and I was thinking it could make a nice little interview-style coding question (which I’d probably fail).

Suppose you have 10,000 rows of data and need to continually train and retrain a model, training on at most 1,000 rows at a time and retraining every 500 rows. Can you tell me how many “batches” of data this will create, and the start and end index of each batch?

So that’s:

n = 10000
train_max = 1000
train_every = 500

And we want a dictionary like this:

{
  1: {"start": 1, "end": 1000},
  2: {"start": 500, "end": 1500},
  ...
  ?: {"start": ?, "end": ?},
}

After doing some crazy loops in Python for a while, I decided to go back to basics and work it out Jeremy Howard style in Excel (well, a gsheet, I’m not a savage).

And here is my Python solution:

def calc_batches(train_max: int, train_every: int, n: int) -> dict:
    batches = dict()
    # loop over up to as many records as you have
    for batch in range(n):
        # work out the start of the batch, with a max() to handle first batch
        start = max(train_every * batch, 1)
        # work out the end of the batch, with a min() to handle last batch
        end = min(train_max + (train_every * batch), n)
        # add batch info to the dictionary
        batches[batch + 1] = {"start": start, "end": end}
        # break out once you have assigned all rows to a batch
        if end == n:
            break
    return batches

calc_batches(train_max=1000, train_every=500, n=10000)
'''
{1: {'start': 1, 'end': 1000},
2: {'start': 500, 'end': 1500},
3: {'start': 1000, 'end': 2000},
4: {'start': 1500, 'end': 2500},
5: {'start': 2000, 'end': 3000},
6: {'start': 2500, 'end': 3500},
7: {'start': 3000, 'end': 4000},
8: {'start': 3500, 'end': 4500},
9: {'start': 4000, 'end': 5000},
10: {'start': 4500, 'end': 5500},
11: {'start': 5000, 'end': 6000},
12: {'start': 5500, 'end': 6500},
13: {'start': 6000, 'end': 7000},
14: {'start': 6500, 'end': 7500},
15: {'start': 7000, 'end': 8000},
16: {'start': 7500, 'end': 8500},
17: {'start': 8000, 'end': 9000},
18: {'start': 8500, 'end': 9500},
19: {'start': 9000, 'end': 10000}}
'''
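
For the “how many batches” part, you don’t strictly need the loop. The last batch is the first one whose end reaches n, so the count is the number of train_every steps needed to cover the rows beyond train_max, plus one for the first batch. Here is a minimal sketch of that arithmetic, assuming n >= 1 and train_every >= 1; count_batches is a made-up helper name and not part of the gist above.

```python
def count_batches(train_max: int, train_every: int, n: int) -> int:
    """Number of batches produced by the loop-based calc_batches above.

    Assumes n >= 1 and train_every >= 1 (hypothetical helper, for illustration).
    """
    # rows that do not fit in the first batch
    remaining = max(n - train_max, 0)
    # ceiling division without floats, plus one for the first batch
    return -(-remaining // train_every) + 1


print(count_batches(train_max=1000, train_every=500, n=10000))
```

With the numbers from the question that is (10000 - 1000) / 500 + 1 = 19, which matches the dictionary above.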

…I’m pretty sure someone will come along with a super Pythonic one-liner that shows maybe I am an idiot after all.
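
In that spirit, here is one possible shorter version: a sketch rather than the definitive answer. It steps a range over the batch offsets and builds the same dictionary with a comprehension. The name calc_batches_short is invented for this example, and it assumes n >= 1 and train_every >= 1. It keeps the same quirk as the original, where the first batch starts at 1 and the later ones start at multiples of train_every.

```python
def calc_batches_short(train_max: int, train_every: int, n: int) -> dict:
    """Comprehension-based sketch of the loop version (hypothetical name)."""
    # offsets 0, train_every, 2 * train_every, ... up to the first one
    # whose batch end reaches n
    offsets = range(0, max(n - train_max, 0) + train_every, train_every)
    return {
        i: {"start": max(offset, 1), "end": min(offset + train_max, n)}
        for i, offset in enumerate(offsets, start=1)
    }


batches = calc_batches_short(train_max=1000, train_every=500, n=10000)
print(len(batches), batches[1], batches[19])
```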

Ok now back to work.

Update: Actually, I think what I want is more something like the below, where you can define a minimum and maximum size for your training data and then roll that window over your data.

def calc_batches(train_min: int, train_max: int, train_every: int, n: int) -> dict:
    batches = dict()
    batch = 0
    for row in range(1, n + 1):
        if row < train_min:
            # not enough rows yet to train a first model
            pass
        elif row == train_min:
            # first batch: everything seen so far
            batches[batch] = dict(start=0, end=row)
        elif row % train_every == 0:
            # retrain on (at most) the last train_max rows
            batch += 1
            batches[batch] = dict(start=max(0, row - train_max), end=row)
    return batches

calc_batches(train_min=1000, train_max=5000, train_every=500, n=10000)
"""
{0: {'start': 0, 'end': 1000},
1: {'start': 0, 'end': 1500},
2: {'start': 0, 'end': 2000},
3: {'start': 0, 'end': 2500},
4: {'start': 0, 'end': 3000},
5: {'start': 0, 'end': 3500},
6: {'start': 0, 'end': 4000},
7: {'start': 0, 'end': 4500},
8: {'start': 0, 'end': 5000},
9: {'start': 500, 'end': 5500},
10: {'start': 1000, 'end': 6000},
11: {'start': 1500, 'end': 6500},
12: {'start': 2000, 'end': 7000},
13: {'start': 2500, 'end': 7500},
14: {'start': 3000, 'end': 8000},
15: {'start': 3500, 'end': 8500},
16: {'start': 4000, 'end': 9000},
17: {'start': 4500, 'end': 9500},
18: {'start': 5000, 'end': 10000}}
"""

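The updated version walks over every single row, but only the rows where a batch ends actually matter. Here is a rough sketch that jumps straight to those rows: train_min first, then every multiple of train_every after it. The name calc_batches_rolling is made up for this example, and it assumes 1 <= train_min <= train_max and train_every >= 1, which is what makes it line up with the loop above.

```python
def calc_batches_rolling(train_min: int, train_max: int, train_every: int, n: int) -> dict:
    """Sketch of the min/max rolling window without a per-row loop.

    Hypothetical helper; assumes 1 <= train_min <= train_max and train_every >= 1.
    """
    if n < train_min:
        # never enough rows to train a first model
        return {}
    # first multiple of train_every strictly after train_min
    first_multiple = (train_min // train_every + 1) * train_every
    # batch ends: the first model at train_min, then every retrain point
    ends = [train_min, *range(first_multiple, n + 1, train_every)]
    return {
        i: {"start": max(0, end - train_max), "end": end}
        for i, end in enumerate(ends)
    }


batches = calc_batches_rolling(train_min=1000, train_max=5000, train_every=500, n=10000)
print(len(batches), batches[0], batches[18])
```

If start and end are read as a zero-based slice (my assumption, something like data[start:end]), each batch then has between train_min and train_max rows.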