diff --git a/ns/ns.go b/ns/ns.go index 9996aac7..97c6a1e9 100644 --- a/ns/ns.go +++ b/ns/ns.go @@ -82,11 +82,11 @@ func WithNetNS(ns *os.File, lockThread bool, f func(*os.File) error) error { if err = SetNS(ns, syscall.CLONE_NEWNET); err != nil { return fmt.Errorf("Error switching to ns %v: %v", ns.Name(), err) } + defer SetNS(thisNS, syscall.CLONE_NEWNET) // switch back if err = f(thisNS); err != nil { return err } - // switch back - return SetNS(thisNS, syscall.CLONE_NEWNET) + return nil } diff --git a/ns/ns_suite_test.go b/ns/ns_suite_test.go new file mode 100644 index 00000000..ff26ec2e --- /dev/null +++ b/ns/ns_suite_test.go @@ -0,0 +1,20 @@ +package ns_test + +import ( + "math/rand" + "runtime" + + . "github.com/onsi/ginkgo" + "github.com/onsi/ginkgo/config" + . "github.com/onsi/gomega" + + "testing" +) + +func TestNs(t *testing.T) { + rand.Seed(config.GinkgoConfig.RandomSeed) + runtime.LockOSThread() + + RegisterFailHandler(Fail) + RunSpecs(t, "pkg/ns Suite") +} diff --git a/ns/ns_test.go b/ns/ns_test.go new file mode 100644 index 00000000..ad52f7f4 --- /dev/null +++ b/ns/ns_test.go @@ -0,0 +1,153 @@ +package ns_test + +import ( + "errors" + "fmt" + "math/rand" + "os" + "os/exec" + "path/filepath" + + "golang.org/x/sys/unix" + + "github.com/appc/cni/pkg/ns" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func getInode(path string) (uint64, error) { + file, err := os.Open(path) + if err != nil { + return 0, err + } + defer file.Close() + return getInodeF(file) +} + +func getInodeF(file *os.File) (uint64, error) { + stat := &unix.Stat_t{} + err := unix.Fstat(int(file.Fd()), stat) + return stat.Ino, err +} + +const CurrentNetNS = "/proc/self/ns/net" + +var _ = Describe("Linux namespace operations", func() { + Describe("WithNetNS", func() { + var ( + originalNetNS *os.File + + targetNetNSName string + targetNetNSPath string + targetNetNS *os.File + ) + + BeforeEach(func() { + var err error + originalNetNS, err = os.Open(CurrentNetNS) + Expect(err).NotTo(HaveOccurred()) + + targetNetNSName = fmt.Sprintf("test-netns-%d", rand.Int()) + + err = exec.Command("ip", "netns", "add", targetNetNSName).Run() + Expect(err).NotTo(HaveOccurred()) + + targetNetNSPath = filepath.Join("/var/run/netns/", targetNetNSName) + targetNetNS, err = os.Open(targetNetNSPath) + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + Expect(targetNetNS.Close()).To(Succeed()) + + err := exec.Command("ip", "netns", "del", targetNetNSName).Run() + Expect(err).NotTo(HaveOccurred()) + + Expect(originalNetNS.Close()).To(Succeed()) + }) + + It("executes the callback within the target network namespace", func() { + expectedInode, err := getInode(targetNetNSPath) + Expect(err).NotTo(HaveOccurred()) + + var actualInode uint64 + var innerErr error + err = ns.WithNetNS(targetNetNS, false, func(*os.File) error { + actualInode, innerErr = getInode(CurrentNetNS) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + + Expect(innerErr).NotTo(HaveOccurred()) + Expect(actualInode).To(Equal(expectedInode)) + }) + + It("provides the original namespace as the argument to the callback", func() { + hostNSInode, err := getInode(CurrentNetNS) + Expect(err).NotTo(HaveOccurred()) + + var inputNSInode uint64 + var innerErr error + err = ns.WithNetNS(targetNetNS, false, func(inputNS *os.File) error { + inputNSInode, err = getInodeF(inputNS) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + + Expect(innerErr).NotTo(HaveOccurred()) + Expect(inputNSInode).To(Equal(hostNSInode)) + }) + + It("restores the calling thread to the original network namespace", func() { + preTestInode, err := getInode(CurrentNetNS) + Expect(err).NotTo(HaveOccurred()) + + err = ns.WithNetNS(targetNetNS, false, func(*os.File) error { + return nil + }) + Expect(err).NotTo(HaveOccurred()) + + postTestInode, err := getInode(CurrentNetNS) + Expect(err).NotTo(HaveOccurred()) + + Expect(postTestInode).To(Equal(preTestInode)) + }) + + Context("when the callback returns an error", func() { + It("restores the calling thread to the original namespace before returning", func() { + preTestInode, err := getInode(CurrentNetNS) + Expect(err).NotTo(HaveOccurred()) + + _ = ns.WithNetNS(targetNetNS, false, func(*os.File) error { + return errors.New("potato") + }) + + postTestInode, err := getInode(CurrentNetNS) + Expect(err).NotTo(HaveOccurred()) + + Expect(postTestInode).To(Equal(preTestInode)) + }) + + It("returns the error from the callback", func() { + err := ns.WithNetNS(targetNetNS, false, func(*os.File) error { + return errors.New("potato") + }) + Expect(err).To(MatchError("potato")) + }) + }) + + Describe("validating inode mapping to namespaces", func() { + It("checks that different namespaces have different inodes", func() { + hostNSInode, err := getInode(CurrentNetNS) + Expect(err).NotTo(HaveOccurred()) + + testNsInode, err := getInode(targetNetNSPath) + Expect(err).NotTo(HaveOccurred()) + + Expect(hostNSInode).NotTo(Equal(0)) + Expect(testNsInode).NotTo(Equal(0)) + Expect(testNsInode).NotTo(Equal(hostNSInode)) + }) + }) + }) +})