Compare commits

..

7 Commits

Author SHA1 Message Date
6fb30a6700 Merge pull request #222 from steveeJ/ns-check-path
pkg/ns: verify netns when initialized with GetNS
2016-05-25 08:54:10 +02:00
d6751cea24 pkg/ns: test IsNSFS() 2016-05-24 22:30:49 +02:00
c43ccc703a pkg/ns: test case for rejecting a non-ns nspath 2016-05-24 22:30:49 +02:00
76ea259ff9 pkg/ns: verify netns when initialized with GetNS 2016-05-24 22:30:49 +02:00
c29cd52628 Merge pull request #223 from steveeJ/ns-respect-close
pkg/ns: don't allow operations after Close()
2016-05-24 22:16:09 +02:00
2de97b7e98 pkg/ns: add tests cases for Close()'d NS 2016-05-24 21:15:51 +02:00
b23895a7c7 pkg/ns: don't allow operations after Close() 2016-05-24 20:52:00 +02:00
2 changed files with 113 additions and 2 deletions

View File

@ -21,6 +21,7 @@ import (
"path"
"runtime"
"sync"
"syscall"
"golang.org/x/sys/unix"
)
@ -58,6 +59,7 @@ type NetNS interface {
type netNS struct {
file *os.File
mounted bool
closed bool
}
func getCurrentThreadNetNSPath() string {
@ -76,11 +78,34 @@ func GetCurrentNS() (NetNS, error) {
func GetNS(nspath string) (NetNS, error) {
fd, err := os.Open(nspath)
if err != nil {
return nil, fmt.Errorf("Failed to open %v: %v", nspath, err)
}
isNSFS, err := IsNSFS(nspath)
if err != nil {
fd.Close()
return nil, err
}
if !isNSFS {
fd.Close()
return nil, fmt.Errorf("%v is not of type NSFS", nspath)
}
return &netNS{file: fd}, nil
}
// Returns whether or not the nspath argument points to a network namespace
func IsNSFS(nspath string) (bool, error) {
const NSFS_MAGIC = 0x6e736673
stat := syscall.Statfs_t{}
if err := syscall.Statfs(nspath, &stat); err != nil {
return false, fmt.Errorf("failed to Statfs %q: %v", nspath, err)
}
return stat.Type == NSFS_MAGIC, nil
}
// Creates a new persistent network namespace and returns an object
// representing that namespace, without switching to it
func NewNS() (NetNS, error) {
@ -165,8 +190,22 @@ func (ns *netNS) Fd() uintptr {
return ns.file.Fd()
}
func (ns *netNS) errorIfClosed() error {
if ns.closed {
return fmt.Errorf("%q has already been closed", ns.file.Name())
}
return nil
}
func (ns *netNS) Close() error {
ns.file.Close()
if err := ns.errorIfClosed(); err != nil {
return err
}
if err := ns.file.Close(); err != nil {
return fmt.Errorf("Failed to close %q: %v", ns.file.Name(), err)
}
ns.closed = true
if ns.mounted {
if err := unix.Unmount(ns.file.Name(), unix.MNT_DETACH); err != nil {
@ -175,11 +214,17 @@ func (ns *netNS) Close() error {
if err := os.RemoveAll(ns.file.Name()); err != nil {
return fmt.Errorf("Failed to clean up namespace %s: %v", ns.file.Name(), err)
}
ns.mounted = false
}
return nil
}
func (ns *netNS) Do(toRun func(NetNS) error) error {
if err := ns.errorIfClosed(); err != nil {
return err
}
containedCall := func(hostNS NetNS) error {
threadNS, err := GetNS(getCurrentThreadNetNSPath())
if err != nil {
@ -218,6 +263,10 @@ func (ns *netNS) Do(toRun func(NetNS) error) error {
}
func (ns *netNS) Set() error {
if err := ns.errorIfClosed(); err != nil {
return err
}
if _, _, err := unix.Syscall(unix.SYS_SETNS, ns.Fd(), uintptr(unix.CLONE_NEWNET), 0); err != 0 {
return fmt.Errorf("Error switching to ns %v: %v", ns.file.Name(), err)
}
@ -230,7 +279,7 @@ func (ns *netNS) Set() error {
func WithNetNSPath(nspath string, toRun func(NetNS) error) error {
ns, err := GetNS(nspath)
if err != nil {
return fmt.Errorf("Failed to open %v: %v", nspath, err)
return err
}
defer ns.Close()
return ns.Do(toRun)

View File

@ -17,6 +17,7 @@ package ns_test
import (
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
@ -169,6 +170,67 @@ var _ = Describe("Linux namespace operations", func() {
Expect(netnsInode).NotTo(Equal(createdNetNSInode))
}
})
It("fails when the path is not a namespace", func() {
tempFile, err := ioutil.TempFile("", "nstest")
Expect(err).NotTo(HaveOccurred())
defer tempFile.Close()
nspath := tempFile.Name()
defer os.Remove(nspath)
_, err = ns.GetNS(nspath)
Expect(err).To(MatchError(fmt.Sprintf("%v is not of type NSFS", nspath)))
})
})
Describe("closing a network namespace", func() {
It("should prevent further operations", func() {
createdNetNS, err := ns.NewNS()
Expect(err).NotTo(HaveOccurred())
err = createdNetNS.Close()
Expect(err).NotTo(HaveOccurred())
err = createdNetNS.Do(func(ns.NetNS) error { return nil })
Expect(err).To(HaveOccurred())
err = createdNetNS.Set()
Expect(err).To(HaveOccurred())
})
It("should only work once", func() {
createdNetNS, err := ns.NewNS()
Expect(err).NotTo(HaveOccurred())
err = createdNetNS.Close()
Expect(err).NotTo(HaveOccurred())
err = createdNetNS.Close()
Expect(err).To(HaveOccurred())
})
})
})
Describe("IsNSFS", func() {
It("should detect a namespace", func() {
createdNetNS, err := ns.NewNS()
isNSFS, err := ns.IsNSFS(createdNetNS.Path())
Expect(err).NotTo(HaveOccurred())
Expect(isNSFS).To(Equal(true))
})
It("should refuse other paths", func() {
tempFile, err := ioutil.TempFile("", "nstest")
Expect(err).NotTo(HaveOccurred())
defer tempFile.Close()
nspath := tempFile.Name()
defer os.Remove(nspath)
isNSFS, err := ns.IsNSFS(nspath)
Expect(err).NotTo(HaveOccurred())
Expect(isNSFS).To(Equal(false))
})
})
})