diff --git a/pkg/ns/ns_linux.go b/pkg/ns/ns_linux.go index a34f9717..3b745d49 100644 --- a/pkg/ns/ns_linux.go +++ b/pkg/ns/ns_linux.go @@ -26,6 +26,11 @@ import ( // Returns an object representing the current OS thread's network namespace func GetCurrentNS() (NetNS, error) { + // Lock the thread in case other goroutine executes in it and changes its + // network namespace after getCurrentThreadNetNSPath(), otherwise it might + // return an unexpected network namespace. + runtime.LockOSThread() + defer runtime.UnlockOSThread() return GetNS(getCurrentThreadNetNSPath()) } diff --git a/pkg/ns/ns_linux_test.go b/pkg/ns/ns_linux_test.go index 26e1ecd1..3559410c 100644 --- a/pkg/ns/ns_linux_test.go +++ b/pkg/ns/ns_linux_test.go @@ -20,6 +20,7 @@ import ( "io/ioutil" "os" "path/filepath" + "sync" "github.com/containernetworking/plugins/pkg/ns" "github.com/containernetworking/plugins/pkg/testutils" @@ -118,6 +119,33 @@ var _ = Describe("Linux namespace operations", func() { Expect(err).NotTo(HaveOccurred()) }) + Context("when called concurrently", func() { + It("provides the original namespace as the argument to the callback", func() { + concurrency := 200 + origNS, err := ns.GetCurrentNS() + Expect(err).NotTo(HaveOccurred()) + origNSInode, err := getInodeNS(origNS) + Expect(err).NotTo(HaveOccurred()) + + var wg sync.WaitGroup + wg.Add(concurrency) + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + targetNetNS.Do(func(hostNS ns.NetNS) error { + defer GinkgoRecover() + + hostNSInode, err := getInodeNS(hostNS) + Expect(err).NotTo(HaveOccurred()) + Expect(hostNSInode).To(Equal(origNSInode)) + return nil + }) + }() + } + wg.Wait() + }) + }) + Context("when the callback returns an error", func() { It("restores the calling thread to the original namespace before returning", func() { err := originalNetNS.Do(func(ns.NetNS) error {