diff --git a/cmd/gcs/main.go b/cmd/gcs/main.go
index 36ae1991b6..09bfe02305 100644
--- a/cmd/gcs/main.go
+++ b/cmd/gcs/main.go
@@ -31,6 +31,7 @@ import (
 	"github.com/Microsoft/hcsshim/internal/log"
 	"github.com/Microsoft/hcsshim/internal/oc"
 	"github.com/Microsoft/hcsshim/internal/version"
+	"github.com/Microsoft/hcsshim/pkg/amdsevsnp"
 	"github.com/Microsoft/hcsshim/pkg/securitypolicy"
 )
 
@@ -359,9 +360,22 @@ func main() {
 		logrus.WithError(err).Fatal("failed to initialize new runc runtime")
 	}
 	mux := bridge.NewBridgeMux()
+
+	forceSequential, err := amdsevsnp.IsSNP()
+	if err != nil {
+		// IsSNP cannot fail on LCOW
+		logrus.Errorf("Got unexpected error from IsSNP(): %v", err)
+		// If it fails, we proceed with forceSequential enabled to be safe
+		forceSequential = true
+	}
+
 	b := bridge.Bridge{
 		Handler: mux,
 		EnableV4: *v4,
+
+		// For confidential containers, we protect ourselves against attacks caused
+		// by concurrent modifications, by processing one request at a time.
+		ForceSequential: forceSequential,
 	}
 	h := hcsv2.NewHost(rtime, tport, initialEnforcer, logWriter)
 	// Initialize virtual pod support in the host
