How to sync batches in petastorm?

Issue

This content is from Stack Overflow. Question asked by Omar Puentes.

I'm new to petastorm and I'm facing some issues.
I need to iterate over a dataset and get three identical batches, so that I can transform two of them to extract some information.
The dataset consists of users' ratings of movies (like the MovieLens dataset). I need three batches containing the same ratings (rows), so that I can extract each user (a user can appear more than once in the ratings) and each rated movie. I wrote this code.

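Creating a fake dataset and the Spark converter:

from petastorm.spark import make_spark_converter

# Assumes an active SparkSession named `spark`.
ratings_l = [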
    {'uid_dec': 0, 'mid_dec': 6, 'eval': 2.18},
    {'uid_dec': 0, 'mid_dec': 7, 'eval': 3.83},
    {'uid_dec': 0, 'mid_dec': 8, 'eval': 3.94},
    {'uid_dec': 0, 'mid_dec': 9, 'eval': 4.31},
    {'uid_dec': 0, 'mid_dec': 10, 'eval': 4.48},
    {'uid_dec': 0, 'mid_dec': 11, 'eval': 3.74},
    {'uid_dec': 1, 'mid_dec': 6, 'eval': 3.21},
    {'uid_dec': 1, 'mid_dec': 7, 'eval': 2.05},
    {'uid_dec': 1, 'mid_dec': 8, 'eval': 2.24},
    {'uid_dec': 1, 'mid_dec': 9, 'eval': 2.08},
    {'uid_dec': 1, 'mid_dec': 10, 'eval': 4.94},
    {'uid_dec': 1, 'mid_dec': 11, 'eval': 4.22},
    {'uid_dec': 2, 'mid_dec': 6, 'eval': 3.52},
    {'uid_dec': 2, 'mid_dec': 7, 'eval': 2.67},
    {'uid_dec': 2, 'mid_dec': 8, 'eval': 2.69},
    {'uid_dec': 2, 'mid_dec': 9, 'eval': 2.75},
    {'uid_dec': 2, 'mid_dec': 10, 'eval': 4.93},
    {'uid_dec': 2, 'mid_dec': 11, 'eval': 2.9},
    {'uid_dec': 3, 'mid_dec': 6, 'eval': 2.0},
    {'uid_dec': 3, 'mid_dec': 7, 'eval': 2.9},
    {'uid_dec': 3, 'mid_dec': 8, 'eval': 4.74},
    {'uid_dec': 3, 'mid_dec': 9, 'eval': 2.5},
    {'uid_dec': 3, 'mid_dec': 10, 'eval': 2.18},
    {'uid_dec': 3, 'mid_dec': 11, 'eval': 4.93},
    {'uid_dec': 4, 'mid_dec': 6, 'eval': 4.46},
    {'uid_dec': 4, 'mid_dec': 7, 'eval': 2.23},
    {'uid_dec': 4, 'mid_dec': 8, 'eval': 4.42},
    {'uid_dec': 4, 'mid_dec': 9, 'eval': 4.67},
    {'uid_dec': 4, 'mid_dec': 10, 'eval': 2.65},
    {'uid_dec': 4, 'mid_dec': 11, 'eval': 2.11},
    {'uid_dec': 5, 'mid_dec': 6, 'eval': 2.31},
    {'uid_dec': 5, 'mid_dec': 7, 'eval': 2.69},
    {'uid_dec': 5, 'mid_dec': 8, 'eval': 2.41},
    {'uid_dec': 5, 'mid_dec': 9, 'eval': 4.62},
    {'uid_dec': 5, 'mid_dec': 10, 'eval': 3.96},
    {'uid_dec': 5, 'mid_dec': 11, 'eval': 2.23}
]

train_ds = spark.createDataFrame(ratings_l)

conv_train = make_spark_converter(train_ds)

Getting three batches from the same converter (hoping they are the same):

epochs = 4
batch_size = 6
with conv_train.make_tf_dataset(batch_size=batch_size, num_epochs=epochs, seed=1) as train, \
     conv_train.make_tf_dataset(batch_size=batch_size, num_epochs=epochs, seed=1) as train1, \
     conv_train.make_tf_dataset(batch_size=batch_size, num_epochs=epochs, seed=1) as train2:
    epoch_eval = True
    # Walk the three datasets in lockstep; 36 rows // batch_size batches make up one labelled epoch.
    for i, (b, b1, b2) in enumerate(zip(train, train1, train2)):
        if i % (36 // batch_size) == 0:
            print('==========Epoch==========: {0}'.format(i // (36 // batch_size)))
        print('==========Batch: {}'.format(i % (36 // batch_size)))
        # b[0] is the first field of each batch (the 'eval' ratings in the output).
        print(b[0].numpy())
        print(b1[0].numpy())
        print(b2[0].numpy())

This is the output:

==========Epoch==========: 0
==========Batch: 0
[2.   2.9  4.74 2.5  2.18 4.93]
[2.18 3.83 3.94 4.31 4.48 3.74]
[2.   2.9  4.74 2.5  2.18 4.93]
==========Batch: 1
[4.46 2.23 4.42 4.67 2.65 2.11]
[3.21 2.05 2.24 2.08 4.94 4.22]
[4.46 2.23 4.42 4.67 2.65 2.11]
==========Batch: 2
[2.31 2.69 2.41 4.62 3.96 2.23]
[3.52 2.67 2.69 2.75 4.93 2.9 ]
[2.31 2.69 2.41 4.62 3.96 2.23]
==========Batch: 3
[2.18 3.83 3.94 4.31 4.48 3.74]
[2.18 3.83 3.94 4.31 4.48 3.74]
[2.18 3.83 3.94 4.31 4.48 3.74]
==========Batch: 4
[3.21 2.05 2.24 2.08 4.94 4.22]
[3.21 2.05 2.24 2.08 4.94 4.22]
[3.21 2.05 2.24 2.08 4.94 4.22]
==========Batch: 5
[3.52 2.67 2.69 2.75 4.93 2.9 ]
[3.52 2.67 2.69 2.75 4.93 2.9 ]
[3.52 2.67 2.69 2.75 4.93 2.9 ]
==========Epoch==========: 1
==========Batch: 0
[2.18 3.83 3.94 4.31 4.48 3.74]
[2.   2.9  4.74 2.5  2.18 4.93]
[2.18 3.83 3.94 4.31 4.48 3.74]
==========Batch: 1
[3.21 2.05 2.24 2.08 4.94 4.22]
[4.46 2.23 4.42 4.67 2.65 2.11]
[3.21 2.05 2.24 2.08 4.94 4.22]
==========Batch: 2
[3.52 2.67 2.69 2.75 4.93 2.9 ]
[2.31 2.69 2.41 4.62 3.96 2.23]
[3.52 2.67 2.69 2.75 4.93 2.9 ]
==========Batch: 3
[2.   2.9  4.74 2.5  2.18 4.93]
[2.   2.9  4.74 2.5  2.18 4.93]
[2.   2.9  4.74 2.5  2.18 4.93]
==========Batch: 4
[4.46 2.23 4.42 4.67 2.65 2.11]
[4.46 2.23 4.42 4.67 2.65 2.11]
[4.46 2.23 4.42 4.67 2.65 2.11]
==========Batch: 5
[2.31 2.69 2.41 4.62 3.96 2.23]
[2.31 2.69 2.41 4.62 3.96 2.23]
[2.31 2.69 2.41 4.62 3.96 2.23]
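
To see exactly where the three streams diverge, the batches can be compared position by position. This is only a minimal sketch and not part of the petastorm API: `mismatched_batches` is a hypothetical helper, and the plain lists stand in for the values printed for the first six batches (real tensors would first need something like `.numpy().tolist()`).

```python
# Ratings per user, copied from the fake dataset; in the printout each
# batch of six values happens to be one user's six ratings.
ratings_by_user = {
    0: [2.18, 3.83, 3.94, 4.31, 4.48, 3.74],
    1: [3.21, 2.05, 2.24, 2.08, 4.94, 4.22],
    2: [3.52, 2.67, 2.69, 2.75, 4.93, 2.9],
    3: [2.0, 2.9, 4.74, 2.5, 2.18, 4.93],
    4: [4.46, 2.23, 4.42, 4.67, 2.65, 2.11],
    5: [2.31, 2.69, 2.41, 4.62, 3.96, 2.23],
}

# Order in which the first six printed batches arrived in each stream.
train = [ratings_by_user[u] for u in (3, 4, 5, 0, 1, 2)]
train1 = [ratings_by_user[u] for u in (0, 1, 2, 0, 1, 2)]
train2 = [ratings_by_user[u] for u in (3, 4, 5, 0, 1, 2)]


def mismatched_batches(*streams):
    """Return the indices at which the zipped streams do not all yield equal batches."""
    return [
        i
        for i, group in enumerate(zip(*streams))
        if any(list(batch) != list(group[0]) for batch in group[1:])
    ]


print(mismatched_batches(train, train1, train2))
```

In this printout `train` and `train2` agree throughout, while `train1` delivers the same 18-row chunks in a different order.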

The question is: why am I getting three different batches instead of three batches with the same values?
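
One possible way to sidestep the problem, offered as a sketch rather than a confirmed fix, is to read a single stream and derive the user list and movie list from each batch, so the three views describe the same rows by construction. Everything below is an assumption for illustration: `Batch` is a hypothetical stand-in for the records the dataset yields (with fields named after the columns, and `eval` first as in the printout), and `split_views` and `iterate_once` are hypothetical helpers.

```python
from collections import namedtuple

# Hypothetical stand-in for one batch; field order mirrors the printout,
# where b[0] holds the 'eval' column.
Batch = namedtuple("Batch", ["eval", "mid_dec", "uid_dec"])


def split_views(batch):
    """Derive all three views from one batch so they always cover the same rows."""
    users = sorted(set(batch.uid_dec))    # each user once, even if repeated
    movies = sorted(set(batch.mid_dec))   # each rated movie once
    return batch, users, movies


def iterate_once(batches):
    """Iterate a single stream of batches and yield the three views per batch."""
    for batch in batches:
        yield split_views(batch)


# Small demo using rows from the fake dataset. With real tensors, each
# field would first be converted, e.g. with .numpy().tolist().
demo = [
    Batch(eval=[2.18, 3.83, 3.94], mid_dec=[6, 7, 8], uid_dec=[0, 0, 0]),
    Batch(eval=[4.31, 3.21, 2.05], mid_dec=[9, 6, 7], uid_dec=[0, 1, 1]),
]

for ratings, users, movies in iterate_once(demo):
    print(ratings.eval, users, movies)
```

Because only one reader is opened, the ordering of a second or third reader no longer matters.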



Solution

This question has not been answered yet.

This question and answer were collected from Stack Overflow and tested by the JTuto community, and are licensed under the terms of CC BY-SA 2.5, CC BY-SA 3.0, and CC BY-SA 4.0.
