Merge pull request #222 from steveeJ/ns-check-path

pkg/ns: verify netns when initialized with GetNS
This commit is contained in:
Stefan Junker 2016-05-25 08:54:10 +02:00
commit 6fb30a6700
2 changed files with 60 additions and 1 deletions

View File

@ -21,6 +21,7 @@ import (
"path" "path"
"runtime" "runtime"
"sync" "sync"
"syscall"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -77,11 +78,34 @@ func GetCurrentNS() (NetNS, error) {
func GetNS(nspath string) (NetNS, error) { func GetNS(nspath string) (NetNS, error) {
fd, err := os.Open(nspath) fd, err := os.Open(nspath)
if err != nil { if err != nil {
return nil, fmt.Errorf("Failed to open %v: %v", nspath, err)
}
isNSFS, err := IsNSFS(nspath)
if err != nil {
fd.Close()
return nil, err return nil, err
} }
if !isNSFS {
fd.Close()
return nil, fmt.Errorf("%v is not of type NSFS", nspath)
}
return &netNS{file: fd}, nil return &netNS{file: fd}, nil
} }
// Returns whether or not the nspath argument points to a network namespace
func IsNSFS(nspath string) (bool, error) {
const NSFS_MAGIC = 0x6e736673
stat := syscall.Statfs_t{}
if err := syscall.Statfs(nspath, &stat); err != nil {
return false, fmt.Errorf("failed to Statfs %q: %v", nspath, err)
}
return stat.Type == NSFS_MAGIC, nil
}
// Creates a new persistent network namespace and returns an object // Creates a new persistent network namespace and returns an object
// representing that namespace, without switching to it // representing that namespace, without switching to it
func NewNS() (NetNS, error) { func NewNS() (NetNS, error) {
@ -255,7 +279,7 @@ func (ns *netNS) Set() error {
func WithNetNSPath(nspath string, toRun func(NetNS) error) error { func WithNetNSPath(nspath string, toRun func(NetNS) error) error {
ns, err := GetNS(nspath) ns, err := GetNS(nspath)
if err != nil { if err != nil {
return fmt.Errorf("Failed to open %v: %v", nspath, err) return err
} }
defer ns.Close() defer ns.Close()
return ns.Do(toRun) return ns.Do(toRun)

View File

@ -17,6 +17,7 @@ package ns_test
import ( import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
@ -169,6 +170,18 @@ var _ = Describe("Linux namespace operations", func() {
Expect(netnsInode).NotTo(Equal(createdNetNSInode)) Expect(netnsInode).NotTo(Equal(createdNetNSInode))
} }
}) })
It("fails when the path is not a namespace", func() {
tempFile, err := ioutil.TempFile("", "nstest")
Expect(err).NotTo(HaveOccurred())
defer tempFile.Close()
nspath := tempFile.Name()
defer os.Remove(nspath)
_, err = ns.GetNS(nspath)
Expect(err).To(MatchError(fmt.Sprintf("%v is not of type NSFS", nspath)))
})
}) })
Describe("closing a network namespace", func() { Describe("closing a network namespace", func() {
@ -198,6 +211,28 @@ var _ = Describe("Linux namespace operations", func() {
}) })
}) })
}) })
Describe("IsNSFS", func() {
It("should detect a namespace", func() {
createdNetNS, err := ns.NewNS()
isNSFS, err := ns.IsNSFS(createdNetNS.Path())
Expect(err).NotTo(HaveOccurred())
Expect(isNSFS).To(Equal(true))
})
It("should refuse other paths", func() {
tempFile, err := ioutil.TempFile("", "nstest")
Expect(err).NotTo(HaveOccurred())
defer tempFile.Close()
nspath := tempFile.Name()
defer os.Remove(nspath)
isNSFS, err := ns.IsNSFS(nspath)
Expect(err).NotTo(HaveOccurred())
Expect(isNSFS).To(Equal(false))
})
})
}) })
func allNetNSInCurrentProcess() []string { func allNetNSInCurrentProcess() []string {