Skip to content
Open
14 changes: 14 additions & 0 deletions cmd/gcs/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/Microsoft/hcsshim/internal/log"
"github.com/Microsoft/hcsshim/internal/oc"
"github.com/Microsoft/hcsshim/internal/version"
"github.com/Microsoft/hcsshim/pkg/amdsevsnp"
"github.com/Microsoft/hcsshim/pkg/securitypolicy"
)

Expand Down Expand Up @@ -359,9 +360,22 @@ func main() {
logrus.WithError(err).Fatal("failed to initialize new runc runtime")
}
mux := bridge.NewBridgeMux()

forceSequential, err := amdsevsnp.IsSNP()
if err != nil {
// IsSNP cannot fail on LCOW
logrus.Errorf("Got unexpected error from IsSNP(): %v", err)
// If it fails, we proceed with forceSequential enabled to be safe
forceSequential = true
}

b := bridge.Bridge{
Handler: mux,
EnableV4: *v4,

// For confidential containers, we protect ourselves against attacks caused
// by concurrent modifications, by processing one request at a time.
ForceSequential: forceSequential,
}
h := hcsv2.NewHost(rtime, tport, initialEnforcer, logWriter)
// Initialize virtual pod support in the host
Expand Down
53 changes: 53 additions & 0 deletions internal/gcs/unrecoverable_error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package gcs

import (
"context"
"fmt"
"os"
"runtime"
"time"

"github.com/Microsoft/hcsshim/internal/log"
"github.com/Microsoft/hcsshim/pkg/amdsevsnp"
"github.com/sirupsen/logrus"
)