diff --git a/internal/gcs/unrecoverable_error.go b/internal/gcs/unrecoverable_error.go
new file mode 100644
index 0000000000..96528088aa
--- /dev/null
+++ b/internal/gcs/unrecoverable_error.go
@@ -0,0 +1,53 @@
+package gcs
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"runtime"
+	"time"
+
+	"github.com/Microsoft/hcsshim/internal/log"
+	"github.com/Microsoft/hcsshim/pkg/amdsevsnp"
+	"github.com/sirupsen/logrus"
+)
+
+// UnrecoverableError logs err with a dump of all goroutine stacks and never
+// returns: on SNP guests (or if SNP detection fails) it sleeps forever, since
+// the behaviour of GCS panics is unpredictable; otherwise it panics. It could
+// be extended to, for example, try to shut down the VM cleanly.
+func UnrecoverableError(err error) {
+	buf := make([]byte, 300*(1<<10))
+	stackSize := runtime.Stack(buf, true)
+	stackTrace := string(buf[:stackSize])
+
+	errPrint := fmt.Sprintf(
+		"Unrecoverable error in GCS: %v\n%s",
+		err, stackTrace,
+	)
+
+	isSnp, snpErr := amdsevsnp.IsSNP()
+	if snpErr != nil {
+		// IsSNP() cannot fail on LCOW, but if it does we proceed as if we're
+		// on SNP to be safe. snpErr is kept separate so err is not overwritten.
+		isSnp = true
+	}
+
+	if isSnp {
+		errPrint += "\nThis thread will now enter an infinite loop."
+	}
+	log.G(context.Background()).WithError(err).Logf(
+		logrus.FatalLevel,
+		"%s",
+		errPrint,
+	)
+
+	if !isSnp {
+		panic(fmt.Sprintf("Unrecoverable error in GCS: %v", err))
+	}
+	// On SNP, park this goroutine forever instead of panicking.
+	fmt.Fprintf(os.Stderr, "%s\n", errPrint)
+	for {
+		time.Sleep(time.Hour)
+	}
+}
diff --git a/internal/guest/bridge/bridge.go b/internal/guest/bridge/bridge.go
index 4ea03ed104..875def5809 100644
--- a/internal/guest/bridge/bridge.go
+++ b/internal/guest/bridge/bridge.go
@@ -177,6 +177,10 @@ type Bridge struct {
 	Handler Handler
 	// EnableV4 enables the v4+ bridge and the schema v2+ interfaces.
 	EnableV4 bool
+	// Setting ForceSequential to true will force the bridge to only process one
+	// request at a time, except for certain long-running operations (as defined
+	// in alwaysAsyncMessages).
+	ForceSequential bool
 
 	// responseChan is the response channel used for both request/response
 	// and publish notification workflows.
@@ -191,6 +195,14 @@ type Bridge struct {
 	protVer prot.ProtocolVersion
 }
 
+// Messages that will be processed asynchronously even in sequential mode. Note
+// that in sequential mode, these messages will still wait for any in-progress
+// non-async messages to be handled before they are processed, but once they are
+// "acknowledged", the rest will be done asynchronously.
+var alwaysAsyncMessages map[prot.MessageIdentifier]bool = map[prot.MessageIdentifier]bool{ + prot.ComputeSystemWaitForProcessV1: true, +} + // AssignHandlers creates and assigns the appropriate bridge // events to be listen for and intercepted on `mux` before forwarding // to `gcs` for handling. @@ -238,6 +250,10 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser defer close(requestErrChan) defer bridgeIn.Close() + if b.ForceSequential { + log.G(context.Background()).Info("bridge: ForceSequential enabled") + } + // Receive bridge requests and schedule them to be processed. go func() { var recverr error @@ -340,30 +356,36 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser }() // Process each bridge request async and create the response writer. go func() { - for req := range requestChan { - go func(r *Request) { - br := bridgeResponse{ - ctx: r.Context, - header: &prot.MessageHeader{ - Type: prot.GetResponseIdentifier(r.Header.Type), - ID: r.Header.ID, - }, - } - resp, err := b.Handler.ServeMsg(r) - if resp == nil { - resp = &prot.MessageResponseBase{} - } - resp.Base().ActivityID = r.ActivityID - if err != nil { - span := trace.FromContext(r.Context) - if span != nil { - oc.SetSpanStatus(span, err) - } - setErrorForResponseBase(resp.Base(), err, "gcs" /* moduleName */) + doOneRequest := func(r *Request) { + br := bridgeResponse{ + ctx: r.Context, + header: &prot.MessageHeader{ + Type: prot.GetResponseIdentifier(r.Header.Type), + ID: r.Header.ID, + }, + } + resp, err := b.Handler.ServeMsg(r) + if resp == nil { + resp = &prot.MessageResponseBase{} + } + resp.Base().ActivityID = r.ActivityID + if err != nil { + span := trace.FromContext(r.Context) + if span != nil { + oc.SetSpanStatus(span, err) } - br.response = resp - b.responseChan <- br - }(req) + setErrorForResponseBase(resp.Base(), err, "gcs" /* moduleName */) + } + br.response = resp + b.responseChan <- br + } + + for req := range requestChan { + 
if b.ForceSequential && !alwaysAsyncMessages[req.Header.Type] { + runSequentialRequestHandler(req, doOneRequest) + } else { + go doOneRequest(req) + } } }() // Process each bridge response sync. This channel is for request/response and publish workflows. @@ -423,6 +445,32 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser } } +// Do handleFn(r), but prints a warning if handleFn does not, or takes too long +// to return. +func runSequentialRequestHandler(r *Request, handleFn func(*Request)) { + // Note that this is only a context used for triggering the blockage + // warning, the request processing still uses r.Context. We don't want to + // cancel the request handling itself when we reach the 5s timeout. + timeoutCtx, cancel := context.WithTimeout(r.Context, 5*time.Second) + go func() { + <-timeoutCtx.Done() + if errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) { + log.G(timeoutCtx).WithFields(logrus.Fields{ + // We want to log those even though we're providing r.Context, since if + // the request never finishes the span end log will never get written, + // and we may therefore not be able to find out about the following info + // otherwise: + "message-type": r.Header.Type.String(), + "message-id": r.Header.ID, + "activity-id": r.ActivityID, + "container-id": r.ContainerID, + }).Warnf("bridge: request processing thread in sequential mode blocked on the current request for more than 5 seconds") + } + }() + defer cancel() + handleFn(r) +} + // PublishNotification writes a specific notification to the bridge. 
func (b *Bridge) PublishNotification(n *prot.ContainerNotification) { ctx, span := oc.StartSpan(context.Background(), diff --git a/internal/guest/bridge/bridge_v2.go b/internal/guest/bridge/bridge_v2.go index 800094e549..2f105ef5b6 100644 --- a/internal/guest/bridge/bridge_v2.go +++ b/internal/guest/bridge/bridge_v2.go @@ -467,16 +467,10 @@ func (b *Bridge) deleteContainerStateV2(r *Request) (_ RequestResponse, err erro return nil, errors.Wrapf(err, "failed to unmarshal JSON in message \"%s\"", r.Message) } - c, err := b.hostState.GetCreatedContainer(request.ContainerID) + err = b.hostState.DeleteContainerState(ctx, request.ContainerID) if err != nil { return nil, err } - // remove container state regardless of delete's success - defer b.hostState.RemoveContainer(request.ContainerID) - - if err := c.Delete(ctx); err != nil { - return nil, err - } return &prot.MessageResponseBase{}, nil } diff --git a/internal/guest/network/network.go b/internal/guest/network/network.go index 49564126a6..ca837d0c3f 100644 --- a/internal/guest/network/network.go +++ b/internal/guest/network/network.go @@ -9,6 +9,7 @@ import ( "fmt" "os" "path/filepath" + "regexp" "strings" "time" @@ -32,6 +33,18 @@ var ( // maxDNSSearches is limited to 6 in `man 5 resolv.conf` const maxDNSSearches = 6 +var validHostnameRegex = regexp.MustCompile(`^[a-zA-Z0-9_\-\.]{0,255}$`) + +// Check that the hostname is safe. This function is less strict than +// technically allowed, but ensures that when the hostname is inserted to +// /etc/hosts, it cannot lead to injection attacks. +func ValidateHostname(hostname string) error { + if !validHostnameRegex.MatchString(hostname) { + return errors.Errorf("hostname %q invalid: must match %s", hostname, validHostnameRegex.String()) + } + return nil +} + // GenerateEtcHostsContent generates a /etc/hosts file based on `hostname`. 
func GenerateEtcHostsContent(ctx context.Context, hostname string) string { _, span := oc.StartSpan(ctx, "network::GenerateEtcHostsContent") diff --git a/internal/guest/network/network_test.go b/internal/guest/network/network_test.go index 9a718aed66..cc0ec1be3a 100644 --- a/internal/guest/network/network_test.go +++ b/internal/guest/network/network_test.go @@ -7,6 +7,7 @@ import ( "context" "os" "path/filepath" + "strings" "testing" "time" ) @@ -70,6 +71,40 @@ func Test_GenerateResolvConfContent(t *testing.T) { } } +func Test_ValidateHostname(t *testing.T) { + validNames := []string{ + "localhost", + "my-hostname", + "my.hostname", + "my-host-name123", + "_underscores.are.allowed.too", + "", // Allow not specifying a hostname + } + + invalidNames := []string{ + "localhost\n13.104.0.1 ip6-localhost ip6-loopback localhost", + "localhost\n2603:1000::1 ip6-localhost ip6-loopback", + "hello@microsoft.com", + "has space", + "has,comma", + "\x00", + "a\nb", + strings.Repeat("a", 1000), + } + + for _, n := range validNames { + if err := ValidateHostname(n); err != nil { + t.Fatalf("expected %q to be valid, got: %v", n, err) + } + } + + for _, n := range invalidNames { + if err := ValidateHostname(n); err == nil { + t.Fatalf("expected %q to be invalid, but got nil error", n) + } + } +} + func Test_GenerateEtcHostsContent(t *testing.T) { type testcase struct { name string diff --git a/internal/guest/runtime/hcsv2/container.go b/internal/guest/runtime/hcsv2/container.go index bb9c3af5ea..886eb0528d 100644 --- a/internal/guest/runtime/hcsv2/container.go +++ b/internal/guest/runtime/hcsv2/container.go @@ -73,6 +73,9 @@ type Container struct { // and deal with the extra pointer dereferencing overhead. status atomic.Uint32 + // Set to true when the init process for the container has exited + terminated atomic.Bool + // scratchDirPath represents the path inside the UVM where the scratch directory // of this container is located. 
Usually, this is either `/run/gcs/c/` or // `/run/gcs/c//container_` if scratch is shared with UVM scratch. diff --git a/internal/guest/runtime/hcsv2/process.go b/internal/guest/runtime/hcsv2/process.go index e94c9792f6..96564cfab0 100644 --- a/internal/guest/runtime/hcsv2/process.go +++ b/internal/guest/runtime/hcsv2/process.go @@ -99,6 +99,7 @@ func newProcess(c *Container, spec *oci.Process, process runtime.Process, pid ui log.G(ctx).WithError(err).Error("failed to wait for runc process") } p.exitCode = exitCode + c.terminated.Store(true) log.G(ctx).WithField("exitCode", p.exitCode).Debug("process exited") // Free any process waiters diff --git a/internal/guest/runtime/hcsv2/sandbox_container.go b/internal/guest/runtime/hcsv2/sandbox_container.go index 471a28bfe1..6c9b3511c6 100644 --- a/internal/guest/runtime/hcsv2/sandbox_container.go +++ b/internal/guest/runtime/hcsv2/sandbox_container.go @@ -54,6 +54,9 @@ func setupSandboxContainerSpec(ctx context.Context, id string, spec *oci.Spec) ( // Write the hostname hostname := spec.Hostname + if err = network.ValidateHostname(hostname); err != nil { + return err + } if hostname == "" { var err error hostname, err = os.Hostname() diff --git a/internal/guest/runtime/hcsv2/standalone_container.go b/internal/guest/runtime/hcsv2/standalone_container.go index 1a23d276fc..e10d967f85 100644 --- a/internal/guest/runtime/hcsv2/standalone_container.go +++ b/internal/guest/runtime/hcsv2/standalone_container.go @@ -60,6 +60,9 @@ func setupStandaloneContainerSpec(ctx context.Context, id string, spec *oci.Spec }() hostname := spec.Hostname + if err = network.ValidateHostname(hostname); err != nil { + return err + } if hostname == "" { var err error hostname, err = os.Hostname() diff --git a/internal/guest/runtime/hcsv2/uvm.go b/internal/guest/runtime/hcsv2/uvm.go index 48c09b2018..935d363071 100644 --- a/internal/guest/runtime/hcsv2/uvm.go +++ b/internal/guest/runtime/hcsv2/uvm.go @@ -15,8 +15,10 @@ import ( "path" "path/filepath" 
"regexp" + "slices" "strings" "sync" + "sync/atomic" "syscall" "time" @@ -26,6 +28,7 @@ import ( "github.com/opencontainers/runtime-spec/specs-go" "github.com/pkg/errors" "github.com/sirupsen/logrus" + "go.opencensus.io/trace" "golang.org/x/sys/unix" "github.com/Microsoft/hcsshim/internal/bridgeutils/gcserr" @@ -44,6 +47,7 @@ import ( "github.com/Microsoft/hcsshim/internal/guestpath" "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/logfields" + "github.com/Microsoft/hcsshim/internal/oc" "github.com/Microsoft/hcsshim/internal/oci" "github.com/Microsoft/hcsshim/internal/protocol/guestrequest" "github.com/Microsoft/hcsshim/internal/protocol/guestresource" @@ -107,8 +111,20 @@ type Host struct { securityOptions *securitypolicy.SecurityOptions // hostMounts keeps the state of currently mounted devices and file systems, - // which is used for GCS hardening. + // which is used for GCS hardening. It is only used for confidential + // containers, and is initialized in SetConfidentialUVMOptions. If this is + // nil, we do not do add any special restrictions on mounts / unmounts. hostMounts *hostMounts + // A permanent flag to indicate that further mounts, unmounts and container + // creation should not be allowed. This is set when, because of a failure + // during an unmount operation, we end up in a state where the policy + // enforcer's state is out of sync with what we have actually done, but we + // cannot safely revert its state. + // + // Not used in non-confidential mode. + mountsBroken atomic.Bool + // A user-friendly error message for why mountsBroken was set. 
+	mountsBrokenCausedBy string
 }
 
 func NewHost(rtime runtime.Runtime, vsock transport.Transport, initialEnforcer securitypolicy.SecurityPolicyEnforcer, logWriter io.Writer) *Host {
@@ -126,8 +142,9 @@ func NewHost(rtime runtime.Runtime, vsock transport.Transport, initialEnforcer s
 		rtime: rtime,
 		vsock: vsock,
 		devNullTransport: &transport.DevNullTransport{},
-		hostMounts: newHostMounts(),
+		hostMounts: nil,
 		securityOptions: securityPolicyOptions,
+		mountsBroken: atomic.Bool{},
 	}
 }
 
@@ -309,7 +326,44 @@ func checkContainerSettings(sandboxID, containerID string, settings *prot.VMHost
 	return nil
 }
 
+// checkMountsNotBroken returns an error if h.mountsBroken is set (and we're in
+// a confidential container host).
+func (h *Host) checkMountsNotBroken() error {
+	if h.HasSecurityPolicy() && h.mountsBroken.Load() {
+		return errors.Errorf(
+			"Mount, unmount, container creation and deletion has been disabled in this UVM due to a previous error (%q)",
+			h.mountsBrokenCausedBy,
+		)
+	}
+	return nil
+}
+
+func (h *Host) setMountsBrokenIfConfidential(cause string) {
+	if !h.HasSecurityPolicy() {
+		return
+	}
+	// Record the cause first: checkMountsNotBroken reads it once it sees the flag.
+	h.mountsBrokenCausedBy = cause
+	h.mountsBroken.Store(true)
+	log.G(context.Background()).WithField("cause", cause).
+		Error("Host::mountsBroken set to true. All further mounts/unmounts, container creation and deletion will fail.")
+}
+
+func checkExists(path string) (error, bool) {
+	if _, err := os.Stat(path); err != nil {
+		if os.IsNotExist(err) {
+			return nil, false
+		}
+		return errors.Wrapf(err, "failed to determine if path '%s' exists", path), false
+	}
+	return nil, true
+}
+
 func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VMHostedContainerSettingsV2) (_ *Container, err error) {
+	if err = h.checkMountsNotBroken(); err != nil {
+		return nil, err
+	}
+
 	criType, isCRI := settings.OCISpecification.Annotations[annotations.KubernetesContainerType]
 
 	// Check for virtual pod annotation
@@ -347,6 +401,7 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM
 		isSandbox: criType == "sandbox",
 		exitType: prot.NtUnexpectedExit,
 		processes: make(map[uint32]*containerProcess),
+		terminated: atomic.Bool{},
 		scratchDirPath: settings.ScratchDirPath,
 	}
 	c.setStatus(containerCreating)
@@ -390,6 +445,18 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM
 		}
 	}
 
+	// Take a backup of the devices array before we populate it with any devices
+	// found by GCS, in order to pass to the policy enforcer later.
+	//
+	// In specGuest.ApplyAnnotationsToSpec, if this is a privileged container,
+	// we will add devices found in the GCS namespace's /dev. Regardless of
+	// privileged or not, we also always include /dev/sev-guest. Since the
+	// policy already lets the user enforce whether the container should be
+	// privileged or not, and the sev-guest device is always added for a
+	// confidential container, we do not need the policy enforcer to check these
+	// devices we dynamically add again.
+	extraLinuxDevices := slices.Clone(settings.OCISpecification.Linux.Devices)
+
 	// Normally we would be doing policy checking here at the start of our
 	// "policy gated function".
However, we can't for create container as we // need a properly correct sandboxID which might be changed by the code @@ -525,21 +592,27 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM return nil, err } - envToKeep, capsToKeep, allowStdio, err := h.securityOptions.PolicyEnforcer.EnforceCreateContainerPolicy( + privileged := isPrivilegedContainerCreationRequest(ctx, settings.OCISpecification) + noNewPrivileges := settings.OCISpecification.Process.NoNewPrivileges + opts := &securitypolicy.CreateContainerOptions{ + SandboxID: sandboxID, + Privileged: &privileged, + NoNewPrivileges: &noNewPrivileges, + Groups: groups, + Umask: umask, + Capabilities: settings.OCISpecification.Process.Capabilities, + SeccompProfileSHA256: seccomp, + LinuxDevices: extraLinuxDevices, + } + envToKeep, capsToKeep, allowStdio, err := h.securityOptions.PolicyEnforcer.EnforceCreateContainerPolicyV2( ctx, - sandboxID, id, settings.OCISpecification.Process.Args, settings.OCISpecification.Process.Env, settings.OCISpecification.Process.Cwd, settings.OCISpecification.Mounts, - isPrivilegedContainerCreationRequest(ctx, settings.OCISpecification), - settings.OCISpecification.Process.NoNewPrivileges, user, - groups, - umask, - settings.OCISpecification.Process.Capabilities, - seccomp, + opts, ) if err != nil { return nil, errors.Wrapf(err, "container creation denied due to policy") @@ -671,6 +744,25 @@ func writeSpecToFile(ctx context.Context, configFile string, spec *specs.Spec) e return nil } +// Returns whether there is a running container that is currently using the +// given overlay (as its rootfs). 
+func (h *Host) IsOverlayInUse(overlayPath string) bool { + h.containersMutex.Lock() + defer h.containersMutex.Unlock() + + for _, c := range h.containers { + if c.terminated.Load() { + continue + } + + if c.spec.Root.Path == overlayPath { + return true + } + } + + return false +} + func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req *guestrequest.ModificationRequest) (retErr error) { if h.HasSecurityPolicy() { if err := checkValidContainerID(containerID, "container"); err != nil { @@ -682,6 +774,10 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * case guestresource.ResourceTypeSCSIDevice: return modifySCSIDevice(ctx, req.RequestType, req.Settings.(*guestresource.SCSIDevice)) case guestresource.ResourceTypeMappedVirtualDisk: + if err := h.checkMountsNotBroken(); err != nil { + return err + } + mvd := req.Settings.(*guestresource.LCOWMappedVirtualDisk) // find the actual controller number on the bus and update the incoming request. var cNum uint8 @@ -690,47 +786,25 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * return err } mvd.Controller = cNum - // first we try to update the internal state for read-write attachments. 
- if !mvd.ReadOnly { - localCtx, cancel := context.WithTimeout(ctx, time.Second*5) - defer cancel() - source, err := scsi.GetDevicePath(localCtx, mvd.Controller, mvd.Lun, mvd.Partition) - if err != nil { - return err - } - switch req.RequestType { - case guestrequest.RequestTypeAdd: - if err := h.hostMounts.AddRWDevice(mvd.MountPath, source, mvd.Encrypted); err != nil { - return err - } - defer func() { - if retErr != nil { - _ = h.hostMounts.RemoveRWDevice(mvd.MountPath, source) - } - }() - case guestrequest.RequestTypeRemove: - if err := h.hostMounts.RemoveRWDevice(mvd.MountPath, source); err != nil { - return err - } - defer func() { - if retErr != nil { - _ = h.hostMounts.AddRWDevice(mvd.MountPath, source, mvd.Encrypted) - } - }() - } - } - return modifyMappedVirtualDisk(ctx, req.RequestType, mvd, h.securityOptions.PolicyEnforcer) + return h.modifyMappedVirtualDisk(ctx, req.RequestType, mvd) case guestresource.ResourceTypeMappedDirectory: - return modifyMappedDirectory(ctx, h.vsock, req.RequestType, req.Settings.(*guestresource.LCOWMappedDirectory), h.securityOptions.PolicyEnforcer) + if err := h.checkMountsNotBroken(); err != nil { + return err + } + + return h.modifyMappedDirectory(ctx, h.vsock, req.RequestType, req.Settings.(*guestresource.LCOWMappedDirectory)) case guestresource.ResourceTypeVPMemDevice: - return modifyMappedVPMemDevice(ctx, req.RequestType, req.Settings.(*guestresource.LCOWMappedVPMemDevice), h.securityOptions.PolicyEnforcer) + if err := h.checkMountsNotBroken(); err != nil { + return err + } + + return h.modifyMappedVPMemDevice(ctx, req.RequestType, req.Settings.(*guestresource.LCOWMappedVPMemDevice)) case guestresource.ResourceTypeCombinedLayers: - cl := req.Settings.(*guestresource.LCOWCombinedLayers) - // when cl.ScratchPath == "", we mount overlay as read-only, in which case - // we don't really care about scratch encryption, since the host already - // knows about the layers and the overlayfs. 
- encryptedScratch := cl.ScratchPath != "" && h.hostMounts.IsEncrypted(cl.ScratchPath) - return modifyCombinedLayers(ctx, req.RequestType, req.Settings.(*guestresource.LCOWCombinedLayers), encryptedScratch, h.securityOptions.PolicyEnforcer) + if err := h.checkMountsNotBroken(); err != nil { + return err + } + + return h.modifyCombinedLayers(ctx, req.RequestType, req.Settings.(*guestresource.LCOWCombinedLayers)) case guestresource.ResourceTypeNetwork: return modifyNetwork(ctx, req.RequestType, req.Settings.(*guestresource.LCOWNetworkAdapter)) case guestresource.ResourceTypeVPCIDevice: @@ -746,10 +820,22 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * if !ok { return errors.New("the request's settings are not of type ConfidentialOptions") } - return h.securityOptions.SetConfidentialOptions(ctx, + err := h.securityOptions.SetConfidentialOptions(ctx, r.EnforcerType, r.EncodedSecurityPolicy, r.EncodedUVMReference) + if err != nil { + return err + } + + // Start tracking mounts and restricting unmounts on confidential containers. + // As long as we started off with the ClosedDoorSecurityPolicyEnforcer, no + // mounts should have been allowed until this point. 
+ if h.HasSecurityPolicy() { + log.G(ctx).Debug("hostMounts initialized") + h.hostMounts = newHostMounts() + } + return nil case guestresource.ResourceTypePolicyFragment: r, ok := req.Settings.(*guestresource.SecurityPolicyFragment) if !ok { @@ -1118,23 +1204,38 @@ func modifySCSIDevice( } } -func modifyMappedVirtualDisk( +func (h *Host) modifyMappedVirtualDisk( ctx context.Context, rt guestrequest.RequestType, mvd *guestresource.LCOWMappedVirtualDisk, - securityPolicy securitypolicy.SecurityPolicyEnforcer, ) (err error) { + ctx, span := oc.StartSpan(ctx, "gcs::Host::modifyMappedVirtualDisk") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + span.AddAttributes( + trace.StringAttribute("requestType", string(rt)), + trace.BoolAttribute("hasHostMounts", h.hostMounts != nil), + trace.Int64Attribute("controller", int64(mvd.Controller)), + trace.Int64Attribute("lun", int64(mvd.Lun)), + trace.Int64Attribute("partition", int64(mvd.Partition)), + trace.BoolAttribute("readOnly", mvd.ReadOnly), + trace.StringAttribute("mountPath", mvd.MountPath), + ) + var verityInfo *guestresource.DeviceVerityInfo + securityPolicy := h.securityOptions.PolicyEnforcer + devPath, err := scsi.GetDevicePath(ctx, mvd.Controller, mvd.Lun, mvd.Partition) + if err != nil { + return err + } + span.AddAttributes(trace.StringAttribute("devicePath", devPath)) + if mvd.ReadOnly { // The only time the policy is empty, and we want it to be empty // is when no policy is provided, and we default to open door // policy. In any other case, e.g. explicit open door or any // other rego policy we would like to mount layers with verity. 
- if len(securityPolicy.EncodedSecurityPolicy()) > 0 { - devPath, err := scsi.GetDevicePath(ctx, mvd.Controller, mvd.Lun, mvd.Partition) - if err != nil { - return err - } + if h.HasSecurityPolicy() { verityInfo, err = verity.ReadVeritySuperBlock(ctx, devPath) if err != nil { return err @@ -1144,11 +1245,40 @@ func modifyMappedVirtualDisk( } } } + + // For confidential containers, we revert the policy metadata on both mount + // and unmount errors, but if we've actually called Unmount and it fails we + // permanently block further device operations. + var rev securitypolicy.RevertableSectionHandle + rev, err = securityPolicy.StartRevertableSection() + if err != nil { + return errors.Wrapf(err, "failed to start revertable section on security policy enforcer") + } + defer h.commitOrRollbackPolicyRevSection(ctx, rev, &err) + switch rt { case guestrequest.RequestTypeAdd: mountCtx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() if mvd.MountPath != "" { + if h.HasSecurityPolicy() { + // The only option we allow if there is policy enforcement is + // "ro", and it must match the readonly field in the request. 
+ mountOptionHasRo := false + for _, opt := range mvd.Options { + if opt == "ro" { + mountOptionHasRo = true + continue + } + return errors.Errorf("mounting scsi device controller %d lun %d onto %s: mount option %q denied by policy", mvd.Controller, mvd.Lun, mvd.MountPath, opt) + } + if mvd.ReadOnly != mountOptionHasRo { + return errors.Errorf( + "mounting scsi device controller %d lun %d onto %s with mount option %q failed due to mount option mismatch: mvd.ReadOnly=%t but mountOptionHasRo=%t", + mvd.Controller, mvd.Lun, mvd.MountPath, strings.Join(mvd.Options, ","), mvd.ReadOnly, mountOptionHasRo, + ) + } + } if mvd.ReadOnly { var deviceHash string if verityInfo != nil { @@ -1158,11 +1288,42 @@ func modifyMappedVirtualDisk( if err != nil { return errors.Wrapf(err, "mounting scsi device controller %d lun %d onto %s denied by policy", mvd.Controller, mvd.Lun, mvd.MountPath) } + if h.hostMounts != nil { + h.hostMounts.Lock() + defer h.hostMounts.Unlock() + + err = h.hostMounts.AddRODevice(mvd.MountPath, devPath) + if err != nil { + return err + } + // Note: "When a function returns, its deferred calls are + // executed in last-in-first-out order." - so we are safe to + // call RemoveRODevice in this defer. 
+ defer func() { + if err != nil { + _ = h.hostMounts.RemoveRODevice(mvd.MountPath, devPath) + } + }() + } } else { err = securityPolicy.EnforceRWDeviceMountPolicy(ctx, mvd.MountPath, mvd.Encrypted, mvd.EnsureFilesystem, mvd.Filesystem) if err != nil { return errors.Wrapf(err, "mounting scsi device controller %d lun %d onto %s denied by policy", mvd.Controller, mvd.Lun, mvd.MountPath) } + if h.hostMounts != nil { + h.hostMounts.Lock() + defer h.hostMounts.Unlock() + + err = h.hostMounts.AddRWDevice(mvd.MountPath, devPath, mvd.Encrypted) + if err != nil { + return err + } + defer func() { + if err != nil { + _ = h.hostMounts.RemoveRWDevice(mvd.MountPath, devPath, mvd.Encrypted) + } + }() + } } config := &scsi.Config{ Encrypted: mvd.Encrypted, @@ -1171,6 +1332,12 @@ func modifyMappedVirtualDisk( Filesystem: mvd.Filesystem, BlockDev: mvd.BlockDev, } + // Since we're rolling back the policy metadata (via the revertable + // section) on failure, we need to ensure that we have reverted all + // the side effects from this failed mount attempt, otherwise the + // Rego metadata is technically still inconsistent with reality. + // Mount cleans up the created directory and dm devices on failure, + // so we're good. 
return scsi.Mount(mountCtx, mvd.Controller, mvd.Lun, mvd.Partition, mvd.MountPath, mvd.ReadOnly, mvd.Options, config) } @@ -1178,13 +1345,58 @@ func modifyMappedVirtualDisk( case guestrequest.RequestTypeRemove: if mvd.MountPath != "" { if mvd.ReadOnly { - if err := securityPolicy.EnforceDeviceUnmountPolicy(ctx, mvd.MountPath); err != nil { + if err = securityPolicy.EnforceDeviceUnmountPolicy(ctx, mvd.MountPath); err != nil { return fmt.Errorf("unmounting scsi device at %s denied by policy: %w", mvd.MountPath, err) } + if h.hostMounts != nil { + h.hostMounts.Lock() + defer h.hostMounts.Unlock() + + if err = h.hostMounts.RemoveRODevice(mvd.MountPath, devPath); err != nil { + return err + } + defer func() { + if err != nil { + _ = h.hostMounts.AddRODevice(mvd.MountPath, devPath) + } + }() + } } else { - if err := securityPolicy.EnforceRWDeviceUnmountPolicy(ctx, mvd.MountPath); err != nil { + if err = securityPolicy.EnforceRWDeviceUnmountPolicy(ctx, mvd.MountPath); err != nil { return fmt.Errorf("unmounting scsi device at %s denied by policy: %w", mvd.MountPath, err) } + if h.hostMounts != nil { + h.hostMounts.Lock() + defer h.hostMounts.Unlock() + + if err = h.hostMounts.RemoveRWDevice(mvd.MountPath, devPath, mvd.Encrypted); err != nil { + return err + } + defer func() { + if err != nil { + _ = h.hostMounts.AddRWDevice(mvd.MountPath, devPath, mvd.Encrypted) + } + }() + } + } + // Check that the directory actually exists first, and if it does + // not then we just refuse to do anything, without closing the dm + // device or setting the mountsBroken flag. Policy metadata is + // still reverted to reflect the fact that we have not done + // anything. 
+ // + // Note: we should not do this check before calling the policy + // enforcer, as otherwise we might inadvertently allow the host to + // find out whether an arbitrary path (which may point to sensitive + // data within a container rootfs) exists or not + if h.HasSecurityPolicy() { + err, exists := checkExists(mvd.MountPath) + if err != nil { + return err + } + if !exists { + return errors.Errorf("unmounting scsi device at %s failed: directory does not exist", mvd.MountPath) + } } config := &scsi.Config{ Encrypted: mvd.Encrypted, @@ -1193,8 +1405,11 @@ func modifyMappedVirtualDisk( Filesystem: mvd.Filesystem, BlockDev: mvd.BlockDev, } - if err := scsi.Unmount(ctx, mvd.Controller, mvd.Lun, mvd.Partition, - mvd.MountPath, config); err != nil { + err = scsi.Unmount(ctx, mvd.Controller, mvd.Lun, mvd.Partition, mvd.MountPath, config) + if err != nil { + h.setMountsBrokenIfConfidential( + fmt.Sprintf("unmounting scsi device at %s failed: %v", mvd.MountPath, err), + ) return err } } @@ -1204,13 +1419,23 @@ func modifyMappedVirtualDisk( } } -func modifyMappedDirectory( +func (h *Host) modifyMappedDirectory( ctx context.Context, vsock transport.Transport, rt guestrequest.RequestType, md *guestresource.LCOWMappedDirectory, - securityPolicy securitypolicy.SecurityPolicyEnforcer, ) (err error) { + securityPolicy := h.securityOptions.PolicyEnforcer + // For confidential containers, we revert the policy metadata on both mount + // and unmount errors, but if we've actually called Unmount and it fails we + // permanently block further device operations. 
+ var rev securitypolicy.RevertableSectionHandle + rev, err = securityPolicy.StartRevertableSection() + if err != nil { + return errors.Wrapf(err, "failed to start revertable section on security policy enforcer") + } + defer h.commitOrRollbackPolicyRevSection(ctx, rev, &err) + switch rt { case guestrequest.RequestTypeAdd: err = securityPolicy.EnforcePlan9MountPolicy(ctx, md.MountPath) @@ -1218,6 +1443,15 @@ func modifyMappedDirectory( return errors.Wrapf(err, "mounting plan9 device at %s denied by policy", md.MountPath) } + if h.HasSecurityPolicy() { + if err = plan9.ValidateShareName(md.ShareName); err != nil { + return err + } + } + + // Similar to the reasoning in modifyMappedVirtualDisk, since we're + // rolling back the policy metadata, plan9.Mount here must clean up + // everything if it fails, which it does do. return plan9.Mount(ctx, vsock, md.MountPath, md.ShareName, uint32(md.Port), md.ReadOnly) case guestrequest.RequestTypeRemove: err = securityPolicy.EnforcePlan9UnmountPolicy(ctx, md.MountPath) @@ -1225,20 +1459,28 @@ func modifyMappedDirectory( return errors.Wrapf(err, "unmounting plan9 device at %s denied by policy", md.MountPath) } - return storage.UnmountPath(ctx, md.MountPath, true) + // Note: storage.UnmountPath is a no-op if the path does not exist.
+ err = storage.UnmountPath(ctx, md.MountPath, true) + if err != nil { + h.setMountsBrokenIfConfidential( + fmt.Sprintf("unmounting plan9 device at %s failed: %v", md.MountPath, err), + ) + return err + } + return nil default: return newInvalidRequestTypeError(rt) } } -func modifyMappedVPMemDevice(ctx context.Context, +func (h *Host) modifyMappedVPMemDevice(ctx context.Context, rt guestrequest.RequestType, vpd *guestresource.LCOWMappedVPMemDevice, - securityPolicy securitypolicy.SecurityPolicyEnforcer, ) (err error) { var verityInfo *guestresource.DeviceVerityInfo + securityPolicy := h.securityOptions.PolicyEnforcer var deviceHash string - if len(securityPolicy.EncodedSecurityPolicy()) > 0 { + if h.HasSecurityPolicy() { if vpd.MappingInfo != nil { return fmt.Errorf("multi mapping is not supported with verity") } @@ -1248,6 +1490,17 @@ func modifyMappedVPMemDevice(ctx context.Context, } deviceHash = verityInfo.RootDigest } + + // For confidential containers, we revert the policy metadata on both mount + // and unmount errors, but if we've actually called Unmount and it fails we + // permanently block further device operations. + var rev securitypolicy.RevertableSectionHandle + rev, err = securityPolicy.StartRevertableSection() + if err != nil { + return errors.Wrapf(err, "failed to start revertable section on security policy enforcer") + } + defer h.commitOrRollbackPolicyRevSection(ctx, rev, &err) + switch rt { case guestrequest.RequestTypeAdd: err = securityPolicy.EnforceDeviceMountPolicy(ctx, vpd.MountPath, deviceHash) @@ -1255,13 +1508,39 @@ func modifyMappedVPMemDevice(ctx context.Context, return errors.Wrapf(err, "mounting pmem device %d onto %s denied by policy", vpd.DeviceNumber, vpd.MountPath) } + // Similar to the reasoning in modifyMappedVirtualDisk, since we're + // rolling back the policy metadata, pmem.Mount here must clean up + // everything if it fails, which it does do. 
return pmem.Mount(ctx, vpd.DeviceNumber, vpd.MountPath, vpd.MappingInfo, verityInfo) case guestrequest.RequestTypeRemove: - if err := securityPolicy.EnforceDeviceUnmountPolicy(ctx, vpd.MountPath); err != nil { + if err = securityPolicy.EnforceDeviceUnmountPolicy(ctx, vpd.MountPath); err != nil { return errors.Wrapf(err, "unmounting pmem device from %s denied by policy", vpd.MountPath) } - return pmem.Unmount(ctx, vpd.DeviceNumber, vpd.MountPath, vpd.MappingInfo, verityInfo) + // Check that the directory actually exists first, and if it does not + // then we just refuse to do anything, without closing the dm-linear or + // dm-verity device or setting the mountsBroken flag. + // + // Similar to the reasoning in modifyMappedVirtualDisk, we should not do + // this check before calling the policy enforcer. + if h.HasSecurityPolicy() { + err, exists := checkExists(vpd.MountPath) + if err != nil { + return err + } + if !exists { + return errors.Errorf("unmounting pmem device at %s failed: directory does not exist", vpd.MountPath) + } + } + + err = pmem.Unmount(ctx, vpd.DeviceNumber, vpd.MountPath, vpd.MappingInfo, verityInfo) + if err != nil { + h.setMountsBrokenIfConfidential( + fmt.Sprintf("unmounting pmem device at %s failed: %v", vpd.MountPath, err), + ) + return err + } + return nil default: return newInvalidRequestTypeError(rt) } @@ -1276,19 +1555,43 @@ func modifyMappedVPCIDevice(ctx context.Context, rt guestrequest.RequestType, vp } } -func modifyCombinedLayers( +func (h *Host) modifyCombinedLayers( ctx context.Context, rt guestrequest.RequestType, cl *guestresource.LCOWCombinedLayers, - scratchEncrypted bool, - securityPolicy securitypolicy.SecurityPolicyEnforcer, ) (err error) { - isConfidential := len(securityPolicy.EncodedSecurityPolicy()) > 0 + ctx, span := oc.StartSpan(ctx, "gcs::Host::modifyCombinedLayers") + defer span.End() + defer func() { oc.SetSpanStatus(span, err) }() + span.AddAttributes( + trace.StringAttribute("requestType", string(rt)), + 
trace.BoolAttribute("hasHostMounts", h.hostMounts != nil), + trace.StringAttribute("containerRootPath", cl.ContainerRootPath), + trace.StringAttribute("scratchPath", cl.ScratchPath), + ) + + securityPolicy := h.securityOptions.PolicyEnforcer containerID := cl.ContainerID + // For confidential containers, we revert the policy metadata on both mount + // and unmount errors, but if we've actually called Unmount and it fails we + // permanently block further device operations. + var rev securitypolicy.RevertableSectionHandle + rev, err = securityPolicy.StartRevertableSection() + if err != nil { + return errors.Wrapf(err, "failed to start revertable section on security policy enforcer") + } + defer h.commitOrRollbackPolicyRevSection(ctx, rev, &err) + + if h.hostMounts != nil { + // We will need this in multiple places, let's take the lock once here. + h.hostMounts.Lock() + defer h.hostMounts.Unlock() + } + switch rt { case guestrequest.RequestTypeAdd: - if isConfidential { + if h.HasSecurityPolicy() { if err := checkValidContainerID(containerID, "container"); err != nil { return err } @@ -1333,25 +1636,68 @@ func modifyCombinedLayers( } else { upperdirPath = filepath.Join(cl.ScratchPath, "upper") workdirPath = filepath.Join(cl.ScratchPath, "work") + scratchEncrypted := false + if h.hostMounts != nil { + scratchEncrypted = h.hostMounts.IsEncrypted(cl.ScratchPath) + } if err := securityPolicy.EnforceScratchMountPolicy(ctx, cl.ScratchPath, scratchEncrypted); err != nil { return fmt.Errorf("scratch mounting denied by policy: %w", err) } } - if err := securityPolicy.EnforceOverlayMountPolicy(ctx, containerID, layerPaths, cl.ContainerRootPath); err != nil { + if err = securityPolicy.EnforceOverlayMountPolicy(ctx, containerID, layerPaths, cl.ContainerRootPath); err != nil { return fmt.Errorf("overlay creation denied by policy: %w", err) } + if h.hostMounts != nil { + if err = h.hostMounts.AddOverlay(cl.ContainerRootPath, layerPaths, cl.ScratchPath); err != nil { + return err + 
} + defer func() { + if err != nil { + _, _ = h.hostMounts.RemoveOverlay(cl.ContainerRootPath) + } + }() + } + // Correctness for policy revertable section: + // MountLayer does two things - mkdir, then mount. On mount failure, the + // target directory is cleaned up. Therefore we're clean in terms of + // side effects. return overlay.MountLayer(ctx, layerPaths, upperdirPath, workdirPath, cl.ContainerRootPath, readonly) case guestrequest.RequestTypeRemove: // cl.ContainerID is not set on remove requests, but rego checks that we can // only umount previously mounted targets anyway - if err := securityPolicy.EnforceOverlayUnmountPolicy(ctx, cl.ContainerRootPath); err != nil { + if err = securityPolicy.EnforceOverlayUnmountPolicy(ctx, cl.ContainerRootPath); err != nil { return errors.Wrap(err, "overlay removal denied by policy") } - return storage.UnmountPath(ctx, cl.ContainerRootPath, true) + // Check that no running container is using this overlay as its rootfs. + if h.HasSecurityPolicy() && h.IsOverlayInUse(cl.ContainerRootPath) { + return fmt.Errorf("overlay %q is in use by a running container", cl.ContainerRootPath) + } + + if h.hostMounts != nil { + var undoRemoveOverlay func() + if undoRemoveOverlay, err = h.hostMounts.RemoveOverlay(cl.ContainerRootPath); err != nil { + return err + } + defer func() { + if err != nil && undoRemoveOverlay != nil { + undoRemoveOverlay() + } + }() + } + + // Note: storage.UnmountPath is a no-op if the path does not exist. + err = storage.UnmountPath(ctx, cl.ContainerRootPath, true) + if err != nil { + h.setMountsBrokenIfConfidential( + fmt.Sprintf("unmounting overlay at %s failed: %v", cl.ContainerRootPath, err), + ) + return err + } + return nil default: return newInvalidRequestTypeError(rt) } @@ -1638,3 +1984,59 @@ func setupVirtualPodHugePageMountsPath(virtualSandboxID string) error { return storage.MountRShared(mountPath) } + +// If *err is not nil, the section is rolled back, otherwise it is committed. 
+func (h *Host) commitOrRollbackPolicyRevSection( + ctx context.Context, + rev securitypolicy.RevertableSectionHandle, + err *error, +) { + if !h.HasSecurityPolicy() { + // Don't produce bogus log entries if we aren't in confidential mode, + // even though rev.Rollback would have been no-op. + return + } + if *err != nil { + rev.Rollback() + logrus.WithContext(ctx).WithError(*err).Warn("rolling back security policy revertable section due to error") + } else { + rev.Commit() + } +} + +func (h *Host) DeleteContainerState(ctx context.Context, containerID string) error { + if h.HasSecurityPolicy() { + if err := checkValidContainerID(containerID, "container"); err != nil { + return err + } + } + + if err := h.checkMountsNotBroken(); err != nil { + return err + } + + c, err := h.GetCreatedContainer(containerID) + if err != nil { + return err + } + if h.HasSecurityPolicy() { + if !c.terminated.Load() { + return errors.Errorf("Denied deleting state of a running container %q", containerID) + } + overlay := c.spec.Root.Path + h.hostMounts.Lock() + defer h.hostMounts.Unlock() + if h.hostMounts.HasOverlayMountedAt(overlay) { + return errors.Errorf("Denied deleting state of a container with a overlay mount still active") + } + } + + // remove container state regardless of delete's success + defer h.RemoveContainer(containerID) + + if err = c.Delete(ctx); err != nil { + return err + } + + return nil +} diff --git a/internal/guest/runtime/hcsv2/uvm_state.go b/internal/guest/runtime/hcsv2/uvm_state.go index dd1ff521f0..96e64371a2 100644 --- a/internal/guest/runtime/hcsv2/uvm_state.go +++ b/internal/guest/runtime/hcsv2/uvm_state.go @@ -4,91 +4,360 @@ package hcsv2 import ( + "context" + "errors" "fmt" "path/filepath" "strings" "sync" + + "github.com/Microsoft/hcsshim/internal/gcs" + "github.com/Microsoft/hcsshim/internal/log" + "github.com/sirupsen/logrus" +) + +type deviceType int + +const ( + DeviceTypeRW deviceType = iota + DeviceTypeRO + DeviceTypeOverlay ) -type rwDevice 
struct { +func (d deviceType) String() string { + switch d { + case DeviceTypeRW: + return "RW" + case DeviceTypeRO: + return "RO" + case DeviceTypeOverlay: + return "Overlay" + default: + return fmt.Sprintf("Unknown(%d)", d) + } +} + +type device struct { + // fields common to all mountPath string + ty deviceType + usage int sourcePath string - encrypted bool + + // rw devices + encrypted bool + + // overlay devices + referencedDevices []*device } +// hostMounts tracks the state of fs/overlay mounts and their usage +// relationship. Users of this struct must call hm.Lock() before calling any +// other methods and call hm.Unlock() when done. +// +// Since mount/unmount operations can fail, the expected way to use this struct +// is to first lock it, call the method to add/remove the device, then, with the +// lock still held, perform the actual operation. If the operation fails, the +// caller must undo the operation by calling the appropriate remove/add method +// or the returned undo function, before unlocking. type hostMounts struct { - stateMutex sync.Mutex + stateMutex sync.Mutex + stateMutexLocked bool - // Holds information about read-write devices, which can be encrypted and - // contain overlay fs upper/work directory mounts. - readWriteMounts map[string]*rwDevice + // Map from mountPath to device struct + devices map[string]*device } func newHostMounts() *hostMounts { return &hostMounts{ - readWriteMounts: map[string]*rwDevice{}, + devices: make(map[string]*device), } } -// AddRWDevice adds read-write device metadata for device mounted at `mountPath`. -// Returns an error if there's an existing device mounted at `mountPath` location. -func (hm *hostMounts) AddRWDevice(mountPath string, sourcePath string, encrypted bool) error { +func (hm *hostMounts) expectLocked() { + if !hm.stateMutexLocked { + gcs.UnrecoverableError(errors.New("hostMounts: expected stateMutex to be locked, but it was not")) + } +} + +// Locks the state mutex. 
This is not re-entrant, calling it twice in the same +// thread will deadlock/panic. +func (hm *hostMounts) Lock() { hm.stateMutex.Lock() - defer hm.stateMutex.Unlock() + // Since we just acquired the lock, either it was not locked before, or + // somebody just unlocked it. Either case, hm.stateMutexLocked should be + // false. + if hm.stateMutexLocked { + gcs.UnrecoverableError(errors.New("hostMounts: stateMutexLocked already true when locking stateMutex")) + } + hm.stateMutexLocked = true +} + +// Unlocks the state mutex +func (hm *hostMounts) Unlock() { + hm.expectLocked() + hm.stateMutexLocked = false + hm.stateMutex.Unlock() +} - mountTarget := filepath.Clean(mountPath) - if source, ok := hm.readWriteMounts[mountTarget]; ok { - return fmt.Errorf("read-write with source %q and mount target %q already exists", source.sourcePath, mountPath) +func (hm *hostMounts) findDeviceAtPath(mountPath string) *device { + hm.expectLocked() + + if dev, ok := hm.devices[mountPath]; ok { + return dev } - hm.readWriteMounts[mountTarget] = &rwDevice{ - mountPath: mountTarget, - sourcePath: sourcePath, - encrypted: encrypted, + return nil +} + +func (hm *hostMounts) addDeviceToMapChecked(dev *device) error { + hm.expectLocked() + + if _, ok := hm.devices[dev.mountPath]; ok { + return fmt.Errorf("device at mount path %q already exists", dev.mountPath) } + hm.devices[dev.mountPath] = dev return nil } -// RemoveRWDevice removes the read-write device metadata for device mounted at -// `mountPath`. -func (hm *hostMounts) RemoveRWDevice(mountPath string, sourcePath string) error { - hm.stateMutex.Lock() - defer hm.stateMutex.Unlock() +func (hm *hostMounts) findDeviceContainingPath(path string) *device { + hm.expectLocked() + + // TODO: can we refactor this function by walking each component of the path + // from leaf to root, each time checking if the current component is a mount + // point? (i.e. why do we have to use filepath.Rel?) 
+ + var foundDev *device + cleanPath := filepath.Clean(path) + for devPath, dev := range hm.devices { + relPath, err := filepath.Rel(devPath, cleanPath) + // skip further checks if an error is returned or the relative path + // contains "..", meaning that the `path` isn't directly nested under + // `devPath`. + if err != nil || strings.HasPrefix(relPath, "..") { + continue + } + if foundDev == nil { + foundDev = dev + } else if len(dev.mountPath) > len(foundDev.mountPath) { + // The current device is mounted on top of a previously found device. + foundDev = dev + } + } + return foundDev +} + +func (hm *hostMounts) usePath(path string) (*device, error) { + hm.expectLocked() + + // Find the device containing the given path and increment its usage count. + dev := hm.findDeviceContainingPath(path) + if dev == nil { + return nil, nil + } + dev.usage++ + return dev, nil +} + +func (hm *hostMounts) releaseDeviceUsage(dev *device) { + hm.expectLocked() + + if dev.usage <= 0 { + log.G(context.Background()).WithFields(logrus.Fields{ + "device": dev.mountPath, + "deviceSource": dev.sourcePath, + "deviceType": dev.ty, + "usage": dev.usage, + }).Error("hostMounts::releaseDeviceUsage: unexpected zero usage count") + return + } + dev.usage-- +} + +// User should carefully handle side-effects of adding a device if the device +// fails to be added.
+func (hm *hostMounts) doAddDevice(mountPath string, ty deviceType, sourcePath string) (*device, error) { + hm.expectLocked() + + dev := &device{ + mountPath: filepath.Clean(mountPath), + ty: ty, + usage: 0, + sourcePath: sourcePath, + } + + if err := hm.addDeviceToMapChecked(dev); err != nil { + return nil, err + } + return dev, nil +} + +// Once checks is called, unless it returns an error, this function will always +// succeed +func (hm *hostMounts) doRemoveDevice(mountPath string, ty deviceType, sourcePath string, checks func(*device) error) error { + hm.expectLocked() unmountTarget := filepath.Clean(mountPath) - device, ok := hm.readWriteMounts[unmountTarget] - if !ok { + device := hm.findDeviceAtPath(unmountTarget) + if device == nil { // already removed or didn't exist return nil } if device.sourcePath != sourcePath { - return fmt.Errorf("wrong sourcePath %s", sourcePath) + return fmt.Errorf("wrong sourcePath %s, expected %s", sourcePath, device.sourcePath) + } + if device.ty != ty { + return fmt.Errorf("wrong device type %s, expected %s", ty, device.ty) + } + if device.usage > 0 { + log.G(context.Background()).WithFields(logrus.Fields{ + "device": device.mountPath, + "deviceSource": device.sourcePath, + "deviceType": device.ty, + "usage": device.usage, + }).Error("hostMounts::doRemoveDevice: device still in use, refusing unmount") + return fmt.Errorf("device at %q is still in use, can't unmount", unmountTarget) + } + if checks != nil { + if err := checks(device); err != nil { + return err + } } - delete(hm.readWriteMounts, unmountTarget) + delete(hm.devices, unmountTarget) return nil } +func (hm *hostMounts) AddRODevice(mountPath string, sourcePath string) error { + hm.expectLocked() + + _, err := hm.doAddDevice(mountPath, DeviceTypeRO, sourcePath) + return err +} + +// AddRWDevice adds read-write device metadata for device mounted at `mountPath`. +// Returns an error if there's an existing device mounted at `mountPath` location. 
+func (hm *hostMounts) AddRWDevice(mountPath string, sourcePath string, encrypted bool) error { + hm.expectLocked() + + dev, err := hm.doAddDevice(mountPath, DeviceTypeRW, sourcePath) + if err != nil { + return err + } + dev.encrypted = encrypted + return nil +} + +func (hm *hostMounts) AddOverlay(mountPath string, layers []string, scratchDir string) (err error) { + hm.expectLocked() + + dev, err := hm.doAddDevice(mountPath, DeviceTypeOverlay, mountPath) + if err != nil { + return err + } + dev.referencedDevices = make([]*device, 0, len(layers)+1) + defer func() { + if err != nil { + // If we failed to use any of the paths, we need to release the ones + // that we did use. + for _, d := range dev.referencedDevices { + hm.releaseDeviceUsage(d) + } + delete(hm.devices, mountPath) + } + }() + + for _, layer := range layers { + refDev, err := hm.usePath(layer) + if err != nil { + return err + } + if refDev != nil { + dev.referencedDevices = append(dev.referencedDevices, refDev) + } + } + refDev, err := hm.usePath(scratchDir) + if err != nil { + return err + } + if refDev != nil { + dev.referencedDevices = append(dev.referencedDevices, refDev) + } + + return nil +} + +func (hm *hostMounts) RemoveRODevice(mountPath string, sourcePath string) error { + hm.expectLocked() + + return hm.doRemoveDevice(mountPath, DeviceTypeRO, sourcePath, nil) +} + +// RemoveRWDevice removes the read-write device metadata for device mounted at +// `mountPath`. 
+func (hm *hostMounts) RemoveRWDevice(mountPath string, sourcePath string, encrypted bool) error { + hm.expectLocked() + + return hm.doRemoveDevice(mountPath, DeviceTypeRW, sourcePath, func(dev *device) error { + if dev.encrypted != encrypted { + return fmt.Errorf("encrypted flag wrong, provided %v, expected %v", encrypted, dev.encrypted) + } + return nil + }) +} + +func (hm *hostMounts) RemoveOverlay(mountPath string) (undo func(), err error) { + hm.expectLocked() + + var dev *device + err = hm.doRemoveDevice(mountPath, DeviceTypeOverlay, mountPath, func(_dev *device) error { + dev = _dev + for _, refDev := range dev.referencedDevices { + hm.releaseDeviceUsage(refDev) + } + return nil + }) + if err != nil { + // If we get an error from doRemoveDevice, we have not released anything + // yet. + return nil, err + } + undo = func() { + hm.expectLocked() + + for _, refDev := range dev.referencedDevices { + refDev.usage++ + } + + if _, ok := hm.devices[mountPath]; ok { + log.G(context.Background()).WithField("mountPath", mountPath).Error( + "hostMounts::RemoveOverlay: failed to undo remove: device that was removed exists in map", + ) + return + } + + hm.devices[mountPath] = dev + } + return undo, nil +} + // IsEncrypted checks if the given path is a sub-path of an encrypted read-write // device. func (hm *hostMounts) IsEncrypted(path string) bool { - hm.stateMutex.Lock() - defer hm.stateMutex.Unlock() + hm.expectLocked() - parentPath := "" - encrypted := false - cleanPath := filepath.Clean(path) - for rwPath, rwDev := range hm.readWriteMounts { - relPath, err := filepath.Rel(rwPath, cleanPath) - // skip further checks if an error is returned or the relative path - // contains "..", meaning that the `path` isn't directly nested under - // `rwPath`. 
- if err != nil || strings.HasPrefix(relPath, "..") { - continue - } - if len(rwDev.mountPath) > len(parentPath) { - parentPath = rwDev.mountPath - encrypted = rwDev.encrypted - } + dev := hm.findDeviceContainingPath(path) + if dev == nil { + return false + } + return dev.encrypted +} + +func (hm *hostMounts) HasOverlayMountedAt(path string) bool { + hm.expectLocked() + + dev := hm.findDeviceAtPath(filepath.Clean(path)) + if dev == nil { + return false } - return encrypted + return dev.ty == DeviceTypeOverlay } diff --git a/internal/guest/runtime/hcsv2/uvm_state_test.go b/internal/guest/runtime/hcsv2/uvm_state_test.go index b708caaeba..e87a207308 100644 --- a/internal/guest/runtime/hcsv2/uvm_state_test.go +++ b/internal/guest/runtime/hcsv2/uvm_state_test.go @@ -12,10 +12,13 @@ func Test_Add_Remove_RWDevice(t *testing.T) { mountPath := "/run/gcs/c/abcd" sourcePath := "/dev/sda" + hm.Lock() + defer hm.Unlock() + if err := hm.AddRWDevice(mountPath, sourcePath, false); err != nil { t.Fatalf("unexpected error adding RW device: %s", err) } - if err := hm.RemoveRWDevice(mountPath, sourcePath); err != nil { + if err := hm.RemoveRWDevice(mountPath, sourcePath, false); err != nil { t.Fatalf("unexpected error removing RW device: %s", err) } } @@ -25,29 +28,55 @@ func Test_Cannot_AddRWDevice_Twice(t *testing.T) { mountPath := "/run/gcs/c/abc" sourcePath := "/dev/sda" + hm.Lock() if err := hm.AddRWDevice(mountPath, sourcePath, false); err != nil { t.Fatalf("unexpected error: %s", err) } + hm.Unlock() + + hm.Lock() if err := hm.AddRWDevice(mountPath, sourcePath, false); err == nil { t.Fatalf("expected error adding %q for the second time", mountPath) } + hm.Unlock() } func Test_Cannot_RemoveRWDevice_Wrong_Source(t *testing.T) { hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + mountPath := "/run/gcs/c/abcd" sourcePath := "/dev/sda" wrongSource := "/dev/sdb" if err := hm.AddRWDevice(mountPath, sourcePath, false); err != nil { t.Fatalf("unexpected error: %s", err) } - if err 
:= hm.RemoveRWDevice(mountPath, wrongSource); err == nil { + if err := hm.RemoveRWDevice(mountPath, wrongSource, false); err == nil { t.Fatalf("expected error removing wrong source %s", wrongSource) } } +func Test_Cannot_RemoveRWDevice_Wrong_Encrypted(t *testing.T) { + hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + + mountPath := "/run/gcs/c/abcd" + sourcePath := "/dev/sda" + if err := hm.AddRWDevice(mountPath, sourcePath, false); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if err := hm.RemoveRWDevice(mountPath, sourcePath, true); err == nil { + t.Fatalf("expected error removing RW device with wrong encrypted flag") + } +} + func Test_HostMounts_IsEncrypted(t *testing.T) { hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + encryptedPath := "/run/gcs/c/encrypted" encryptedSource := "/dev/sda" if err := hm.AddRWDevice(encryptedPath, encryptedSource, true); err != nil { @@ -108,3 +137,189 @@ func Test_HostMounts_IsEncrypted(t *testing.T) { }) } } + +func Test_HostMounts_AddRemoveRODevice(t *testing.T) { + hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + + mountPath := "/run/gcs/c/abcd" + sourcePath := "/dev/sda" + + if err := hm.AddRODevice(mountPath, sourcePath); err != nil { + t.Fatalf("unexpected error adding RO device: %s", err) + } + + if err := hm.RemoveRODevice(mountPath, sourcePath); err != nil { + t.Fatalf("unexpected error removing RO device: %s", err) + } +} + +func Test_HostMounts_Cannot_AddRODevice_Twice(t *testing.T) { + hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + + mountPath := "/run/gcs/c/abc" + sourcePath := "/dev/sda" + + if err := hm.AddRODevice(mountPath, sourcePath); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if err := hm.AddRODevice(mountPath, sourcePath); err == nil { + t.Fatalf("expected error adding %q for the second time", mountPath) + } +} + +func Test_HostMounts_AddRemoveOverlay(t *testing.T) { + hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + + mountPath := 
"/run/gcs/c/aaaa/rootfs" + layers := []string{ + "/run/mounts/scsi/m1", + "/run/mounts/scsi/m2", + "/run/mounts/scsi/m3", + } + for _, layer := range layers { + if err := hm.AddRODevice(layer, layer); err != nil { + t.Fatalf("unexpected error adding RO device: %s", err) + } + } + scratchDir := "/run/gcs/c/aaaa/scratch" + if err := hm.AddRWDevice(scratchDir, scratchDir, true); err != nil { + t.Fatalf("unexpected error adding RW device: %s", err) + } + if err := hm.AddOverlay(mountPath, layers, scratchDir); err != nil { + t.Fatalf("unexpected error adding overlay: %s", err) + } + undo, err := hm.RemoveOverlay(mountPath) + if err != nil { + t.Fatalf("unexpected error removing overlay: %s", err) + } + if undo == nil { + t.Fatalf("expected undo function to be non-nil") + } + undo() + if _, err = hm.RemoveOverlay(mountPath); err != nil { + t.Fatalf("unexpected error removing overlay again: %s", err) + } +} + +func Test_HostMounts_Cannot_RemoveInUseDeviceByOverlay(t *testing.T) { + hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + + mountPath := "/run/gcs/c/aaaa/rootfs" + layers := []string{ + "/run/mounts/scsi/m1", + "/run/mounts/scsi/m2", + "/run/mounts/scsi/m3", + } + for _, layer := range layers { + if err := hm.AddRODevice(layer, layer); err != nil { + t.Fatalf("unexpected error adding RO device: %s", err) + } + } + scratchDir := "/run/gcs/c/aaaa/scratch" + if err := hm.AddRWDevice(scratchDir, scratchDir, true); err != nil { + t.Fatalf("unexpected error adding RW device: %s", err) + } + if err := hm.AddOverlay(mountPath, layers, scratchDir); err != nil { + t.Fatalf("unexpected error adding overlay: %s", err) + } + + for _, layer := range layers { + if err := hm.RemoveRODevice(layer, layer); err == nil { + t.Fatalf("expected error removing RO device %s while in use by overlay", layer) + } + } + if err := hm.RemoveRWDevice(scratchDir, scratchDir, true); err == nil { + t.Fatalf("expected error removing RW device %s while in use by overlay", scratchDir) + } + + if 
_, err := hm.RemoveOverlay(mountPath); err != nil { + t.Fatalf("unexpected error removing overlay: %s", err) + } + + // now we can remove + for _, layer := range layers { + if err := hm.RemoveRODevice(layer, layer); err != nil { + t.Fatalf("unexpected error removing RO device %s: %s", layer, err) + } + } + if err := hm.RemoveRWDevice(scratchDir, scratchDir, true); err != nil { + t.Fatalf("unexpected error removing RW device %s: %s", scratchDir, err) + } +} + +func Test_HostMounts_Cannot_RemoveInUseDeviceByOverlay_MultipleUsers(t *testing.T) { + hm := newHostMounts() + hm.Lock() + defer hm.Unlock() + + overlay1 := "/run/gcs/c/aaaa/rootfs" + overlay2 := "/run/gcs/c/bbbb/rootfs" + layers := []string{ + "/run/mounts/scsi/m1", + "/run/mounts/scsi/m2", + "/run/mounts/scsi/m3", + } + for _, layer := range layers { + if err := hm.AddRODevice(layer, layer); err != nil { + t.Fatalf("unexpected error adding RO device: %s", err) + } + } + sharedScratchMount := "/run/gcs/c/sandbox" + scratch1 := sharedScratchMount + "/scratch/aaaa" + scratch2 := sharedScratchMount + "/scratch/bbbb" + if err := hm.AddRWDevice(sharedScratchMount, sharedScratchMount, true); err != nil { + t.Fatalf("unexpected error adding RW device: %s", err) + } + if err := hm.AddOverlay(overlay1, layers, scratch1); err != nil { + t.Fatalf("unexpected error adding overlay1: %s", err) + } + + if err := hm.AddOverlay(overlay2, layers[0:2], scratch2); err != nil { + t.Fatalf("unexpected error adding overlay2: %s", err) + } + + for _, layer := range layers { + if err := hm.RemoveRODevice(layer, layer); err == nil { + t.Fatalf("expected error removing RO device %s while in use by overlay", layer) + } + } + if err := hm.RemoveRWDevice(sharedScratchMount, sharedScratchMount, true); err == nil { + t.Fatalf("expected error removing RW device %s while in use by overlay", sharedScratchMount) + } + + if _, err := hm.RemoveOverlay(overlay1); err != nil { + t.Fatalf("unexpected error removing overlay 1: %s", err) + } + + for 
_, layer := range layers[0:2] { + if err := hm.RemoveRODevice(layer, layer); err == nil { + t.Fatalf("expected error removing RO device %s (still in use by overlay 2)", layer) + } + } + if err := hm.RemoveRODevice(layers[2], layers[2]); err != nil { + t.Fatalf("unexpected error removing layers[2] which is not being used by overlay 2: %s", err) + } + if err := hm.RemoveRWDevice(sharedScratchMount, sharedScratchMount, true); err == nil { + t.Fatalf("expected error removing RW device %s while in use by overlay 2", scratch2) + } + + if _, err := hm.RemoveOverlay(overlay2); err != nil { + t.Fatalf("unexpected error removing overlay 2: %s", err) + } + for _, layer := range layers[0:2] { + if err := hm.RemoveRODevice(layer, layer); err != nil { + t.Fatalf("unexpected error removing RO device %s: %s", layer, err) + } + } + if err := hm.RemoveRWDevice(sharedScratchMount, sharedScratchMount, true); err != nil { + t.Fatalf("unexpected error removing RW device %s: %s", sharedScratchMount, err) + } +} diff --git a/internal/guest/storage/mount.go b/internal/guest/storage/mount.go index a3d10a3b25..142f0ccbbc 100644 --- a/internal/guest/storage/mount.go +++ b/internal/guest/storage/mount.go @@ -16,6 +16,7 @@ import ( "go.opencensus.io/trace" "golang.org/x/sys/unix" + "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/oc" ) @@ -126,6 +127,7 @@ func UnmountPath(ctx context.Context, target string, removeTarget bool) (err err if _, err := osStat(target); err != nil { if os.IsNotExist(err) { + log.G(ctx).WithField("target", target).Warnf("UnmountPath called for non-existent path") return nil } return errors.Wrapf(err, "failed to determine if path '%s' exists", target) diff --git a/internal/guest/storage/overlay/overlay.go b/internal/guest/storage/overlay/overlay.go index aa4877508f..84bf8fa529 100644 --- a/internal/guest/storage/overlay/overlay.go +++ b/internal/guest/storage/overlay/overlay.go @@ -56,8 +56,7 @@ func processErrNoSpace(ctx 
context.Context, path string, err error) { }).WithError(err).Warn("got ENOSPC, gathering diagnostics") } -// MountLayer first enforces the security policy for the container's layer paths -// and then calls Mount to mount the layer paths as an overlayfs. +// MountLayer calls Mount to mount the layer paths as an overlayfs. func MountLayer( ctx context.Context, layerPaths []string, diff --git a/internal/guest/storage/plan9/plan9.go b/internal/guest/storage/plan9/plan9.go index 5c1f1d74f4..44ac0f4e4e 100644 --- a/internal/guest/storage/plan9/plan9.go +++ b/internal/guest/storage/plan9/plan9.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "os" + "regexp" "syscall" "github.com/Microsoft/hcsshim/internal/guest/transport" @@ -25,6 +26,19 @@ var ( unixMount = unix.Mount ) +// c.f. v9fs_parse_options in linux/fs/9p/v9fs.c - technically anything other +// than ',' is ok (quoting is not handled), however, this name is generated from +// a counter in AddPlan9 (internal/uvm/plan9.go), and therefore we expect only +// digits from a normal hcsshim host. +var validShareNameRegex = regexp.MustCompile(`^[0-9]+$`) + +func ValidateShareName(name string) error { + if !validShareNameRegex.MatchString(name) { + return fmt.Errorf("invalid plan9 share name %q: must match regex %q", name, validShareNameRegex.String()) + } + return nil +} + // Mount dials a connection from `vsock` and mounts a Plan9 share to `target`. // // `target` will be created. On mount failure the created `target` will be diff --git a/internal/guest/storage/scsi/scsi.go b/internal/guest/storage/scsi/scsi.go index ec62636590..83c586c3eb 100644 --- a/internal/guest/storage/scsi/scsi.go +++ b/internal/guest/storage/scsi/scsi.go @@ -121,8 +121,9 @@ type Config struct { // Mount creates a mount from the SCSI device on `controller` index `lun` to // `target` // -// `target` will be created. On mount failure the created `target` will be -// automatically cleaned up. +// `target` will be created. 
On mount failure the created `target`, as well as +// any associated dm-crypt or dm-verity devices will be automatically cleaned +// up. // // If the config has `encrypted` is set to true, the SCSI device will be // encrypted using dm-crypt. @@ -200,7 +201,8 @@ func Mount( var deviceFS string if config.Encrypted { cryptDeviceName := fmt.Sprintf(cryptDeviceFmt, controller, lun, partition) - encryptedSource, err := encryptDevice(spnCtx, source, cryptDeviceName) + var encryptedSource string + encryptedSource, err = encryptDevice(spnCtx, source, cryptDeviceName) if err != nil { // todo (maksiman): add better retry logic, similar to how SCSI device mounts are // retried on unix.ENOENT and unix.ENXIO. The retry should probably be on an @@ -211,6 +213,13 @@ func Mount( } } source = encryptedSource + defer func() { + if err != nil { + if err := cleanupCryptDevice(spnCtx, cryptDeviceName); err != nil { + log.G(spnCtx).WithError(err).WithField("cryptDeviceName", cryptDeviceName).Debug("failed to cleanup dm-crypt device after mount failure") + } + } + }() } else { // Get the filesystem that is already on the device (if any) and use that // as the mountType unless `Filesystem` was given. 
diff --git a/internal/guest/storage/scsi/scsi_test.go b/internal/guest/storage/scsi/scsi_test.go index ebfcf8e382..94992047bd 100644 --- a/internal/guest/storage/scsi/scsi_test.go +++ b/internal/guest/storage/scsi/scsi_test.go @@ -999,6 +999,12 @@ func Test_Mount_EncryptDevice_Mkfs_Error(t *testing.T) { } return expectedDevicePath, nil } + cleanupCryptDevice = func(_ context.Context, dmCryptName string) error { + if dmCryptName != expectedCryptTarget { + t.Fatalf("expected cleanupCryptDevice name %q got %q", expectedCryptTarget, dmCryptName) + } + return nil + } osStat = osStatNoop xfsFormat = func(arg string) error { diff --git a/internal/regopolicyinterpreter/regopolicyinterpreter.go b/internal/regopolicyinterpreter/regopolicyinterpreter.go index 66f62c5114..6e316f9b41 100644 --- a/internal/regopolicyinterpreter/regopolicyinterpreter.go +++ b/internal/regopolicyinterpreter/regopolicyinterpreter.go @@ -63,6 +63,9 @@ type RegoModule struct { type regoMetadata map[string]map[string]interface{} +const metadataRootKey = "metadata" +const metadataOperationsKey = "metadata" + type regoMetadataAction string const ( @@ -81,6 +84,11 @@ type regoMetadataOperation struct { // The result from a policy query type RegoQueryResult map[string]interface{} +// An immutable, saved copy of the metadata state. +type SavedMetadata struct { + metadataRoot regoMetadata +} + // deep copy for an object func copyObject(data map[string]interface{}) (map[string]interface{}, error) { objJSON, err := json.Marshal(data) @@ -113,6 +121,24 @@ func copyValue(value interface{}) (interface{}, error) { return valueCopy, nil } +// deep copy for regoMetadata. +// We cannot use copyObject for this due to the fact that map[string]interface{} +// is a concrete type and a map of it cannot be used as a map of interface{}. 
+func copyRegoMetadata(value regoMetadata) (regoMetadata, error) { + valueJSON, err := json.Marshal(value) + if err != nil { + return nil, err + } + + var valueCopy regoMetadata + err = json.Unmarshal(valueJSON, &valueCopy) + if err != nil { + return nil, err + } + + return valueCopy, nil +} + // NewRegoPolicyInterpreter creates a new RegoPolicyInterpreter, using the code provided. // inputData is the Rego data which should be used as the initial state // of the interpreter. A deep copy is performed on it such that it will @@ -123,8 +149,8 @@ func NewRegoPolicyInterpreter(code string, inputData map[string]interface{}) (*R return nil, fmt.Errorf("unable to copy the input data: %w", err) } - if _, ok := data["metadata"]; !ok { - data["metadata"] = make(regoMetadata) + if _, ok := data[metadataRootKey]; !ok { + data[metadataRootKey] = make(regoMetadata) } policy := &RegoPolicyInterpreter{ @@ -207,7 +233,7 @@ func (r *RegoPolicyInterpreter) GetMetadata(name string, key string) (interface{ r.dataAndModulesMutex.Lock() defer r.dataAndModulesMutex.Unlock() - metadataRoot, ok := r.data["metadata"].(regoMetadata) + metadataRoot, ok := r.data[metadataRootKey].(regoMetadata) if !ok { return nil, errors.New("illegal interpreter state: invalid metadata object type") } @@ -228,6 +254,32 @@ func (r *RegoPolicyInterpreter) GetMetadata(name string, key string) (interface{ } } +// Saves a copy of the internal policy metadata state. +func (r *RegoPolicyInterpreter) SaveMetadata() (s SavedMetadata, err error) { + r.dataAndModulesMutex.Lock() + defer r.dataAndModulesMutex.Unlock() + + metadataRoot, ok := r.data[metadataRootKey].(regoMetadata) + if !ok { + return SavedMetadata{}, errors.New("illegal interpreter state: invalid metadata object type") + } + s.metadataRoot, err = copyRegoMetadata(metadataRoot) + return s, err +} + +// Restores a previously saved metadata state. 
+func (r *RegoPolicyInterpreter) RestoreMetadata(m SavedMetadata) error { + r.dataAndModulesMutex.Lock() + defer r.dataAndModulesMutex.Unlock() + + copied, err := copyRegoMetadata(m.metadataRoot) + if err != nil { + return fmt.Errorf("unable to copy metadata: %w", err) + } + r.data[metadataRootKey] = copied + return nil +} + func newRegoMetadataOperation(operation interface{}) (*regoMetadataOperation, error) { var metadataOp regoMetadataOperation @@ -286,7 +338,7 @@ func (r *RegoPolicyInterpreter) UpdateOSType(os string) error { func (r *RegoPolicyInterpreter) updateMetadata(ops []*regoMetadataOperation) error { // dataAndModulesMutex must be held before calling this - metadataRoot, ok := r.data["metadata"].(regoMetadata) + metadataRoot, ok := r.data[metadataRootKey].(regoMetadata) if !ok { return errors.New("illegal interpreter state: invalid metadata object type") } @@ -431,7 +483,7 @@ func (r *RegoPolicyInterpreter) logMetadata() { return } - contents, err := json.Marshal(r.data["metadata"]) + contents, err := json.Marshal(r.data[metadataRootKey]) if err != nil { r.metadataLogger.Printf("error marshaling metadata: %v\n", err.Error()) } else { @@ -637,7 +689,7 @@ func (r *RegoPolicyInterpreter) Query(rule string, input map[string]interface{}) r.logResult(rule, resultSet) ops := []*regoMetadataOperation{} - if rawMetadata, ok := resultSet["metadata"]; ok { + if rawMetadata, ok := resultSet[metadataOperationsKey]; ok { metadata, ok := rawMetadata.([]interface{}) if !ok { return nil, errors.New("error loading metadata array: invalid type") @@ -660,7 +712,7 @@ func (r *RegoPolicyInterpreter) Query(rule string, input map[string]interface{}) } for name, value := range resultSet { - if name == "metadata" { + if name == metadataOperationsKey { continue } else { result[name] = value diff --git a/internal/regopolicyinterpreter/regopolicyinterpreter_test.go b/internal/regopolicyinterpreter/regopolicyinterpreter_test.go index b7d86609f7..3872afff51 100644 --- 
a/internal/regopolicyinterpreter/regopolicyinterpreter_test.go +++ b/internal/regopolicyinterpreter/regopolicyinterpreter_test.go @@ -72,6 +72,37 @@ func Test_copyValue(t *testing.T) { } } +func Test_copyRegoMetadata(t *testing.T) { + f := func(orig testRegoMetadata) bool { + copy, err := copyRegoMetadata(regoMetadata(orig)) + if err != nil { + t.Error(err) + return false + } + + if len(orig) != len(copy) { + t.Errorf("original and copy have different number of objects: %d != %d", len(orig), len(copy)) + return false + } + + for name, origObject := range orig { + if copyObject, ok := copy[name]; ok { + if !assertObjectsEqual(origObject, copyObject) { + t.Errorf("original and copy differ on key %s", name) + } + } else { + t.Errorf("copy missing object %s", name) + } + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 30, Rand: testRand}); err != nil { + t.Errorf("Test_copyRegoMetadata: %v", err) + } +} + //go:embed test.rego var testCode string @@ -364,6 +395,107 @@ func Test_Metadata_Remove(t *testing.T) { } } +func Test_Metadata_SaveRestore(t *testing.T) { + rego, err := setupRego() + if err != nil { + t.Fatal(err) + } + + f := func(pairs1before, pairs1after intPairArray, name1 metadataName, pairs2before, pairs2after intPairArray, name2 metadataName) bool { + if name1 == name2 { + t.Fatalf("generated two identical names: %s", name1) + } + + err := appendAll(rego, pairs1before, name1) + if err != nil { + t.Errorf("error appending pairs1before: %v", err) + return false + } + err = appendAll(rego, pairs2before, name2) + if err != nil { + t.Errorf("error appending pairs2before: %v", err) + return false + } + + saved, err := rego.SaveMetadata() + if err != nil { + t.Errorf("unable to save metadata: %v", err) + return false + } + + beforeSum1 := getExpectedGapFromPairs(pairs1before) + err = computeGap(rego, name1, beforeSum1) + if err != nil { + t.Error(err) + return false + } + + beforeSum2 := getExpectedGapFromPairs(pairs2before) + err = 
computeGap(rego, name2, beforeSum2) + if err != nil { + t.Error(err) + return false + } + + // computeGap would have cleared the list, so we restore it. + err = rego.RestoreMetadata(saved) + if err != nil { + t.Errorf("unable to restore metadata: %v", err) + return false + } + + err = appendAll(rego, pairs1after, name1) + if err != nil { + t.Errorf("error appending pairs1after: %v", err) + return false + } + + err = appendAll(rego, pairs2after, name2) + if err != nil { + t.Errorf("error appending pairs2after: %v", err) + return false + } + + afterSum1 := beforeSum1 + getExpectedGapFromPairs(pairs1after) + err = computeGap(rego, name1, afterSum1) + if err != nil { + t.Error(err) + return false + } + + afterSum2 := beforeSum2 + getExpectedGapFromPairs(pairs2after) + err = computeGap(rego, name2, afterSum2) + if err != nil { + t.Error(err) + return false + } + + err = rego.RestoreMetadata(saved) + if err != nil { + t.Errorf("unable to restore metadata: %v", err) + return false + } + + err = computeGap(rego, name1, beforeSum1) + if err != nil { + t.Errorf("computeGap failed for name1 after restore: %v", err) + return false + } + + err = computeGap(rego, name2, beforeSum2) + if err != nil { + t.Errorf("computeGap failed for name2 after restore: %v", err) + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 100, Rand: testRand}); err != nil { + t.Errorf("Test_Metadata_SaveRestore: %v", err) + } +} + //go:embed module.rego var moduleCode string @@ -508,6 +640,7 @@ type testValue struct { } type testArray []interface{} type testObject map[string]interface{} +type testRegoMetadata regoMetadata type testValueType int @@ -580,6 +713,16 @@ func (testObject) Generate(r *rand.Rand, _ int) reflect.Value { return reflect.ValueOf(value) } +func (testRegoMetadata) Generate(r *rand.Rand, _ int) reflect.Value { + numObjects := r.Intn(maxNumberOfFields) + metadata := make(testRegoMetadata) + for i := 0; i < numObjects; i++ { + name := 
uniqueString(r) + metadata[name] = generateObject(r, 0) + } + return reflect.ValueOf(metadata) +} + func getResult(r *RegoPolicyInterpreter, p intPair, rule string) (RegoQueryResult, error) { input := map[string]interface{}{"a": p.a, "b": p.b} result, err := r.Query("data.test."+rule, input) @@ -640,6 +783,27 @@ func appendLists(r *RegoPolicyInterpreter, p intPair, name metadataName) error { return nil } +func appendAll(r *RegoPolicyInterpreter, pairs intPairArray, name metadataName) error { + for _, pair := range pairs { + if err := appendLists(r, pair, name); err != nil { + return fmt.Errorf("error appending pair %v: %w", pair, err) + } + } + return nil +} + +func getExpectedGapFromPairs(pairs intPairArray) int { + expected := 0 + for _, pair := range pairs { + if pair.a >= pair.b { + expected += pair.a - pair.b + } else { + expected += pair.b - pair.a + } + } + return expected +} + func computeGap(r *RegoPolicyInterpreter, name metadataName, expected int) error { input := map[string]interface{}{"name": string(name)} result, err := r.Query("data.test.compute_gap", input) diff --git a/pkg/securitypolicy/framework.rego b/pkg/securitypolicy/framework.rego index de534eb863..9fc7aba02e 100644 --- a/pkg/securitypolicy/framework.rego +++ b/pkg/securitypolicy/framework.rego @@ -387,6 +387,14 @@ seccomp_ok(seccomp_profile_sha256) { is_windows } +devices_ok(expected_devices, actual_devices) { + # Allow out of order but not duplicates + set_expected := {dev | dev := expected_devices[_]} + set_actual := {dev | dev := actual_devices[_]} + set_expected == set_actual + count(set_actual) == count(actual_devices) +} + default container_started := false container_started { @@ -598,6 +606,8 @@ create_container := {"metadata": [updateMatches, addStarted], command_ok(container.command) mountList_ok(container.mounts, container.allow_elevated) seccomp_ok(container.seccomp_profile_sha256) + # We do not support adding device nodes to the policy yet + devices_ok([], input.devices) ] 
count(possible_after_initial_containers) > 0 @@ -2089,6 +2099,12 @@ errors["capabilities don't match"] { count(possible_after_caps_containers) == 0 } +errors["devices not supported"] { + is_linux + input.rule == "create_container" + not devices_ok([], input.devices) +} + # covers exec_in_container as well. it shouldn't be possible to ever get # an exec_in_container as it "inherits" capabilities rules from create_container errors["containers only distinguishable by capabilties"] { diff --git a/pkg/securitypolicy/rego_utils_test.go b/pkg/securitypolicy/rego_utils_test.go index 5ac12a5a0a..3adbb6a2b6 100644 --- a/pkg/securitypolicy/rego_utils_test.go +++ b/pkg/securitypolicy/rego_utils_test.go @@ -347,18 +347,25 @@ type regoPlan9MountTestConfig struct { } func mountImageForContainer(policy *regoEnforcer, container *securityPolicyContainer) (string, error) { - ctx := context.Background() containerID := testDataGenerator.uniqueContainerID() + if err := mountImageForContainerWithID(policy, container, containerID); err != nil { + return "", err + } + return containerID, nil +} + +func mountImageForContainerWithID(policy *regoEnforcer, container *securityPolicyContainer, containerID string) error { + ctx := context.Background() layerPaths, err := testDataGenerator.createValidOverlayForContainer(policy, container) if err != nil { - return "", fmt.Errorf("error creating valid overlay: %w", err) + return fmt.Errorf("error creating valid overlay: %w", err) } scratchDisk := getScratchDiskMountTarget(containerID) err = policy.EnforceRWDeviceMountPolicy(ctx, scratchDisk, true, true, "xfs") if err != nil { - return "", fmt.Errorf("error mounting scratch disk: %w", err) + return fmt.Errorf("error mounting scratch disk: %w", err) } overlayTarget := getOverlayMountTarget(containerID) @@ -367,12 +374,13 @@ func mountImageForContainer(policy *regoEnforcer, container *securityPolicyConta err = policy.EnforceOverlayMountPolicy( ctx, containerID, copyStrings(layerPaths), overlayTarget) if 
err != nil { - return "", fmt.Errorf("error mounting filesystem: %w", err) + return fmt.Errorf("error mounting filesystem: %w", err) } - return containerID, nil + return nil } + func buildMountSpecFromMountArray(mounts []mountInternal, sandboxID string, r *rand.Rand) *oci.Spec { mountSpec := new(oci.Spec) @@ -1404,6 +1412,10 @@ func setupRegoCreateContainerTest(gc *generatedConstraints, testContainer *secur return nil, err } + return createTestContainerSpec(gc, containerID, testContainer, privilegedError, policy, defaultMounts, privilegedMounts) +} + +func createTestContainerSpec(gc *generatedConstraints, containerID string, testContainer *securityPolicyContainer, privilegedError bool, policy *regoEnforcer, defaultMounts, privilegedMounts []mountInternal) (*regoContainerTestConfig, error) { envList := buildEnvironmentVariablesFromEnvRules(testContainer.EnvRules, testRand) sandboxID := testDataGenerator.uniqueSandboxID() @@ -2994,3 +3006,19 @@ type containerInitProcess struct { WorkingDir string AllowStdioAccess bool } + +func startRevertableSection(t *testing.T, policy *regoEnforcer) RevertableSectionHandle { + rev, err := policy.StartRevertableSection() + if err != nil { + t.Fatalf("Failed to start revertable section: %v", err) + } + return rev +} + +func commitOrRollback(rev RevertableSectionHandle, shouldCommit bool) { + if shouldCommit { + rev.Commit() + } else { + rev.Rollback() + } +} diff --git a/pkg/securitypolicy/regopolicy_linux_test.go b/pkg/securitypolicy/regopolicy_linux_test.go index 8dd409fccf..1aa733590a 100644 --- a/pkg/securitypolicy/regopolicy_linux_test.go +++ b/pkg/securitypolicy/regopolicy_linux_test.go @@ -963,6 +963,90 @@ func Test_Rego_EnforceOverlayMountPolicy_Multiple_Instances_Same_Container(t *te } } +func Test_Rego_EnforceOverlayMountPolicy_MountFail(t *testing.T) { + f := func(gc *generatedConstraints, commitOnEnforcementFailure bool) bool { + securityPolicy := gc.toPolicy() + policy, err := newRegoPolicy(securityPolicy.marshalRego(), 
[]oci.Mount{}, []oci.Mount{}, testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return false + } + tc := selectContainerFromContainerList(gc.containers, testRand) + tid := testDataGenerator.uniqueContainerID() + scratchTarget := getScratchDiskMountTarget(tid) + + rev := startRevertableSection(t, policy) + err = policy.EnforceRWDeviceMountPolicy(gc.ctx, scratchTarget, true, true, "xfs") + if err != nil { + t.Errorf("failed to EnforceRWDeviceMountPolicy: %v", err) + return false + } + rev.Commit() + + layerToErr := testRand.Intn(len(tc.Layers)) + errLayerPathIndex := len(tc.Layers) - layerToErr - 1 + layerPaths := make([]string, len(tc.Layers)) + for i, layerHash := range tc.Layers { + rev := startRevertableSection(t, policy) + target := testDataGenerator.uniqueLayerMountTarget() + layerPaths[len(tc.Layers)-i-1] = target + err = policy.EnforceDeviceMountPolicy(gc.ctx, target, layerHash) + if err != nil { + t.Errorf("failed to EnforceDeviceMountPolicy: %v", err) + return false + } + if i == layerToErr { + // Simulate a mount failure at this point, which will cause us to rollback + rev.Rollback() + } else { + rev.Commit() + } + } + + rev = startRevertableSection(t, policy) + overlayTarget := getOverlayMountTarget(tid) + err = policy.EnforceOverlayMountPolicy(gc.ctx, tid, layerPaths, overlayTarget) + if !assertDecisionJSONContains(t, err, append(slices.Clone(layerPaths), "no matching containers for overlay")...) 
{ + return false + } + commitOrRollback(rev, commitOnEnforcementFailure) + + layerPathsWithoutErr := make([]string, 0) + for i, layerPath := range layerPaths { + if i != errLayerPathIndex { + layerPathsWithoutErr = append(layerPathsWithoutErr, layerPath) + } + } + + rev = startRevertableSection(t, policy) + err = policy.EnforceOverlayMountPolicy(gc.ctx, tid, layerPathsWithoutErr, overlayTarget) + if !assertDecisionJSONContains(t, err, append(slices.Clone(layerPathsWithoutErr), "no matching containers for overlay")...) { + return false + } + commitOrRollback(rev, commitOnEnforcementFailure) + + retryTarget := layerPaths[errLayerPathIndex] + rev = startRevertableSection(t, policy) + err = policy.EnforceDeviceMountPolicy(gc.ctx, retryTarget, tc.Layers[layerToErr]) + if err != nil { + t.Errorf("failed to EnforceDeviceMountPolicy again after one previous reverted failure: %v", err) + return false + } + rev.Commit() + err = policy.EnforceOverlayMountPolicy(gc.ctx, tid, layerPaths, overlayTarget) + if err != nil { + t.Errorf("failed to EnforceOverlayMountPolicy after one previous reverted failure: %v", err) + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_EnforceOverlayMountPolicy_MountFail: %v", err) + } +} + func Test_Rego_EnforceOverlayUnmountPolicy(t *testing.T) { f := func(p *generatedConstraints) bool { tc, err := setupRegoOverlayTest(p, true) @@ -2043,6 +2127,39 @@ func Test_Rego_EnforceCreateContainer_Capabilities_Drop_NoMatches(t *testing.T) } } +func Test_Regi_EnforceCreateContainer_RequireNoDevices(t *testing.T) { + f := func(p *generatedConstraints) bool { + tc, err := setupSimpleRegoCreateContainerTest(p) + if err != nil { + t.Error(err) + return false + } + + privileged := false + + _, _, _, err = tc.policy.EnforceCreateContainerPolicyV2(p.ctx, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, tc.user, &CreateContainerOptions{ + SandboxID: 
tc.sandboxID, + Privileged: &privileged, + NoNewPrivileges: &tc.noNewPrivileges, + Groups: tc.groups, + Umask: tc.umask, + Capabilities: tc.capabilities, + SeccompProfileSHA256: tc.seccomp, + LinuxDevices: []oci.LinuxDevice{ + { + Path: "/test", + }, + }, + }) + + return assertDecisionJSONContains(t, err, "devices not supported") + } + + if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { + t.Errorf("Test_Regi_EnforceCreateContainer_RequireNoDevices: %v", err) + } +} + func Test_Rego_ExtendDefaultMounts(t *testing.T) { f := func(p *generatedConstraints) bool { tc, err := setupSimpleRegoCreateContainerTest(p) @@ -6103,6 +6220,195 @@ func Test_Rego_Enforce_CreateContainer_RequiredEnvMissingHasErrorMessage(t *test } } +func Test_Rego_EnforceCreateContainer_RejectRevertedOverlayMount(t *testing.T) { + f := func(gc *generatedConstraints, commitOnEnforcementFailure bool) bool { + container := selectContainerFromContainerList(gc.containers, testRand) + securityPolicy := gc.toPolicy() + defaultMounts := generateMounts(testRand) + privilegedMounts := generateMounts(testRand) + + policy, err := newRegoPolicy(securityPolicy.marshalRego(), + toOCIMounts(defaultMounts), + toOCIMounts(privilegedMounts), testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return false + } + + containerID := testDataGenerator.uniqueContainerID() + tc, err := createTestContainerSpec(gc, containerID, container, false, policy, defaultMounts, privilegedMounts) + if err != nil { + t.Fatal(err) + } + + layers, err := testDataGenerator.createValidOverlayForContainer(policy, container) + if err != nil { + t.Errorf("Failed to createValidOverlayForContainer: %v", err) + return false + } + + scratchMountTarget := getScratchDiskMountTarget(containerID) + rev := startRevertableSection(t, policy) + err = policy.EnforceRWDeviceMountPolicy(gc.ctx, scratchMountTarget, true, true, "xfs") + if err != nil { + t.Errorf("Failed to 
EnforceRWDeviceMountPolicy: %v", err) + return false + } + rev.Commit() + + rev = startRevertableSection(t, policy) + overlayTarget := getOverlayMountTarget(containerID) + err = policy.EnforceOverlayMountPolicy(gc.ctx, containerID, layers, overlayTarget) + if err != nil { + t.Errorf("Failed to EnforceOverlayMountPolicy: %v", err) + return false + } + // Simulate a failure by rolling back the overlay mount + rev.Rollback() + + rev = startRevertableSection(t, policy) + _, _, _, err = policy.EnforceCreateContainerPolicy(gc.ctx, tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, false, tc.noNewPrivileges, tc.user, tc.groups, tc.umask, tc.capabilities, tc.seccomp) + if err == nil { + t.Errorf("EnforceCreateContainerPolicy should have failed due to missing (reverted) overlay mount") + return false + } + commitOrRollback(rev, commitOnEnforcementFailure) + + // "Retry" overlay mount + rev = startRevertableSection(t, policy) + err = policy.EnforceOverlayMountPolicy(gc.ctx, tc.containerID, layers, overlayTarget) + if err != nil { + t.Errorf("Failed to EnforceOverlayMountPolicy: %v", err) + return false + } + rev.Commit() + + rev = startRevertableSection(t, policy) + _, _, _, err = policy.EnforceCreateContainerPolicy(gc.ctx, tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, false, tc.noNewPrivileges, tc.user, tc.groups, tc.umask, tc.capabilities, tc.seccomp) + if err != nil { + t.Errorf("Failed to EnforceCreateContainerPolicy after retrying overlay mount: %v", err) + return false + } + rev.Commit() + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_EnforceCreateContainerPolicy_RejectRevertedOverlayMount: %v", err) + } +} + +func Test_Rego_EnforceCreateContainer_RetryEverything(t *testing.T) { + f := func(gc *generatedConstraints, + newContainerID, failScratchMount, testDenyInvalidContainerCreation bool, + ) bool { + container := 
selectContainerFromContainerList(gc.containers, testRand) + securityPolicy := gc.toPolicy() + defaultMounts := generateMounts(testRand) + privilegedMounts := generateMounts(testRand) + + policy, err := newRegoPolicy(securityPolicy.marshalRego(), + toOCIMounts(defaultMounts), + toOCIMounts(privilegedMounts), testOSType) + if err != nil { + t.Errorf("cannot make rego policy from constraints: %v", err) + return false + } + + containerID := testDataGenerator.uniqueContainerID() + tc, err := createTestContainerSpec(gc, containerID, container, false, policy, defaultMounts, privilegedMounts) + if err != nil { + t.Fatal(err) + } + + scratchMountTarget := getScratchDiskMountTarget(containerID) + rev := startRevertableSection(t, policy) + err = policy.EnforceRWDeviceMountPolicy(gc.ctx, scratchMountTarget, true, true, "xfs") + if err != nil { + t.Errorf("Failed to EnforceRWDeviceMountPolicy: %v", err) + return false + } + + succeedLayerPaths := make([]string, 0) + + if failScratchMount { + rev.Rollback() + } else { + rev.Commit() + + // Simulate one of the layers failing to mount, after which the outside + // gives up on this container and starts over. 
+ layerToErr := testRand.Intn(len(container.Layers)) + for i, layerHash := range container.Layers { + rev := startRevertableSection(t, policy) + target := testDataGenerator.uniqueLayerMountTarget() + err = policy.EnforceDeviceMountPolicy(gc.ctx, target, layerHash) + if err != nil { + t.Errorf("failed to EnforceDeviceMountPolicy: %v", err) + return false + } + if i == layerToErr { + // Simulate a mount failure at this point, which will cause us to rollback + rev.Rollback() + break + } else { + rev.Commit() + succeedLayerPaths = append(succeedLayerPaths, target) + } + } + + for _, layerPath := range succeedLayerPaths { + rev := startRevertableSection(t, policy) + err = policy.EnforceDeviceUnmountPolicy(gc.ctx, layerPath) + if err != nil { + t.Errorf("Failed to EnforceDeviceUnmountPolicy: %v", err) + return false + } + rev.Commit() + } + + rev = startRevertableSection(t, policy) + err = policy.EnforceRWDeviceUnmountPolicy(gc.ctx, scratchMountTarget) + if err != nil { + t.Errorf("Failed to EnforceRWDeviceUnmountPolicy: %v", err) + return false + } + rev.Commit() + } + + if testDenyInvalidContainerCreation { + rev = startRevertableSection(t, policy) + _, _, _, err = policy.EnforceCreateContainerPolicy(gc.ctx, tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, false, tc.noNewPrivileges, tc.user, tc.groups, tc.umask, tc.capabilities, tc.seccomp) + if err == nil { + t.Errorf("EnforceCreateContainerPolicy should have failed due to missing (reverted) overlay mount") + } + rev.Rollback() + } + + if newContainerID { + tc.containerID = testDataGenerator.uniqueContainerID() + } + + err = mountImageForContainerWithID(policy, container, tc.containerID) + if err != nil { + t.Errorf("Failed to mount image for container after reverting and retrying: %v", err) + return false + } + _, _, _, err = policy.EnforceCreateContainerPolicy(gc.ctx, tc.sandboxID, tc.containerID, tc.argList, tc.envList, tc.workingDir, tc.mounts, false, tc.noNewPrivileges, tc.user, 
tc.groups, tc.umask, tc.capabilities, tc.seccomp) + if err != nil { + t.Errorf("Failed to EnforceCreateContainerPolicy after retrying: %v", err) + return false + } + + return true + } + + if err := quick.Check(f, &quick.Config{MaxCount: 50, Rand: testRand}); err != nil { + t.Errorf("Test_Rego_EnforceCreateContainerPolicy_RejectRevertedOverlayMount: %v", err) + } +} + func Test_Rego_ExecInContainerPolicy_RequiredEnvMissingHasErrorMessage(t *testing.T) { constraints := generateConstraints(testRand, 1) container := selectContainerFromContainerList(constraints.containers, testRand) diff --git a/pkg/securitypolicy/securitypolicyenforcer.go b/pkg/securitypolicy/securitypolicyenforcer.go index 6cae97118d..19fddcc744 100644 --- a/pkg/securitypolicy/securitypolicyenforcer.go +++ b/pkg/securitypolicy/securitypolicyenforcer.go @@ -29,6 +29,7 @@ type CreateContainerOptions struct { Umask string Capabilities *oci.LinuxCapabilities SeccompProfileSHA256 string + LinuxDevices []oci.LinuxDevice } type SignalContainerOptions struct { IsInitProcess bool @@ -57,6 +58,14 @@ func init() { registeredEnforcers[openDoorEnforcerName] = createOpenDoorEnforcer } +// Represents an in-progress revertable section. To ensure state is consistent, +// Commit() and Rollback() must not fail, so they do not return anything, and if +// an error does occur they should panic. 
+type RevertableSectionHandle interface { + Commit() + Rollback() +} + type SecurityPolicyEnforcer interface { EnforceDeviceMountPolicy(ctx context.Context, target string, deviceHash string) (err error) EnforceRWDeviceMountPolicy(ctx context.Context, target string, encrypted, ensureFilesystem bool, filesystem string) (err error) @@ -128,6 +137,7 @@ type SecurityPolicyEnforcer interface { GetUserInfo(spec *oci.Process, rootPath string) (IDName, []IDName, string, error) EnforceVerifiedCIMsPolicy(ctx context.Context, containerID string, layerHashes []string) (err error) EnforceRegistryChangesPolicy(ctx context.Context, containerID string, registryValues interface{}) error + StartRevertableSection() (RevertableSectionHandle, error) } //nolint:unused @@ -182,6 +192,11 @@ func CreateSecurityPolicyEnforcer( } } +type nopRevertableSectionHandle struct{} + +func (nopRevertableSectionHandle) Commit() {} +func (nopRevertableSectionHandle) Rollback() {} + type OpenDoorSecurityPolicyEnforcer struct { encodedSecurityPolicy string } @@ -324,6 +339,10 @@ func (OpenDoorSecurityPolicyEnforcer) EnforceRegistryChangesPolicy(ctx context.C return nil } +func (*OpenDoorSecurityPolicyEnforcer) StartRevertableSection() (RevertableSectionHandle, error) { + return nopRevertableSectionHandle{}, nil +} + type ClosedDoorSecurityPolicyEnforcer struct{} var _ SecurityPolicyEnforcer = (*ClosedDoorSecurityPolicyEnforcer)(nil) @@ -452,3 +471,7 @@ func (ClosedDoorSecurityPolicyEnforcer) EnforceVerifiedCIMsPolicy(ctx context.Co func (ClosedDoorSecurityPolicyEnforcer) EnforceRegistryChangesPolicy(ctx context.Context, containerID string, registryValues interface{}) error { return errors.New("registry changes are denied by policy") } + +func (*ClosedDoorSecurityPolicyEnforcer) StartRevertableSection() (RevertableSectionHandle, error) { + return nopRevertableSectionHandle{}, nil +} diff --git a/pkg/securitypolicy/securitypolicyenforcer_rego.go b/pkg/securitypolicy/securitypolicyenforcer_rego.go index 
1da5d78a0e..88af50bf5f 100644
--- a/pkg/securitypolicy/securitypolicyenforcer_rego.go
+++ b/pkg/securitypolicy/securitypolicyenforcer_rego.go
@@ -12,8 +12,10 @@ import (
 	"regexp"
 	"slices"
 	"strings"
+	"sync"
 	"syscall"
 
+	"github.com/Microsoft/hcsshim/internal/gcs"
 	"github.com/Microsoft/hcsshim/internal/guestpath"
 	hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2"
 	"github.com/Microsoft/hcsshim/internal/log"
@@ -58,6 +60,10 @@ type regoEnforcer struct {
 	maxErrorMessageLength int
 	// OS type
 	osType string
+	// Mutex to ensure only one revertable section is active
+	revertableSectionLock sync.Mutex
+	// Saved metadata for the revertable section
+	savedMetadata rpi.SavedMetadata
 }
 
 var _ SecurityPolicyEnforcer = (*regoEnforcer)(nil)
@@ -710,6 +716,7 @@ func (policy *regoEnforcer) EnforceCreateContainerPolicy(
 		Umask: umask,
 		Capabilities: capabilities,
 		SeccompProfileSHA256: seccompProfileSHA256,
+		LinuxDevices: []oci.LinuxDevice{},
 	}
 	return policy.EnforceCreateContainerPolicyV2(ctx, containerID, argList, envList, workingDir, mounts, user, opts)
 }
@@ -747,6 +754,7 @@ func (policy *regoEnforcer) EnforceCreateContainerPolicyV2(
 		"sandboxDir": SandboxMountsDir(opts.SandboxID),
 		"hugePagesDir": HugePagesMountsDir(opts.SandboxID),
 		"mounts": appendMountData([]interface{}{}, mounts),
+		"devices": appendDeviceData([]interface{}{}, opts.LinuxDevices),
 		"privileged": opts.Privileged,
 		"noNewPrivileges": opts.NoNewPrivileges,
 		"user": user.toInput(),
@@ -835,6 +843,25 @@ func appendMountData(mountData []interface{}, mounts []oci.Mount) []interface{} 
 	return mountData
 }
 
+// appendDeviceData appends one input object per device to deviceData,
+// carrying the device's path, type, major/minor numbers, file mode, uid and
+// gid, and returns the extended slice.
+func appendDeviceData(deviceData []interface{}, devices []oci.LinuxDevice) []interface{} {
+	for _, device := range devices {
+		deviceData = append(deviceData, inputData{
+			"path":     device.Path,
+			"type":     device.Type,
+			"major":    device.Major,
+			"minor":    device.Minor,
+			"fileMode": device.FileMode,
+			"uid":      device.UID,
+			"gid":      device.GID,
+		})
+	}
+
+	return deviceData
+}
+
 func (policy *regoEnforcer) ExtendDefaultMounts(mounts []oci.Mount) error {
 	policy.defaultMounts = append(policy.defaultMounts, mounts...)
 	defaultMounts := appendMountData([]interface{}{}, policy.defaultMounts)
@@ -1190,3 +1217,98 @@ func (policy *regoEnforcer) EnforceRegistryChangesPolicy(ctx context.Context, co
 func (policy *regoEnforcer) GetUserInfo(process *oci.Process, rootPath string) (IDName, []IDName, string, error) {
 	return GetAllUserInfo(process, rootPath)
 }
+
+// revertableSectionHandle is the RevertableSectionHandle that
+// regoEnforcer.StartRevertableSection hands out.
+type revertableSectionHandle struct {
+	// policy is cleared once this struct is "used", to prevent accidental
+	// duplicate Commit/Rollback calls.
+	policy *regoEnforcer
+}
+
+// inRevertableSection reports whether a revertable section is currently
+// active, judged by whether revertableSectionLock is held.
+//
+// NOTE(review): the probe briefly takes the lock itself, so two concurrent
+// probes could make one of them report true spuriously; confirm that callers
+// are serialized.
+func (policy *regoEnforcer) inRevertableSection() bool {
+	succ := policy.revertableSectionLock.TryLock()
+	if succ {
+		// since nobody else has the lock, we're not in fact in a revertable
+		// section.
+		policy.revertableSectionLock.Unlock()
+		return false
+	}
+	// somebody else (i.e. the caller) has the lock, so we're in a revertable
+	// section. Don't unlock it here!
+	return true
+}
+
+// StartRevertableSection starts a revertable section by saving the current
+// policy state. If another revertable section is already active, this will
+// wait until that one is finished.
+func (policy *regoEnforcer) StartRevertableSection() (RevertableSectionHandle, error) {
+	policy.revertableSectionLock.Lock()
+	var err error
+	policy.savedMetadata, err = policy.rego.SaveMetadata()
+	if err != nil {
+		err = errors.Wrap(err, "unable to save metadata for revertable section")
+		policy.revertableSectionLock.Unlock()
+		return &revertableSectionHandle{}, err
+	}
+	// Keep policy.revertableSectionLock locked until the end of the section.
+	sh := &revertableSectionHandle{
+		policy: policy,
+	}
+	return sh, nil
+}
+
+// Commit ends the revertable section and keeps the policy changes made in it.
+// Using a handle a second time, or when no section is active, is reported
+// through gcs.UnrecoverableError.
+func (sh *revertableSectionHandle) Commit() {
+	if sh.policy == nil {
+		gcs.UnrecoverableError(errors.New("revertable section handle already used"))
+	}
+
+	policy := sh.policy
+	sh.policy = nil
+	lockSucc := policy.revertableSectionLock.TryLock()
+	if lockSucc {
+		gcs.UnrecoverableError(errors.New("not in a revertable section"))
+	} else {
+		// somebody else (i.e. the caller) has the lock, so we're in a revertable
+		// section. Clear the saved metadata just in case, then unlock to exit the
+		// section.
+		policy.savedMetadata = rpi.SavedMetadata{}
+		policy.revertableSectionLock.Unlock()
+	}
+}
+
+// Rollback ends the revertable section and restores the policy metadata saved
+// when it started. Using a handle a second time, or when no section is
+// active, is reported through gcs.UnrecoverableError, as is a failed restore.
+func (sh *revertableSectionHandle) Rollback() {
+	if sh.policy == nil {
+		gcs.UnrecoverableError(errors.New("revertable section handle already used"))
+	}
+
+	policy := sh.policy
+	sh.policy = nil
+	lockSucc := policy.revertableSectionLock.TryLock()
+	if lockSucc {
+		gcs.UnrecoverableError(errors.New("not in a revertable section"))
+	} else {
+		// somebody else (i.e. the caller) has the lock, so we're in a revertable
+		// section. Restore the saved metadata, then unlock to exit the section.
+		err := policy.rego.RestoreMetadata(policy.savedMetadata)
+		if err != nil {
+			gcs.UnrecoverableError(errors.Wrap(err, "unable to restore metadata for revertable section"))
+		}
+		// Drop the snapshot, as Commit does, so a stale copy is not kept alive
+		// after the section has ended.
+		policy.savedMetadata = rpi.SavedMetadata{}
+		policy.revertableSectionLock.Unlock()
+	}
+}