Skip to content

Commit 1925d8b

Browse files
authored
Fix seed setting in samplers (#5816)
1 parent a450123 commit 1925d8b

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

scio-core/src/main/scala/com/spotify/scio/util/random/RandomSampler.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ private[scio] object RandomSampler {
4040

4141
abstract private[scio] class RandomSampler[T, R] extends DoFn[T, T] {
4242
protected var rng: R = _
43-
protected var seed: Long = -1
43+
protected var seed: Option[Long] = None
4444

4545
// TODO: is it necessary to setSeed for each instance like Spark does?
4646
@StartBundle
@@ -58,7 +58,7 @@ abstract private[scio] class RandomSampler[T, R] extends DoFn[T, T] {
5858

5959
def init: R
6060
def samples: Int
61-
def setSeed(seed: Long): Unit = this.seed = seed
61+
def setSeed(seed: Long): Unit = this.seed = Some(seed)
6262
}
6363

6464
/**
@@ -74,6 +74,7 @@ abstract private[scio] class RandomSampler[T, R] extends DoFn[T, T] {
7474
*/
7575
private[scio] class BernoulliSampler[T](val fraction: Double, private val seedOpt: Option[Long])
7676
extends RandomSampler[T, JRandom] {
77+
this.seed = seedOpt
7778

7879
/** Epsilon slop to avoid failure from floating point jitter */
7980
require(
@@ -84,7 +85,7 @@ private[scio] class BernoulliSampler[T](val fraction: Double, private val seedOp
8485

8586
override def init: JRandom = {
8687
val r = RandomSampler.newDefaultRNG
87-
seedOpt.foreach(r.setSeed)
88+
this.seed.foreach(r.setSeed)
8889
r
8990
}
9091

@@ -117,6 +118,7 @@ private[scio] object BernoulliSampler {
117118
*/
118119
private[scio] class PoissonSampler[T](val fraction: Double, private val seedOpt: Option[Long])
119120
extends RandomSampler[T, IntegerDistribution] {
121+
this.seed = seedOpt
120122

121123
/** Epsilon slop to avoid failure from floating point jitter. */
122124
require(
@@ -128,7 +130,7 @@ private[scio] class PoissonSampler[T](val fraction: Double, private val seedOpt:
128130
// If fraction is <= 0, 0 is used below, so we can use any placeholder value.
129131
override def init: IntegerDistribution = {
130132
val r = new PoissonDistribution(if (fraction > 0.0) fraction else 1.0)
131-
seedOpt.foreach(r.reseedRandomGenerator)
133+
this.seed.foreach(r.reseedRandomGenerator)
132134
r
133135
}
134136

0 commit comments

Comments
 (0)