diff --git a/pkg/utils/iptables.go b/pkg/utils/iptables.go new file mode 100644 index 00000000..f1a61696 --- /dev/null +++ b/pkg/utils/iptables.go @@ -0,0 +1,64 @@ +// Copyright 2017 CNI authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "errors" + "fmt" + + "github.com/coreos/go-iptables/iptables" +) + +const statusChainExists = 1 + +// EnsureChain idempotently creates the iptables chain. It does not +// return an error if the chain already exists. +func EnsureChain(ipt *iptables.IPTables, table, chain string) error { + if ipt == nil { + return errors.New("failed to ensure iptable chain: IPTables was nil") + } + exists, err := ChainExists(ipt, table, chain) + if err != nil { + return fmt.Errorf("failed to list iptables chains: %v", err) + } + if !exists { + err = ipt.NewChain(table, chain) + if err != nil { + eerr, eok := err.(*iptables.Error) + if eok && eerr.ExitStatus() != statusChainExists { + return err + } + } + } + return nil +} + +// ChainExists checks whether an iptables chain exists. +func ChainExists(ipt *iptables.IPTables, table, chain string) (bool, error) { + if ipt == nil { + return false, errors.New("failed to check iptable chain: IPTables was nil") + } + chains, err := ipt.ListChains(table) + if err != nil { + return false, err + } + + for _, ch := range chains { + if ch == chain { + return true, nil + } + } + return false, nil +} diff --git a/pkg/utils/iptables_test.go b/pkg/utils/iptables_test.go new file mode 100644 index 00000000..e293f22a --- /dev/null +++ b/pkg/utils/iptables_test.go @@ -0,0 +1,78 @@ +// Copyright 2017-2018 CNI authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "fmt" + "math/rand" + "runtime" + + "github.com/containernetworking/plugins/pkg/ns" + "github.com/containernetworking/plugins/pkg/testutils" + "github.com/coreos/go-iptables/iptables" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +const TABLE = "filter" // We'll monkey around here + +var _ = Describe("chain tests", func() { + var testChain string + var ipt *iptables.IPTables + var cleanup func() + + BeforeEach(func() { + + // Save a reference to the original namespace, + // Add a new NS + currNs, err := ns.GetCurrentNS() + Expect(err).NotTo(HaveOccurred()) + + testNs, err := testutils.NewNS() + Expect(err).NotTo(HaveOccurred()) + + testChain = fmt.Sprintf("cni-test-%d", rand.Intn(10000000)) + + ipt, err = iptables.NewWithProtocol(iptables.ProtocolIPv4) + Expect(err).NotTo(HaveOccurred()) + + runtime.LockOSThread() + err = testNs.Set() + Expect(err).NotTo(HaveOccurred()) + + cleanup = func() { + if ipt == nil { + return + } + ipt.ClearChain(TABLE, testChain) + ipt.DeleteChain(TABLE, testChain) + currNs.Set() + } + + }) + + AfterEach(func() { + cleanup() + }) + + It("creates chains idempotently", func() { + err := EnsureChain(ipt, TABLE, testChain) + Expect(err).NotTo(HaveOccurred()) + + // Create it again! + err = EnsureChain(ipt, TABLE, testChain) + Expect(err).NotTo(HaveOccurred()) + }) +}) diff --git a/plugins/meta/firewall/firewall_iptables_test.go b/plugins/meta/firewall/firewall_iptables_test.go index 6c7358f4..ac474dc8 100644 --- a/plugins/meta/firewall/firewall_iptables_test.go +++ b/plugins/meta/firewall/firewall_iptables_test.go @@ -270,6 +270,13 @@ var _ = Describe("firewall plugin iptables backend", func() { Expect(err).NotTo(HaveOccurred()) validateFullRuleset(fullConf) + + // ensure creation is idempotent + _, _, err = testutils.CmdAdd(targetNS.Path(), args.ContainerID, IFNAME, fullConf, func() error { + return cmdAdd(args) + }) + Expect(err).NotTo(HaveOccurred()) + return nil }) Expect(err).NotTo(HaveOccurred()) diff --git a/plugins/meta/firewall/iptables.go b/plugins/meta/firewall/iptables.go index faae35c6..857584e6 100644 --- a/plugins/meta/firewall/iptables.go +++ b/plugins/meta/firewall/iptables.go @@ -22,6 +22,7 @@ import ( "net" "github.com/containernetworking/cni/pkg/types/current" + "github.com/containernetworking/plugins/pkg/utils" "github.com/coreos/go-iptables/iptables" ) @@ -32,20 +33,6 @@ func getPrivChainRules(ip string) [][]string { return rules } -func ensureChain(ipt *iptables.IPTables, table, chain string) error { - chains, err := ipt.ListChains(table) - if err != nil { - return fmt.Errorf("failed to list iptables chains: %v", err) - } - for _, ch := range chains { - if ch == chain { - return nil - } - } - - return ipt.NewChain(table, chain) -} - func generateFilterRule(privChainName string) []string { return []string{"-m", "comment", "--comment", "CNI firewall plugin rules", "-j", privChainName} } @@ -73,10 +60,10 @@ func (ib *iptablesBackend) setupChains(ipt *iptables.IPTables) error { adminRule := generateFilterRule(ib.adminChainName) // Ensure our private chains exist - if err := ensureChain(ipt, "filter", ib.privChainName); err != nil { + if err := utils.EnsureChain(ipt, "filter", ib.privChainName); err != nil { return err } - if err := ensureChain(ipt, "filter", ib.adminChainName); err != nil { + if err := utils.EnsureChain(ipt, "filter", ib.adminChainName); err != nil { return err } @@ -160,10 +147,10 @@ func (ib *iptablesBackend) checkRules(conf *FirewallNetConf, result *current.Res } // Ensure our private chains exist - if err := ensureChain(ipt, "filter", ib.privChainName); err != nil { + if err := utils.EnsureChain(ipt, "filter", ib.privChainName); err != nil { return err } - if err := ensureChain(ipt, "filter", ib.adminChainName); err != nil { + if err := utils.EnsureChain(ipt, "filter", ib.adminChainName); err != nil { return err } diff --git a/plugins/meta/portmap/chain.go b/plugins/meta/portmap/chain.go index bca8214a..8e8dbe3f 100644 --- a/plugins/meta/portmap/chain.go +++ b/plugins/meta/portmap/chain.go @@ -18,6 +18,7 @@ import ( "fmt" "strings" + "github.com/containernetworking/plugins/pkg/utils" "github.com/coreos/go-iptables/iptables" "github.com/mattn/go-shellwords" ) @@ -35,16 +36,11 @@ type chain struct { // setup idempotently creates the chain. It will not error if the chain exists. func (c *chain) setup(ipt *iptables.IPTables) error { - // create the chain - exists, err := chainExists(ipt, c.table, c.name) + + err := utils.EnsureChain(ipt, c.table, c.name) if err != nil { return err } - if !exists { - if err := ipt.NewChain(c.table, c.name); err != nil { - return err - } - } // Add the rules to the chain for _, rule := range c.rules { @@ -125,24 +121,10 @@ func insertUnique(ipt *iptables.IPTables, table, chain string, prepend bool, rul } } -func chainExists(ipt *iptables.IPTables, tableName, chainName string) (bool, error) { - chains, err := ipt.ListChains(tableName) - if err != nil { - return false, err - } - - for _, ch := range chains { - if ch == chainName { - return true, nil - } - } - return false, nil -} - // check the chain. func (c *chain) check(ipt *iptables.IPTables) error { - exists, err := chainExists(ipt, c.table, c.name) + exists, err := utils.ChainExists(ipt, c.table, c.name) if err != nil { return err } diff --git a/plugins/meta/portmap/portmap.go b/plugins/meta/portmap/portmap.go index bd0be1fb..16d7d428 100644 --- a/plugins/meta/portmap/portmap.go +++ b/plugins/meta/portmap/portmap.go @@ -124,7 +124,7 @@ func checkPorts(config *PortMapConf, containerIP net.IP) error { } if ip4t != nil { - exists, err := chainExists(ip4t, dnatChain.table, dnatChain.name) + exists, err := utils.ChainExists(ip4t, dnatChain.table, dnatChain.name) if err != nil { return err } @@ -137,7 +137,7 @@ func checkPorts(config *PortMapConf, containerIP net.IP) error { } if ip6t != nil { - exists, err := chainExists(ip6t, dnatChain.table, dnatChain.name) + exists, err := utils.ChainExists(ip6t, dnatChain.table, dnatChain.name) if err != nil { return err }