From 799d3cbf4cf17ccfc37a0e764236e5033ae6aed1 Mon Sep 17 00:00:00 2001 From: Quan Tian Date: Fri, 21 Aug 2020 12:28:52 +0800 Subject: [PATCH] Fix race condition in GetCurrentNS In GetCurrentNS, If there is a context-switch between getCurrentThreadNetNSPath and GetNS, another goroutine may execute in the original thread and change its network namespace, then the original goroutine would get the updated network namespace, which could lead to unexpected behavior, especially when GetCurrentNS is used to get the host network namespace in netNS.Do. The added test has a chance to reproduce it with "-count=50". The patch fixes it by locking the thread in GetCurrentNS. Signed-off-by: Quan Tian --- pkg/ns/ns_linux.go | 5 +++++ pkg/ns/ns_linux_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/pkg/ns/ns_linux.go b/pkg/ns/ns_linux.go index a34f9717..3b745d49 100644 --- a/pkg/ns/ns_linux.go +++ b/pkg/ns/ns_linux.go @@ -26,6 +26,11 @@ import ( // Returns an object representing the current OS thread's network namespace func GetCurrentNS() (NetNS, error) { + // Lock the thread in case other goroutine executes in it and changes its + // network namespace after getCurrentThreadNetNSPath(), otherwise it might + // return an unexpected network namespace. + runtime.LockOSThread() + defer runtime.UnlockOSThread() return GetNS(getCurrentThreadNetNSPath()) } diff --git a/pkg/ns/ns_linux_test.go b/pkg/ns/ns_linux_test.go index 26e1ecd1..3559410c 100644 --- a/pkg/ns/ns_linux_test.go +++ b/pkg/ns/ns_linux_test.go @@ -20,6 +20,7 @@ import ( "io/ioutil" "os" "path/filepath" + "sync" "github.com/containernetworking/plugins/pkg/ns" "github.com/containernetworking/plugins/pkg/testutils" @@ -118,6 +119,33 @@ var _ = Describe("Linux namespace operations", func() { Expect(err).NotTo(HaveOccurred()) }) + Context("when called concurrently", func() { + It("provides the original namespace as the argument to the callback", func() { + concurrency := 200 + origNS, err := ns.GetCurrentNS() + Expect(err).NotTo(HaveOccurred()) + origNSInode, err := getInodeNS(origNS) + Expect(err).NotTo(HaveOccurred()) + + var wg sync.WaitGroup + wg.Add(concurrency) + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + targetNetNS.Do(func(hostNS ns.NetNS) error { + defer GinkgoRecover() + + hostNSInode, err := getInodeNS(hostNS) + Expect(err).NotTo(HaveOccurred()) + Expect(hostNSInode).To(Equal(origNSInode)) + return nil + }) + }() + } + wg.Wait() + }) + }) + Context("when the callback returns an error", func() { It("restores the calling thread to the original namespace before returning", func() { err := originalNetNS.Do(func(ns.NetNS) error {