diff --git a/pkg/ns/ns.go b/pkg/ns/ns.go index 837ab8be..328fff34 100644 --- a/pkg/ns/ns.go +++ b/pkg/ns/ns.go @@ -21,6 +21,7 @@ import ( "path" "runtime" "sync" + "syscall" "golang.org/x/sys/unix" ) @@ -77,11 +78,34 @@ func GetCurrentNS() (NetNS, error) { func GetNS(nspath string) (NetNS, error) { fd, err := os.Open(nspath) if err != nil { + return nil, fmt.Errorf("Failed to open %v: %v", nspath, err) + } + + isNSFS, err := IsNSFS(nspath) + if err != nil { + fd.Close() return nil, err } + if !isNSFS { + fd.Close() + return nil, fmt.Errorf("%v is not of type NSFS", nspath) + } + return &netNS{file: fd}, nil } +// Returns whether or not the nspath argument points to a network namespace +func IsNSFS(nspath string) (bool, error) { + const NSFS_MAGIC = 0x6e736673 + + stat := syscall.Statfs_t{} + if err := syscall.Statfs(nspath, &stat); err != nil { + return false, fmt.Errorf("failed to Statfs %q: %v", nspath, err) + } + + return stat.Type == NSFS_MAGIC, nil +} + // Creates a new persistent network namespace and returns an object // representing that namespace, without switching to it func NewNS() (NetNS, error) { @@ -255,7 +279,7 @@ func (ns *netNS) Set() error { func WithNetNSPath(nspath string, toRun func(NetNS) error) error { ns, err := GetNS(nspath) if err != nil { - return fmt.Errorf("Failed to open %v: %v", nspath, err) + return err } defer ns.Close() return ns.Do(toRun) diff --git a/pkg/ns/ns_test.go b/pkg/ns/ns_test.go index de0f3853..82001ea0 100644 --- a/pkg/ns/ns_test.go +++ b/pkg/ns/ns_test.go @@ -17,6 +17,7 @@ package ns_test import ( "errors" "fmt" + "io/ioutil" "os" "path/filepath" @@ -169,6 +170,18 @@ var _ = Describe("Linux namespace operations", func() { Expect(netnsInode).NotTo(Equal(createdNetNSInode)) } }) + + It("fails when the path is not a namespace", func() { + tempFile, err := ioutil.TempFile("", "nstest") + Expect(err).NotTo(HaveOccurred()) + defer tempFile.Close() + + nspath := tempFile.Name() + defer os.Remove(nspath) + + _, err = ns.GetNS(nspath) + Expect(err).To(MatchError(fmt.Sprintf("%v is not of type NSFS", nspath))) + }) }) Describe("closing a network namespace", func() { @@ -198,6 +211,28 @@ var _ = Describe("Linux namespace operations", func() { }) }) }) + + Describe("IsNSFS", func() { + It("should detect a namespace", func() { + createdNetNS, err := ns.NewNS() + isNSFS, err := ns.IsNSFS(createdNetNS.Path()) + Expect(err).NotTo(HaveOccurred()) + Expect(isNSFS).To(Equal(true)) + }) + + It("should refuse other paths", func() { + tempFile, err := ioutil.TempFile("", "nstest") + Expect(err).NotTo(HaveOccurred()) + defer tempFile.Close() + + nspath := tempFile.Name() + defer os.Remove(nspath) + + isNSFS, err := ns.IsNSFS(nspath) + Expect(err).NotTo(HaveOccurred()) + Expect(isNSFS).To(Equal(false)) + }) + }) }) func allNetNSInCurrentProcess() []string {