Quan Tian 799d3cbf4c Fix race condition in GetCurrentNS
In GetCurrentNS, if there is a context switch between
getCurrentThreadNetNSPath and GetNS, another goroutine may execute on
the original thread and change its network namespace; the original
goroutine would then get the updated network namespace, which could lead to
unexpected behavior, especially when GetCurrentNS is used to get the
host network namespace in netNS.Do.

The added test has a chance to reproduce it with "-count=50".

The patch fixes it by locking the thread in GetCurrentNS.

Signed-off-by: Quan Tian <qtian@vmware.com>
2020-08-21 13:05:21 +08:00

291 lines
7.9 KiB
Go

// Copyright 2016-2018 CNI authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ns_test
import (
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"sync"
"github.com/containernetworking/plugins/pkg/ns"
"github.com/containernetworking/plugins/pkg/testutils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"golang.org/x/sys/unix"
)
// getInodeCurNetNS returns the inode number of the network namespace handle
// obtained from ns.GetCurrentNS. The handle is closed before returning.
func getInodeCurNetNS() (uint64, error) {
	current, err := ns.GetCurrentNS()
	if err != nil {
		return 0, err
	}
	defer current.Close()

	inode, err := getInodeNS(current)
	return inode, err
}
// getInodeNS returns the inode number behind the file descriptor exposed by
// the given namespace handle.
func getInodeNS(netns ns.NetNS) (uint64, error) {
	fd := int(netns.Fd())
	return getInodeFd(fd)
}
// getInode opens the file at path and returns its inode number. The file is
// closed before returning; an error from opening it is returned unchanged.
func getInode(path string) (uint64, error) {
	f, openErr := os.Open(path)
	if openErr != nil {
		return 0, openErr
	}
	defer f.Close()

	fd := int(f.Fd())
	return getInodeFd(fd)
}
// getInodeFd runs fstat on fd and returns the Ino field of the result
// together with whatever error unix.Fstat reported.
func getInodeFd(fd int) (uint64, error) {
	var st unix.Stat_t
	fstatErr := unix.Fstat(fd, &st)
	return st.Ino, fstatErr
}
// Specs for the network-namespace helpers in package ns: running callbacks
// inside a namespace via Do, checking that namespaces map to distinct inodes,
// closing namespace handles, and validating paths with IsNSorErr.
var _ = Describe("Linux namespace operations", func() {
	Describe("WithNetNS", func() {
		var (
			originalNetNS ns.NetNS
			targetNetNS   ns.NetNS
		)

		// Every spec in this group starts with two freshly created namespaces.
		BeforeEach(func() {
			var err error

			originalNetNS, err = testutils.NewNS()
			Expect(err).NotTo(HaveOccurred())

			targetNetNS, err = testutils.NewNS()
			Expect(err).NotTo(HaveOccurred())
		})

		// Close both handles first, then unmount them.
		AfterEach(func() {
			targetNetNS.Close()
			originalNetNS.Close()

			Expect(testutils.UnmountNS(targetNetNS)).To(Succeed())
			Expect(testutils.UnmountNS(originalNetNS)).To(Succeed())
		})

		It("executes the callback within the target network namespace", func() {
			expectedInode, err := getInodeNS(targetNetNS)
			Expect(err).NotTo(HaveOccurred())

			err = targetNetNS.Do(func(ns.NetNS) error {
				defer GinkgoRecover()

				actualInode, err := getInodeCurNetNS()
				Expect(err).NotTo(HaveOccurred())
				Expect(actualInode).To(Equal(expectedInode))
				return nil
			})
			Expect(err).NotTo(HaveOccurred())
		})

		It("provides the original namespace as the argument to the callback", func() {
			// Ensure we start in originalNetNS
			err := originalNetNS.Do(func(ns.NetNS) error {
				defer GinkgoRecover()

				origNSInode, err := getInodeNS(originalNetNS)
				Expect(err).NotTo(HaveOccurred())

				err = targetNetNS.Do(func(hostNS ns.NetNS) error {
					defer GinkgoRecover()

					hostNSInode, err := getInodeNS(hostNS)
					Expect(err).NotTo(HaveOccurred())
					Expect(hostNSInode).To(Equal(origNSInode))
					return nil
				})
				Expect(err).NotTo(HaveOccurred())
				return nil
			})
			Expect(err).NotTo(HaveOccurred())
		})

		Context("when called concurrently", func() {
			It("provides the original namespace as the argument to the callback", func() {
				concurrency := 200

				origNS, err := ns.GetCurrentNS()
				Expect(err).NotTo(HaveOccurred())
				// Release the handle when the spec finishes so it is not leaked.
				defer origNS.Close()

				origNSInode, err := getInodeNS(origNS)
				Expect(err).NotTo(HaveOccurred())

				var wg sync.WaitGroup
				wg.Add(concurrency)
				for i := 0; i < concurrency; i++ {
					go func() {
						defer wg.Done()
						// This goroutine asserts on the result of Do below, so a
						// failed assertion has to be recovered here as well.
						defer GinkgoRecover()

						err := targetNetNS.Do(func(hostNS ns.NetNS) error {
							defer GinkgoRecover()

							hostNSInode, err := getInodeNS(hostNS)
							Expect(err).NotTo(HaveOccurred())
							Expect(hostNSInode).To(Equal(origNSInode))
							return nil
						})
						// If Do itself failed, the callback assertions above
						// never ran; do not let the spec pass vacuously.
						Expect(err).NotTo(HaveOccurred())
					}()
				}
				wg.Wait()
			})
		})

		Context("when the callback returns an error", func() {
			It("restores the calling thread to the original namespace before returning", func() {
				err := originalNetNS.Do(func(ns.NetNS) error {
					defer GinkgoRecover()

					preTestInode, err := getInodeCurNetNS()
					Expect(err).NotTo(HaveOccurred())

					// The callback error is deliberately discarded: this spec
					// only checks which namespace is current afterwards.
					_ = targetNetNS.Do(func(ns.NetNS) error {
						return errors.New("potato")
					})

					postTestInode, err := getInodeCurNetNS()
					Expect(err).NotTo(HaveOccurred())
					Expect(postTestInode).To(Equal(preTestInode))
					return nil
				})
				Expect(err).NotTo(HaveOccurred())
			})

			It("returns the error from the callback", func() {
				err := targetNetNS.Do(func(ns.NetNS) error {
					return errors.New("potato")
				})
				Expect(err).To(MatchError("potato"))
			})
		})

		Describe("validating inode mapping to namespaces", func() {
			It("checks that different namespaces have different inodes", func() {
				origNSInode, err := getInodeNS(originalNetNS)
				Expect(err).NotTo(HaveOccurred())

				testNsInode, err := getInodeNS(targetNetNS)
				Expect(err).NotTo(HaveOccurred())

				// Equal is type-strict: the inode is a uint64, so it must be
				// compared against a uint64 zero for this check to be meaningful.
				Expect(testNsInode).NotTo(Equal(uint64(0)))
				Expect(testNsInode).NotTo(Equal(origNSInode))
			})

			It("should not leak a closed netns onto any threads in the process", func() {
				By("creating a new netns")
				createdNetNS, err := testutils.NewNS()
				Expect(err).NotTo(HaveOccurred())

				By("discovering the inode of the created netns")
				createdNetNSInode, err := getInodeNS(createdNetNS)
				Expect(err).NotTo(HaveOccurred())
				createdNetNS.Close()
				Expect(testutils.UnmountNS(createdNetNS)).NotTo(HaveOccurred())

				By("comparing against the netns inode of every thread in the process")
				for _, netnsPath := range allNetNSInCurrentProcess() {
					netnsInode, err := getInode(netnsPath)
					Expect(err).NotTo(HaveOccurred())
					Expect(netnsInode).NotTo(Equal(createdNetNSInode))
				}
			})

			It("fails when the path is not a namespace", func() {
				tempFile, err := ioutil.TempFile("", "nstest")
				Expect(err).NotTo(HaveOccurred())
				defer tempFile.Close()

				nspath := tempFile.Name()
				defer os.Remove(nspath)

				_, err = ns.GetNS(nspath)
				Expect(err).To(HaveOccurred())
				Expect(err).To(BeAssignableToTypeOf(ns.NSPathNotNSErr{}))
				Expect(err).NotTo(BeAssignableToTypeOf(ns.NSPathNotExistErr{}))
			})
		})

		Describe("closing a network namespace", func() {
			It("should prevent further operations", func() {
				createdNetNS, err := testutils.NewNS()
				// Check the error before scheduling cleanup so the deferred
				// unmount never receives a handle from a failed NewNS.
				Expect(err).NotTo(HaveOccurred())
				defer testutils.UnmountNS(createdNetNS)

				err = createdNetNS.Close()
				Expect(err).NotTo(HaveOccurred())

				err = createdNetNS.Do(func(ns.NetNS) error { return nil })
				Expect(err).To(HaveOccurred())

				err = createdNetNS.Set()
				Expect(err).To(HaveOccurred())
			})

			It("should only work once", func() {
				createdNetNS, err := testutils.NewNS()
				Expect(err).NotTo(HaveOccurred())
				defer testutils.UnmountNS(createdNetNS)

				err = createdNetNS.Close()
				Expect(err).NotTo(HaveOccurred())

				err = createdNetNS.Close()
				Expect(err).To(HaveOccurred())
			})
		})
	})

	Describe("IsNSorErr", func() {
		It("should detect a namespace", func() {
			createdNetNS, err := testutils.NewNS()
			Expect(err).NotTo(HaveOccurred())
			// Deferred calls run last-in first-out: the handle is closed and
			// then unmounted, the same order the WithNetNS AfterEach uses.
			defer testutils.UnmountNS(createdNetNS)
			defer createdNetNS.Close()

			err = ns.IsNSorErr(createdNetNS.Path())
			Expect(err).NotTo(HaveOccurred())
		})

		It("should refuse other paths", func() {
			tempFile, err := ioutil.TempFile("", "nstest")
			Expect(err).NotTo(HaveOccurred())
			defer tempFile.Close()

			nspath := tempFile.Name()
			defer os.Remove(nspath)

			err = ns.IsNSorErr(nspath)
			Expect(err).To(HaveOccurred())
			Expect(err).To(BeAssignableToTypeOf(ns.NSPathNotNSErr{}))
			Expect(err).NotTo(BeAssignableToTypeOf(ns.NSPathNotExistErr{}))
		})

		It("should error on non-existing paths", func() {
			err := ns.IsNSorErr("/tmp/IDoNotExist")
			Expect(err).To(HaveOccurred())
			Expect(err).To(BeAssignableToTypeOf(ns.NSPathNotExistErr{}))
			Expect(err).NotTo(BeAssignableToTypeOf(ns.NSPathNotNSErr{}))
		})
	})
})
// allNetNSInCurrentProcess returns every path matching
// /proc/<pid>/task/*/ns/net for the current process ID. A glob error fails the
// running spec through Expect.
func allNetNSInCurrentProcess() []string {
	pattern := fmt.Sprintf("/proc/%d/task/*/ns/net", unix.Getpid())

	matches, globErr := filepath.Glob(pattern)
	Expect(globErr).NotTo(HaveOccurred())

	return matches
}