Skip to content

Commit ef08790

Browse files
authored
IA-4720 Calculate spark.driver.memory (#4007)
``` terra-quality-c73fdcc7/automation-test-a0fs2ov5z errored due to List(RuntimeError(Jupyter, Welder failed to start after 10 minutes.,None,2023-12-06T21:11:49Z,Some(TraceId(eafd22dba736444681b48f1541d714f5/3939911508519804487)))) ``` One test failure, and this PR doesn't affect GCE VMs, so going to merge as is.
1 parent dbdda15 commit ef08790

File tree

5 files changed

+90
-8
lines changed

5 files changed

+90
-8
lines changed

core/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/runtimeModels.scala

+7-1
Original file line numberDiff line numberDiff line change
@@ -542,8 +542,14 @@ object MemorySize {
542542
* Resource constraints for a runtime.
543543
* See https://docs.docker.com/compose/compose-file/compose-file-v2/#cpu-and-other-resources
544544
* for other types of resources we may want to add here.
545+
*
546+
* driverMemory will be populated when the runtime is Dataproc and the machine type is n1-standard-{x} or n1-highmem-{x}; in that case a different algorithm is used
547+
* for calculating spark:spark.driver.memory
545548
*/
546-
final case class RuntimeResourceConstraints(memoryLimit: MemorySize, totalMachineMemory: MemorySize)
549+
final case class RuntimeResourceConstraints(memoryLimit: MemorySize,
550+
totalMachineMemory: MemorySize,
551+
driverMemory: Option[MemorySize]
552+
)
547553

548554
final case class RuntimeMetrics(cloudContext: CloudContext,
549555
runtimeName: RuntimeName,

http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/util/DataprocInterpreter.scala

+15-3
Original file line numberDiff line numberDiff line change
@@ -786,13 +786,21 @@ class DataprocInterpreter[F[_]: Parallel](
786786
val sparkMemoryConfigRatio = config.dataprocConfig.sparkMemoryConfigRatio.getOrElse(0.8)
787787
// We still want a minimum to run Jupyter and other system processes.
788788
val minRuntimeMemoryGb = config.dataprocConfig.minimumRuntimeMemoryInGb.getOrElse(4.0)
789+
// Note this algorithm is recommended by Hail team. See more info in https://broadworkbench.atlassian.net/browse/IA-4720
790+
val sparkDriverMemory = machineType match {
791+
case MachineTypeName(n1standard) if n1standard.startsWith("n1-standard") =>
792+
Some(MemorySize.fromGb((total.bytes / MemorySize.gbInBytes - 7) * 0.9))
793+
case MachineTypeName(n1highmem) if n1highmem.startsWith("n1-highmem") =>
794+
Some(MemorySize.fromGb((total.bytes / MemorySize.gbInBytes - 11) * 0.9))
795+
case _ => none[MemorySize]
796+
}
789797
val runtimeAllocatedMemory =
790798
Math.max(
791799
(total.bytes * (1 - sparkMemoryConfigRatio)).toLong,
792800
MemorySize.fromGb(minRuntimeMemoryGb).bytes
793801
)
794802

795-
RuntimeResourceConstraints(MemorySize(runtimeAllocatedMemory), MemorySize(total.bytes))
803+
RuntimeResourceConstraints(MemorySize(runtimeAllocatedMemory), MemorySize(total.bytes), sparkDriverMemory)
796804
}
797805

798806
/**
@@ -887,8 +895,12 @@ class DataprocInterpreter[F[_]: Parallel](
887895
Map("dataproc:dataproc.allow.zero.workers" -> "true")
888896
} else Map.empty[String, String]
889897

890-
val memoryLimitInMb =
891-
(jupyterResourceConstraints.totalMachineMemory.bytes - jupyterResourceConstraints.memoryLimit.bytes) / MemorySize.mbInBytes
898+
val memoryLimitInMb = jupyterResourceConstraints.driverMemory match {
899+
case Some(value) => value.bytes / MemorySize.mbInBytes
900+
case None =>
901+
// We use a different algorithm to calculate spark.driver.memory when the machine type is neither n1-standard nor n1-highmem
902+
(jupyterResourceConstraints.totalMachineMemory.bytes - jupyterResourceConstraints.memoryLimit.bytes) / MemorySize.mbInBytes
903+
}
892904
val driverMemoryProp = Map("spark:spark.driver.memory" -> s"${memoryLimitInMb}m")
893905

894906
val yarnProps = Map(

http/src/main/scala/org/broadinstitute/dsde/workbench/leonardo/util/GceInterpreter.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ class GceInterpreter[F[_]](
544544
gceAllocated = config.gceConfig.gceReservedMemory.map(_.bytes).getOrElse(0L)
545545
welderAllocated = config.welderConfig.welderReservedMemory.map(_.bytes).getOrElse(0L)
546546
result = MemorySize(total.bytes - gceAllocated - welderAllocated)
547-
} yield RuntimeResourceConstraints(result, total)
547+
} yield RuntimeResourceConstraints(result, total, None)
548548

549549
private def buildNetworkInterfaces(runtimeProjectAndName: RuntimeProjectAndName,
550550
subnetwork: SubnetworkName,

http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/CommonTestData.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ object CommonTestData {
239239
val cryptoDetectorImage =
240240
RuntimeImage(CryptoDetector, "crypto/crypto:0.0.1", None, Instant.now.truncatedTo(ChronoUnit.MICROS))
241241

242-
val clusterResourceConstraints = RuntimeResourceConstraints(MemorySize.fromMb(3584), MemorySize.fromMb(7680))
242+
val clusterResourceConstraints = RuntimeResourceConstraints(MemorySize.fromMb(3584), MemorySize.fromMb(7680), None)
243243
val hostToIpMapping = Ref.unsafe[IO, Map[String, IP]](Map.empty)
244244

245245
def makeAsyncRuntimeFields(index: Int): AsyncRuntimeFields =

http/src/test/scala/org/broadinstitute/dsde/workbench/leonardo/util/DataprocInterpreterSpec.scala

+66-2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ import java.time.Instant
4646
import scala.concurrent.ExecutionContext.Implicits.global
4747
import scala.concurrent.Future
4848
import scala.concurrent.duration._
49+
import org.scalatest.prop.TableDrivenPropertyChecks.Table
50+
import org.scalatest.prop.TableDrivenPropertyChecks.forAll
4951

5052
class DataprocInterpreterSpec
5153
extends TestKit(ActorSystem("leonardotest"))
@@ -214,7 +216,7 @@ class DataprocInterpreterSpec
214216
)
215217

216218
val runtimeConfig = RuntimeConfig.DataprocConfig(0,
217-
MachineTypeName("n1-standard-4"),
219+
MachineTypeName("n2-standard-4"),
218220
DiskSize(500),
219221
None,
220222
None,
@@ -251,9 +253,71 @@ class DataprocInterpreterSpec
251253

252254
}
253255

256+
it should "calculate cluster resource constraints and software config correctly for n1-standard and n1-highmem machine types" in isolatedDbTest {
257+
val highMemGoogleComputeService = new FakeGoogleComputeService {
258+
override def getMachineType(project: GoogleProject, zone: ZoneName, machineTypeName: MachineTypeName)(implicit
259+
ev: Ask[IO, TraceId]
260+
): IO[Option[MachineType]] =
261+
IO.pure(Some(MachineType.newBuilder().setName("pass").setMemoryMb(104 * 1024).setGuestCpus(4).build()))
262+
}
263+
264+
def dataprocInterpHighMem(computeService: GoogleComputeService[IO] = highMemGoogleComputeService,
265+
dataprocCluster: GoogleDataprocService[IO] = MockGoogleDataprocService,
266+
googleDirectoryDao: GoogleDirectoryDAO = mockGoogleDirectoryDAO
267+
) =
268+
new DataprocInterpreter[IO](
269+
Config.dataprocInterpreterConfig,
270+
bucketHelper,
271+
vpcInterp,
272+
dataprocCluster,
273+
computeService,
274+
MockGoogleDiskService,
275+
googleDirectoryDao,
276+
mockGoogleIamDAO,
277+
mockGoogleResourceService,
278+
MockWelderDAO
279+
)
280+
281+
val machineTypes = Table("machineType", MachineTypeName("n1-standard-4"), MachineTypeName("n1-highmem-64"))
282+
forAll(machineTypes) { machineType: MachineTypeName =>
283+
val runtimeConfig = RuntimeConfig.DataprocConfig(0,
284+
machineType,
285+
DiskSize(500),
286+
None,
287+
None,
288+
None,
289+
None,
290+
Map.empty[String, String],
291+
RegionName("us-central1"),
292+
true,
293+
false
294+
)
295+
val resourceConstraints = dataprocInterpHighMem()
296+
.getDataprocRuntimeResourceContraints(testClusterClusterProjectAndName,
297+
runtimeConfig.machineType,
298+
RegionName("us-central1")
299+
)
300+
.unsafeRunSync()(cats.effect.unsafe.IORuntime.global)
301+
302+
val dataProcSoftwareConfig = dataprocInterp().getSoftwareConfig(
303+
GoogleProject("MyGoogleProject"),
304+
RuntimeName("MyRuntimeName"),
305+
runtimeConfig,
306+
resourceConstraints
307+
)
308+
309+
val propertyMap = dataProcSoftwareConfig.getPropertiesMap()
310+
val expectedMemory =
311+
if (machineType == MachineTypeName("n1-standard-4")) (104 - 7) * 0.9 * 1024 else (104 - 11) * 0.9 * 1024
312+
propertyMap.get(
313+
"spark:spark.driver.memory"
314+
) shouldBe s"${expectedMemory.toInt}m"
315+
}
316+
}
317+
254318
it should "create correct softwareConfig - minimum runtime memory 4gb" in isolatedDbTest {
255319
val runtimeConfig = RuntimeConfig.DataprocConfig(0,
256-
MachineTypeName("n1-highmem-64"),
320+
MachineTypeName("n2-highmem-64"),
257321
DiskSize(500),
258322
None,
259323
None,

0 commit comments

Comments
 (0)