Skip to content

Commit

Permalink
add SamUtils trait
Browse files Browse the repository at this point in the history
  • Loading branch information
marctalbott committed Jan 2, 2025
1 parent ff76a5a commit 811d04a
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 104 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package org.broadinstitute.dsde.workbench.leonardo.dao.sam

import akka.http.scaladsl.model.StatusCodes
import akka.http.scaladsl.model.headers.OAuth2BearerToken
import cats.effect.Async
import cats.implicits.{catsSyntaxApplicativeError, toFlatMapOps}
import cats.mtl.Ask
import org.broadinstitute.dsde.workbench.leonardo.model.{
ForbiddenError,
LeoException,
RuntimeNotFoundByWorkspaceIdException,
RuntimeNotFoundException
}
import org.broadinstitute.dsde.workbench.leonardo.{
AppContext,
CloudContext,
RuntimeAction,
RuntimeName,
SamResourceId,
WorkspaceId
}
import org.broadinstitute.dsde.workbench.model.{UserInfo, WorkbenchEmail}

trait SamUtils[F[_]] {
val samService: SamService[F]

def checkRuntimeAction(userInfo: UserInfo,
cloudContext: CloudContext,
runtimeName: RuntimeName,
samResourceId: SamResourceId,
action: RuntimeAction,
userEmail: Option[WorkbenchEmail] = None
)(implicit F: Async[F], as: Ask[F, AppContext]): F[Unit] =
checkRuntimeActionInternal(
userInfo.accessToken,
userEmail.getOrElse(userInfo.userEmail),
samResourceId,
action,
RuntimeNotFoundException(cloudContext, runtimeName, "Not found in database")
)

def checkRuntimeAction(userInfo: UserInfo,
workspaceId: WorkspaceId,
runtimeName: RuntimeName,
samResourceId: SamResourceId,
action: RuntimeAction
)(implicit F: Async[F], as: Ask[F, AppContext]): F[Unit] =
checkRuntimeActionInternal(
userInfo.accessToken,
userInfo.userEmail,
samResourceId,
action,
RuntimeNotFoundByWorkspaceIdException(workspaceId, runtimeName, "Not found in database")
)

private def checkRuntimeActionInternal(userToken: OAuth2BearerToken,
userEmail: WorkbenchEmail,
samResourceId: SamResourceId,
action: RuntimeAction,
notFoundException: LeoException
)(implicit F: Async[F], as: Ask[F, AppContext]): F[Unit] =
samService
.checkAuthorized(userToken.token, samResourceId, action)
.handleErrorWith {
// If we've already checked read access and the user doesn't have it, pretend the runtime doesn't exist to avoid leaking its existence
case e: SamException if e.statusCode == StatusCodes.Forbidden && action == RuntimeAction.GetRuntimeStatus =>
F.raiseError(notFoundException)
// Check if the user can read the runtime to determine which error to raise
case e: SamException if e.statusCode == StatusCodes.Forbidden =>
samService
.checkAuthorized(userToken.token, samResourceId, RuntimeAction.GetRuntimeStatus)
.attempt
.flatMap {
// The user can read the runtime, but they don't have the required action. Raise the original Forbidden action from Sam
case Right(_) => F.raiseError(ForbiddenError(userEmail))
// The user can't read the runtime, pretend it doesn't exist to avoid leaking its existence
case Left(_) => F.raiseError(notFoundException)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.broadinstitute.dsde.workbench.leonardo.SamResourceId._
import org.broadinstitute.dsde.workbench.leonardo.config.ProxyConfig
import org.broadinstitute.dsde.workbench.leonardo.dao.HostStatus._
import org.broadinstitute.dsde.workbench.leonardo.dao.google.GoogleOAuth2Service
import org.broadinstitute.dsde.workbench.leonardo.dao.sam.{SamException, SamService}
import org.broadinstitute.dsde.workbench.leonardo.dao.sam.{SamService, SamUtils}
import org.broadinstitute.dsde.workbench.leonardo.dao.{HostStatus, JupyterDAO, Proxy, SamDAO, TerminalName}
import org.broadinstitute.dsde.workbench.leonardo.db.{appQuery, clusterQuery, DbReference, KubernetesServiceDbQueries}
import org.broadinstitute.dsde.workbench.leonardo.dns.{KubernetesDnsCache, ProxyResolver, RuntimeDnsCache}
Expand Down Expand Up @@ -93,14 +93,15 @@ class ProxyService(
samDAO: SamDAO[IO],
googleTokenCache: Cache[IO, String, (UserInfo, Instant)],
samResourceCache: Cache[IO, SamResourceCacheKey, (Option[String], Option[AppAccessScope])],
samService: SamService[IO]
val samService: SamService[IO]
)(implicit
val system: ActorSystem,
executionContext: ExecutionContext,
dbRef: DbReference[IO],
loggerIO: StructuredLogger[IO],
metrics: OpenTelemetryMetrics[IO]
) extends LazyLogging {
) extends LazyLogging
with SamUtils[IO] {
val httpsConnectionContext = ConnectionContext.httpsClient(sslContext)
val clientConnectionSettings =
ClientConnectionSettings(system).withTransport(ClientTransport.withCustomResolver(proxyResolver.resolveAkka))
Expand Down Expand Up @@ -558,37 +559,6 @@ class ProxyService(
}
}

private def checkRuntimeAction(userInfo: UserInfo,
cloudContext: CloudContext,
runtimeName: RuntimeName,
samResourceId: SamResourceId,
action: RuntimeAction
)(implicit ev: Ask[IO, AppContext]): IO[Unit] =
samService
.checkAuthorized(userInfo.accessToken.token, samResourceId, action)
.handleErrorWith {
// If we've already checked read access and the user doesn't have it, pretend the runtime doesn't exist to avoid leaking its existence
case e: SamException if e.statusCode == StatusCodes.Forbidden && action == RuntimeAction.GetRuntimeStatus =>
IO.raiseError(RuntimeNotFoundException(cloudContext, runtimeName, "Not found in database"))
// Check if the user can read the runtime to determine which error to raise
case e: SamException if e.statusCode == StatusCodes.Forbidden =>
samService
.checkAuthorized(userInfo.accessToken.token, samResourceId, RuntimeAction.GetRuntimeStatus)
.attempt
.flatMap {
// The user can read the runtime, but they don't have the required action so raise a ForbiddenError
case Right(_) =>
IO.raiseError(
ForbiddenError(userInfo.userEmail)
)
// The user can't read the runtime, pretend it doesn't exist to avoid leaking its existence
case Left(_) =>
IO.raiseError(
RuntimeNotFoundException(cloudContext, runtimeName, "Not found in database")
)
}
}

private def filterHeaders(headers: immutable.Seq[HttpHeader]): immutable.Seq[HttpHeader] =
headers.filterNot(header => HeadersToFilter(header.lowercaseName()))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.broadinstitute.dsde.workbench.leonardo.SamResourceId.{
}
import org.broadinstitute.dsde.workbench.leonardo.config._
import org.broadinstitute.dsde.workbench.leonardo.dao.DockerDAO
import org.broadinstitute.dsde.workbench.leonardo.dao.sam.{SamException, SamService}
import org.broadinstitute.dsde.workbench.leonardo.dao.sam.{SamService, SamUtils}
import org.broadinstitute.dsde.workbench.leonardo.db._
import org.broadinstitute.dsde.workbench.leonardo.http.service.DiskServiceInterp.getDiskSamPolicyMap
import org.broadinstitute.dsde.workbench.leonardo.model.SamResourceAction.{
Expand Down Expand Up @@ -66,14 +66,15 @@ class RuntimeServiceInterp[F[_]: Parallel](
googleStorageService: Option[GoogleStorageService[F]],
googleComputeService: Option[GoogleComputeService[F]],
publisherQueue: Queue[F, LeoPubsubMessage],
samService: SamService[F]
val samService: SamService[F]
)(implicit
F: Async[F],
log: StructuredLogger[F],
dbReference: DbReference[F],
ec: ExecutionContext,
metrics: OpenTelemetryMetrics[F]
) extends RuntimeService[F] {
) extends RuntimeService[F]
with SamUtils[F] {

override def createRuntime(
userInfo: UserInfo,
Expand Down Expand Up @@ -818,38 +819,6 @@ class RuntimeServiceInterp[F[_]: Parallel](

_ <- checkRuntimeAction(userInfo, cloudContext, runtimeName, runtime.samResource, action, userEmail)
} yield runtime

private def checkRuntimeAction(userInfo: UserInfo,
cloudContext: CloudContext,
runtimeName: RuntimeName,
samResourceId: RuntimeSamResourceId,
action: RuntimeAction,
userEmail: Option[WorkbenchEmail] = None
)(implicit as: Ask[F, AppContext]): F[Unit] =
samService
.checkAuthorized(userInfo.accessToken.token, samResourceId, action)
.handleErrorWith {
// If we've already checked read access and the user doesn't have it, pretend the runtime doesn't exist to avoid leaking its existence
case e: SamException if e.statusCode == StatusCodes.Forbidden && action == RuntimeAction.GetRuntimeStatus =>
F.raiseError(RuntimeNotFoundException(cloudContext, runtimeName, "Not found in database"))
// Check if the user can read the runtime to determine which error to raise
case e: SamException if e.statusCode == StatusCodes.Forbidden =>
samService
.checkAuthorized(userInfo.accessToken.token, samResourceId, RuntimeAction.GetRuntimeStatus)
.attempt
.flatMap {
// The user can read the runtime, but they don't have the required action so raise a ForbiddenError
case Right(_) =>
F.raiseError(
ForbiddenError(userEmail.getOrElse(userInfo.userEmail))
)
// The user can't read the runtime, pretend it doesn't exist to avoid leaking its existence
case Left(_) =>
F.raiseError(
RuntimeNotFoundException(cloudContext, runtimeName, "Not found in database")
)
}
}
}

object RuntimeServiceInterp {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import org.broadinstitute.dsde.workbench.leonardo.SamResourceId.{
}
import org.broadinstitute.dsde.workbench.leonardo.config.PersistentDiskConfig
import org.broadinstitute.dsde.workbench.leonardo.dao._
import org.broadinstitute.dsde.workbench.leonardo.dao.sam.{SamException, SamService}
import org.broadinstitute.dsde.workbench.leonardo.dao.sam.{SamService, SamUtils}
import org.broadinstitute.dsde.workbench.leonardo.db._
import org.broadinstitute.dsde.workbench.leonardo.http.service.DiskServiceInterp.getDiskSamPolicyMap
import org.broadinstitute.dsde.workbench.leonardo.http.service.RuntimeServiceInterp.getRuntimeSamPolicyMap
Expand All @@ -42,9 +42,10 @@ class RuntimeV2ServiceInterp[F[_]: Parallel](
publisherQueue: Queue[F, LeoPubsubMessage],
dateAccessUpdaterQueue: Queue[F, UpdateDateAccessedMessage],
wsmClientProvider: WsmApiClientProvider[F],
samService: SamService[F]
val samService: SamService[F]
)(implicit F: Async[F], dbReference: DbReference[F], ec: ExecutionContext, log: StructuredLogger[F])
extends RuntimeV2Service[F] {
extends RuntimeV2Service[F]
with SamUtils[F] {

override def createRuntime(
userInfo: UserInfo,
Expand Down Expand Up @@ -488,42 +489,10 @@ class RuntimeV2ServiceInterp[F[_]: Parallel](
action: RuntimeAction
)(implicit as: Ask[F, AppContext]): F[ClusterRecord] =
for {
ctx <- as.ask
runtime <- RuntimeServiceDbQueries.getActiveRuntimeRecord(workspaceId, runtimeName).transaction
_ <- checkRuntimeAction(userInfo, workspaceId, runtimeName, RuntimeSamResourceId(runtime.internalId), action)
} yield runtime

private def checkRuntimeAction(userInfo: UserInfo,
workspaceId: WorkspaceId,
runtimeName: RuntimeName,
samResourceId: SamResourceId,
action: RuntimeAction
)(implicit as: Ask[F, AppContext]): F[Unit] =
samService
.checkAuthorized(userInfo.accessToken.token, samResourceId, action)
.handleErrorWith {
// If we've already checked read access and the user doesn't have it, pretend the runtime doesn't exist to avoid leaking its existence
case e: SamException if e.statusCode == StatusCodes.Forbidden && action == RuntimeAction.GetRuntimeStatus =>
F.raiseError(RuntimeNotFoundByWorkspaceIdException(workspaceId, runtimeName, "Not found in database"))
// Check if the user can read the runtime to determine which error to raise
case e: SamException if e.statusCode == StatusCodes.Forbidden =>
samService
.checkAuthorized(userInfo.accessToken.token, samResourceId, RuntimeAction.GetRuntimeStatus)
.attempt
.flatMap {
// The user can read the runtime, but they don't have the required action. Raise the original Forbidden action from Sam
case Right(_) =>
F.raiseError(
ForbiddenError(userInfo.userEmail)
)
// The user can't read the runtime, pretend it doesn't exist to avoid leaking its existence
case Left(_) =>
F.raiseError(
RuntimeNotFoundByWorkspaceIdException(workspaceId, runtimeName, "Not found in database")
)
}
}

private def errorHandler(runtimeId: Long, ctx: AppContext): Throwable => F[Unit] =
e =>
clusterErrorQuery
Expand Down

0 comments on commit 811d04a

Please sign in to comment.