// UnrecoverableError logs the error and then puts the current thread into an
// infinite sleep loop. This is to be used instead of panicking, as the
// behaviour of GCS panics is unpredictable. This function can be extended to,
// for example, try to shutdown the VM cleanly.
func UnrecoverableError(err error) {
	// Capture a stack dump of all goroutines (300 KiB buffer) so the fatal
	// log carries enough context to diagnose the failure post-mortem.
	buf := make([]byte, 300*(1<<10))
	stackSize := runtime.Stack(buf, true)
	stackTrace := string(buf[:stackSize])

	errPrint := fmt.Sprintf(
		"Unrecoverable error in GCS: %v\n%s",
		err, stackTrace,
	)

	// BUGFIX: use a distinct variable for the IsSNP error. The previous code
	// wrote into `err` (short variable declaration redeclares a name already
	// in the function-body scope), clobbering the original error: the log
	// below then reported the (usually nil) IsSNP error, and the panic path
	// dereferenced a nil error.
	isSnp, snpErr := amdsevsnp.IsSNP()
	if snpErr != nil {
		// IsSNP() cannot fail on LCOW,
		// but if it does, we proceed as if we're on SNP to be safe.
		isSnp = true
	}

	if isSnp {
		errPrint += "\nThis thread will now enter an infinite loop."
	}
	log.G(context.Background()).WithError(err).Logf(
		logrus.FatalLevel,
		"%s",
		errPrint,
	)

	if !isSnp {
		// Outside SNP a panic is acceptable; %v formats a nil error safely,
		// unlike calling err.Error() directly.
		panic(fmt.Sprintf("Unrecoverable error in GCS: %v", err))
	} else {
		// On SNP we deliberately park this thread forever instead of
		// panicking, since GCS panic behaviour is unpredictable.
		fmt.Fprintf(os.Stderr, "%s\n", errPrint)
		for {
			time.Sleep(time.Hour)
		}
	}
}
94 changes: 71 additions & 23 deletions internal/guest/bridge/bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ type Bridge struct {
Handler Handler
// EnableV4 enables the v4+ bridge and the schema v2+ interfaces.
EnableV4 bool
// Setting ForceSequential to true will force the bridge to only process one
// request at a time, except for certain long-running operations (as defined
// in alwaysAsyncMessages).
ForceSequential bool

// responseChan is the response channel used for both request/response
// and publish notification workflows.
Expand All @@ -191,6 +195,14 @@ type Bridge struct {
protVer prot.ProtocolVersion
}

// alwaysAsyncMessages lists the message types that are processed
// asynchronously even in sequential mode (ForceSequential). Note that in
// sequential mode these messages still wait for any in-progress non-async
// message to be handled before they are processed, but once they are
// "acknowledged", the rest is done asynchronously. Wait-for-process must be
// async because it can block for the container's whole lifetime.
var alwaysAsyncMessages = map[prot.MessageIdentifier]bool{
	prot.ComputeSystemWaitForProcessV1: true,
}

// AssignHandlers creates and assigns the appropriate bridge
// events to be listen for and intercepted on `mux` before forwarding
// to `gcs` for handling.
Expand Down Expand Up @@ -238,6 +250,10 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser
defer close(requestErrChan)
defer bridgeIn.Close()

if b.ForceSequential {
log.G(context.Background()).Info("bridge: ForceSequential enabled")
}

// Receive bridge requests and schedule them to be processed.
go func() {
var recverr error
Expand Down Expand Up @@ -340,30 +356,36 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser
}()
// Process each bridge request async and create the response writer.
go func() {
for req := range requestChan {
go func(r *Request) {
br := bridgeResponse{
ctx: r.Context,
header: &prot.MessageHeader{
Type: prot.GetResponseIdentifier(r.Header.Type),
ID: r.Header.ID,
},
}
resp, err := b.Handler.ServeMsg(r)
if resp == nil {
resp = &prot.MessageResponseBase{}
}
resp.Base().ActivityID = r.ActivityID
if err != nil {
span := trace.FromContext(r.Context)
if span != nil {
oc.SetSpanStatus(span, err)
}
setErrorForResponseBase(resp.Base(), err, "gcs" /* moduleName */)
doOneRequest := func(r *Request) {
br := bridgeResponse{
ctx: r.Context,
header: &prot.MessageHeader{
Type: prot.GetResponseIdentifier(r.Header.Type),
ID: r.Header.ID,
},
}
resp, err := b.Handler.ServeMsg(r)
if resp == nil {
resp = &prot.MessageResponseBase{}
}
resp.Base().ActivityID = r.ActivityID
if err != nil {
span := trace.FromContext(r.Context)
if span != nil {
oc.SetSpanStatus(span, err)
}
br.response = resp
b.responseChan <- br
}(req)
setErrorForResponseBase(resp.Base(), err, "gcs" /* moduleName */)
}
br.response = resp
b.responseChan <- br
}

for req := range requestChan {
if b.ForceSequential && !alwaysAsyncMessages[req.Header.Type] {
runSequentialRequestHandler(req, doOneRequest)
} else {
go doOneRequest(req)
}
}
}()
// Process each bridge response sync. This channel is for request/response and publish workflows.
Expand Down Expand Up @@ -423,6 +445,32 @@ func (b *Bridge) ListenAndServe(bridgeIn io.ReadCloser, bridgeOut io.WriteCloser
}
}

// runSequentialRequestHandler runs handleFn(r) synchronously, and logs a
// warning if handleFn takes longer than 5 seconds to return. Used in
// ForceSequential mode, where a stuck handler blocks the whole request loop.
func runSequentialRequestHandler(r *Request, handleFn func(*Request)) {
	// Note that this is only a context used for triggering the blockage
	// warning, the request processing still uses r.Context. We don't want to
	// cancel the request handling itself when we reach the 5s timeout.
	timeoutCtx, cancel := context.WithTimeout(r.Context, 5*time.Second)
	go func() {
		// Wakes either when the 5s deadline fires or when the deferred
		// cancel() below runs (handler finished in time / r.Context ended),
		// so this goroutine never leaks. Only the deadline case warns.
		<-timeoutCtx.Done()
		if errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) {
			log.G(timeoutCtx).WithFields(logrus.Fields{
				// We want to log those even though we're providing r.Context, since if
				// the request never finishes the span end log will never get written,
				// and we may therefore not be able to find out about the following info
				// otherwise:
				"message-type": r.Header.Type.String(),
				"message-id":   r.Header.ID,
				"activity-id":  r.ActivityID,
				"container-id": r.ContainerID,
			}).Warnf("bridge: request processing thread in sequential mode blocked on the current request for more than 5 seconds")
		}
	}()
	defer cancel()
	handleFn(r)
}

// PublishNotification writes a specific notification to the bridge.
func (b *Bridge) PublishNotification(n *prot.ContainerNotification) {
ctx, span := oc.StartSpan(context.Background(),
Expand Down
8 changes: 1 addition & 7 deletions internal/guest/bridge/bridge_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,16 +467,10 @@ func (b *Bridge) deleteContainerStateV2(r *Request) (_ RequestResponse, err erro
return nil, errors.Wrapf(err, "failed to unmarshal JSON in message \"%s\"", r.Message)
}

c, err := b.hostState.GetCreatedContainer(request.ContainerID)
err = b.hostState.DeleteContainerState(ctx, request.ContainerID)
if err != nil {
return nil, err
}
// remove container state regardless of delete's success
defer b.hostState.RemoveContainer(request.ContainerID)

if err := c.Delete(ctx); err != nil {
return nil, err
}

return &prot.MessageResponseBase{}, nil
}
13 changes: 13 additions & 0 deletions internal/guest/network/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"time"

Expand All @@ -32,6 +33,18 @@ var (
// maxDNSSearches is limited to 6 in `man 5 resolv.conf`
const maxDNSSearches = 6

// validHostnameRegex accepts up to 255 characters drawn from letters, digits,
// underscore, hyphen, and dot (the empty string is allowed).
var validHostnameRegex = regexp.MustCompile(`^[a-zA-Z0-9_\-\.]{0,255}$`)

// ValidateHostname checks that the hostname is safe. This function is less
// strict than technically allowed, but ensures that when the hostname is
// inserted to /etc/hosts, it cannot lead to injection attacks.
func ValidateHostname(hostname string) error {
	if validHostnameRegex.MatchString(hostname) {
		return nil
	}
	return errors.Errorf("hostname %q invalid: must match %s", hostname, validHostnameRegex.String())
}

// GenerateEtcHostsContent generates a /etc/hosts file based on `hostname`.
func GenerateEtcHostsContent(ctx context.Context, hostname string) string {
_, span := oc.StartSpan(ctx, "network::GenerateEtcHostsContent")
Expand Down
35 changes: 35 additions & 0 deletions internal/guest/network/network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
Expand Down Expand Up @@ -70,6 +71,40 @@ func Test_GenerateResolvConfContent(t *testing.T) {
}
}

// Test_ValidateHostname exercises ValidateHostname with a table of hostnames
// paired with their expected validity.
func Test_ValidateHostname(t *testing.T) {
	cases := []struct {
		hostname string
		valid    bool
	}{
		{"localhost", true},
		{"my-hostname", true},
		{"my.hostname", true},
		{"my-host-name123", true},
		{"_underscores.are.allowed.too", true},
		{"", true}, // Allow not specifying a hostname
		{"localhost\n13.104.0.1 ip6-localhost ip6-loopback localhost", false},
		{"localhost\n2603:1000::1 ip6-localhost ip6-loopback", false},
		{"hello@microsoft.com", false},
		{"has space", false},
		{"has,comma", false},
		{"\x00", false},
		{"a\nb", false},
		{strings.Repeat("a", 1000), false},
	}

	for _, tc := range cases {
		err := ValidateHostname(tc.hostname)
		if tc.valid && err != nil {
			t.Fatalf("expected %q to be valid, got: %v", tc.hostname, err)
		}
		if !tc.valid && err == nil {
			t.Fatalf("expected %q to be invalid, but got nil error", tc.hostname)
		}
	}
}

func Test_GenerateEtcHostsContent(t *testing.T) {
type testcase struct {
name string
Expand Down
3 changes: 3 additions & 0 deletions internal/guest/runtime/hcsv2/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ type Container struct {
// and deal with the extra pointer dereferencing overhead.
status atomic.Uint32

// Set to true when the init process for the container has exited
terminated atomic.Bool

// scratchDirPath represents the path inside the UVM where the scratch directory
// of this container is located. Usually, this is either `/run/gcs/c/<containerID>` or
// `/run/gcs/c/<UVMID>/container_<containerID>` if scratch is shared with UVM scratch.
Expand Down
1 change: 1 addition & 0 deletions internal/guest/runtime/hcsv2/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ func newProcess(c *Container, spec *oci.Process, process runtime.Process, pid ui
log.G(ctx).WithError(err).Error("failed to wait for runc process")
}
p.exitCode = exitCode
c.terminated.Store(true)
log.G(ctx).WithField("exitCode", p.exitCode).Debug("process exited")

// Free any process waiters
Expand Down
3 changes: 3 additions & 0 deletions internal/guest/runtime/hcsv2/sandbox_container.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ func setupSandboxContainerSpec(ctx context.Context, id string, spec *oci.Spec) (

// Write the hostname
hostname := spec.Hostname
if err = network.ValidateHostname(hostname); err != nil {
return err
}
if hostname == "" {
var err error
hostname, err = os.Hostname()
Expand Down
3 changes: 3 additions & 0 deletions internal/guest/runtime/hcsv2/standalone_container.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ func setupStandaloneContainerSpec(ctx context.Context, id string, spec *oci.Spec
}()

hostname := spec.Hostname
if err = network.ValidateHostname(hostname); err != nil {
return err
}
if hostname == "" {
var err error
hostname, err = os.Hostname()
Expand Down
Loading
Loading