Compare commits

..

3 Commits
main ... v1.1.1

Author SHA1 Message Date
Matt Dupre
4744ec27b8
Merge pull request #716 from squeed/cp-bugfixes
[release-1.1] Cherry-pick some bugfixes
2022-03-09 09:06:40 -08:00
Fabian Wiesel
b1782e50d7 ipam/dhcp: Fix client id in renew/release
The client id was constructed differently in the acquire
function compared to the release and renew functions,
which caused the dhcp-server to consider it a different client.
This is now encapsulated in a common function.

Signed-off-by: Fabian Wiesel <fabian.wiesel@sap.com>
2022-03-09 17:47:10 +01:00
gojoy
b03deb63a9 call ipam.ExceDel after clean up device in netns
fix #666

Signed-off-by: gojoy <729324352@qq.com>
2022-03-09 17:46:59 +01:00
1781 changed files with 66990 additions and 197388 deletions

View File

@ -1,4 +1,4 @@
FROM alpine:3.21 FROM alpine:3.10
RUN apk add --no-cache curl jq RUN apk add --no-cache curl jq

View File

@ -8,4 +8,4 @@ runs:
using: 'docker' using: 'docker'
image: 'Dockerfile' image: 'Dockerfile'
env: env:
GITHUB_TOKEN: ${{ inputs.token }} GITHUB_TOKEN: ${{ inputs.token }}

View File

@ -27,10 +27,10 @@ curl --request GET \
--header "authorization: Bearer ${GITHUB_TOKEN}" \ --header "authorization: Bearer ${GITHUB_TOKEN}" \
--header "content-type: application/json" | jq '.workflow_runs | max_by(.run_number)' > run.json --header "content-type: application/json" | jq '.workflow_runs | max_by(.run_number)' > run.json
RUN_URL=$(jq -r '.rerun_url' run.json) RERUN_URL=$(jq -r '.rerun_url' run.json)
curl --request POST \ curl --request POST \
--url "${RUN_URL}/rerun-failed-jobs" \ --url "${RERUN_URL}" \
--header "authorization: Bearer ${GITHUB_TOKEN}" \ --header "authorization: Bearer ${GITHUB_TOKEN}" \
--header "content-type: application/json" --header "content-type: application/json"
@ -42,4 +42,4 @@ curl --request POST \
--header "authorization: Bearer ${GITHUB_TOKEN}" \ --header "authorization: Bearer ${GITHUB_TOKEN}" \
--header "accept: application/vnd.github.squirrel-girl-preview+json" \ --header "accept: application/vnd.github.squirrel-girl-preview+json" \
--header "content-type: application/json" \ --header "content-type: application/json" \
--data '{ "content" : "rocket" }' --data '{ "content" : "rocket" }'

View File

@ -1,25 +0,0 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
- package-ecosystem: "docker" # See documentation for possible values
directory: "/.github/actions/retest-action" # Location of package manifests
schedule:
interval: "weekly"
- package-ecosystem: "github-actions" # See documentation for possible values
directory: "/" # Location of package manifests
schedule:
interval: "weekly"
- package-ecosystem: "gomod" # See documentation for possible values
directory: "/" # Location of package manifests
schedule:
interval: "weekly"
groups:
golang:
patterns:
- "*"
exclude-patterns:
- "github.com/containernetworking/*"

1
.github/go-version vendored
View File

@ -1 +0,0 @@
1.23

View File

@ -9,7 +9,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Check out code - name: Check out code
uses: actions/checkout@v4 uses: actions/checkout@v2
- name: Re-Test Action - name: Re-Test Action
uses: ./.github/actions/retest-action uses: ./.github/actions/retest-action

View File

@ -1,114 +0,0 @@
---
name: Release binaries
on:
push:
tags:
- 'v*'
jobs:
linux_release:
name: Release linux binaries
runs-on: ubuntu-latest
strategy:
matrix:
goarch: [amd64, arm, arm64, mips64le, ppc64le, riscv64, s390x]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: .github/go-version
- name: Build
env:
GOARCH: ${{ matrix.goarch }}
CGO_ENABLED: 0
run: ./build_linux.sh -ldflags '-extldflags -static -X github.com/containernetworking/plugins/pkg/utils/buildversion.BuildVersion=${{ github.ref_name }}'
- name: COPY files
run: cp README.md LICENSE bin/
- name: Change plugin file ownership
working-directory: ./bin
run: sudo chown -R root:root .
- name: Create dist directory
run: mkdir dist
- name: Create archive file
working-directory: ./bin
run: tar cfzpv ../dist/cni-plugins-linux-${{ matrix.goarch }}-${{ github.ref_name }}.tgz .
- name: Create sha256 checksum
working-directory: ./dist
run: sha256sum cni-plugins-linux-${{ matrix.goarch }}-${{ github.ref_name }}.tgz | tee cni-plugins-linux-${{ matrix.goarch }}-${{ github.ref_name }}.tgz.sha256
- name: Create sha512 checksum
working-directory: ./dist
run: sha512sum cni-plugins-linux-${{ matrix.goarch }}-${{ github.ref_name }}.tgz | tee cni-plugins-linux-${{ matrix.goarch }}-${{ github.ref_name }}.tgz.sha512
- name: Upload binaries to release
uses: svenstaro/upload-release-action@v2
with:
repo_token: ${{ secrets.GITHUB_TOKEN }}
file: ./dist/*
tag: ${{ github.ref }}
overwrite: true
file_glob: true
windows_releases:
name: Release windows binaries
runs-on: ubuntu-latest
strategy:
matrix:
goarch: [amd64]
steps:
- name: Install dos2unix
run: sudo apt-get install dos2unix
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: .github/go-version
- name: Build
env:
GOARCH: ${{ matrix.goarch }}
CGO_ENABLED: 0
run: ./build_windows.sh -ldflags '-extldflags -static -X github.com/containernetworking/plugins/pkg/utils/buildversion.BuildVersion=${{ github.ref_name }}'
- name: COPY files
run: cp README.md LICENSE bin/
- name: Change plugin file ownership
working-directory: ./bin
run: sudo chown -R root:root .
- name: Create dist directory
run: mkdir dist
- name: Create archive file
working-directory: ./bin
run: tar cpfzv ../dist/cni-plugins-windows-${{ matrix.goarch }}-${{ github.ref_name }}.tgz .
- name: Create sha256 checksum
working-directory: ./dist
run: sha256sum cni-plugins-windows-${{ matrix.goarch }}-${{ github.ref_name }}.tgz | tee cni-plugins-windows-${{ matrix.goarch }}-${{ github.ref_name }}.tgz.sha256
- name: Create sha512 checksum
working-directory: ./dist
run: sha512sum cni-plugins-windows-${{ matrix.goarch }}-${{ github.ref_name }}.tgz | tee cni-plugins-windows-${{ matrix.goarch }}-${{ github.ref_name }}.tgz.sha512
- name: Upload binaries to release
uses: svenstaro/upload-release-action@v2
with:
repo_token: ${{ secrets.GITHUB_TOKEN }}
file: ./dist/*
tag: ${{ github.ref }}
overwrite: true
file_glob: true

View File

@ -1,53 +1,23 @@
--- ---
name: test name: test
on: on: ["push", "pull_request"]
pull_request: {}
env: env:
LINUX_ARCHES: "amd64 386 arm arm64 s390x mips64le ppc64le riscv64" GO_VERSION: "1.17"
LINUX_ARCHES: "amd64 386 arm arm64 s390x mips64le ppc64le"
jobs: jobs:
lint:
name: Lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: setup go
uses: actions/setup-go@v5
with:
go-version-file: .github/go-version
- uses: ibiqlik/action-yamllint@v3
with:
format: auto
- uses: golangci/golangci-lint-action@v6
with:
version: v1.61.0
args: -v
verify-vendor:
name: Verify vendor directory
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: .github/go-version
- name: Check module vendoring
run: |
go mod tidy
go mod vendor
test -z "$(git status --porcelain)" || (echo "please run 'go mod tidy && go mod vendor', and submit your changes"; exit 1)
build: build:
name: Build all linux architectures name: Build all linux architectures
needs: lint
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4
- name: setup go - name: setup go
uses: actions/setup-go@v5 uses: actions/setup-go@v2
with: with:
go-version-file: .github/go-version go-version: ${{ env.GO_VERSION }}
- uses: actions/checkout@v2
- name: Build on all supported architectures - name: Build on all supported architectures
run: | run: |
set -e set -e
@ -56,9 +26,9 @@ jobs:
GOARCH=$arch ./build_linux.sh GOARCH=$arch ./build_linux.sh
rm bin/* rm bin/*
done done
test-linux: test-linux:
name: Run tests on Linux amd64 name: Run tests on Linux amd64
needs: build
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Install kernel module - name: Install kernel module
@ -67,25 +37,24 @@ jobs:
sudo apt-get install linux-modules-extra-$(uname -r) sudo apt-get install linux-modules-extra-$(uname -r)
- name: Install nftables - name: Install nftables
run: sudo apt-get install nftables run: sudo apt-get install nftables
- name: Install dnsmasq(dhcp server)
run: |
sudo apt-get install dnsmasq
sudo systemctl disable --now dnsmasq
- uses: actions/checkout@v4
- name: setup go - name: setup go
uses: actions/setup-go@v5 uses: actions/setup-go@v2
with: with:
go-version-file: .github/go-version go-version: ${{ env.GO_VERSION }}
- name: Set up Go for root - name: Set up Go for root
run: | run: |
sudo ln -sf `which go` `sudo which go` || true sudo ln -sf `which go` `sudo which go` || true
sudo go version sudo go version
- uses: actions/checkout@v2
- name: Install test binaries - name: Install test binaries
env:
GO111MODULE: off
run: | run: |
go install github.com/containernetworking/cni/cnitool@latest go get github.com/containernetworking/cni/cnitool
go install github.com/mattn/goveralls@latest go get github.com/mattn/goveralls
go install github.com/modocache/gover@latest go get github.com/modocache/gover
- name: test - name: test
run: PATH=$PATH:$(go env GOPATH)/bin COVERALLS=1 ./test_linux.sh run: PATH=$PATH:$(go env GOPATH)/bin COVERALLS=1 ./test_linux.sh
@ -97,15 +66,15 @@ jobs:
PATH=$PATH:$(go env GOPATH)/bin PATH=$PATH:$(go env GOPATH)/bin
gover gover
goveralls -coverprofile=gover.coverprofile -service=github goveralls -coverprofile=gover.coverprofile -service=github
test-win: test-win:
name: Build and run tests on Windows name: Build and run tests on Windows
needs: build
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- uses: actions/checkout@v4
- name: setup go - name: setup go
uses: actions/setup-go@v5 uses: actions/setup-go@v2
with: with:
go-version-file: .github/go-version go-version: ${{ env.GO_VERSION }}
- uses: actions/checkout@v2
- name: test - name: test
run: bash ./test_windows.sh run: bash ./test_windows.sh

View File

@ -1,44 +0,0 @@
issues:
exclude-rules:
- linters:
- revive
text: "don't use ALL_CAPS in Go names; use CamelCase"
- linters:
- revive
text: " and that stutters;"
- path: '(.+)_test\.go'
text: "dot-imports: should not use dot imports"
linters:
disable:
- errcheck
enable:
- contextcheck
- durationcheck
- gci
- ginkgolinter
- gocritic
- gofumpt
- gosimple
- govet
- ineffassign
- misspell
- nonamedreturns
- predeclared
- revive
- staticcheck
- unconvert
- unparam
- unused
- wastedassign
linters-settings:
gci:
sections:
- standard
- default
- prefix(github.com/containernetworking)
run:
timeout: 5m
modules-download-mode: vendor

View File

@ -1,12 +0,0 @@
extends: default
ignore: |
vendor
rules:
document-start: disable
line-length: disable
truthy:
ignore: |
.github/workflows/*.yml
.github/workflows/*.yaml

View File

@ -7,4 +7,3 @@ This is the official list of the CNI network plugins owners:
- Matt Dupre <matt@tigera.io> (@matthewdupre) - Matt Dupre <matt@tigera.io> (@matthewdupre)
- Michael Cambria <mcambria@redhat.com> (@mccv1r0) - Michael Cambria <mcambria@redhat.com> (@mccv1r0)
- Piotr Skarmuk <piotr.skarmuk@gmail.com> (@jellonek) - Piotr Skarmuk <piotr.skarmuk@gmail.com> (@jellonek)
- Michael Zappa <michael.zappa@gmail.com> (@MikeZappa87)

View File

@ -14,14 +14,13 @@ Read [CONTRIBUTING](CONTRIBUTING.md) for build and test instructions.
* `ptp`: Creates a veth pair. * `ptp`: Creates a veth pair.
* `vlan`: Allocates a vlan device. * `vlan`: Allocates a vlan device.
* `host-device`: Move an already-existing device into a container. * `host-device`: Move an already-existing device into a container.
* `dummy`: Creates a new Dummy device in the container.
#### Windows: Windows specific #### Windows: Windows specific
* `win-bridge`: Creates a bridge, adds the host and the container to it. * `win-bridge`: Creates a bridge, adds the host and the container to it.
* `win-overlay`: Creates an overlay interface to the container. * `win-overlay`: Creates an overlay interface to the container.
### IPAM: IP address allocation ### IPAM: IP address allocation
* `dhcp`: Runs a daemon on the host to make DHCP requests on behalf of the container * `dhcp`: Runs a daemon on the host to make DHCP requests on behalf of the container
* `host-local`: Maintains a local database of allocated IPs * `host-local`: Maintains a local database of allocated IPs
* `static`: Allocate a single static IPv4/IPv6 address to container. It's useful in debugging purpose. * `static`: Allocate a static IPv4/IPv6 addresses to container and it's useful in debugging purpose.
### Meta: other plugins ### Meta: other plugins
* `tuning`: Tweaks sysctl parameters of an existing interface * `tuning`: Tweaks sysctl parameters of an existing interface

View File

@ -1,8 +1,8 @@
#!/usr/bin/env sh #!/usr/bin/env bash
set -e set -e
cd "$(dirname "$0")" cd "$(dirname "$0")"
if [ "$(uname)" = "Darwin" ]; then if [ "$(uname)" == "Darwin" ]; then
export GOOS="${GOOS:-linux}" export GOOS="${GOOS:-linux}"
fi fi

View File

@ -1,4 +1,4 @@
#!/usr/bin/env sh #!/usr/bin/env bash
set -e set -e
cd "$(dirname "$0")" cd "$(dirname "$0")"

70
go.mod
View File

@ -1,54 +1,40 @@
module github.com/containernetworking/plugins module github.com/containernetworking/plugins
go 1.23 go 1.17
require ( require (
github.com/Microsoft/hcsshim v0.12.9 github.com/Microsoft/hcsshim v0.8.20
github.com/alexflint/go-filemutex v1.3.0 github.com/alexflint/go-filemutex v1.1.0
github.com/buger/jsonparser v1.1.1 github.com/buger/jsonparser v1.1.1
github.com/containernetworking/cni v1.2.3 github.com/containernetworking/cni v1.0.1
github.com/coreos/go-iptables v0.8.0 github.com/coreos/go-iptables v0.6.0
github.com/coreos/go-systemd/v22 v22.5.0 github.com/coreos/go-systemd/v22 v22.3.2
github.com/godbus/dbus/v5 v5.1.0 github.com/d2g/dhcp4 v0.0.0-20170904100407-a1d1b6c41b1c
github.com/insomniacslk/dhcp v0.0.0-20240829085014-a3a4c1f04475 github.com/d2g/dhcp4client v1.0.0
github.com/d2g/dhcp4server v0.0.0-20181031114812-7d4a0a7f59a5
github.com/godbus/dbus/v5 v5.0.4
github.com/mattn/go-shellwords v1.0.12 github.com/mattn/go-shellwords v1.0.12
github.com/networkplumbing/go-nft v0.4.0 github.com/networkplumbing/go-nft v0.2.0
github.com/onsi/ginkgo/v2 v2.22.2 github.com/onsi/ginkgo v1.16.4
github.com/onsi/gomega v1.36.2 github.com/onsi/gomega v1.15.0
github.com/opencontainers/selinux v1.11.1 github.com/safchain/ethtool v0.0.0-20210803160452-9aa261dae9b1
github.com/safchain/ethtool v0.5.9 github.com/vishvananda/netlink v1.1.1-0.20210330154013-f5de75959ad5
github.com/vishvananda/netlink v1.3.0 golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e
golang.org/x/sys v0.29.0
sigs.k8s.io/knftables v0.0.18
) )
require ( require (
github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/go-winio v0.4.17 // indirect
github.com/containerd/cgroups/v3 v3.0.3 // indirect github.com/containerd/cgroups v1.0.1 // indirect
github.com/containerd/errdefs v0.3.0 // indirect github.com/fsnotify/fsnotify v1.4.9 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/containerd/typeurl/v2 v2.2.0 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect
github.com/google/go-cmp v0.6.0 // indirect github.com/nxadm/tail v1.4.8 // indirect
github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/mdlayher/packet v1.1.2 // indirect
github.com/mdlayher/socket v0.5.1 // indirect
github.com/pierrec/lz4/v4 v4.1.21 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect github.com/sirupsen/logrus v1.8.1 // indirect
github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 // indirect github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f // indirect
github.com/vishvananda/netns v0.0.4 // indirect go.opencensus.io v0.22.3 // indirect
go.opencensus.io v0.24.0 // indirect golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 // indirect
golang.org/x/net v0.33.0 // indirect golang.org/x/text v0.3.6 // indirect
golang.org/x/sync v0.10.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
golang.org/x/text v0.21.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
golang.org/x/tools v0.28.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/grpc v1.67.0 // indirect
google.golang.org/protobuf v1.36.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
) )

953
go.sum

File diff suppressed because it is too large Load Diff

View File

@ -4,7 +4,7 @@
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
@ -14,21 +14,21 @@
package integration_test package integration_test
import ( import (
"bytes"
"fmt" "fmt"
"io"
"log"
"math/rand" "math/rand"
"net"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"bytes"
"io"
"net"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"github.com/onsi/gomega/gbytes" "github.com/onsi/gomega/gbytes"
"github.com/onsi/gomega/gexec" "github.com/onsi/gomega/gexec"
@ -61,13 +61,6 @@ var _ = Describe("Basic PTP using cnitool", func() {
netConfPath, err := filepath.Abs("./testdata") netConfPath, err := filepath.Abs("./testdata")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
// Flush ipam stores to avoid conflicts
err = os.RemoveAll("/tmp/chained-ptp-bandwidth-test")
Expect(err).NotTo(HaveOccurred())
err = os.RemoveAll("/tmp/basic-ptp-test")
Expect(err).NotTo(HaveOccurred())
env = TestEnv([]string{ env = TestEnv([]string{
"CNI_PATH=" + cniPath, "CNI_PATH=" + cniPath,
"NETCONFPATH=" + netConfPath, "NETCONFPATH=" + netConfPath,
@ -90,7 +83,6 @@ var _ = Describe("Basic PTP using cnitool", func() {
env.runInNS(hostNS, cnitoolBin, "add", netName, contNS.LongName()) env.runInNS(hostNS, cnitoolBin, "add", netName, contNS.LongName())
addrOutput := env.runInNS(contNS, "ip", "addr") addrOutput := env.runInNS(contNS, "ip", "addr")
Expect(addrOutput).To(ContainSubstring(expectedIPPrefix)) Expect(addrOutput).To(ContainSubstring(expectedIPPrefix))
env.runInNS(hostNS, cnitoolBin, "del", netName, contNS.LongName()) env.runInNS(hostNS, cnitoolBin, "del", netName, contNS.LongName())
@ -154,14 +146,10 @@ var _ = Describe("Basic PTP using cnitool", func() {
chainedBridgeBandwidthEnv.runInNS(hostNS, cnitoolBin, "del", "network-chain-test", contNS1.LongName()) chainedBridgeBandwidthEnv.runInNS(hostNS, cnitoolBin, "del", "network-chain-test", contNS1.LongName())
basicBridgeEnv.runInNS(hostNS, cnitoolBin, "del", "network-chain-test", contNS2.LongName()) basicBridgeEnv.runInNS(hostNS, cnitoolBin, "del", "network-chain-test", contNS2.LongName())
contNS1.Del()
contNS2.Del()
hostNS.Del()
}) })
It("limits traffic only on the restricted bandwidth veth device", func() { Measure("limits traffic only on the restricted bandwith veth device", func(b Benchmarker) {
ipRegexp := regexp.MustCompile(`10\.1[12]\.2\.\d{1,3}`) ipRegexp := regexp.MustCompile("10\\.1[12]\\.2\\.\\d{1,3}")
By(fmt.Sprintf("adding %s to %s\n\n", "chained-bridge-bandwidth", contNS1.ShortName())) By(fmt.Sprintf("adding %s to %s\n\n", "chained-bridge-bandwidth", contNS1.ShortName()))
chainedBridgeBandwidthEnv.runInNS(hostNS, cnitoolBin, "add", "network-chain-test", contNS1.LongName()) chainedBridgeBandwidthEnv.runInNS(hostNS, cnitoolBin, "add", "network-chain-test", contNS1.LongName())
@ -174,30 +162,31 @@ var _ = Describe("Basic PTP using cnitool", func() {
Expect(basicBridgeIP).To(ContainSubstring("10.11.2.")) Expect(basicBridgeIP).To(ContainSubstring("10.11.2."))
var chainedBridgeBandwidthPort, basicBridgePort int var chainedBridgeBandwidthPort, basicBridgePort int
var err error
By(fmt.Sprintf("starting echo server in %s\n\n", contNS1.ShortName())) By(fmt.Sprintf("starting echo server in %s\n\n", contNS1.ShortName()))
chainedBridgeBandwidthPort, chainedBridgeBandwidthSession = startEchoServerInNamespace(contNS1) chainedBridgeBandwidthPort, chainedBridgeBandwidthSession, err = startEchoServerInNamespace(contNS1)
Expect(err).ToNot(HaveOccurred())
By(fmt.Sprintf("starting echo server in %s\n\n", contNS2.ShortName())) By(fmt.Sprintf("starting echo server in %s\n\n", contNS2.ShortName()))
basicBridgePort, basicBridgeSession = startEchoServerInNamespace(contNS2) basicBridgePort, basicBridgeSession, err = startEchoServerInNamespace(contNS2)
Expect(err).ToNot(HaveOccurred())
packetInBytes := 20000 // The shaper needs to 'warm'. Send enough to cause it to throttle, packetInBytes := 20000 // The shaper needs to 'warm'. Send enough to cause it to throttle,
// balanced by run time. // balanced by run time.
By(fmt.Sprintf("sending tcp traffic to the chained, bridged, traffic shaped container on ip address '%s:%d'\n\n", chainedBridgeIP, chainedBridgeBandwidthPort)) By(fmt.Sprintf("sending tcp traffic to the chained, bridged, traffic shaped container on ip address '%s:%d'\n\n", chainedBridgeIP, chainedBridgeBandwidthPort))
start := time.Now() runtimeWithLimit := b.Time("with chained bridge and bandwidth plugins", func() {
makeTCPClientInNS(hostNS.ShortName(), chainedBridgeIP, chainedBridgeBandwidthPort, packetInBytes) makeTcpClientInNS(hostNS.ShortName(), chainedBridgeIP, chainedBridgeBandwidthPort, packetInBytes)
runtimeWithLimit := time.Since(start) })
log.Printf("Runtime with qos limit %.2f seconds", runtimeWithLimit.Seconds())
By(fmt.Sprintf("sending tcp traffic to the basic bridged container on ip address '%s:%d'\n\n", basicBridgeIP, basicBridgePort)) By(fmt.Sprintf("sending tcp traffic to the basic bridged container on ip address '%s:%d'\n\n", basicBridgeIP, basicBridgePort))
start = time.Now() runtimeWithoutLimit := b.Time("with basic bridged plugin", func() {
makeTCPClientInNS(hostNS.ShortName(), basicBridgeIP, basicBridgePort, packetInBytes) makeTcpClientInNS(hostNS.ShortName(), basicBridgeIP, basicBridgePort, packetInBytes)
runtimeWithoutLimit := time.Since(start) })
log.Printf("Runtime without qos limit %.2f seconds", runtimeWithoutLimit.Seconds())
Expect(runtimeWithLimit).To(BeNumerically(">", runtimeWithoutLimit+1000*time.Millisecond)) Expect(runtimeWithLimit).To(BeNumerically(">", runtimeWithoutLimit+1000*time.Millisecond))
}) }, 1)
}) })
}) })
@ -235,7 +224,7 @@ func (n Namespace) Del() {
(TestEnv{}).run("ip", "netns", "del", string(n)) (TestEnv{}).run("ip", "netns", "del", string(n))
} }
func makeTCPClientInNS(netns string, address string, port int, numBytes int) { func makeTcpClientInNS(netns string, address string, port int, numBytes int) {
payload := bytes.Repeat([]byte{'a'}, numBytes) payload := bytes.Repeat([]byte{'a'}, numBytes)
message := string(payload) message := string(payload)
@ -254,7 +243,7 @@ func makeTCPClientInNS(netns string, address string, port int, numBytes int) {
Expect(string(out)).To(Equal(message)) Expect(string(out)).To(Equal(message))
} }
func startEchoServerInNamespace(netNS Namespace) (int, *gexec.Session) { func startEchoServerInNamespace(netNS Namespace) (int, *gexec.Session, error) {
session, err := startInNetNS(echoServerBinaryPath, netNS) session, err := startInNetNS(echoServerBinaryPath, netNS)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -271,7 +260,7 @@ func startEchoServerInNamespace(netNS Namespace) (int, *gexec.Session) {
io.Copy(GinkgoWriter, io.MultiReader(session.Out, session.Err)) io.Copy(GinkgoWriter, io.MultiReader(session.Out, session.Err))
}() }()
return port, session return port, session, nil
} }
func startInNetNS(binPath string, namespace Namespace) (*gexec.Session, error) { func startInNetNS(binPath string, namespace Namespace) (*gexec.Session, error) {

View File

@ -4,7 +4,7 @@
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
@ -17,7 +17,7 @@ import (
"strings" "strings"
"testing" "testing"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"github.com/onsi/gomega/gexec" "github.com/onsi/gomega/gexec"
) )

View File

@ -6,7 +6,6 @@
"mtu": 512, "mtu": 512,
"ipam": { "ipam": {
"type": "host-local", "type": "host-local",
"subnet": "10.1.2.0/24", "subnet": "10.1.2.0/24"
"dataDir": "/tmp/basic-ptp-test"
} }
} }

View File

@ -8,8 +8,7 @@
"mtu": 512, "mtu": 512,
"ipam": { "ipam": {
"type": "host-local", "type": "host-local",
"subnet": "10.9.2.0/24", "subnet": "10.9.2.0/24"
"dataDir": "/tmp/chained-ptp-bandwidth-test"
} }
}, },
{ {

View File

@ -43,7 +43,7 @@ func TestAnnotate(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
if !reflect.DeepEqual(Annotate(test.existingErr, test.contextMessage), test.expectedErr) { if !reflect.DeepEqual(Annotatef(test.existingErr, test.contextMessage), test.expectedErr) {
t.Errorf("test case %s fails", test.name) t.Errorf("test case %s fails", test.name)
return return
} }

View File

@ -39,7 +39,6 @@ type EndpointInfo struct {
NetworkId string NetworkId string
Gateway net.IP Gateway net.IP
IpAddress net.IP IpAddress net.IP
MacAddress string
} }
// GetSandboxContainerID returns the sandbox ID of this pod. // GetSandboxContainerID returns the sandbox ID of this pod.
@ -249,7 +248,6 @@ func GenerateHcnEndpoint(epInfo *EndpointInfo, n *NetConf) (*hcn.HostComputeEndp
Minor: 0, Minor: 0,
}, },
Name: epInfo.EndpointName, Name: epInfo.EndpointName,
MacAddress: epInfo.MacAddress,
HostComputeNetwork: epInfo.NetworkId, HostComputeNetwork: epInfo.NetworkId,
Dns: hcn.Dns{ Dns: hcn.Dns{
Domain: epInfo.DNS.Domain, Domain: epInfo.DNS.Domain,
@ -282,16 +280,6 @@ func RemoveHcnEndpoint(epName string) error {
} }
return errors.Annotatef(err, "failed to find HostComputeEndpoint %s", epName) return errors.Annotatef(err, "failed to find HostComputeEndpoint %s", epName)
} }
epNamespace, err := hcn.GetNamespaceByID(hcnEndpoint.HostComputeNamespace)
if err != nil && !hcn.IsNotFoundError(err) {
return errors.Annotatef(err, "failed to get HostComputeNamespace %s", epName)
}
if epNamespace != nil {
err = hcn.RemoveNamespaceEndpoint(hcnEndpoint.HostComputeNamespace, hcnEndpoint.Id)
if err != nil && !hcn.IsNotFoundError(err) {
return errors.Annotatef(err,"error removing endpoint: %s from namespace", epName)
}
}
err = hcnEndpoint.Delete() err = hcnEndpoint.Delete()
if err != nil { if err != nil {

View File

@ -14,7 +14,7 @@
package hns package hns
import ( import (
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"testing" "testing"

View File

@ -18,7 +18,7 @@ import (
"net" "net"
"github.com/Microsoft/hcsshim/hcn" "github.com/Microsoft/hcsshim/hcn"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )

View File

@ -19,87 +19,43 @@ import (
"net" "net"
) )
// NextIP returns IP incremented by 1, if IP is invalid, return nil // NextIP returns IP incremented by 1
func NextIP(ip net.IP) net.IP { func NextIP(ip net.IP) net.IP {
normalizedIP := normalizeIP(ip) i := ipToInt(ip)
if normalizedIP == nil { return intToIP(i.Add(i, big.NewInt(1)))
return nil
}
i := ipToInt(normalizedIP)
return intToIP(i.Add(i, big.NewInt(1)), len(normalizedIP) == net.IPv6len)
} }
// PrevIP returns IP decremented by 1, if IP is invalid, return nil // PrevIP returns IP decremented by 1
func PrevIP(ip net.IP) net.IP { func PrevIP(ip net.IP) net.IP {
normalizedIP := normalizeIP(ip) i := ipToInt(ip)
if normalizedIP == nil { return intToIP(i.Sub(i, big.NewInt(1)))
return nil
}
i := ipToInt(normalizedIP)
return intToIP(i.Sub(i, big.NewInt(1)), len(normalizedIP) == net.IPv6len)
} }
// Cmp compares two IPs, returning the usual ordering: // Cmp compares two IPs, returning the usual ordering:
// a < b : -1 // a < b : -1
// a == b : 0 // a == b : 0
// a > b : 1 // a > b : 1
// incomparable : -2
func Cmp(a, b net.IP) int { func Cmp(a, b net.IP) int {
normalizedA := normalizeIP(a) aa := ipToInt(a)
normalizedB := normalizeIP(b) bb := ipToInt(b)
return aa.Cmp(bb)
if len(normalizedA) == len(normalizedB) && len(normalizedA) != 0 {
return ipToInt(normalizedA).Cmp(ipToInt(normalizedB))
}
return -2
} }
func ipToInt(ip net.IP) *big.Int { func ipToInt(ip net.IP) *big.Int {
return big.NewInt(0).SetBytes(ip) if v := ip.To4(); v != nil {
return big.NewInt(0).SetBytes(v)
}
return big.NewInt(0).SetBytes(ip.To16())
} }
func intToIP(i *big.Int, isIPv6 bool) net.IP { func intToIP(i *big.Int) net.IP {
intBytes := i.Bytes() return net.IP(i.Bytes())
if len(intBytes) == net.IPv4len || len(intBytes) == net.IPv6len {
return intBytes
}
if isIPv6 {
return append(make([]byte, net.IPv6len-len(intBytes)), intBytes...)
}
return append(make([]byte, net.IPv4len-len(intBytes)), intBytes...)
} }
// normalizeIP will normalize IP by family, // Network masks off the host portion of the IP
// IPv4 : 4-byte form
// IPv6 : 16-byte form
// others : nil
func normalizeIP(ip net.IP) net.IP {
if ipTo4 := ip.To4(); ipTo4 != nil {
return ipTo4
}
return ip.To16()
}
// Network masks off the host portion of the IP, if IPNet is invalid,
// return nil
func Network(ipn *net.IPNet) *net.IPNet { func Network(ipn *net.IPNet) *net.IPNet {
if ipn == nil {
return nil
}
maskedIP := ipn.IP.Mask(ipn.Mask)
if maskedIP == nil {
return nil
}
return &net.IPNet{ return &net.IPNet{
IP: maskedIP, IP: ipn.IP.Mask(ipn.Mask),
Mask: ipn.Mask, Mask: ipn.Mask,
} }
} }

View File

@ -1,247 +0,0 @@
// Copyright 2022 CNI authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ip
import (
"net"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("CIDR functions", func() {
It("NextIP", func() {
testCases := []struct {
ip net.IP
nextIP net.IP
}{
{
[]byte{192, 0, 2},
nil,
},
{
net.ParseIP("192.168.0.1"),
net.IPv4(192, 168, 0, 2).To4(),
},
{
net.ParseIP("192.168.0.255"),
net.IPv4(192, 168, 1, 0).To4(),
},
{
net.ParseIP("0.1.0.5"),
net.IPv4(0, 1, 0, 6).To4(),
},
{
net.ParseIP("AB12::123"),
net.ParseIP("AB12::124"),
},
{
net.ParseIP("AB12::FFFF"),
net.ParseIP("AB12::1:0"),
},
{
net.ParseIP("0::123"),
net.ParseIP("0::124"),
},
}
for _, test := range testCases {
ip := NextIP(test.ip)
Expect(ip).To(Equal(test.nextIP))
}
})
It("PrevIP", func() {
testCases := []struct {
ip net.IP
prevIP net.IP
}{
{
[]byte{192, 0, 2},
nil,
},
{
net.ParseIP("192.168.0.2"),
net.IPv4(192, 168, 0, 1).To4(),
},
{
net.ParseIP("192.168.1.0"),
net.IPv4(192, 168, 0, 255).To4(),
},
{
net.ParseIP("0.1.0.5"),
net.IPv4(0, 1, 0, 4).To4(),
},
{
net.ParseIP("AB12::123"),
net.ParseIP("AB12::122"),
},
{
net.ParseIP("AB12::1:0"),
net.ParseIP("AB12::FFFF"),
},
{
net.ParseIP("0::124"),
net.ParseIP("0::123"),
},
}
for _, test := range testCases {
ip := PrevIP(test.ip)
Expect(ip).To(Equal(test.prevIP))
}
})
It("Cmp", func() {
testCases := []struct {
a net.IP
b net.IP
result int
}{
{
net.ParseIP("192.168.0.2"),
nil,
-2,
},
{
net.ParseIP("192.168.0.2"),
[]byte{192, 168, 5},
-2,
},
{
net.ParseIP("192.168.0.2"),
net.ParseIP("AB12::123"),
-2,
},
{
net.ParseIP("192.168.0.2"),
net.ParseIP("192.168.0.5"),
-1,
},
{
net.ParseIP("192.168.0.2"),
net.ParseIP("192.168.0.5").To4(),
-1,
},
{
net.ParseIP("192.168.0.10"),
net.ParseIP("192.168.0.5"),
1,
},
{
net.ParseIP("192.168.0.10"),
net.ParseIP("192.168.0.10"),
0,
},
{
net.ParseIP("192.168.0.10"),
net.ParseIP("192.168.0.10").To4(),
0,
},
{
net.ParseIP("AB12::122"),
net.ParseIP("AB12::123"),
-1,
},
{
net.ParseIP("AB12::210"),
net.ParseIP("AB12::123"),
1,
},
{
net.ParseIP("AB12::210"),
net.ParseIP("AB12::210"),
0,
},
}
for _, test := range testCases {
result := Cmp(test.a, test.b)
Expect(result).To(Equal(test.result))
}
})
It("Network", func() {
testCases := []struct {
ipNet *net.IPNet
result *net.IPNet
}{
{
nil,
nil,
},
{
&net.IPNet{
IP: nil,
Mask: net.IPv4Mask(255, 255, 255, 0),
},
nil,
},
{
&net.IPNet{
IP: net.IPv4(192, 168, 0, 1),
Mask: nil,
},
nil,
},
{
&net.IPNet{
IP: net.ParseIP("AB12::123"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
nil,
},
{
&net.IPNet{
IP: net.IPv4(192, 168, 0, 100).To4(),
Mask: net.CIDRMask(120, 128),
},
&net.IPNet{
IP: net.IPv4(192, 168, 0, 0).To4(),
Mask: net.CIDRMask(120, 128),
},
},
{
&net.IPNet{
IP: net.IPv4(192, 168, 0, 100),
Mask: net.CIDRMask(24, 32),
},
&net.IPNet{
IP: net.IPv4(192, 168, 0, 0).To4(),
Mask: net.CIDRMask(24, 32),
},
},
{
&net.IPNet{
IP: net.ParseIP("AB12::123"),
Mask: net.CIDRMask(120, 128),
},
&net.IPNet{
IP: net.ParseIP("AB12::100"),
Mask: net.CIDRMask(120, 128),
},
},
}
for _, test := range testCases {
result := Network(test.ipNet)
Expect(result).To(Equal(test.result))
}
})
})

View File

@ -47,12 +47,13 @@ func ParseIP(s string) *IP {
return nil return nil
} }
return newIP(ip, ipNet.Mask) return newIP(ip, ipNet.Mask)
} else {
ip := net.ParseIP(s)
if ip == nil {
return nil
}
return newIP(ip, nil)
} }
ip := net.ParseIP(s)
if ip == nil {
return nil
}
return newIP(ip, nil)
} }
// ToIP will return a net.IP in standard form from this IP. // ToIP will return a net.IP in standard form from this IP.

View File

@ -15,10 +15,10 @@
package ip_test package ip_test
import ( import (
"testing" . "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"testing"
) )
func TestIp(t *testing.T) { func TestIp(t *testing.T) {

View File

@ -19,7 +19,7 @@ import (
"fmt" "fmt"
"net" "net"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
@ -124,7 +124,7 @@ var _ = Describe("IP Operations", func() {
} }
for _, test := range testCases { for _, test := range testCases {
Expect(test.ip.ToIP()).To(HaveLen(test.expectedLen)) Expect(len(test.ip.ToIP())).To(Equal(test.expectedLen))
Expect(test.ip.ToIP()).To(Equal(test.expectedIP)) Expect(test.ip.ToIP()).To(Equal(test.expectedIP))
} }
}) })
@ -174,8 +174,8 @@ var _ = Describe("IP Operations", func() {
} }
}) })
Context("Decode", func() { It("Decode", func() {
It("valid IP", func() { Context("valid IP", func() {
testCases := []struct { testCases := []struct {
text string text string
expected *IP expected *IP
@ -205,9 +205,10 @@ var _ = Describe("IP Operations", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(ip).To(Equal(test.expected)) Expect(ip).To(Equal(test.expected))
} }
}) })
It("empty text", func() { Context("empty text", func() {
ip := &IP{} ip := &IP{}
err := json.Unmarshal([]byte(`""`), ip) err := json.Unmarshal([]byte(`""`), ip)
@ -215,7 +216,7 @@ var _ = Describe("IP Operations", func() {
Expect(ip).To(Equal(newIP(nil, nil))) Expect(ip).To(Equal(newIP(nil, nil)))
}) })
It("invalid IP", func() { Context("invalid IP", func() {
testCases := []struct { testCases := []struct {
text string text string
expectedErr error expectedErr error
@ -242,7 +243,7 @@ var _ = Describe("IP Operations", func() {
} }
}) })
It("IP slice", func() { Context("IP slice", func() {
testCases := []struct { testCases := []struct {
text string text string
expected []*IP expected []*IP

View File

@ -16,7 +16,7 @@ package ip
import ( import (
"bytes" "bytes"
"os" "io/ioutil"
current "github.com/containernetworking/cni/pkg/types/100" current "github.com/containernetworking/cni/pkg/types/100"
) )
@ -53,10 +53,10 @@ func EnableForward(ips []*current.IPConfig) error {
} }
func echo1(f string) error { func echo1(f string) error {
if content, err := os.ReadFile(f); err == nil { if content, err := ioutil.ReadFile(f); err == nil {
if bytes.Equal(bytes.TrimSpace(content), []byte("1")) { if bytes.Equal(bytes.TrimSpace(content), []byte("1")) {
return nil return nil
} }
} }
return os.WriteFile(f, []byte("1"), 0o644) return ioutil.WriteFile(f, []byte("1"), 0644)
} }

View File

@ -1,16 +1,17 @@
package ip package ip
import ( import (
"io/ioutil"
"os" "os"
"time" "time"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
var _ = Describe("IpforwardLinux", func() { var _ = Describe("IpforwardLinux", func() {
It("echo1 must not write the file if content is 1", func() { It("echo1 must not write the file if content is 1", func() {
file, err := os.CreateTemp("", "containernetworking") file, err := ioutil.TempFile(os.TempDir(), "containernetworking")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
defer os.Remove(file.Name()) defer os.Remove(file.Name())
err = echo1(file.Name()) err = echo1(file.Name())

View File

@ -1,180 +0,0 @@
// Copyright 2015 CNI authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ip
import (
"errors"
"fmt"
"net"
"strings"
"github.com/coreos/go-iptables/iptables"
"github.com/containernetworking/cni/pkg/types"
"github.com/containernetworking/plugins/pkg/utils"
)
// setupIPMasqIPTables is the iptables-based implementation of SetupIPMasqForNetworks
func setupIPMasqIPTables(ipns []*net.IPNet, network, _, containerID string) error {
// Note: for historical reasons, the iptables implementation ignores ifname.
chain := utils.FormatChainName(network, containerID)
comment := utils.FormatComment(network, containerID)
for _, ip := range ipns {
if err := SetupIPMasq(ip, chain, comment); err != nil {
return err
}
}
return nil
}
// SetupIPMasq installs iptables rules to masquerade traffic
// coming from ip of ipn and going outside of ipn.
// Deprecated: This function only supports iptables. Use SetupIPMasqForNetworks, which
// supports both iptables and nftables.
func SetupIPMasq(ipn *net.IPNet, chain string, comment string) error {
isV6 := ipn.IP.To4() == nil
var ipt *iptables.IPTables
var err error
var multicastNet string
if isV6 {
ipt, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
multicastNet = "ff00::/8"
} else {
ipt, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
multicastNet = "224.0.0.0/4"
}
if err != nil {
return fmt.Errorf("failed to locate iptables: %v", err)
}
// Create chain if doesn't exist
exists := false
chains, err := ipt.ListChains("nat")
if err != nil {
return fmt.Errorf("failed to list chains: %v", err)
}
for _, ch := range chains {
if ch == chain {
exists = true
break
}
}
if !exists {
if err = ipt.NewChain("nat", chain); err != nil {
return err
}
}
// Packets to this network should not be touched
if err := ipt.AppendUnique("nat", chain, "-d", ipn.String(), "-j", "ACCEPT", "-m", "comment", "--comment", comment); err != nil {
return err
}
// Don't masquerade multicast - pods should be able to talk to other pods
// on the local network via multicast.
if err := ipt.AppendUnique("nat", chain, "!", "-d", multicastNet, "-j", "MASQUERADE", "-m", "comment", "--comment", comment); err != nil {
return err
}
// Packets from the specific IP of this network will hit the chain
return ipt.AppendUnique("nat", "POSTROUTING", "-s", ipn.IP.String(), "-j", chain, "-m", "comment", "--comment", comment)
}
// teardownIPMasqIPTables is the iptables-based implementation of TeardownIPMasqForNetworks
func teardownIPMasqIPTables(ipns []*net.IPNet, network, _, containerID string) error {
// Note: for historical reasons, the iptables implementation ignores ifname.
chain := utils.FormatChainName(network, containerID)
comment := utils.FormatComment(network, containerID)
var errs []string
for _, ipn := range ipns {
err := TeardownIPMasq(ipn, chain, comment)
if err != nil {
errs = append(errs, err.Error())
}
}
if errs == nil {
return nil
}
return errors.New(strings.Join(errs, "\n"))
}
// TeardownIPMasq undoes the effects of SetupIPMasq.
// Deprecated: This function only supports iptables. Use TeardownIPMasqForNetworks, which
// supports both iptables and nftables.
func TeardownIPMasq(ipn *net.IPNet, chain string, comment string) error {
isV6 := ipn.IP.To4() == nil
var ipt *iptables.IPTables
var err error
if isV6 {
ipt, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
} else {
ipt, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
}
if err != nil {
return fmt.Errorf("failed to locate iptables: %v", err)
}
err = ipt.Delete("nat", "POSTROUTING", "-s", ipn.IP.String(), "-j", chain, "-m", "comment", "--comment", comment)
if err != nil && !isNotExist(err) {
return err
}
// for downward compatibility
err = ipt.Delete("nat", "POSTROUTING", "-s", ipn.String(), "-j", chain, "-m", "comment", "--comment", comment)
if err != nil && !isNotExist(err) {
return err
}
err = ipt.ClearChain("nat", chain)
if err != nil && !isNotExist(err) {
return err
}
err = ipt.DeleteChain("nat", chain)
if err != nil && !isNotExist(err) {
return err
}
return nil
}
// gcIPMasqIPTables is the iptables-based implementation of GCIPMasqForNetwork
func gcIPMasqIPTables(_ string, _ []types.GCAttachment) error {
// FIXME: The iptables implementation does not support GC.
//
// (In theory, it _could_ backward-compatibly support it, by adding a no-op rule
// with a comment indicating the network to each chain it creates, so that it
// could later figure out which chains corresponded to which networks; older
// implementations would ignore the extra rule but would still correctly delete
// the chain on teardown (because they ClearChain() before doing DeleteChain()).
return nil
}
// isNotExist returnst true if the error is from iptables indicating
// that the target does not exist.
func isNotExist(err error) bool {
e, ok := err.(*iptables.Error)
if !ok {
return false
}
return e.IsNotExist()
}

View File

@ -15,78 +15,112 @@
package ip package ip
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"strings"
"github.com/containernetworking/cni/pkg/types" "github.com/coreos/go-iptables/iptables"
"github.com/containernetworking/plugins/pkg/utils"
) )
// SetupIPMasqForNetworks installs rules to masquerade traffic coming from ips of ipns and // SetupIPMasq installs iptables rules to masquerade traffic
// going outside of ipns, using a chain name based on network, ifname, and containerID. The // coming from ip of ipn and going outside of ipn
// backend can be either "iptables" or "nftables"; if it is nil, then a suitable default func SetupIPMasq(ipn *net.IPNet, chain string, comment string) error {
// implementation will be used. isV6 := ipn.IP.To4() == nil
func SetupIPMasqForNetworks(backend *string, ipns []*net.IPNet, network, ifname, containerID string) error {
if backend == nil { var ipt *iptables.IPTables
// Prefer iptables, unless only nftables is available var err error
defaultBackend := "iptables" var multicastNet string
if !utils.SupportsIPTables() && utils.SupportsNFTables() {
defaultBackend = "nftables" if isV6 {
ipt, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
multicastNet = "ff00::/8"
} else {
ipt, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
multicastNet = "224.0.0.0/4"
}
if err != nil {
return fmt.Errorf("failed to locate iptables: %v", err)
}
// Create chain if doesn't exist
exists := false
chains, err := ipt.ListChains("nat")
if err != nil {
return fmt.Errorf("failed to list chains: %v", err)
}
for _, ch := range chains {
if ch == chain {
exists = true
break
}
}
if !exists {
if err = ipt.NewChain("nat", chain); err != nil {
return err
} }
backend = &defaultBackend
} }
switch *backend { // Packets to this network should not be touched
case "iptables": if err := ipt.AppendUnique("nat", chain, "-d", ipn.String(), "-j", "ACCEPT", "-m", "comment", "--comment", comment); err != nil {
return setupIPMasqIPTables(ipns, network, ifname, containerID) return err
case "nftables":
return setupIPMasqNFTables(ipns, network, ifname, containerID)
default:
return fmt.Errorf("unknown ipmasq backend %q", *backend)
} }
// Don't masquerade multicast - pods should be able to talk to other pods
// on the local network via multicast.
if err := ipt.AppendUnique("nat", chain, "!", "-d", multicastNet, "-j", "MASQUERADE", "-m", "comment", "--comment", comment); err != nil {
return err
}
// Packets from the specific IP of this network will hit the chain
return ipt.AppendUnique("nat", "POSTROUTING", "-s", ipn.IP.String(), "-j", chain, "-m", "comment", "--comment", comment)
} }
// TeardownIPMasqForNetworks undoes the effects of SetupIPMasqForNetworks // TeardownIPMasq undoes the effects of SetupIPMasq
func TeardownIPMasqForNetworks(ipns []*net.IPNet, network, ifname, containerID string) error { func TeardownIPMasq(ipn *net.IPNet, chain string, comment string) error {
var errs []string isV6 := ipn.IP.To4() == nil
// Do both the iptables and the nftables cleanup, since the pod may have been var ipt *iptables.IPTables
// created with a different version of this plugin or a different configuration. var err error
err := teardownIPMasqIPTables(ipns, network, ifname, containerID) if isV6 {
if err != nil && utils.SupportsIPTables() { ipt, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
errs = append(errs, err.Error()) } else {
ipt, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
}
if err != nil {
return fmt.Errorf("failed to locate iptables: %v", err)
} }
err = teardownIPMasqNFTables(ipns, network, ifname, containerID) err = ipt.Delete("nat", "POSTROUTING", "-s", ipn.IP.String(), "-j", chain, "-m", "comment", "--comment", comment)
if err != nil && utils.SupportsNFTables() { if err != nil && !isNotExist(err) {
errs = append(errs, err.Error()) return err
} }
if errs == nil { // for downward compatibility
return nil err = ipt.Delete("nat", "POSTROUTING", "-s", ipn.String(), "-j", chain, "-m", "comment", "--comment", comment)
if err != nil && !isNotExist(err) {
return err
} }
return errors.New(strings.Join(errs, "\n"))
err = ipt.ClearChain("nat", chain)
if err != nil && !isNotExist(err) {
return err
}
err = ipt.DeleteChain("nat", chain)
if err != nil && !isNotExist(err) {
return err
}
return nil
} }
// GCIPMasqForNetwork garbage collects stale IPMasq entries for network // isNotExist returnst true if the error is from iptables indicating
func GCIPMasqForNetwork(network string, attachments []types.GCAttachment) error { // that the target does not exist.
var errs []string func isNotExist(err error) bool {
e, ok := err.(*iptables.Error)
err := gcIPMasqIPTables(network, attachments) if !ok {
if err != nil && utils.SupportsIPTables() { return false
errs = append(errs, err.Error())
} }
return e.IsNotExist()
err = gcIPMasqNFTables(network, attachments)
if err != nil && utils.SupportsNFTables() {
errs = append(errs, err.Error())
}
if errs == nil {
return nil
}
return errors.New(strings.Join(errs, "\n"))
} }

View File

@ -1,231 +0,0 @@
// Copyright 2023 CNI authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ip
import (
"context"
"fmt"
"net"
"strings"
"sigs.k8s.io/knftables"
"github.com/containernetworking/cni/pkg/types"
"github.com/containernetworking/plugins/pkg/utils"
)
const (
ipMasqTableName = "cni_plugins_masquerade"
ipMasqChainName = "masq_checks"
)
// The nftables ipmasq implementation is mostly like the iptables implementation, with
// minor updates to fix a bug (adding `ifname`) and to allow future GC support.
//
// We add a rule for each mapping, with a comment containing a hash of its identifiers,
// so that we can later reliably delete the rules we want. (This is important because in
// edge cases, it's possible the plugin might see "ADD container A with IP 192.168.1.3",
// followed by "ADD container B with IP 192.168.1.3" followed by "DEL container A with IP
// 192.168.1.3", and we need to make sure that the DEL causes us to delete the rule for
// container A, and not the rule for container B.)
//
// It would be more nftables-y to have a chain with a single rule doing a lookup against a
// set with an element per mapping, rather than having a chain with a rule per mapping.
// But there's no easy, non-racy way to say "delete the element 192.168.1.3 from the set,
// but only if it was added for container A, not if it was added for container B".
// hashForNetwork returns a unique hash for this network
func hashForNetwork(network string) string {
return utils.MustFormatHashWithPrefix(16, "", network)
}
// hashForInstance returns a unique hash identifying the rules for this
// network/ifname/containerID
func hashForInstance(network, ifname, containerID string) string {
return hashForNetwork(network) + "-" + utils.MustFormatHashWithPrefix(16, "", ifname+":"+containerID)
}
// commentForInstance returns a comment string that begins with a unique hash and
// ends with a (possibly-truncated) human-readable description.
func commentForInstance(network, ifname, containerID string) string {
comment := fmt.Sprintf("%s, net: %s, if: %s, id: %s",
hashForInstance(network, ifname, containerID),
strings.ReplaceAll(network, `"`, ``),
strings.ReplaceAll(ifname, `"`, ``),
strings.ReplaceAll(containerID, `"`, ``),
)
if len(comment) > knftables.CommentLengthMax {
comment = comment[:knftables.CommentLengthMax]
}
return comment
}
// setupIPMasqNFTables is the nftables-based implementation of SetupIPMasqForNetworks
func setupIPMasqNFTables(ipns []*net.IPNet, network, ifname, containerID string) error {
nft, err := knftables.New(knftables.InetFamily, ipMasqTableName)
if err != nil {
return err
}
return setupIPMasqNFTablesWithInterface(nft, ipns, network, ifname, containerID)
}
func setupIPMasqNFTablesWithInterface(nft knftables.Interface, ipns []*net.IPNet, network, ifname, containerID string) error {
staleRules, err := findRules(nft, hashForInstance(network, ifname, containerID))
if err != nil {
return err
}
tx := nft.NewTransaction()
// Ensure that our table and chains exist.
tx.Add(&knftables.Table{
Comment: knftables.PtrTo("Masquerading for plugins from github.com/containernetworking/plugins"),
})
tx.Add(&knftables.Chain{
Name: ipMasqChainName,
Comment: knftables.PtrTo("Masquerade traffic from certain IPs to any (non-multicast) IP outside their subnet"),
})
// Ensure that the postrouting chain exists and has the correct rules. (Has to be
// done after creating ipMasqChainName, so we can jump to it.)
tx.Add(&knftables.Chain{
Name: "postrouting",
Type: knftables.PtrTo(knftables.NATType),
Hook: knftables.PtrTo(knftables.PostroutingHook),
Priority: knftables.PtrTo(knftables.SNATPriority),
})
tx.Flush(&knftables.Chain{
Name: "postrouting",
})
tx.Add(&knftables.Rule{
Chain: "postrouting",
Rule: "ip daddr == 224.0.0.0/4 return",
})
tx.Add(&knftables.Rule{
Chain: "postrouting",
Rule: "ip6 daddr == ff00::/8 return",
})
tx.Add(&knftables.Rule{
Chain: "postrouting",
Rule: knftables.Concat(
"goto", ipMasqChainName,
),
})
// Delete stale rules, add new rules to masquerade chain
for _, rule := range staleRules {
tx.Delete(rule)
}
for _, ipn := range ipns {
ip := "ip"
if ipn.IP.To4() == nil {
ip = "ip6"
}
// e.g. if ipn is "192.168.1.4/24", then dstNet is "192.168.1.0/24"
dstNet := &net.IPNet{IP: ipn.IP.Mask(ipn.Mask), Mask: ipn.Mask}
tx.Add(&knftables.Rule{
Chain: ipMasqChainName,
Rule: knftables.Concat(
ip, "saddr", "==", ipn.IP,
ip, "daddr", "!=", dstNet,
"masquerade",
),
Comment: knftables.PtrTo(commentForInstance(network, ifname, containerID)),
})
}
return nft.Run(context.TODO(), tx)
}
// teardownIPMasqNFTables is the nftables-based implementation of TeardownIPMasqForNetworks
func teardownIPMasqNFTables(ipns []*net.IPNet, network, ifname, containerID string) error {
nft, err := knftables.New(knftables.InetFamily, ipMasqTableName)
if err != nil {
return err
}
return teardownIPMasqNFTablesWithInterface(nft, ipns, network, ifname, containerID)
}
func teardownIPMasqNFTablesWithInterface(nft knftables.Interface, _ []*net.IPNet, network, ifname, containerID string) error {
rules, err := findRules(nft, hashForInstance(network, ifname, containerID))
if err != nil {
return err
} else if len(rules) == 0 {
return nil
}
tx := nft.NewTransaction()
for _, rule := range rules {
tx.Delete(rule)
}
return nft.Run(context.TODO(), tx)
}
// gcIPMasqNFTables is the nftables-based implementation of GCIPMasqForNetwork
func gcIPMasqNFTables(network string, attachments []types.GCAttachment) error {
nft, err := knftables.New(knftables.InetFamily, ipMasqTableName)
if err != nil {
return err
}
return gcIPMasqNFTablesWithInterface(nft, network, attachments)
}
func gcIPMasqNFTablesWithInterface(nft knftables.Interface, network string, attachments []types.GCAttachment) error {
// Find all rules for the network
rules, err := findRules(nft, hashForNetwork(network))
if err != nil {
return err
} else if len(rules) == 0 {
return nil
}
// Compute the comments for all elements of attachments
validAttachments := map[string]bool{}
for _, attachment := range attachments {
validAttachments[commentForInstance(network, attachment.IfName, attachment.ContainerID)] = true
}
// Delete anything in rules that isn't in validAttachments
tx := nft.NewTransaction()
for _, rule := range rules {
if !validAttachments[*rule.Comment] {
tx.Delete(rule)
}
}
return nft.Run(context.TODO(), tx)
}
// findRules finds rules with comments that start with commentPrefix.
func findRules(nft knftables.Interface, commentPrefix string) ([]*knftables.Rule, error) {
rules, err := nft.ListRules(context.TODO(), ipMasqChainName)
if err != nil {
if knftables.IsNotFound(err) {
// If ipMasqChainName doesn't exist yet, that's fine
return nil, nil
}
return nil, err
}
matchingRules := make([]*knftables.Rule, 0, 1)
for _, rule := range rules {
if rule.Comment != nil && strings.HasPrefix(*rule.Comment, commentPrefix) {
matchingRules = append(matchingRules, rule)
}
}
return matchingRules, nil
}

View File

@ -1,213 +0,0 @@
// Copyright 2023 CNI authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ip
import (
"net"
"strings"
"testing"
"github.com/vishvananda/netlink"
"sigs.k8s.io/knftables"
"github.com/containernetworking/cni/pkg/types"
)
func Test_setupIPMasqNFTables(t *testing.T) {
nft := knftables.NewFake(knftables.InetFamily, ipMasqTableName)
containers := []struct {
network string
ifname string
containerID string
addrs []string
}{
{
network: "unit-test",
ifname: "eth0",
containerID: "one",
addrs: []string{"192.168.1.1/24"},
},
{
network: "unit-test",
ifname: "eth0",
containerID: "two",
addrs: []string{"192.168.1.2/24", "2001:db8::2/64"},
},
{
network: "unit-test",
ifname: "eth0",
containerID: "three",
addrs: []string{"192.168.99.5/24"},
},
{
network: "alternate",
ifname: "net1",
containerID: "three",
addrs: []string{
"10.0.0.5/24",
"10.0.0.6/24",
"10.0.1.7/24",
"2001:db8::5/64",
"2001:db8::6/64",
"2001:db8:1::7/64",
},
},
}
for _, c := range containers {
ipns := []*net.IPNet{}
for _, addr := range c.addrs {
nladdr, err := netlink.ParseAddr(addr)
if err != nil {
t.Fatalf("failed to parse test addr: %v", err)
}
ipns = append(ipns, nladdr.IPNet)
}
err := setupIPMasqNFTablesWithInterface(nft, ipns, c.network, c.ifname, c.containerID)
if err != nil {
t.Fatalf("error from setupIPMasqNFTables: %v", err)
}
}
expected := strings.TrimSpace(`
add table inet cni_plugins_masquerade { comment "Masquerading for plugins from github.com/containernetworking/plugins" ; }
add chain inet cni_plugins_masquerade masq_checks { comment "Masquerade traffic from certain IPs to any (non-multicast) IP outside their subnet" ; }
add chain inet cni_plugins_masquerade postrouting { type nat hook postrouting priority 100 ; }
add rule inet cni_plugins_masquerade masq_checks ip saddr == 192.168.1.1 ip daddr != 192.168.1.0/24 masquerade comment "6fd94d501e58f0aa-287fc69eff0574a2, net: unit-test, if: eth0, id: one"
add rule inet cni_plugins_masquerade masq_checks ip saddr == 192.168.1.2 ip daddr != 192.168.1.0/24 masquerade comment "6fd94d501e58f0aa-d750b2c8f0f25d5f, net: unit-test, if: eth0, id: two"
add rule inet cni_plugins_masquerade masq_checks ip6 saddr == 2001:db8::2 ip6 daddr != 2001:db8::/64 masquerade comment "6fd94d501e58f0aa-d750b2c8f0f25d5f, net: unit-test, if: eth0, id: two"
add rule inet cni_plugins_masquerade masq_checks ip saddr == 192.168.99.5 ip daddr != 192.168.99.0/24 masquerade comment "6fd94d501e58f0aa-a4d4adb82b669cfe, net: unit-test, if: eth0, id: three"
add rule inet cni_plugins_masquerade masq_checks ip saddr == 10.0.0.5 ip daddr != 10.0.0.0/24 masquerade comment "82783ef24bdc7036-acb19d111858e348, net: alternate, if: net1, id: three"
add rule inet cni_plugins_masquerade masq_checks ip saddr == 10.0.0.6 ip daddr != 10.0.0.0/24 masquerade comment "82783ef24bdc7036-acb19d111858e348, net: alternate, if: net1, id: three"
add rule inet cni_plugins_masquerade masq_checks ip saddr == 10.0.1.7 ip daddr != 10.0.1.0/24 masquerade comment "82783ef24bdc7036-acb19d111858e348, net: alternate, if: net1, id: three"
add rule inet cni_plugins_masquerade masq_checks ip6 saddr == 2001:db8::5 ip6 daddr != 2001:db8::/64 masquerade comment "82783ef24bdc7036-acb19d111858e348, net: alternate, if: net1, id: three"
add rule inet cni_plugins_masquerade masq_checks ip6 saddr == 2001:db8::6 ip6 daddr != 2001:db8::/64 masquerade comment "82783ef24bdc7036-acb19d111858e348, net: alternate, if: net1, id: three"
add rule inet cni_plugins_masquerade masq_checks ip6 saddr == 2001:db8:1::7 ip6 daddr != 2001:db8:1::/64 masquerade comment "82783ef24bdc7036-acb19d111858e348, net: alternate, if: net1, id: three"
add rule inet cni_plugins_masquerade postrouting ip daddr == 224.0.0.0/4 return
add rule inet cni_plugins_masquerade postrouting ip6 daddr == ff00::/8 return
add rule inet cni_plugins_masquerade postrouting goto masq_checks
`)
dump := strings.TrimSpace(nft.Dump())
if dump != expected {
t.Errorf("expected nftables state:\n%s\n\nactual:\n%s\n\n", expected, dump)
}
// Add a new container reusing "one"'s address, before deleting "one"
c := containers[0]
addr, err := netlink.ParseAddr(c.addrs[0])
if err != nil {
t.Fatalf("failed to parse test addr: %v", err)
}
err = setupIPMasqNFTablesWithInterface(nft, []*net.IPNet{addr.IPNet}, "unit-test", "eth0", "four")
if err != nil {
t.Fatalf("error from setupIPMasqNFTables: %v", err)
}
// Remove "one"
err = teardownIPMasqNFTablesWithInterface(nft, []*net.IPNet{addr.IPNet}, c.network, c.ifname, c.containerID)
if err != nil {
t.Fatalf("error from teardownIPMasqNFTables: %v", err)
}
// Check that "one" was deleted (and "four" wasn't)
expected = strings.TrimSpace(`
add table inet cni_plugins_masquerade { comment "Masquerading for plugins from github.com/containernetworking/plugins" ; }
add chain inet cni_plugins_masquerade masq_checks { comment "Masquerade traffic from certain IPs to any (non-multicast) IP outside their subnet" ; }
add chain inet cni_plugins_masquerade postrouting { type nat hook postrouting priority 100 ; }
add rule inet cni_plugins_masquerade masq_checks ip saddr == 192.168.1.2 ip daddr != 192.168.1.0/24 masquerade comment "6fd94d501e58f0aa-d750b2c8f0f25d5f, net: unit-test, if: eth0, id: two"
add rule inet cni_plugins_masquerade masq_checks ip6 saddr == 2001:db8::2 ip6 daddr != 2001:db8::/64 masquerade comment "6fd94d501e58f0aa-d750b2c8f0f25d5f, net: unit-test, if: eth0, id: two"
add rule inet cni_plugins_masquerade masq_checks ip saddr == 192.168.99.5 ip daddr != 192.168.99.0/24 masquerade comment "6fd94d501e58f0aa-a4d4adb82b669cfe, net: unit-test, if: eth0, id: three"
add rule inet cni_plugins_masquerade masq_checks ip saddr == 10.0.0.5 ip daddr != 10.0.0.0/24 masquerade comment "82783ef24bdc7036-acb19d111858e348, net: alternate, if: net1, id: three"
add rule inet cni_plugins_masquerade masq_checks ip saddr == 10.0.0.6 ip daddr != 10.0.0.0/24 masquerade comment "82783ef24bdc7036-acb19d111858e348, net: alternate, if: net1, id: three"
add rule inet cni_plugins_masquerade masq_checks ip saddr == 10.0.1.7 ip daddr != 10.0.1.0/24 masquerade comment "82783ef24bdc7036-acb19d111858e348, net: alternate, if: net1, id: three"
add rule inet cni_plugins_masquerade masq_checks ip6 saddr == 2001:db8::5 ip6 daddr != 2001:db8::/64 masquerade comment "82783ef24bdc7036-acb19d111858e348, net: alternate, if: net1, id: three"
add rule inet cni_plugins_masquerade masq_checks ip6 saddr == 2001:db8::6 ip6 daddr != 2001:db8::/64 masquerade comment "82783ef24bdc7036-acb19d111858e348, net: alternate, if: net1, id: three"
add rule inet cni_plugins_masquerade masq_checks ip6 saddr == 2001:db8:1::7 ip6 daddr != 2001:db8:1::/64 masquerade comment "82783ef24bdc7036-acb19d111858e348, net: alternate, if: net1, id: three"
add rule inet cni_plugins_masquerade masq_checks ip saddr == 192.168.1.1 ip daddr != 192.168.1.0/24 masquerade comment "6fd94d501e58f0aa-e766de567ef6c543, net: unit-test, if: eth0, id: four"
add rule inet cni_plugins_masquerade postrouting ip daddr == 224.0.0.0/4 return
add rule inet cni_plugins_masquerade postrouting ip6 daddr == ff00::/8 return
add rule inet cni_plugins_masquerade postrouting goto masq_checks
`)
dump = strings.TrimSpace(nft.Dump())
if dump != expected {
t.Errorf("expected nftables state:\n%s\n\nactual:\n%s\n\n", expected, dump)
}
// GC "four" from the "unit-test" network
err = gcIPMasqNFTablesWithInterface(nft, "unit-test", []types.GCAttachment{
{IfName: "eth0", ContainerID: "two"},
{IfName: "eth0", ContainerID: "three"},
// (irrelevant extra element)
{IfName: "eth0", ContainerID: "one"},
})
if err != nil {
t.Fatalf("error from gcIPMasqNFTables: %v", err)
}
// GC the "alternate" network without removing anything
err = gcIPMasqNFTablesWithInterface(nft, "alternate", []types.GCAttachment{
{IfName: "net1", ContainerID: "three"},
})
if err != nil {
t.Fatalf("error from gcIPMasqNFTables: %v", err)
}
// Re-dump
expected = strings.TrimSpace(`
add table inet cni_plugins_masquerade { comment "Masquerading for plugins from github.com/containernetworking/plugins" ; }
add chain inet cni_plugins_masquerade masq_checks { comment "Masquerade traffic from certain IPs to any (non-multicast) IP outside their subnet" ; }
add chain inet cni_plugins_masquerade postrouting { type nat hook postrouting priority 100 ; }
add rule inet cni_plugins_masquerade masq_checks ip saddr == 192.168.1.2 ip daddr != 192.168.1.0/24 masquerade comment "6fd94d501e58f0aa-d750b2c8f0f25d5f, net: unit-test, if: eth0, id: two"
add rule inet cni_plugins_masquerade masq_checks ip6 saddr == 2001:db8::2 ip6 daddr != 2001:db8::/64 masquerade comment "6fd94d501e58f0aa-d750b2c8f0f25d5f, net: unit-test, if: eth0, id: two"
add rule inet cni_plugins_masquerade masq_checks ip saddr == 192.168.99.5 ip daddr != 192.168.99.0/24 masquerade comment "6fd94d501e58f0aa-a4d4adb82b669cfe, net: unit-test, if: eth0, id: three"
add rule inet cni_plugins_masquerade masq_checks ip saddr == 10.0.0.5 ip daddr != 10.0.0.0/24 masquerade comment "82783ef24bdc7036-acb19d111858e348, net: alternate, if: net1, id: three"
add rule inet cni_plugins_masquerade masq_checks ip saddr == 10.0.0.6 ip daddr != 10.0.0.0/24 masquerade comment "82783ef24bdc7036-acb19d111858e348, net: alternate, if: net1, id: three"
add rule inet cni_plugins_masquerade masq_checks ip saddr == 10.0.1.7 ip daddr != 10.0.1.0/24 masquerade comment "82783ef24bdc7036-acb19d111858e348, net: alternate, if: net1, id: three"
add rule inet cni_plugins_masquerade masq_checks ip6 saddr == 2001:db8::5 ip6 daddr != 2001:db8::/64 masquerade comment "82783ef24bdc7036-acb19d111858e348, net: alternate, if: net1, id: three"
add rule inet cni_plugins_masquerade masq_checks ip6 saddr == 2001:db8::6 ip6 daddr != 2001:db8::/64 masquerade comment "82783ef24bdc7036-acb19d111858e348, net: alternate, if: net1, id: three"
add rule inet cni_plugins_masquerade masq_checks ip6 saddr == 2001:db8:1::7 ip6 daddr != 2001:db8:1::/64 masquerade comment "82783ef24bdc7036-acb19d111858e348, net: alternate, if: net1, id: three"
add rule inet cni_plugins_masquerade postrouting ip daddr == 224.0.0.0/4 return
add rule inet cni_plugins_masquerade postrouting ip6 daddr == ff00::/8 return
add rule inet cni_plugins_masquerade postrouting goto masq_checks
`)
dump = strings.TrimSpace(nft.Dump())
if dump != expected {
t.Errorf("expected nftables state:\n%s\n\nactual:\n%s\n\n", expected, dump)
}
// GC everything
err = gcIPMasqNFTablesWithInterface(nft, "unit-test", []types.GCAttachment{})
if err != nil {
t.Fatalf("error from gcIPMasqNFTables: %v", err)
}
err = gcIPMasqNFTablesWithInterface(nft, "alternate", []types.GCAttachment{})
if err != nil {
t.Fatalf("error from gcIPMasqNFTables: %v", err)
}
expected = strings.TrimSpace(`
add table inet cni_plugins_masquerade { comment "Masquerading for plugins from github.com/containernetworking/plugins" ; }
add chain inet cni_plugins_masquerade masq_checks { comment "Masquerade traffic from certain IPs to any (non-multicast) IP outside their subnet" ; }
add chain inet cni_plugins_masquerade postrouting { type nat hook postrouting priority 100 ; }
add rule inet cni_plugins_masquerade postrouting ip daddr == 224.0.0.0/4 return
add rule inet cni_plugins_masquerade postrouting ip6 daddr == ff00::/8 return
add rule inet cni_plugins_masquerade postrouting goto masq_checks
`)
dump = strings.TrimSpace(nft.Dump())
if dump != expected {
t.Errorf("expected nftables state:\n%s\n\nactual:\n%s\n\n", expected, dump)
}
}

View File

@ -28,16 +28,17 @@ import (
"github.com/containernetworking/plugins/pkg/utils/sysctl" "github.com/containernetworking/plugins/pkg/utils/sysctl"
) )
var ErrLinkNotFound = errors.New("link not found") var (
ErrLinkNotFound = errors.New("link not found")
)
// makeVethPair is called from within the container's network namespace // makeVethPair is called from within the container's network namespace
func makeVethPair(name, peer string, mtu int, mac string, hostNS ns.NetNS) (netlink.Link, error) { func makeVethPair(name, peer string, mtu int, mac string, hostNS ns.NetNS) (netlink.Link, error) {
linkAttrs := netlink.NewLinkAttrs()
linkAttrs.Name = name
linkAttrs.MTU = mtu
veth := &netlink.Veth{ veth := &netlink.Veth{
LinkAttrs: linkAttrs, LinkAttrs: netlink.LinkAttrs{
Name: name,
MTU: mtu,
},
PeerName: peer, PeerName: peer,
PeerNamespace: netlink.NsFd(int(hostNS.Fd())), PeerNamespace: netlink.NsFd(int(hostNS.Fd())),
} }
@ -68,37 +69,38 @@ func peerExists(name string) bool {
return true return true
} }
func makeVeth(name, vethPeerName string, mtu int, mac string, hostNS ns.NetNS) (string, netlink.Link, error) { func makeVeth(name, vethPeerName string, mtu int, mac string, hostNS ns.NetNS) (peerName string, veth netlink.Link, err error) {
var peerName string
var veth netlink.Link
var err error
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
if vethPeerName != "" { if vethPeerName != "" {
peerName = vethPeerName peerName = vethPeerName
} else { } else {
peerName, err = RandomVethName() peerName, err = RandomVethName()
if err != nil { if err != nil {
return peerName, nil, err return
} }
} }
veth, err = makeVethPair(name, peerName, mtu, mac, hostNS) veth, err = makeVethPair(name, peerName, mtu, mac, hostNS)
switch { switch {
case err == nil: case err == nil:
return peerName, veth, nil return
case os.IsExist(err): case os.IsExist(err):
if peerExists(peerName) && vethPeerName == "" { if peerExists(peerName) && vethPeerName == "" {
continue continue
} }
return peerName, veth, fmt.Errorf("container veth name (%q) peer provided (%q) already exists", name, peerName) err = fmt.Errorf("container veth name provided (%v) already exists", name)
return
default: default:
return peerName, veth, fmt.Errorf("failed to make veth pair: %v", err) err = fmt.Errorf("failed to make veth pair: %v", err)
return
} }
} }
// should really never be hit // should really never be hit
return peerName, nil, fmt.Errorf("failed to find a unique veth name") err = fmt.Errorf("failed to find a unique veth name")
return
} }
// RandomVethName returns string "veth" with random prefix (hashed from entropy) // RandomVethName returns string "veth" with random prefix (hashed from entropy)

View File

@ -20,15 +20,22 @@ import (
"fmt" "fmt"
"net" "net"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"github.com/vishvananda/netlink"
"github.com/containernetworking/plugins/pkg/ip" "github.com/containernetworking/plugins/pkg/ip"
"github.com/containernetworking/plugins/pkg/ns" "github.com/containernetworking/plugins/pkg/ns"
"github.com/containernetworking/plugins/pkg/testutils" "github.com/containernetworking/plugins/pkg/testutils"
"github.com/vishvananda/netlink"
) )
func getHwAddr(linkname string) string {
veth, err := netlink.LinkByName(linkname)
Expect(err).NotTo(HaveOccurred())
return fmt.Sprintf("%s", veth.Attrs().HardwareAddr)
}
var _ = Describe("Link", func() { var _ = Describe("Link", func() {
const ( const (
ifaceFormatString string = "i%d" ifaceFormatString string = "i%d"
@ -57,7 +64,7 @@ var _ = Describe("Link", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
fakeBytes := make([]byte, 20) fakeBytes := make([]byte, 20)
// to be reset in AfterEach block //to be reset in AfterEach block
rand.Reader = bytes.NewReader(fakeBytes) rand.Reader = bytes.NewReader(fakeBytes)
_ = containerNetNS.Do(func(ns.NetNS) error { _ = containerNetNS.Do(func(ns.NetNS) error {
@ -149,9 +156,9 @@ var _ = Describe("Link", func() {
It("returns useful error", func() { It("returns useful error", func() {
_ = containerNetNS.Do(func(ns.NetNS) error { _ = containerNetNS.Do(func(ns.NetNS) error {
defer GinkgoRecover() defer GinkgoRecover()
testHostVethName := "test" + hostVethName
_, _, err := ip.SetupVethWithName(containerVethName, testHostVethName, mtu, "", hostNetNS) _, _, err := ip.SetupVeth(containerVethName, mtu, "", hostNetNS)
Expect(err.Error()).To(Equal(fmt.Sprintf("container veth name (%q) peer provided (%q) already exists", containerVethName, testHostVethName))) Expect(err.Error()).To(Equal(fmt.Sprintf("container veth name provided (%s) already exists", containerVethName)))
return nil return nil
}) })
@ -174,14 +181,15 @@ var _ = Describe("Link", func() {
Context("when there is no name available for the host-side", func() { Context("when there is no name available for the host-side", func() {
BeforeEach(func() { BeforeEach(func() {
// adding different interface to container ns //adding different interface to container ns
containerVethName += "0" containerVethName += "0"
}) })
It("returns useful error", func() { It("returns useful error", func() {
_ = containerNetNS.Do(func(ns.NetNS) error { _ = containerNetNS.Do(func(ns.NetNS) error {
defer GinkgoRecover() defer GinkgoRecover()
_, _, err := ip.SetupVethWithName(containerVethName, hostVethName, mtu, "", hostNetNS) _, _, err := ip.SetupVeth(containerVethName, mtu, "", hostNetNS)
Expect(err.Error()).To(Equal(fmt.Sprintf("container veth name (%q) peer provided (%q) already exists", containerVethName, hostVethName))) Expect(err.Error()).To(HavePrefix("container veth name provided"))
Expect(err.Error()).To(HaveSuffix("already exists"))
return nil return nil
}) })
}) })
@ -189,7 +197,7 @@ var _ = Describe("Link", func() {
Context("when there is no name conflict for the host or container interfaces", func() { Context("when there is no name conflict for the host or container interfaces", func() {
BeforeEach(func() { BeforeEach(func() {
// adding different interface to container and host ns //adding different interface to container and host ns
containerVethName += "0" containerVethName += "0"
rand.Reader = originalRandReader rand.Reader = originalRandReader
}) })
@ -203,7 +211,7 @@ var _ = Describe("Link", func() {
return nil return nil
}) })
// verify veths are in different namespaces //verify veths are in different namespaces
_ = containerNetNS.Do(func(ns.NetNS) error { _ = containerNetNS.Do(func(ns.NetNS) error {
defer GinkgoRecover() defer GinkgoRecover()
@ -282,7 +290,7 @@ var _ = Describe("Link", func() {
// this will delete the host endpoint too // this will delete the host endpoint too
addr, err := ip.DelLinkByNameAddr(containerVethName) addr, err := ip.DelLinkByNameAddr(containerVethName)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(addr).To(BeEmpty()) Expect(addr).To(HaveLen(0))
return nil return nil
}) })
}) })

View File

@ -42,24 +42,6 @@ func AddHostRoute(ipn *net.IPNet, gw net.IP, dev netlink.Link) error {
// AddDefaultRoute sets the default route on the given gateway. // AddDefaultRoute sets the default route on the given gateway.
func AddDefaultRoute(gw net.IP, dev netlink.Link) error { func AddDefaultRoute(gw net.IP, dev netlink.Link) error {
var defNet *net.IPNet _, defNet, _ := net.ParseCIDR("0.0.0.0/0")
if gw.To4() != nil {
_, defNet, _ = net.ParseCIDR("0.0.0.0/0")
} else {
_, defNet, _ = net.ParseCIDR("::/0")
}
return AddRoute(defNet, gw, dev) return AddRoute(defNet, gw, dev)
} }
// IsIPNetZero check if the IPNet is "0.0.0.0/0" or "::/0"
// This is needed as go-netlink replaces nil Dst with a '0' IPNet since
// https://github.com/vishvananda/netlink/commit/acdc658b8613655ddb69f978e9fb4cf413e2b830
func IsIPNetZero(ipnet *net.IPNet) bool {
if ipnet == nil {
return true
}
if ones, _ := ipnet.Mask.Size(); ones != 0 {
return false
}
return ipnet.IP.Equal(net.IPv4zero) || ipnet.IP.Equal(net.IPv6zero)
}

View File

@ -21,13 +21,13 @@ import (
"fmt" "fmt"
"net" "net"
"github.com/vishvananda/netlink"
"github.com/containernetworking/cni/pkg/types" "github.com/containernetworking/cni/pkg/types"
current "github.com/containernetworking/cni/pkg/types/100" current "github.com/containernetworking/cni/pkg/types/100"
"github.com/vishvananda/netlink"
) )
func ValidateExpectedInterfaceIPs(ifName string, resultIPs []*current.IPConfig) error { func ValidateExpectedInterfaceIPs(ifName string, resultIPs []*current.IPConfig) error {
// Ensure ips // Ensure ips
for _, ips := range resultIPs { for _, ips := range resultIPs {
ourAddr := netlink.Addr{IPNet: &ips.Address} ourAddr := netlink.Addr{IPNet: &ips.Address}
@ -49,15 +49,12 @@ func ValidateExpectedInterfaceIPs(ifName string, resultIPs []*current.IPConfig)
break break
} }
} }
if !match { if match == false {
return fmt.Errorf("Failed to match addr %v on interface %v", ourAddr, ifName) return fmt.Errorf("Failed to match addr %v on interface %v", ourAddr, ifName)
} }
// Convert the host/prefixlen to just prefix for route lookup. // Convert the host/prefixlen to just prefix for route lookup.
_, ourPrefix, err := net.ParseCIDR(ourAddr.String()) _, ourPrefix, err := net.ParseCIDR(ourAddr.String())
if err != nil {
return err
}
findGwy := &netlink.Route{Dst: ourPrefix} findGwy := &netlink.Route{Dst: ourPrefix}
routeFilter := netlink.RT_FILTER_DST routeFilter := netlink.RT_FILTER_DST
@ -80,13 +77,11 @@ func ValidateExpectedInterfaceIPs(ifName string, resultIPs []*current.IPConfig)
} }
func ValidateExpectedRoute(resultRoutes []*types.Route) error { func ValidateExpectedRoute(resultRoutes []*types.Route) error {
// Ensure that each static route in prevResults is found in the routing table // Ensure that each static route in prevResults is found in the routing table
for _, route := range resultRoutes { for _, route := range resultRoutes {
find := &netlink.Route{Dst: &route.Dst, Gw: route.GW} find := &netlink.Route{Dst: &route.Dst, Gw: route.GW}
routeFilter := netlink.RT_FILTER_DST routeFilter := netlink.RT_FILTER_DST | netlink.RT_FILTER_GW
if route.GW != nil {
routeFilter |= netlink.RT_FILTER_GW
}
var family int var family int
switch { switch {

View File

@ -16,7 +16,6 @@ package ipam
import ( import (
"context" "context"
"github.com/containernetworking/cni/pkg/invoke" "github.com/containernetworking/cni/pkg/invoke"
"github.com/containernetworking/cni/pkg/types" "github.com/containernetworking/cni/pkg/types"
) )
@ -32,7 +31,3 @@ func ExecCheck(plugin string, netconf []byte) error {
func ExecDel(plugin string, netconf []byte) error { func ExecDel(plugin string, netconf []byte) error {
return invoke.DelegateDel(context.TODO(), plugin, netconf, nil) return invoke.DelegateDel(context.TODO(), plugin, netconf, nil)
} }
func ExecStatus(plugin string, netconf []byte) error {
return invoke.DelegateStatus(context.TODO(), plugin, netconf, nil)
}

View File

@ -19,11 +19,11 @@ import (
"net" "net"
"os" "os"
"github.com/vishvananda/netlink"
current "github.com/containernetworking/cni/pkg/types/100" current "github.com/containernetworking/cni/pkg/types/100"
"github.com/containernetworking/plugins/pkg/ip" "github.com/containernetworking/plugins/pkg/ip"
"github.com/containernetworking/plugins/pkg/utils/sysctl" "github.com/containernetworking/plugins/pkg/utils/sysctl"
"github.com/vishvananda/netlink"
) )
const ( const (
@ -44,7 +44,7 @@ func ConfigureIface(ifName string, res *current.Result) error {
} }
var v4gw, v6gw net.IP var v4gw, v6gw net.IP
hasEnabledIpv6 := false var has_enabled_ipv6 bool = false
for _, ipc := range res.IPs { for _, ipc := range res.IPs {
if ipc.Interface == nil { if ipc.Interface == nil {
continue continue
@ -57,7 +57,7 @@ func ConfigureIface(ifName string, res *current.Result) error {
// Make sure sysctl "disable_ipv6" is 0 if we are about to add // Make sure sysctl "disable_ipv6" is 0 if we are about to add
// an IPv6 address to the interface // an IPv6 address to the interface
if !hasEnabledIpv6 && ipc.Address.IP.To4() == nil { if !has_enabled_ipv6 && ipc.Address.IP.To4() == nil {
// Enabled IPv6 for loopback "lo" and the interface // Enabled IPv6 for loopback "lo" and the interface
// being configured // being configured
for _, iface := range [2]string{"lo", ifName} { for _, iface := range [2]string{"lo", ifName} {
@ -79,7 +79,7 @@ func ConfigureIface(ifName string, res *current.Result) error {
return fmt.Errorf("failed to enable IPv6 for interface %q (%s=%s): %v", iface, ipv6SysctlValueName, value, err) return fmt.Errorf("failed to enable IPv6 for interface %q (%s=%s): %v", iface, ipv6SysctlValueName, value, err)
} }
} }
hasEnabledIpv6 = true has_enabled_ipv6 = true
} }
addr := &netlink.Addr{IPNet: &ipc.Address, Label: ""} addr := &netlink.Addr{IPNet: &ipc.Address, Label: ""}
@ -117,27 +117,10 @@ func ConfigureIface(ifName string, res *current.Result) error {
Dst: &r.Dst, Dst: &r.Dst,
LinkIndex: link.Attrs().Index, LinkIndex: link.Attrs().Index,
Gw: gw, Gw: gw,
Priority: r.Priority,
}
if r.Table != nil {
route.Table = *r.Table
}
if r.Scope != nil {
route.Scope = netlink.Scope(*r.Scope)
}
if r.Table != nil {
route.Table = *r.Table
}
if r.Scope != nil {
route.Scope = netlink.Scope(*r.Scope)
} }
if err = netlink.RouteAddEcmp(&route); err != nil { if err = netlink.RouteAddEcmp(&route); err != nil {
return fmt.Errorf("failed to add route '%v via %v dev %v metric %d (Scope: %v, Table: %d)': %v", r.Dst, gw, ifName, r.Priority, route.Scope, route.Table, err) return fmt.Errorf("failed to add route '%v via %v dev %v': %v", r.Dst, gw, ifName, err)
} }
} }

View File

@ -18,14 +18,15 @@ import (
"net" "net"
"syscall" "syscall"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/vishvananda/netlink"
"github.com/containernetworking/cni/pkg/types" "github.com/containernetworking/cni/pkg/types"
current "github.com/containernetworking/cni/pkg/types/100" current "github.com/containernetworking/cni/pkg/types/100"
"github.com/containernetworking/plugins/pkg/ns" "github.com/containernetworking/plugins/pkg/ns"
"github.com/containernetworking/plugins/pkg/testutils" "github.com/containernetworking/plugins/pkg/testutils"
"github.com/vishvananda/netlink"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
) )
const LINK_NAME = "eth0" const LINK_NAME = "eth0"
@ -41,11 +42,9 @@ func ipNetEqual(a, b *net.IPNet) bool {
var _ = Describe("ConfigureIface", func() { var _ = Describe("ConfigureIface", func() {
var originalNS ns.NetNS var originalNS ns.NetNS
var ipv4, ipv6, routev4, routev6, routev4Scope *net.IPNet var ipv4, ipv6, routev4, routev6 *net.IPNet
var ipgw4, ipgw6, routegwv4, routegwv6 net.IP var ipgw4, ipgw6, routegwv4, routegwv6 net.IP
var routeScope int
var result *current.Result var result *current.Result
var routeTable int
BeforeEach(func() { BeforeEach(func() {
// Create a new NetNS so we don't modify the host // Create a new NetNS so we don't modify the host
@ -56,12 +55,11 @@ var _ = Describe("ConfigureIface", func() {
err = originalNS.Do(func(ns.NetNS) error { err = originalNS.Do(func(ns.NetNS) error {
defer GinkgoRecover() defer GinkgoRecover()
linkAttrs := netlink.NewLinkAttrs()
linkAttrs.Name = LINK_NAME
// Add master // Add master
err = netlink.LinkAdd(&netlink.Dummy{ err = netlink.LinkAdd(&netlink.Dummy{
LinkAttrs: linkAttrs, LinkAttrs: netlink.LinkAttrs{
Name: LINK_NAME,
},
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
_, err = netlink.LinkByName(LINK_NAME) _, err = netlink.LinkByName(LINK_NAME)
@ -80,10 +78,6 @@ var _ = Describe("ConfigureIface", func() {
routegwv4 = net.ParseIP("1.2.3.5") routegwv4 = net.ParseIP("1.2.3.5")
Expect(routegwv4).NotTo(BeNil()) Expect(routegwv4).NotTo(BeNil())
_, routev4Scope, err = net.ParseCIDR("1.2.3.4/32")
Expect(err).NotTo(HaveOccurred())
Expect(routev4Scope).NotTo(BeNil())
ipgw4 = net.ParseIP("1.2.3.1") ipgw4 = net.ParseIP("1.2.3.1")
Expect(ipgw4).NotTo(BeNil()) Expect(ipgw4).NotTo(BeNil())
@ -100,9 +94,6 @@ var _ = Describe("ConfigureIface", func() {
ipgw6 = net.ParseIP("abcd:1234:ffff::1") ipgw6 = net.ParseIP("abcd:1234:ffff::1")
Expect(ipgw6).NotTo(BeNil()) Expect(ipgw6).NotTo(BeNil())
routeTable := 5000
routeScope = 200
result = &current.Result{ result = &current.Result{
Interfaces: []*current.Interface{ Interfaces: []*current.Interface{
{ {
@ -131,8 +122,6 @@ var _ = Describe("ConfigureIface", func() {
Routes: []*types.Route{ Routes: []*types.Route{
{Dst: *routev4, GW: routegwv4}, {Dst: *routev4, GW: routegwv4},
{Dst: *routev6, GW: routegwv6}, {Dst: *routev6, GW: routegwv6},
{Dst: *routev4, GW: routegwv4, Table: &routeTable},
{Dst: *routev4Scope, Scope: &routeScope},
}, },
} }
}) })
@ -154,12 +143,12 @@ var _ = Describe("ConfigureIface", func() {
v4addrs, err := netlink.AddrList(link, syscall.AF_INET) v4addrs, err := netlink.AddrList(link, syscall.AF_INET)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(v4addrs).To(HaveLen(1)) Expect(len(v4addrs)).To(Equal(1))
Expect(ipNetEqual(v4addrs[0].IPNet, ipv4)).To(BeTrue()) Expect(ipNetEqual(v4addrs[0].IPNet, ipv4)).To(Equal(true))
v6addrs, err := netlink.AddrList(link, syscall.AF_INET6) v6addrs, err := netlink.AddrList(link, syscall.AF_INET6)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(v6addrs).To(HaveLen(2)) Expect(len(v6addrs)).To(Equal(2))
var found bool var found bool
for _, a := range v6addrs { for _, a := range v6addrs {
@ -168,13 +157,13 @@ var _ = Describe("ConfigureIface", func() {
break break
} }
} }
Expect(found).To(BeTrue()) Expect(found).To(Equal(true))
// Ensure the v4 route, v6 route, and subnet route // Ensure the v4 route, v6 route, and subnet route
routes, err := netlink.RouteList(link, 0) routes, err := netlink.RouteList(link, 0)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
var v4found, v6found, v4Scopefound bool var v4found, v6found bool
for _, route := range routes { for _, route := range routes {
isv4 := route.Dst.IP.To4() != nil isv4 := route.Dst.IP.To4() != nil
if isv4 && ipNetEqual(route.Dst, routev4) && route.Gw.Equal(routegwv4) { if isv4 && ipNetEqual(route.Dst, routev4) && route.Gw.Equal(routegwv4) {
@ -183,17 +172,13 @@ var _ = Describe("ConfigureIface", func() {
if !isv4 && ipNetEqual(route.Dst, routev6) && route.Gw.Equal(routegwv6) { if !isv4 && ipNetEqual(route.Dst, routev6) && route.Gw.Equal(routegwv6) {
v6found = true v6found = true
} }
if isv4 && ipNetEqual(route.Dst, routev4Scope) && int(route.Scope) == routeScope {
v4Scopefound = true
}
if v4found && v6found && v4Scopefound { if v4found && v6found {
break break
} }
} }
Expect(v4found).To(BeTrue()) Expect(v4found).To(Equal(true))
Expect(v6found).To(BeTrue()) Expect(v6found).To(Equal(true))
Expect(v4Scopefound).To(BeTrue())
return nil return nil
}) })
@ -217,7 +202,7 @@ var _ = Describe("ConfigureIface", func() {
routes, err := netlink.RouteList(link, 0) routes, err := netlink.RouteList(link, 0)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
var v4found, v6found, v4Tablefound bool var v4found, v6found bool
for _, route := range routes { for _, route := range routes {
isv4 := route.Dst.IP.To4() != nil isv4 := route.Dst.IP.To4() != nil
if isv4 && ipNetEqual(route.Dst, routev4) && route.Gw.Equal(ipgw4) { if isv4 && ipNetEqual(route.Dst, routev4) && route.Gw.Equal(ipgw4) {
@ -231,31 +216,8 @@ var _ = Describe("ConfigureIface", func() {
break break
} }
} }
Expect(v4found).To(BeTrue()) Expect(v4found).To(Equal(true))
Expect(v6found).To(BeTrue()) Expect(v6found).To(Equal(true))
// Need to read all tables, so cannot use RouteList
routeFilter := &netlink.Route{
Table: routeTable,
}
routes, err = netlink.RouteListFiltered(netlink.FAMILY_ALL,
routeFilter,
netlink.RT_FILTER_TABLE)
Expect(err).NotTo(HaveOccurred())
for _, route := range routes {
isv4 := route.Dst.IP.To4() != nil
if isv4 && ipNetEqual(route.Dst, routev4) && route.Gw.Equal(ipgw4) {
v4Tablefound = true
}
if v4Tablefound {
break
}
}
Expect(v4Tablefound).To(BeTrue())
return nil return nil
}) })

View File

@ -15,10 +15,10 @@
package ipam_test package ipam_test
import ( import (
"testing" . "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"testing"
) )
func TestIpam(t *testing.T) { func TestIpam(t *testing.T) {

View File

@ -17,7 +17,7 @@ package link_test
import ( import (
"testing" "testing"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )

View File

@ -15,10 +15,8 @@
package link package link
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"time"
"github.com/networkplumbing/go-nft/nft" "github.com/networkplumbing/go-nft/nft"
"github.com/networkplumbing/go-nft/nft/schema" "github.com/networkplumbing/go-nft/nft/schema"
@ -30,8 +28,8 @@ const (
) )
type NftConfigurer interface { type NftConfigurer interface {
Apply(*nft.Config) (*nft.Config, error) Apply(*nft.Config) error
Read(filterCommands ...string) (*nft.Config, error) Read() (*nft.Config, error)
} }
type SpoofChecker struct { type SpoofChecker struct {
@ -39,23 +37,16 @@ type SpoofChecker struct {
macAddress string macAddress string
refID string refID string
configurer NftConfigurer configurer NftConfigurer
rulestore *nft.Config
} }
type defaultNftConfigurer struct{} type defaultNftConfigurer struct{}
func (dnc defaultNftConfigurer) Apply(cfg *nft.Config) (*nft.Config, error) { func (_ defaultNftConfigurer) Apply(cfg *nft.Config) error {
const timeout = 55 * time.Second return nft.ApplyConfig(cfg)
ctxWithTimeout, cancelFunc := context.WithTimeout(context.Background(), timeout)
defer cancelFunc()
return nft.ApplyConfigEcho(ctxWithTimeout, cfg)
} }
func (dnc defaultNftConfigurer) Read(filterCommands ...string) (*nft.Config, error) { func (_ defaultNftConfigurer) Read() (*nft.Config, error) {
const timeout = 55 * time.Second return nft.ReadConfig()
ctxWithTimeout, cancelFunc := context.WithTimeout(context.Background(), timeout)
defer cancelFunc()
return nft.ReadConfigContext(ctxWithTimeout, filterCommands...)
} }
func NewSpoofChecker(iface, macAddress, refID string) *SpoofChecker { func NewSpoofChecker(iface, macAddress, refID string) *SpoofChecker {
@ -63,7 +54,7 @@ func NewSpoofChecker(iface, macAddress, refID string) *SpoofChecker {
} }
func NewSpoofCheckerWithConfigurer(iface, macAddress, refID string, configurer NftConfigurer) *SpoofChecker { func NewSpoofCheckerWithConfigurer(iface, macAddress, refID string, configurer NftConfigurer) *SpoofChecker {
return &SpoofChecker{iface, macAddress, refID, configurer, nil} return &SpoofChecker{iface, macAddress, refID, configurer}
} }
// Setup applies nftables configuration to restrict traffic // Setup applies nftables configuration to restrict traffic
@ -92,7 +83,7 @@ func (sc *SpoofChecker) Setup() error {
macChain := sc.macChain(ifaceChain.Name) macChain := sc.macChain(ifaceChain.Name)
baseConfig.AddChain(macChain) baseConfig.AddChain(macChain)
if _, err := sc.configurer.Apply(baseConfig); err != nil { if err := sc.configurer.Apply(baseConfig); err != nil {
return fmt.Errorf("failed to setup spoof-check: %v", err) return fmt.Errorf("failed to setup spoof-check: %v", err)
} }
@ -106,51 +97,37 @@ func (sc *SpoofChecker) Setup() error {
rulesConfig.AddRule(sc.matchMacRule(macChain.Name)) rulesConfig.AddRule(sc.matchMacRule(macChain.Name))
rulesConfig.AddRule(sc.dropRule(macChain.Name)) rulesConfig.AddRule(sc.dropRule(macChain.Name))
rulestore, err := sc.configurer.Apply(rulesConfig) if err := sc.configurer.Apply(rulesConfig); err != nil {
if err != nil {
return fmt.Errorf("failed to setup spoof-check: %v", err) return fmt.Errorf("failed to setup spoof-check: %v", err)
} }
sc.rulestore = rulestore
return nil return nil
} }
func (sc *SpoofChecker) findPreroutingRule(ruleToFind *schema.Rule) ([]*schema.Rule, error) {
ruleset := sc.rulestore
if ruleset == nil {
chain, err := sc.configurer.Read(listChainBridgeNatPrerouting()...)
if err != nil {
return nil, err
}
ruleset = chain
}
return ruleset.LookupRule(ruleToFind), nil
}
// Teardown removes the interface and mac-address specific chains and their rules. // Teardown removes the interface and mac-address specific chains and their rules.
// The table and base-chain are expected to survive while the base-chain rule that matches the // The table and base-chain are expected to survive while the base-chain rule that matches the
// interface is removed. // interface is removed.
func (sc *SpoofChecker) Teardown() error { func (sc *SpoofChecker) Teardown() error {
ifaceChain := sc.ifaceChain() ifaceChain := sc.ifaceChain()
expectedRuleToFind := sc.matchIfaceJumpToChainRule(preRoutingBaseChainName, ifaceChain.Name) currentConfig, ifaceMatchRuleErr := sc.configurer.Read()
// It is safer to exclude the statement matching, avoiding cases where a current statement includes if ifaceMatchRuleErr == nil {
// additional default entries (e.g. counters). expectedRuleToFind := sc.matchIfaceJumpToChainRule(preRoutingBaseChainName, ifaceChain.Name)
ruleToFindExcludingStatements := *expectedRuleToFind // It is safer to exclude the statement matching, avoiding cases where a current statement includes
ruleToFindExcludingStatements.Expr = nil // additional default entries (e.g. counters).
ruleToFindExcludingStatements := *expectedRuleToFind
rules, ifaceMatchRuleErr := sc.findPreroutingRule(&ruleToFindExcludingStatements) ruleToFindExcludingStatements.Expr = nil
if ifaceMatchRuleErr == nil && len(rules) > 0 { rules := currentConfig.LookupRule(&ruleToFindExcludingStatements)
c := nft.NewConfig() if len(rules) > 0 {
for _, rule := range rules { c := nft.NewConfig()
c.DeleteRule(rule) for _, rule := range rules {
c.DeleteRule(rule)
}
if err := sc.configurer.Apply(c); err != nil {
ifaceMatchRuleErr = fmt.Errorf("failed to delete iface match rule: %v", err)
}
} else {
fmt.Fprintf(os.Stderr, "spoofcheck/teardown: unable to detect iface match rule for deletion: %+v", expectedRuleToFind)
} }
if _, err := sc.configurer.Apply(c); err != nil {
ifaceMatchRuleErr = fmt.Errorf("failed to delete iface match rule: %v", err)
}
// Drop the cache, it should contain deleted rule(s) now
sc.rulestore = nil
} else {
fmt.Fprintf(os.Stderr, "spoofcheck/teardown: unable to detect iface match rule for deletion: %+v", expectedRuleToFind)
} }
regularChainsConfig := nft.NewConfig() regularChainsConfig := nft.NewConfig()
@ -158,7 +135,7 @@ func (sc *SpoofChecker) Teardown() error {
regularChainsConfig.DeleteChain(sc.macChain(ifaceChain.Name)) regularChainsConfig.DeleteChain(sc.macChain(ifaceChain.Name))
var regularChainsErr error var regularChainsErr error
if _, err := sc.configurer.Apply(regularChainsConfig); err != nil { if err := sc.configurer.Apply(regularChainsConfig); err != nil {
regularChainsErr = fmt.Errorf("failed to delete regular chains: %v", err) regularChainsErr = fmt.Errorf("failed to delete regular chains: %v", err)
} }
@ -218,10 +195,12 @@ func (sc *SpoofChecker) matchMacRule(chain string) *schema.Rule {
} }
func (sc *SpoofChecker) dropRule(chain string) *schema.Rule { func (sc *SpoofChecker) dropRule(chain string) *schema.Rule {
macRulesIndex := nft.NewRuleIndex()
return &schema.Rule{ return &schema.Rule{
Family: schema.FamilyBridge, Family: schema.FamilyBridge,
Table: natTableName, Table: natTableName,
Chain: chain, Chain: chain,
Index: macRulesIndex.Next(),
Expr: []schema.Statement{ Expr: []schema.Statement{
{Verdict: schema.Verdict{SimpleVerdict: schema.SimpleVerdict{Drop: true}}}, {Verdict: schema.Verdict{SimpleVerdict: schema.SimpleVerdict{Drop: true}}},
}, },
@ -229,7 +208,7 @@ func (sc *SpoofChecker) dropRule(chain string) *schema.Rule {
} }
} }
func (sc *SpoofChecker) baseChain() *schema.Chain { func (_ *SpoofChecker) baseChain() *schema.Chain {
chainPriority := -300 chainPriority := -300
return &schema.Chain{ return &schema.Chain{
Family: schema.FamilyBridge, Family: schema.FamilyBridge,
@ -251,7 +230,7 @@ func (sc *SpoofChecker) ifaceChain() *schema.Chain {
} }
} }
func (sc *SpoofChecker) macChain(ifaceChainName string) *schema.Chain { func (_ *SpoofChecker) macChain(ifaceChainName string) *schema.Chain {
macChainName := ifaceChainName + "-mac" macChainName := ifaceChainName + "-mac"
return &schema.Chain{ return &schema.Chain{
Family: schema.FamilyBridge, Family: schema.FamilyBridge,
@ -264,7 +243,3 @@ func ruleComment(id string) string {
const refIDPrefix = "macspoofchk-" const refIDPrefix = "macspoofchk-"
return refIDPrefix + id return refIDPrefix + id
} }
func listChainBridgeNatPrerouting() []string {
return []string{"chain", "bridge", natTableName, preRoutingBaseChainName}
}

View File

@ -15,11 +15,10 @@
package link_test package link_test
import ( import (
"errors"
"fmt" "fmt"
"github.com/networkplumbing/go-nft/nft" "github.com/networkplumbing/go-nft/nft"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"github.com/containernetworking/plugins/pkg/link" "github.com/containernetworking/plugins/pkg/link"
@ -114,32 +113,13 @@ var _ = Describe("spoofcheck", func() {
))) )))
}) })
}) })
Context("echo", func() {
It("succeeds, no read called", func() {
c := configurerStub{}
sc := link.NewSpoofCheckerWithConfigurer(iface, mac, id, &c)
Expect(sc.Setup()).To(Succeed())
Expect(sc.Teardown()).To(Succeed())
Expect(c.readCalled).To(BeFalse())
})
It("succeeds, fall back to config read", func() {
c := configurerStub{applyReturnNil: true}
sc := link.NewSpoofCheckerWithConfigurer(iface, mac, id, &c)
Expect(sc.Setup()).To(Succeed())
c.readConfig = c.applyConfig[0]
Expect(sc.Teardown()).To(Succeed())
Expect(c.readCalled).To(BeTrue())
})
})
}) })
func assertExpectedRegularChainsDeletionInTeardownConfig(action configurerStub) { func assertExpectedRegularChainsDeletionInTeardownConfig(action configurerStub) {
deleteRegularChainRulesJSONConfig, err := action.applyConfig[1].ToJSON() deleteRegularChainRulesJsonConfig, err := action.applyConfig[1].ToJSON()
ExpectWithOffset(1, err).NotTo(HaveOccurred()) ExpectWithOffset(1, err).NotTo(HaveOccurred())
expectedDeleteRegularChainRulesJSONConfig := ` expectedDeleteRegularChainRulesJsonConfig := `
{"nftables": [ {"nftables": [
{"delete": {"chain": { {"delete": {"chain": {
"family": "bridge", "family": "bridge",
@ -153,14 +133,14 @@ func assertExpectedRegularChainsDeletionInTeardownConfig(action configurerStub)
}}} }}}
]}` ]}`
ExpectWithOffset(1, string(deleteRegularChainRulesJSONConfig)).To(MatchJSON(expectedDeleteRegularChainRulesJSONConfig)) ExpectWithOffset(1, string(deleteRegularChainRulesJsonConfig)).To(MatchJSON(expectedDeleteRegularChainRulesJsonConfig))
} }
func assertExpectedBaseChainRuleDeletionInTeardownConfig(action configurerStub) { func assertExpectedBaseChainRuleDeletionInTeardownConfig(action configurerStub) {
deleteBaseChainRuleJSONConfig, err := action.applyConfig[0].ToJSON() deleteBaseChainRuleJsonConfig, err := action.applyConfig[0].ToJSON()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
expectedDeleteIfaceMatchRuleJSONConfig := ` expectedDeleteIfaceMatchRuleJsonConfig := `
{"nftables": [ {"nftables": [
{"delete": {"rule": { {"delete": {"rule": {
"family": "bridge", "family": "bridge",
@ -177,7 +157,7 @@ func assertExpectedBaseChainRuleDeletionInTeardownConfig(action configurerStub)
"comment": "macspoofchk-container99-net1" "comment": "macspoofchk-container99-net1"
}}} }}}
]}` ]}`
Expect(string(deleteBaseChainRuleJSONConfig)).To(MatchJSON(expectedDeleteIfaceMatchRuleJSONConfig)) Expect(string(deleteBaseChainRuleJsonConfig)).To(MatchJSON(expectedDeleteIfaceMatchRuleJsonConfig))
} }
func rowConfigWithRulesOnly() string { func rowConfigWithRulesOnly() string {
@ -274,6 +254,7 @@ func assertExpectedRulesInSetupConfig(c configurerStub) {
"comment":"macspoofchk-container99-net1"}}, "comment":"macspoofchk-container99-net1"}},
{"rule":{"family":"bridge","table":"nat","chain":"cni-br-iface-container99-net1-mac", {"rule":{"family":"bridge","table":"nat","chain":"cni-br-iface-container99-net1-mac",
"expr":[{"drop":null}], "expr":[{"drop":null}],
"index":0,
"comment":"macspoofchk-container99-net1"}} "comment":"macspoofchk-container99-net1"}}
]}` ]}`
ExpectWithOffset(1, string(jsonConfig)).To(MatchJSON(expectedConfig)) ExpectWithOffset(1, string(jsonConfig)).To(MatchJSON(expectedConfig))
@ -294,30 +275,23 @@ type configurerStub struct {
failFirstApplyConfig bool failFirstApplyConfig bool
failSecondApplyConfig bool failSecondApplyConfig bool
failReadConfig bool failReadConfig bool
applyReturnNil bool
readCalled bool
} }
func (a *configurerStub) Apply(c *nft.Config) (*nft.Config, error) { func (a *configurerStub) Apply(c *nft.Config) error {
a.applyCounter++ a.applyCounter++
if a.failFirstApplyConfig && a.applyCounter == 1 { if a.failFirstApplyConfig && a.applyCounter == 1 {
return nil, errors.New(errorFirstApplyText) return fmt.Errorf(errorFirstApplyText)
} }
if a.failSecondApplyConfig && a.applyCounter == 2 { if a.failSecondApplyConfig && a.applyCounter == 2 {
return nil, errors.New(errorSecondApplyText) return fmt.Errorf(errorSecondApplyText)
} }
a.applyConfig = append(a.applyConfig, c) a.applyConfig = append(a.applyConfig, c)
if a.applyReturnNil { return nil
return nil, nil
}
return c, nil
} }
func (a *configurerStub) Read(_ ...string) (*nft.Config, error) { func (a *configurerStub) Read() (*nft.Config, error) {
a.readCalled = true
if a.failReadConfig { if a.failReadConfig {
return nil, errors.New(errorReadText) return nil, fmt.Errorf(errorReadText)
} }
return a.readConfig, nil return a.readConfig, nil
} }

View File

@ -13,10 +13,10 @@ The `ns.Do()` method provides **partial** control over network namespaces for yo
```go ```go
err = targetNs.Do(func(hostNs ns.NetNS) error { err = targetNs.Do(func(hostNs ns.NetNS) error {
linkAttrs := netlink.NewLinkAttrs()
linkAttrs.Name = "dummy0"
dummy := &netlink.Dummy{ dummy := &netlink.Dummy{
LinkAttrs: linkAttrs, LinkAttrs: netlink.LinkAttrs{
Name: "dummy0",
},
} }
return netlink.LinkAdd(dummy) return netlink.LinkAdd(dummy)
}) })

View File

@ -31,10 +31,6 @@ func GetCurrentNS() (NetNS, error) {
// return an unexpected network namespace. // return an unexpected network namespace.
runtime.LockOSThread() runtime.LockOSThread()
defer runtime.UnlockOSThread() defer runtime.UnlockOSThread()
return getCurrentNSNoLock()
}
func getCurrentNSNoLock() (NetNS, error) {
return GetNS(getCurrentThreadNetNSPath()) return GetNS(getCurrentThreadNetNSPath())
} }
@ -156,54 +152,6 @@ func GetNS(nspath string) (NetNS, error) {
return &netNS{file: fd}, nil return &netNS{file: fd}, nil
} }
// Returns a new empty NetNS.
// Calling Close() let the kernel garbage collect the network namespace.
func TempNetNS() (NetNS, error) {
var tempNS NetNS
var err error
var wg sync.WaitGroup
wg.Add(1)
// Create the new namespace in a new goroutine so that if we later fail
// to switch the namespace back to the original one, we can safely
// leave the thread locked to die without a risk of the current thread
// left lingering with incorrect namespace.
go func() {
defer wg.Done()
runtime.LockOSThread()
var threadNS NetNS
// save a handle to current network namespace
threadNS, err = getCurrentNSNoLock()
if err != nil {
err = fmt.Errorf("failed to open current namespace: %v", err)
return
}
defer threadNS.Close()
// create the temporary network namespace
err = unix.Unshare(unix.CLONE_NEWNET)
if err != nil {
return
}
// get a handle to the temporary network namespace
tempNS, err = getCurrentNSNoLock()
err2 := threadNS.Set()
if err2 == nil {
// Unlock the current thread only when we successfully switched back
// to the original namespace; otherwise leave the thread locked which
// will force the runtime to scrap the current thread, that is maybe
// not as optimal but at least always safe to do.
runtime.UnlockOSThread()
}
}()
wg.Wait()
return tempNS, err
}
func (ns *netNS) Path() string { func (ns *netNS) Path() string {
return ns.file.Name() return ns.file.Name()
} }
@ -225,7 +173,7 @@ func (ns *netNS) Do(toRun func(NetNS) error) error {
} }
containedCall := func(hostNS NetNS) error { containedCall := func(hostNS NetNS) error {
threadNS, err := getCurrentNSNoLock() threadNS, err := GetCurrentNS()
if err != nil { if err != nil {
return fmt.Errorf("failed to open current netns: %v", err) return fmt.Errorf("failed to open current netns: %v", err)
} }

View File

@ -17,16 +17,16 @@ package ns_test
import ( import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"sync" "sync"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"golang.org/x/sys/unix"
"github.com/containernetworking/plugins/pkg/ns" "github.com/containernetworking/plugins/pkg/ns"
"github.com/containernetworking/plugins/pkg/testutils" "github.com/containernetworking/plugins/pkg/testutils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"golang.org/x/sys/unix"
) )
func getInodeCurNetNS() (uint64, error) { func getInodeCurNetNS() (uint64, error) {
@ -182,7 +182,7 @@ var _ = Describe("Linux namespace operations", func() {
testNsInode, err := getInodeNS(targetNetNS) testNsInode, err := getInodeNS(targetNetNS)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(testNsInode).NotTo(Equal(uint64(0))) Expect(testNsInode).NotTo(Equal(0))
Expect(testNsInode).NotTo(Equal(origNSInode)) Expect(testNsInode).NotTo(Equal(origNSInode))
}) })
@ -208,7 +208,7 @@ var _ = Describe("Linux namespace operations", func() {
}) })
It("fails when the path is not a namespace", func() { It("fails when the path is not a namespace", func() {
tempFile, err := os.CreateTemp("", "nstest") tempFile, err := ioutil.TempFile("", "nstest")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
defer tempFile.Close() defer tempFile.Close()
@ -262,7 +262,7 @@ var _ = Describe("Linux namespace operations", func() {
}) })
It("should refuse other paths", func() { It("should refuse other paths", func() {
tempFile, err := os.CreateTemp("", "nstest") tempFile, err := ioutil.TempFile("", "nstest")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
defer tempFile.Close() defer tempFile.Close()

View File

@ -15,14 +15,18 @@
package ns_test package ns_test
import ( import (
"math/rand"
"runtime" "runtime"
"testing"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo"
"github.com/onsi/ginkgo/config"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"testing"
) )
func TestNs(t *testing.T) { func TestNs(t *testing.T) {
rand.Seed(config.GinkgoConfig.RandomSeed)
runtime.LockOSThread() runtime.LockOSThread()
RegisterFailHandler(Fail) RegisterFailHandler(Fail)

View File

@ -21,7 +21,7 @@ type BadReader struct {
Error error Error error
} }
func (r *BadReader) Read(_ []byte) (int, error) { func (r *BadReader) Read(buffer []byte) (int, error) {
if r.Error != nil { if r.Error != nil {
return 0, r.Error return 0, r.Error
} }

View File

@ -15,7 +15,7 @@
package testutils package testutils
import ( import (
"io" "io/ioutil"
"os" "os"
"github.com/containernetworking/cni/pkg/skel" "github.com/containernetworking/cni/pkg/skel"
@ -29,7 +29,6 @@ func envCleanup() {
os.Unsetenv("CNI_NETNS") os.Unsetenv("CNI_NETNS")
os.Unsetenv("CNI_IFNAME") os.Unsetenv("CNI_IFNAME")
os.Unsetenv("CNI_CONTAINERID") os.Unsetenv("CNI_CONTAINERID")
os.Unsetenv("CNI_NETNS_OVERRIDE")
} }
func CmdAdd(cniNetns, cniContainerID, cniIfname string, conf []byte, f func() error) (types.Result, []byte, error) { func CmdAdd(cniNetns, cniContainerID, cniIfname string, conf []byte, f func() error) (types.Result, []byte, error) {
@ -38,7 +37,6 @@ func CmdAdd(cniNetns, cniContainerID, cniIfname string, conf []byte, f func() er
os.Setenv("CNI_NETNS", cniNetns) os.Setenv("CNI_NETNS", cniNetns)
os.Setenv("CNI_IFNAME", cniIfname) os.Setenv("CNI_IFNAME", cniIfname)
os.Setenv("CNI_CONTAINERID", cniContainerID) os.Setenv("CNI_CONTAINERID", cniContainerID)
os.Setenv("CNI_NETNS_OVERRIDE", "1")
defer envCleanup() defer envCleanup()
// Redirect stdout to capture plugin result // Redirect stdout to capture plugin result
@ -54,7 +52,7 @@ func CmdAdd(cniNetns, cniContainerID, cniIfname string, conf []byte, f func() er
var out []byte var out []byte
if err == nil { if err == nil {
out, err = io.ReadAll(r) out, err = ioutil.ReadAll(r)
} }
os.Stdout = oldStdout os.Stdout = oldStdout
@ -83,20 +81,19 @@ func CmdAddWithArgs(args *skel.CmdArgs, f func() error) (types.Result, []byte, e
return CmdAdd(args.Netns, args.ContainerID, args.IfName, args.StdinData, f) return CmdAdd(args.Netns, args.ContainerID, args.IfName, args.StdinData, f)
} }
func CmdCheck(cniNetns, cniContainerID, cniIfname string, f func() error) error { func CmdCheck(cniNetns, cniContainerID, cniIfname string, conf []byte, f func() error) error {
os.Setenv("CNI_COMMAND", "CHECK") os.Setenv("CNI_COMMAND", "CHECK")
os.Setenv("CNI_PATH", os.Getenv("PATH")) os.Setenv("CNI_PATH", os.Getenv("PATH"))
os.Setenv("CNI_NETNS", cniNetns) os.Setenv("CNI_NETNS", cniNetns)
os.Setenv("CNI_IFNAME", cniIfname) os.Setenv("CNI_IFNAME", cniIfname)
os.Setenv("CNI_CONTAINERID", cniContainerID) os.Setenv("CNI_CONTAINERID", cniContainerID)
os.Setenv("CNI_NETNS_OVERRIDE", "1")
defer envCleanup() defer envCleanup()
return f() return f()
} }
func CmdCheckWithArgs(args *skel.CmdArgs, f func() error) error { func CmdCheckWithArgs(args *skel.CmdArgs, f func() error) error {
return CmdCheck(args.Netns, args.ContainerID, args.IfName, f) return CmdCheck(args.Netns, args.ContainerID, args.IfName, args.StdinData, f)
} }
func CmdDel(cniNetns, cniContainerID, cniIfname string, f func() error) error { func CmdDel(cniNetns, cniContainerID, cniIfname string, f func() error) error {
@ -105,7 +102,6 @@ func CmdDel(cniNetns, cniContainerID, cniIfname string, f func() error) error {
os.Setenv("CNI_NETNS", cniNetns) os.Setenv("CNI_NETNS", cniNetns)
os.Setenv("CNI_IFNAME", cniIfname) os.Setenv("CNI_IFNAME", cniIfname)
os.Setenv("CNI_CONTAINERID", cniContainerID) os.Setenv("CNI_CONTAINERID", cniContainerID)
os.Setenv("CNI_NETNS_OVERRIDE", "1")
defer envCleanup() defer envCleanup()
return f() return f()
@ -114,12 +110,3 @@ func CmdDel(cniNetns, cniContainerID, cniIfname string, f func() error) error {
func CmdDelWithArgs(args *skel.CmdArgs, f func() error) error { func CmdDelWithArgs(args *skel.CmdArgs, f func() error) error {
return CmdDel(args.Netns, args.ContainerID, args.IfName, f) return CmdDel(args.Netns, args.ContainerID, args.IfName, f)
} }
func CmdStatus(f func() error) error {
os.Setenv("CNI_COMMAND", "STATUS")
os.Setenv("CNI_PATH", os.Getenv("PATH"))
os.Setenv("CNI_NETNS_OVERRIDE", "1")
defer envCleanup()
return f()
}

View File

@ -16,6 +16,7 @@ package testutils
import ( import (
"fmt" "fmt"
"io/ioutil"
"os" "os"
"strings" "strings"
@ -27,7 +28,7 @@ import (
// an error if any occurs while creating/writing the file. It is the caller's // an error if any occurs while creating/writing the file. It is the caller's
// responsibility to remove the file. // responsibility to remove the file.
func TmpResolvConf(dnsConf types.DNS) (string, error) { func TmpResolvConf(dnsConf types.DNS) (string, error) {
f, err := os.CreateTemp("", "cni_test_resolv.conf") f, err := ioutil.TempFile("", "cni_test_resolv.conf")
if err != nil { if err != nil {
return "", fmt.Errorf("failed to get temp file for CNI test resolv.conf: %v", err) return "", fmt.Errorf("failed to get temp file for CNI test resolv.conf: %v", err)
} }

View File

@ -2,12 +2,12 @@ package main_test
import ( import (
"fmt" "fmt"
"io" "io/ioutil"
"net" "net"
"os/exec" "os/exec"
"strings" "strings"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"github.com/onsi/gomega/gbytes" "github.com/onsi/gomega/gbytes"
"github.com/onsi/gomega/gexec" "github.com/onsi/gomega/gexec"
@ -74,7 +74,7 @@ var _ = Describe("Echosvr", func() {
defer conn.Close() defer conn.Close()
fmt.Fprintf(conn, "hello\n") fmt.Fprintf(conn, "hello\n")
Expect(io.ReadAll(conn)).To(Equal([]byte("hello"))) Expect(ioutil.ReadAll(conn)).To(Equal([]byte("hello")))
}) })
}) })
@ -86,7 +86,7 @@ var _ = Describe("Echosvr", func() {
It("connects successfully using echo client", func() { It("connects successfully using echo client", func() {
Eventually(session.Out).Should(gbytes.Say("\n")) Eventually(session.Out).Should(gbytes.Say("\n"))
serverAddress := strings.TrimSpace(string(session.Out.Contents())) serverAddress := strings.TrimSpace(string(session.Out.Contents()))
fmt.Println("Server address", serverAddress) fmt.Println("Server address", string(serverAddress))
cmd := exec.Command(clientBinaryPath, "-target", serverAddress, "-message", "hello") cmd := exec.Command(clientBinaryPath, "-target", serverAddress, "-message", "hello")
clientSession, err := gexec.Start(cmd, GinkgoWriter, GinkgoWriter) clientSession, err := gexec.Start(cmd, GinkgoWriter, GinkgoWriter)

View File

@ -1,10 +1,10 @@
package main_test package main_test
import ( import (
"testing" . "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"testing"
) )
func TestEchosvr(t *testing.T) { func TestEchosvr(t *testing.T) {

View File

@ -1,10 +1,9 @@
// Echosvr is a simple TCP echo server // Echosvr is a simple TCP echo server
// //
// It prints its listen address on stdout // It prints its listen address on stdout
// // 127.0.0.1:xxxxx
// 127.0.0.1:xxxxx // A test should wait for this line, parse it
// A test should wait for this line, parse it // and may then attempt to connect.
// and may then attempt to connect.
package main package main
import ( import (
@ -44,13 +43,11 @@ func main() {
// Start UDP server // Start UDP server
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%s", port)) addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%s", port))
if err != nil { if err != nil {
log.Printf("Error from net.ResolveUDPAddr(): %s", err) log.Fatalf("Error from net.ResolveUDPAddr(): %s", err)
return
} }
sock, err := net.ListenUDP("udp", addr) sock, err := net.ListenUDP("udp", addr)
if err != nil { if err != nil {
log.Printf("Error from ListenUDP(): %s", err) log.Fatalf("Error from ListenUDP(): %s", err)
return
} }
defer sock.Close() defer sock.Close()
@ -58,11 +55,10 @@ func main() {
for { for {
n, addr, err := sock.ReadFrom(buffer) n, addr, err := sock.ReadFrom(buffer)
if err != nil { if err != nil {
log.Printf("Error from ReadFrom(): %s", err) log.Fatalf("Error from ReadFrom(): %s", err)
return
} }
sock.SetWriteDeadline(time.Now().Add(1 * time.Minute)) sock.SetWriteDeadline(time.Now().Add(1 * time.Minute))
_, err = sock.WriteTo(buffer[0:n], addr) n, err = sock.WriteTo(buffer[0:n], addr)
if err != nil { if err != nil {
return return
} }

View File

@ -24,9 +24,8 @@ import (
"sync" "sync"
"syscall" "syscall"
"golang.org/x/sys/unix"
"github.com/containernetworking/plugins/pkg/ns" "github.com/containernetworking/plugins/pkg/ns"
"golang.org/x/sys/unix"
) )
func getNsRunDir() string { func getNsRunDir() string {
@ -50,6 +49,7 @@ func getNsRunDir() string {
// Creates a new persistent (bind-mounted) network namespace and returns an object // Creates a new persistent (bind-mounted) network namespace and returns an object
// representing that namespace, without switching to it. // representing that namespace, without switching to it.
func NewNS() (ns.NetNS, error) { func NewNS() (ns.NetNS, error) {
nsRunDir := getNsRunDir() nsRunDir := getNsRunDir()
b := make([]byte, 16) b := make([]byte, 16)
@ -61,7 +61,7 @@ func NewNS() (ns.NetNS, error) {
// Create the directory for mounting network namespaces // Create the directory for mounting network namespaces
// This needs to be a shared mountpoint in case it is mounted in to // This needs to be a shared mountpoint in case it is mounted in to
// other namespaces (containers) // other namespaces (containers)
err = os.MkdirAll(nsRunDir, 0o755) err = os.MkdirAll(nsRunDir, 0755)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -19,7 +19,7 @@ import (
) )
// AllSpecVersions contains all CNI spec version numbers // AllSpecVersions contains all CNI spec version numbers
var AllSpecVersions = [...]string{"0.1.0", "0.2.0", "0.3.0", "0.3.1", "0.4.0", "1.0.0", "1.1.0"} var AllSpecVersions = [...]string{"0.1.0", "0.2.0", "0.3.0", "0.3.1", "0.4.0", "1.0.0"}
// SpecVersionHasIPVersion returns true if the given CNI specification version // SpecVersionHasIPVersion returns true if the given CNI specification version
// includes the "version" field in the IP address elements // includes the "version" field in the IP address elements
@ -39,13 +39,6 @@ func SpecVersionHasCHECK(ver string) bool {
return ok return ok
} }
// SpecVersionHasSTATUS returns true if the given CNI specification version
// supports the STATUS command
func SpecVersionHasSTATUS(ver string) bool {
ok, _ := version.GreaterThanOrEqualTo(ver, "1.1.0")
return ok
}
// SpecVersionHasChaining returns true if the given CNI specification version // SpecVersionHasChaining returns true if the given CNI specification version
// supports plugin chaining // supports plugin chaining
func SpecVersionHasChaining(ver string) bool { func SpecVersionHasChaining(ver string) bool {

View File

@ -51,7 +51,7 @@ func DeleteConntrackEntriesForDstIP(dstIP string, protocol uint8) error {
filter.AddIP(netlink.ConntrackOrigDstIP, ip) filter.AddIP(netlink.ConntrackOrigDstIP, ip)
filter.AddProtocol(protocol) filter.AddProtocol(protocol)
_, err := netlink.ConntrackDeleteFilters(netlink.ConntrackTable, family, filter) _, err := netlink.ConntrackDeleteFilter(netlink.ConntrackTable, family, filter)
if err != nil { if err != nil {
return fmt.Errorf("error deleting connection tracking state for protocol: %d IP: %s, error: %v", protocol, ip, err) return fmt.Errorf("error deleting connection tracking state for protocol: %d IP: %s, error: %v", protocol, ip, err)
} }
@ -65,7 +65,7 @@ func DeleteConntrackEntriesForDstPort(port uint16, protocol uint8, family netlin
filter.AddProtocol(protocol) filter.AddProtocol(protocol)
filter.AddPort(netlink.ConntrackOrigDstPort, port) filter.AddPort(netlink.ConntrackOrigDstPort, port)
_, err := netlink.ConntrackDeleteFilters(netlink.ConntrackTable, family, filter) _, err := netlink.ConntrackDeleteFilter(netlink.ConntrackTable, family, filter)
if err != nil { if err != nil {
return fmt.Errorf("error deleting connection tracking state for protocol: %d Port: %d, error: %v", protocol, port, err) return fmt.Errorf("error deleting connection tracking state for protocol: %d Port: %d, error: %v", protocol, port, err)
} }

View File

@ -29,9 +29,9 @@ func EnsureChain(ipt *iptables.IPTables, table, chain string) error {
if ipt == nil { if ipt == nil {
return errors.New("failed to ensure iptable chain: IPTables was nil") return errors.New("failed to ensure iptable chain: IPTables was nil")
} }
exists, err := ipt.ChainExists(table, chain) exists, err := ChainExists(ipt, table, chain)
if err != nil { if err != nil {
return fmt.Errorf("failed to check iptables chain existence: %v", err) return fmt.Errorf("failed to list iptables chains: %v", err)
} }
if !exists { if !exists {
err = ipt.NewChain(table, chain) err = ipt.NewChain(table, chain)
@ -45,6 +45,24 @@ func EnsureChain(ipt *iptables.IPTables, table, chain string) error {
return nil return nil
} }
// ChainExists checks whether an iptables chain exists.
func ChainExists(ipt *iptables.IPTables, table, chain string) (bool, error) {
if ipt == nil {
return false, errors.New("failed to check iptable chain: IPTables was nil")
}
chains, err := ipt.ListChains(table)
if err != nil {
return false, err
}
for _, ch := range chains {
if ch == chain {
return true, nil
}
}
return false, nil
}
// DeleteRule idempotently delete the iptables rule in the specified table/chain. // DeleteRule idempotently delete the iptables rule in the specified table/chain.
// It does not return an error if the referring chain doesn't exist // It does not return an error if the referring chain doesn't exist
func DeleteRule(ipt *iptables.IPTables, table, chain string, rulespec ...string) error { func DeleteRule(ipt *iptables.IPTables, table, chain string, rulespec ...string) error {
@ -115,6 +133,7 @@ func InsertUnique(ipt *iptables.IPTables, table, chain string, prepend bool, rul
if prepend { if prepend {
return ipt.Insert(table, chain, 1, rule...) return ipt.Insert(table, chain, 1, rule...)
} else {
return ipt.Append(table, chain, rule...)
} }
return ipt.Append(table, chain, rule...)
} }

View File

@ -19,12 +19,11 @@ import (
"math/rand" "math/rand"
"runtime" "runtime"
"github.com/coreos/go-iptables/iptables"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/containernetworking/plugins/pkg/ns" "github.com/containernetworking/plugins/pkg/ns"
"github.com/containernetworking/plugins/pkg/testutils" "github.com/containernetworking/plugins/pkg/testutils"
"github.com/coreos/go-iptables/iptables"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
) )
const TABLE = "filter" // We'll monkey around here const TABLE = "filter" // We'll monkey around here
@ -35,6 +34,7 @@ var _ = Describe("chain tests", func() {
var cleanup func() var cleanup func()
BeforeEach(func() { BeforeEach(func() {
// Save a reference to the original namespace, // Save a reference to the original namespace,
// Add a new NS // Add a new NS
currNs, err := ns.GetCurrentNS() currNs, err := ns.GetCurrentNS()
@ -60,6 +60,7 @@ var _ = Describe("chain tests", func() {
ipt.DeleteChain(TABLE, testChain) ipt.DeleteChain(TABLE, testChain)
currNs.Set() currNs.Set()
} }
}) })
AfterEach(func() { AfterEach(func() {
@ -92,4 +93,5 @@ var _ = Describe("chain tests", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
}) })
}) })

View File

@ -1,46 +0,0 @@
// Copyright 2023 CNI authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"github.com/coreos/go-iptables/iptables"
"sigs.k8s.io/knftables"
)
// SupportsIPTables tests whether the system supports using netfilter via the iptables API
// (whether via "iptables-legacy" or "iptables-nft"). (Note that this returns true if it
// is *possible* to use iptables; it does not test whether any other components on the
// system are *actually* using iptables.)
func SupportsIPTables() bool {
ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return false
}
// We don't care whether the chain actually exists, only whether we can *check*
// whether it exists.
_, err = ipt.ChainExists("filter", "INPUT")
return err == nil
}
// SupportsNFTables tests whether the system supports using netfilter via the nftables API
// (ie, not via "iptables-nft"). (Note that this returns true if it is *possible* to use
// nftables; it does not test whether any other components on the system are *actually*
// using nftables.)
func SupportsNFTables() bool {
// knftables.New() does sanity checks so we don't need any further test like in
// the iptables case.
_, err := knftables.New(knftables.IPv4Family, "supports_nftables_test")
return err == nil
}

View File

@ -1,52 +0,0 @@
// Copyright 2023 CNI authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package utils
import (
"os"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("netfilter support", func() {
When("it is available", func() {
It("reports that iptables is supported", func() {
Expect(SupportsIPTables()).To(BeTrue(), "This test should only fail if iptables is not available, but the test suite as a whole requires it to be available.")
})
It("reports that nftables is supported", func() {
Expect(SupportsNFTables()).To(BeTrue(), "This test should only fail if nftables is not available, but the test suite as a whole requires it to be available.")
})
})
// These are Serial because os.Setenv has process-wide effect
When("it is not available", Serial, func() {
var origPath string
BeforeEach(func() {
origPath = os.Getenv("PATH")
os.Setenv("PATH", "/does-not-exist")
})
AfterEach(func() {
os.Setenv("PATH", origPath)
})
It("reports that iptables is not supported", func() {
Expect(SupportsIPTables()).To(BeFalse(), "found iptables outside of PATH??")
})
It("reports that nftables is not supported", func() {
Expect(SupportsNFTables()).To(BeFalse(), "found nftables outside of PATH??")
})
})
})

View File

@ -16,7 +16,7 @@ package sysctl
import ( import (
"fmt" "fmt"
"os" "io/ioutil"
"path/filepath" "path/filepath"
"strings" "strings"
) )
@ -36,7 +36,7 @@ func Sysctl(name string, params ...string) (string, error) {
func getSysctl(name string) (string, error) { func getSysctl(name string) (string, error) {
fullName := filepath.Join("/proc/sys", toNormalName(name)) fullName := filepath.Join("/proc/sys", toNormalName(name))
data, err := os.ReadFile(fullName) data, err := ioutil.ReadFile(fullName)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -46,7 +46,7 @@ func getSysctl(name string) (string, error) {
func setSysctl(name, value string) (string, error) { func setSysctl(name, value string) (string, error) {
fullName := filepath.Join("/proc/sys", toNormalName(name)) fullName := filepath.Join("/proc/sys", toNormalName(name))
if err := os.WriteFile(fullName, []byte(value), 0o644); err != nil { if err := ioutil.WriteFile(fullName, []byte(value), 0644); err != nil {
return "", err return "", err
} }

View File

@ -20,13 +20,12 @@ import (
"runtime" "runtime"
"strings" "strings"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/vishvananda/netlink"
"github.com/containernetworking/plugins/pkg/ns" "github.com/containernetworking/plugins/pkg/ns"
"github.com/containernetworking/plugins/pkg/testutils" "github.com/containernetworking/plugins/pkg/testutils"
"github.com/containernetworking/plugins/pkg/utils/sysctl" "github.com/containernetworking/plugins/pkg/utils/sysctl"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/vishvananda/netlink"
) )
const ( const (
@ -38,7 +37,8 @@ var _ = Describe("Sysctl tests", func() {
var testIfaceName string var testIfaceName string
var cleanup func() var cleanup func()
beforeEach := func() { BeforeEach(func() {
// Save a reference to the original namespace, // Save a reference to the original namespace,
// Add a new NS // Add a new NS
currNs, err := ns.GetCurrentNS() currNs, err := ns.GetCurrentNS()
@ -48,11 +48,11 @@ var _ = Describe("Sysctl tests", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
testIfaceName = fmt.Sprintf("cnitest.%d", rand.Intn(100000)) testIfaceName = fmt.Sprintf("cnitest.%d", rand.Intn(100000))
testLinkAttrs := netlink.NewLinkAttrs()
testLinkAttrs.Name = testIfaceName
testLinkAttrs.Namespace = netlink.NsFd(int(testNs.Fd()))
testIface := &netlink.Dummy{ testIface := &netlink.Dummy{
LinkAttrs: testLinkAttrs, LinkAttrs: netlink.LinkAttrs{
Name: testIfaceName,
Namespace: netlink.NsFd(int(testNs.Fd())),
},
} }
err = netlink.LinkAdd(testIface) err = netlink.LinkAdd(testIface)
@ -66,7 +66,8 @@ var _ = Describe("Sysctl tests", func() {
netlink.LinkDel(testIface) netlink.LinkDel(testIface)
currNs.Set() currNs.Set()
} }
}
})
AfterEach(func() { AfterEach(func() {
cleanup() cleanup()
@ -74,8 +75,7 @@ var _ = Describe("Sysctl tests", func() {
Describe("Sysctl", func() { Describe("Sysctl", func() {
It("reads keys with dot separators", func() { It("reads keys with dot separators", func() {
beforeEach() sysctlIfaceName := strings.Replace(testIfaceName, ".", "/", -1)
sysctlIfaceName := strings.ReplaceAll(testIfaceName, ".", "/")
sysctlKey := fmt.Sprintf(sysctlDotKeyTemplate, sysctlIfaceName) sysctlKey := fmt.Sprintf(sysctlDotKeyTemplate, sysctlIfaceName)
_, err := sysctl.Sysctl(sysctlKey) _, err := sysctl.Sysctl(sysctlKey)
@ -85,7 +85,6 @@ var _ = Describe("Sysctl tests", func() {
Describe("Sysctl", func() { Describe("Sysctl", func() {
It("reads keys with slash separators", func() { It("reads keys with slash separators", func() {
beforeEach()
sysctlKey := fmt.Sprintf(sysctlSlashKeyTemplate, testIfaceName) sysctlKey := fmt.Sprintf(sysctlSlashKeyTemplate, testIfaceName)
_, err := sysctl.Sysctl(sysctlKey) _, err := sysctl.Sysctl(sysctlKey)
@ -95,8 +94,7 @@ var _ = Describe("Sysctl tests", func() {
Describe("Sysctl", func() { Describe("Sysctl", func() {
It("writes keys with dot separators", func() { It("writes keys with dot separators", func() {
beforeEach() sysctlIfaceName := strings.Replace(testIfaceName, ".", "/", -1)
sysctlIfaceName := strings.ReplaceAll(testIfaceName, ".", "/")
sysctlKey := fmt.Sprintf(sysctlDotKeyTemplate, sysctlIfaceName) sysctlKey := fmt.Sprintf(sysctlDotKeyTemplate, sysctlIfaceName)
_, err := sysctl.Sysctl(sysctlKey, "1") _, err := sysctl.Sysctl(sysctlKey, "1")
@ -106,11 +104,11 @@ var _ = Describe("Sysctl tests", func() {
Describe("Sysctl", func() { Describe("Sysctl", func() {
It("writes keys with slash separators", func() { It("writes keys with slash separators", func() {
beforeEach()
sysctlKey := fmt.Sprintf(sysctlSlashKeyTemplate, testIfaceName) sysctlKey := fmt.Sprintf(sysctlSlashKeyTemplate, testIfaceName)
_, err := sysctl.Sysctl(sysctlKey, "1") _, err := sysctl.Sysctl(sysctlKey, "1")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
}) })
}) })

View File

@ -17,7 +17,7 @@ package sysctl_test
import ( import (
"testing" "testing"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )

View File

@ -15,10 +15,10 @@
package utils_test package utils_test
import ( import (
"testing" . "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"testing"
) )
func TestUtils(t *testing.T) { func TestUtils(t *testing.T) {

View File

@ -18,7 +18,7 @@ import (
"fmt" "fmt"
"strings" "strings"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
@ -26,29 +26,29 @@ var _ = Describe("Utils", func() {
Describe("FormatChainName", func() { Describe("FormatChainName", func() {
It("must format a short name", func() { It("must format a short name", func() {
chain := FormatChainName("test", "1234") chain := FormatChainName("test", "1234")
Expect(chain).To(HaveLen(maxChainLength)) Expect(len(chain)).To(Equal(maxChainLength))
Expect(chain).To(Equal("CNI-2bbe0c48b91a7d1b8a6753a8")) Expect(chain).To(Equal("CNI-2bbe0c48b91a7d1b8a6753a8"))
}) })
It("must truncate a long name", func() { It("must truncate a long name", func() {
chain := FormatChainName("testalongnamethatdoesnotmakesense", "1234") chain := FormatChainName("testalongnamethatdoesnotmakesense", "1234")
Expect(chain).To(HaveLen(maxChainLength)) Expect(len(chain)).To(Equal(maxChainLength))
Expect(chain).To(Equal("CNI-374f33fe84ab0ed84dcdebe3")) Expect(chain).To(Equal("CNI-374f33fe84ab0ed84dcdebe3"))
}) })
It("must be predictable", func() { It("must be predictable", func() {
chain1 := FormatChainName("testalongnamethatdoesnotmakesense", "1234") chain1 := FormatChainName("testalongnamethatdoesnotmakesense", "1234")
chain2 := FormatChainName("testalongnamethatdoesnotmakesense", "1234") chain2 := FormatChainName("testalongnamethatdoesnotmakesense", "1234")
Expect(chain1).To(HaveLen(maxChainLength)) Expect(len(chain1)).To(Equal(maxChainLength))
Expect(chain2).To(HaveLen(maxChainLength)) Expect(len(chain2)).To(Equal(maxChainLength))
Expect(chain1).To(Equal(chain2)) Expect(chain1).To(Equal(chain2))
}) })
It("must change when a character changes", func() { It("must change when a character changes", func() {
chain1 := FormatChainName("testalongnamethatdoesnotmakesense", "1234") chain1 := FormatChainName("testalongnamethatdoesnotmakesense", "1234")
chain2 := FormatChainName("testalongnamethatdoesnotmakesense", "1235") chain2 := FormatChainName("testalongnamethatdoesnotmakesense", "1235")
Expect(chain1).To(HaveLen(maxChainLength)) Expect(len(chain1)).To(Equal(maxChainLength))
Expect(chain2).To(HaveLen(maxChainLength)) Expect(len(chain2)).To(Equal(maxChainLength))
Expect(chain1).To(Equal("CNI-374f33fe84ab0ed84dcdebe3")) Expect(chain1).To(Equal("CNI-374f33fe84ab0ed84dcdebe3"))
Expect(chain1).NotTo(Equal(chain2)) Expect(chain1).NotTo(Equal(chain2))
}) })
@ -57,35 +57,35 @@ var _ = Describe("Utils", func() {
Describe("MustFormatChainNameWithPrefix", func() { Describe("MustFormatChainNameWithPrefix", func() {
It("generates a chain name with a prefix", func() { It("generates a chain name with a prefix", func() {
chain := MustFormatChainNameWithPrefix("test", "1234", "PREFIX-") chain := MustFormatChainNameWithPrefix("test", "1234", "PREFIX-")
Expect(chain).To(HaveLen(maxChainLength)) Expect(len(chain)).To(Equal(maxChainLength))
Expect(chain).To(Equal("CNI-PREFIX-2bbe0c48b91a7d1b8")) Expect(chain).To(Equal("CNI-PREFIX-2bbe0c48b91a7d1b8"))
}) })
It("must format a short name", func() { It("must format a short name", func() {
chain := MustFormatChainNameWithPrefix("test", "1234", "PREFIX-") chain := MustFormatChainNameWithPrefix("test", "1234", "PREFIX-")
Expect(chain).To(HaveLen(maxChainLength)) Expect(len(chain)).To(Equal(maxChainLength))
Expect(chain).To(Equal("CNI-PREFIX-2bbe0c48b91a7d1b8")) Expect(chain).To(Equal("CNI-PREFIX-2bbe0c48b91a7d1b8"))
}) })
It("must truncate a long name", func() { It("must truncate a long name", func() {
chain := MustFormatChainNameWithPrefix("testalongnamethatdoesnotmakesense", "1234", "PREFIX-") chain := MustFormatChainNameWithPrefix("testalongnamethatdoesnotmakesense", "1234", "PREFIX-")
Expect(chain).To(HaveLen(maxChainLength)) Expect(len(chain)).To(Equal(maxChainLength))
Expect(chain).To(Equal("CNI-PREFIX-374f33fe84ab0ed84")) Expect(chain).To(Equal("CNI-PREFIX-374f33fe84ab0ed84"))
}) })
It("must be predictable", func() { It("must be predictable", func() {
chain1 := MustFormatChainNameWithPrefix("testalongnamethatdoesnotmakesense", "1234", "PREFIX-") chain1 := MustFormatChainNameWithPrefix("testalongnamethatdoesnotmakesense", "1234", "PREFIX-")
chain2 := MustFormatChainNameWithPrefix("testalongnamethatdoesnotmakesense", "1234", "PREFIX-") chain2 := MustFormatChainNameWithPrefix("testalongnamethatdoesnotmakesense", "1234", "PREFIX-")
Expect(chain1).To(HaveLen(maxChainLength)) Expect(len(chain1)).To(Equal(maxChainLength))
Expect(chain2).To(HaveLen(maxChainLength)) Expect(len(chain2)).To(Equal(maxChainLength))
Expect(chain1).To(Equal(chain2)) Expect(chain1).To(Equal(chain2))
}) })
It("must change when a character changes", func() { It("must change when a character changes", func() {
chain1 := MustFormatChainNameWithPrefix("testalongnamethatdoesnotmakesense", "1234", "PREFIX-") chain1 := MustFormatChainNameWithPrefix("testalongnamethatdoesnotmakesense", "1234", "PREFIX-")
chain2 := MustFormatChainNameWithPrefix("testalongnamethatdoesnotmakesense", "1235", "PREFIX-") chain2 := MustFormatChainNameWithPrefix("testalongnamethatdoesnotmakesense", "1235", "PREFIX-")
Expect(chain1).To(HaveLen(maxChainLength)) Expect(len(chain1)).To(Equal(maxChainLength))
Expect(chain2).To(HaveLen(maxChainLength)) Expect(len(chain2)).To(Equal(maxChainLength))
Expect(chain1).To(Equal("CNI-PREFIX-374f33fe84ab0ed84")) Expect(chain1).To(Equal("CNI-PREFIX-374f33fe84ab0ed84"))
Expect(chain1).NotTo(Equal(chain2)) Expect(chain1).NotTo(Equal(chain2))
}) })
@ -161,4 +161,5 @@ var _ = Describe("Utils", func() {
) )
}) })
}) })
}) })

135
plugins/ipam/dhcp/client.go Normal file
View File

@ -0,0 +1,135 @@
// Copyright 2021 CNI authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"github.com/d2g/dhcp4"
"github.com/d2g/dhcp4client"
)
const (
MaxDHCPLen = 576
)
//Send the Discovery Packet to the Broadcast Channel
func DhcpSendDiscoverPacket(c *dhcp4client.Client, options dhcp4.Options) (dhcp4.Packet, error) {
discoveryPacket := c.DiscoverPacket()
for opt, data := range options {
discoveryPacket.AddOption(opt, data)
}
discoveryPacket.PadToMinSize()
return discoveryPacket, c.SendPacket(discoveryPacket)
}
//Send Request Based On the offer Received.
func DhcpSendRequest(c *dhcp4client.Client, options dhcp4.Options, offerPacket *dhcp4.Packet) (dhcp4.Packet, error) {
requestPacket := c.RequestPacket(offerPacket)
for opt, data := range options {
requestPacket.AddOption(opt, data)
}
requestPacket.PadToMinSize()
return requestPacket, c.SendPacket(requestPacket)
}
//Send Decline to the received acknowledgement.
func DhcpSendDecline(c *dhcp4client.Client, acknowledgementPacket *dhcp4.Packet, options dhcp4.Options) (dhcp4.Packet, error) {
declinePacket := c.DeclinePacket(acknowledgementPacket)
for opt, data := range options {
declinePacket.AddOption(opt, data)
}
declinePacket.PadToMinSize()
return declinePacket, c.SendPacket(declinePacket)
}
//Lets do a Full DHCP Request.
func DhcpRequest(c *dhcp4client.Client, options dhcp4.Options) (bool, dhcp4.Packet, error) {
discoveryPacket, err := DhcpSendDiscoverPacket(c, options)
if err != nil {
return false, discoveryPacket, err
}
offerPacket, err := c.GetOffer(&discoveryPacket)
if err != nil {
return false, offerPacket, err
}
requestPacket, err := DhcpSendRequest(c, options, &offerPacket)
if err != nil {
return false, requestPacket, err
}
acknowledgement, err := c.GetAcknowledgement(&requestPacket)
if err != nil {
return false, acknowledgement, err
}
acknowledgementOptions := acknowledgement.ParseOptions()
if dhcp4.MessageType(acknowledgementOptions[dhcp4.OptionDHCPMessageType][0]) != dhcp4.ACK {
return false, acknowledgement, nil
}
return true, acknowledgement, nil
}
//Renew a lease backed on the Acknowledgement Packet.
//Returns Successful, The AcknoledgementPacket, Any Errors
func DhcpRenew(c *dhcp4client.Client, acknowledgement dhcp4.Packet, options dhcp4.Options) (bool, dhcp4.Packet, error) {
renewRequest := c.RenewalRequestPacket(&acknowledgement)
for opt, data := range options {
renewRequest.AddOption(opt, data)
}
renewRequest.PadToMinSize()
err := c.SendPacket(renewRequest)
if err != nil {
return false, renewRequest, err
}
newAcknowledgement, err := c.GetAcknowledgement(&renewRequest)
if err != nil {
return false, newAcknowledgement, err
}
newAcknowledgementOptions := newAcknowledgement.ParseOptions()
if dhcp4.MessageType(newAcknowledgementOptions[dhcp4.OptionDHCPMessageType][0]) != dhcp4.ACK {
return false, newAcknowledgement, nil
}
return true, newAcknowledgement, nil
}
//Release a lease backed on the Acknowledgement Packet.
//Returns Any Errors
func DhcpRelease(c *dhcp4client.Client, acknowledgement dhcp4.Packet, options dhcp4.Options) error {
release := c.ReleasePacket(&acknowledgement)
for opt, data := range options {
release.AddOption(opt, data)
}
release.PadToMinSize()
return c.SendPacket(release)
}

View File

@ -15,50 +15,45 @@
package main package main
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
"os" "os"
"os/signal"
"path/filepath" "path/filepath"
"runtime" "runtime"
"sync" "sync"
"syscall"
"time" "time"
"github.com/coreos/go-systemd/v22/activation"
"github.com/containernetworking/cni/pkg/skel" "github.com/containernetworking/cni/pkg/skel"
current "github.com/containernetworking/cni/pkg/types/100" current "github.com/containernetworking/cni/pkg/types/100"
"github.com/coreos/go-systemd/v22/activation"
) )
var errNoMoreTries = errors.New("no more tries") var errNoMoreTries = errors.New("no more tries")
type DHCP struct { type DHCP struct {
mux sync.Mutex mux sync.Mutex
leases map[string]*DHCPLease leases map[string]*DHCPLease
hostNetnsPrefix string hostNetnsPrefix string
clientTimeout time.Duration clientTimeout time.Duration
clientResendMax time.Duration clientResendMax time.Duration
clientResendTimeout time.Duration broadcast bool
broadcast bool
} }
func newDHCP(clientTimeout, clientResendMax time.Duration, resendTimeout time.Duration) *DHCP { func newDHCP(clientTimeout, clientResendMax time.Duration) *DHCP {
return &DHCP{ return &DHCP{
leases: make(map[string]*DHCPLease), leases: make(map[string]*DHCPLease),
clientTimeout: clientTimeout, clientTimeout: clientTimeout,
clientResendMax: clientResendMax, clientResendMax: clientResendMax,
clientResendTimeout: resendTimeout,
} }
} }
// TODO: current client ID is too long. At least the container ID should not be used directly. // TODO: current client ID is too long. At least the container ID should not be used directly.
// A separate issue is necessary to ensure no breaking change is affecting other users. // A seperate issue is necessary to ensure no breaking change is affecting other users.
func generateClientID(containerID string, netName string, ifName string) string { func generateClientID(containerID string, netName string, ifName string) string {
clientID := containerID + "/" + netName + "/" + ifName clientID := containerID + "/" + netName + "/" + ifName
// defined in RFC 2132, length size can not be larger than 1 octet. So we truncate 254 to make everyone happy. // defined in RFC 2132, length size can not be larger than 1 octet. So we truncate 254 to make everyone happy.
@ -76,26 +71,18 @@ func (d *DHCP) Allocate(args *skel.CmdArgs, result *current.Result) error {
return fmt.Errorf("error parsing netconf: %v", err) return fmt.Errorf("error parsing netconf: %v", err)
} }
opts, err := prepareOptions(args.Args, conf.IPAM.ProvideOptions, conf.IPAM.RequestOptions) optsRequesting, optsProviding, err := prepareOptions(args.Args, conf.IPAM.ProvideOptions, conf.IPAM.RequestOptions)
if err != nil { if err != nil {
return err return err
} }
clientID := generateClientID(args.ContainerID, conf.Name, args.IfName) clientID := generateClientID(args.ContainerID, conf.Name, args.IfName)
hostNetns := d.hostNetnsPrefix + args.Netns
// If we already have an active lease for this clientID, do not create l, err := AcquireLease(clientID, hostNetns, args.IfName,
// another one optsRequesting, optsProviding,
l := d.getLease(clientID) d.clientTimeout, d.clientResendMax, d.broadcast)
if l != nil { if err != nil {
l.Check() return err
} else {
hostNetns := d.hostNetnsPrefix + args.Netns
l, err = AcquireLease(clientID, hostNetns, args.IfName,
opts,
d.clientTimeout, d.clientResendMax, d.clientResendTimeout, d.broadcast)
if err != nil {
return err
}
} }
ipn, err := l.IPNet() ipn, err := l.IPNet()
@ -111,18 +98,13 @@ func (d *DHCP) Allocate(args *skel.CmdArgs, result *current.Result) error {
Gateway: l.Gateway(), Gateway: l.Gateway(),
}} }}
result.Routes = l.Routes() result.Routes = l.Routes()
if conf.IPAM.Priority != 0 {
for _, r := range result.Routes {
r.Priority = conf.IPAM.Priority
}
}
return nil return nil
} }
// Release stops maintenance of the lease acquired in Allocate() // Release stops maintenance of the lease acquired in Allocate()
// and sends a release msg to the DHCP server. // and sends a release msg to the DHCP server.
func (d *DHCP) Release(args *skel.CmdArgs, _ *struct{}) error { func (d *DHCP) Release(args *skel.CmdArgs, reply *struct{}) error {
conf := NetConf{} conf := NetConf{}
if err := json.Unmarshal(args.StdinData, &conf); err != nil { if err := json.Unmarshal(args.StdinData, &conf); err != nil {
return fmt.Errorf("error parsing netconf: %v", err) return fmt.Errorf("error parsing netconf: %v", err)
@ -157,7 +139,7 @@ func (d *DHCP) setLease(clientID string, l *DHCPLease) {
d.leases[clientID] = l d.leases[clientID] = l
} }
// func (d *DHCP) clearLease(contID, netName, ifName string) { //func (d *DHCP) clearLease(contID, netName, ifName string) {
func (d *DHCP) clearLease(clientID string) { func (d *DHCP) clearLease(clientID string) {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
@ -174,7 +156,7 @@ func getListener(socketPath string) (net.Listener, error) {
switch { switch {
case len(l) == 0: case len(l) == 0:
if err := os.MkdirAll(filepath.Dir(socketPath), 0o700); err != nil { if err := os.MkdirAll(filepath.Dir(socketPath), 0700); err != nil {
return nil, err return nil, err
} }
return net.Listen("unix", socketPath) return net.Listen("unix", socketPath)
@ -192,8 +174,7 @@ func getListener(socketPath string) (net.Listener, error) {
func runDaemon( func runDaemon(
pidfilePath, hostPrefix, socketPath string, pidfilePath, hostPrefix, socketPath string,
dhcpClientTimeout time.Duration, resendMax time.Duration, resendTimeout time.Duration, dhcpClientTimeout time.Duration, resendMax time.Duration, broadcast bool,
broadcast bool,
) error { ) error {
// since other goroutines (on separate threads) will change namespaces, // since other goroutines (on separate threads) will change namespaces,
// ensure the RPC server does not get scheduled onto those // ensure the RPC server does not get scheduled onto those
@ -204,7 +185,7 @@ func runDaemon(
if !filepath.IsAbs(pidfilePath) { if !filepath.IsAbs(pidfilePath) {
return fmt.Errorf("Error writing pidfile %q: path not absolute", pidfilePath) return fmt.Errorf("Error writing pidfile %q: path not absolute", pidfilePath)
} }
if err := os.WriteFile(pidfilePath, []byte(fmt.Sprintf("%d", os.Getpid())), 0o644); err != nil { if err := ioutil.WriteFile(pidfilePath, []byte(fmt.Sprintf("%d", os.Getpid())), 0644); err != nil {
return fmt.Errorf("Error writing pidfile %q: %v", pidfilePath, err) return fmt.Errorf("Error writing pidfile %q: %v", pidfilePath, err)
} }
} }
@ -214,27 +195,11 @@ func runDaemon(
return fmt.Errorf("Error getting listener: %v", err) return fmt.Errorf("Error getting listener: %v", err)
} }
srv := http.Server{} dhcp := newDHCP(dhcpClientTimeout, resendMax)
exit := make(chan os.Signal, 1)
done := make(chan bool, 1)
signal.Notify(exit, os.Interrupt, syscall.SIGTERM)
go func() {
<-exit
srv.Shutdown(context.TODO())
os.Remove(hostPrefix + socketPath)
os.Remove(pidfilePath)
done <- true
}()
dhcp := newDHCP(dhcpClientTimeout, resendMax, resendTimeout)
dhcp.hostNetnsPrefix = hostPrefix dhcp.hostNetnsPrefix = hostPrefix
dhcp.broadcast = broadcast dhcp.broadcast = broadcast
rpc.Register(dhcp) rpc.Register(dhcp)
rpc.HandleHTTP() rpc.HandleHTTP()
srv.Serve(l) http.Serve(l, nil)
<-done
return nil return nil
} }

View File

@ -16,19 +16,21 @@ package main
import ( import (
"fmt" "fmt"
"net"
"os" "os"
"os/exec" "os/exec"
"sync" "sync"
"time" "time"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/vishvananda/netlink"
"github.com/containernetworking/cni/pkg/skel" "github.com/containernetworking/cni/pkg/skel"
current "github.com/containernetworking/cni/pkg/types/100" current "github.com/containernetworking/cni/pkg/types/100"
"github.com/containernetworking/plugins/pkg/ns" "github.com/containernetworking/plugins/pkg/ns"
"github.com/containernetworking/plugins/pkg/testutils" "github.com/containernetworking/plugins/pkg/testutils"
"github.com/vishvananda/netlink"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
) )
var _ = Describe("DHCP Multiple Lease Operations", func() { var _ = Describe("DHCP Multiple Lease Operations", func() {
@ -38,10 +40,11 @@ var _ = Describe("DHCP Multiple Lease Operations", func() {
var clientCmd *exec.Cmd var clientCmd *exec.Cmd
var socketPath string var socketPath string
var tmpDir string var tmpDir string
var serverIP net.IPNet
var err error var err error
BeforeEach(func() { BeforeEach(func() {
dhcpServerStopCh, socketPath, originalNS, targetNS, err = dhcpSetupOriginalNS() dhcpServerStopCh, serverIP, socketPath, originalNS, targetNS, err = dhcpSetupOriginalNS()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
// Move the container side to the container's NS // Move the container side to the container's NS
@ -61,12 +64,13 @@ var _ = Describe("DHCP Multiple Lease Operations", func() {
}) })
// Start the DHCP server // Start the DHCP server
dhcpServerDone = dhcpServerStart(originalNS, 2, dhcpServerStopCh) dhcpServerDone, err = dhcpServerStart(originalNS, net.IPv4(192, 168, 1, 5), serverIP.IP, 2, dhcpServerStopCh)
Expect(err).NotTo(HaveOccurred())
// Start the DHCP client daemon // Start the DHCP client daemon
dhcpPluginPath, err := exec.LookPath("dhcp") dhcpPluginPath, err := exec.LookPath("dhcp")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
clientCmd = exec.Command(dhcpPluginPath, "daemon", "-socketpath", socketPath, "--timeout", "2s", "--resendtimeout", "8s") clientCmd = exec.Command(dhcpPluginPath, "daemon", "-socketpath", socketPath)
err = clientCmd.Start() err = clientCmd.Start()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(clientCmd.Process).NotTo(BeNil()) Expect(clientCmd.Process).NotTo(BeNil())
@ -119,7 +123,7 @@ var _ = Describe("DHCP Multiple Lease Operations", func() {
addResult, err = current.GetResult(r) addResult, err = current.GetResult(r)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(addResult.IPs).To(HaveLen(1)) Expect(len(addResult.IPs)).To(Equal(1))
Expect(addResult.IPs[0].Address.String()).To(Equal("192.168.1.5/24")) Expect(addResult.IPs[0].Address.String()).To(Equal("192.168.1.5/24"))
return nil return nil
}) })
@ -142,7 +146,7 @@ var _ = Describe("DHCP Multiple Lease Operations", func() {
addResult, err = current.GetResult(r) addResult, err = current.GetResult(r)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(addResult.IPs).To(HaveLen(1)) Expect(len(addResult.IPs)).To(Equal(1))
Expect(addResult.IPs[0].Address.String()).To(Equal("192.168.1.6/24")) Expect(addResult.IPs[0].Address.String()).To(Equal("192.168.1.6/24"))
return nil return nil
}) })

View File

@ -15,10 +15,10 @@
package main package main
import ( import (
"testing" . "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"testing"
) )
func TestDHCP(t *testing.T) { func TestDHCP(t *testing.T) {

View File

@ -18,6 +18,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"os" "os"
"os/exec" "os/exec"
@ -25,18 +26,24 @@ import (
"sync" "sync"
"time" "time"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/vishvananda/netlink"
"github.com/containernetworking/cni/pkg/skel" "github.com/containernetworking/cni/pkg/skel"
types100 "github.com/containernetworking/cni/pkg/types/100" "github.com/containernetworking/cni/pkg/types/100"
"github.com/containernetworking/plugins/pkg/ns" "github.com/containernetworking/plugins/pkg/ns"
"github.com/containernetworking/plugins/pkg/testutils" "github.com/containernetworking/plugins/pkg/testutils"
"github.com/vishvananda/netlink"
"github.com/d2g/dhcp4"
"github.com/d2g/dhcp4server"
"github.com/d2g/dhcp4server/leasepool"
"github.com/d2g/dhcp4server/leasepool/memorypool"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
) )
func getTmpDir() (string, error) { func getTmpDir() (string, error) {
tmpDir, err := os.MkdirTemp(cniDirPrefix, "dhcp") tmpDir, err := ioutil.TempDir(cniDirPrefix, "dhcp")
if err == nil { if err == nil {
tmpDir = filepath.ToSlash(tmpDir) tmpDir = filepath.ToSlash(tmpDir)
} }
@ -44,52 +51,31 @@ func getTmpDir() (string, error) {
return tmpDir, err return tmpDir, err
} }
type DhcpServer struct { func dhcpServerStart(netns ns.NetNS, leaseIP, serverIP net.IP, numLeases int, stopCh <-chan bool) (*sync.WaitGroup, error) {
cmd *exec.Cmd // Add the expected IP to the pool
lock sync.Mutex lp := memorypool.MemoryPool{}
startAddr net.IP Expect(numLeases).To(BeNumerically(">", 0))
endAddr net.IP // Currently tests only need at most 2
leaseTime time.Duration Expect(numLeases).To(BeNumerically("<=", 2))
}
func (s *DhcpServer) Serve() error { // tests expect first lease to be at address 192.168.1.5
if err := s.Start(); err != nil { for i := 5; i < numLeases+5; i++ {
return err err := lp.AddLease(leasepool.Lease{IP: dhcp4.IPAdd(net.IPv4(192, 168, 1, byte(i)), 0)})
if err != nil {
return nil, fmt.Errorf("error adding IP to DHCP pool: %v", err)
}
} }
return s.cmd.Wait()
}
func (s *DhcpServer) Start() error { dhcpServer, err := dhcp4server.New(
s.lock.Lock() net.IPv4(192, 168, 1, 1),
defer s.lock.Unlock() &lp,
dhcp4server.SetLocalAddr(net.UDPAddr{IP: net.IPv4(0, 0, 0, 0), Port: 67}),
s.cmd = exec.Command( dhcp4server.SetRemoteAddr(net.UDPAddr{IP: net.IPv4bcast, Port: 68}),
"dnsmasq", dhcp4server.LeaseDuration(time.Minute*15),
"--no-daemon",
"--dhcp-sequential-ip", // allocate IPs sequentially
"--port=0", // disable DNS
"--conf-file=-", // Do not read /etc/dnsmasq.conf
fmt.Sprintf("--dhcp-range=%s,%s,%d", s.startAddr, s.endAddr, int(s.leaseTime.Seconds())),
) )
s.cmd.Stdin = bytes.NewBufferString("") if err != nil {
s.cmd.Stdout = os.Stdout return nil, fmt.Errorf("failed to create DHCP server: %v", err)
s.cmd.Stderr = os.Stderr
return s.cmd.Start()
}
func (s *DhcpServer) Stop() error {
s.lock.Lock()
defer s.lock.Unlock()
return s.cmd.Process.Kill()
}
func dhcpServerStart(netns ns.NetNS, numLeases int, stopCh <-chan bool) *sync.WaitGroup {
dhcpServer := &DhcpServer{
startAddr: net.IPv4(192, 168, 1, 5),
endAddr: net.IPv4(192, 168, 1, 5+uint8(numLeases)-1),
leaseTime: 5 * time.Minute,
} }
stopWg := sync.WaitGroup{} stopWg := sync.WaitGroup{}
@ -101,10 +87,9 @@ func dhcpServerStart(netns ns.NetNS, numLeases int, stopCh <-chan bool) *sync.Wa
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
err := netns.Do(func(ns.NetNS) error { err = netns.Do(func(ns.NetNS) error {
startWg.Done() startWg.Done()
if err := dhcpServer.ListenAndServe(); err != nil {
if err := dhcpServer.Serve(); err != nil {
// Log, but don't trap errors; the server will // Log, but don't trap errors; the server will
// always report an error when stopped // always report an error when stopped
GinkgoT().Logf("DHCP server finished with error: %v", err) GinkgoT().Logf("DHCP server finished with error: %v", err)
@ -121,12 +106,12 @@ func dhcpServerStart(netns ns.NetNS, numLeases int, stopCh <-chan bool) *sync.Wa
go func() { go func() {
startWg.Done() startWg.Done()
<-stopCh <-stopCh
dhcpServer.Stop() dhcpServer.Shutdown()
stopWg.Done() stopWg.Done()
}() }()
startWg.Wait() startWg.Wait()
return &stopWg return &stopWg, nil
} }
const ( const (
@ -136,7 +121,7 @@ const (
) )
var _ = BeforeSuite(func() { var _ = BeforeSuite(func() {
err := os.MkdirAll(cniDirPrefix, 0o700) err := os.MkdirAll(cniDirPrefix, 0700)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
}) })
@ -173,11 +158,11 @@ var _ = Describe("DHCP Operations", func() {
err = originalNS.Do(func(ns.NetNS) error { err = originalNS.Do(func(ns.NetNS) error {
defer GinkgoRecover() defer GinkgoRecover()
linkAttrs := netlink.NewLinkAttrs()
linkAttrs.Name = hostVethName
err = netlink.LinkAdd(&netlink.Veth{ err = netlink.LinkAdd(&netlink.Veth{
LinkAttrs: linkAttrs, LinkAttrs: netlink.LinkAttrs{
PeerName: contVethName, Name: hostVethName,
},
PeerName: contVethName,
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -218,7 +203,8 @@ var _ = Describe("DHCP Operations", func() {
}) })
// Start the DHCP server // Start the DHCP server
dhcpServerDone = dhcpServerStart(originalNS, 1, dhcpServerStopCh) dhcpServerDone, err = dhcpServerStart(originalNS, net.IPv4(192, 168, 1, 5), serverIP.IP, 1, dhcpServerStopCh)
Expect(err).NotTo(HaveOccurred())
// Start the DHCP client daemon // Start the DHCP client daemon
dhcpPluginPath, err := exec.LookPath("dhcp") dhcpPluginPath, err := exec.LookPath("dhcp")
@ -288,7 +274,7 @@ var _ = Describe("DHCP Operations", func() {
addResult, err = types100.GetResult(r) addResult, err = types100.GetResult(r)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(addResult.IPs).To(HaveLen(1)) Expect(len(addResult.IPs)).To(Equal(1))
Expect(addResult.IPs[0].Address.String()).To(Equal("192.168.1.5/24")) Expect(addResult.IPs[0].Address.String()).To(Equal("192.168.1.5/24"))
return nil return nil
}) })
@ -331,7 +317,7 @@ var _ = Describe("DHCP Operations", func() {
addResult, err = types100.GetResult(r) addResult, err = types100.GetResult(r)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(addResult.IPs).To(HaveLen(1)) Expect(len(addResult.IPs)).To(Equal(1))
Expect(addResult.IPs[0].Address.String()).To(Equal("192.168.1.5/24")) Expect(addResult.IPs[0].Address.String()).To(Equal("192.168.1.5/24"))
return nil return nil
}) })
@ -349,17 +335,9 @@ var _ = Describe("DHCP Operations", func() {
started.Done() started.Done()
started.Wait() started.Wait()
err := originalNS.Do(func(ns.NetNS) error { err = originalNS.Do(func(ns.NetNS) error {
return testutils.CmdDelWithArgs(args, func() error { return testutils.CmdDelWithArgs(args, func() error {
copiedArgs := &skel.CmdArgs{ return cmdDel(args)
ContainerID: args.ContainerID,
Netns: args.Netns,
IfName: args.IfName,
StdinData: args.StdinData,
Path: args.Path,
Args: args.Args,
}
return cmdDel(copiedArgs)
}) })
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
@ -386,7 +364,7 @@ const (
contVethName1 string = "eth1" contVethName1 string = "eth1"
) )
func dhcpSetupOriginalNS() (chan bool, string, ns.NetNS, ns.NetNS, error) { func dhcpSetupOriginalNS() (chan bool, net.IPNet, string, ns.NetNS, ns.NetNS, error) {
var originalNS, targetNS ns.NetNS var originalNS, targetNS ns.NetNS
var dhcpServerStopCh chan bool var dhcpServerStopCh chan bool
var socketPath string var socketPath string
@ -407,15 +385,20 @@ func dhcpSetupOriginalNS() (chan bool, string, ns.NetNS, ns.NetNS, error) {
targetNS, err = testutils.NewNS() targetNS, err = testutils.NewNS()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
serverIP := net.IPNet{
IP: net.IPv4(192, 168, 1, 1),
Mask: net.IPv4Mask(255, 255, 255, 0),
}
// Use (original) NS // Use (original) NS
err = originalNS.Do(func(ns.NetNS) error { err = originalNS.Do(func(ns.NetNS) error {
defer GinkgoRecover() defer GinkgoRecover()
linkAttrs := netlink.NewLinkAttrs()
linkAttrs.Name = hostBridgeName
// Create bridge in the "host" (original) NS // Create bridge in the "host" (original) NS
br = &netlink.Bridge{ br = &netlink.Bridge{
LinkAttrs: linkAttrs, LinkAttrs: netlink.LinkAttrs{
Name: hostBridgeName,
},
} }
err = netlink.LinkAdd(br) err = netlink.LinkAdd(br)
@ -501,7 +484,7 @@ func dhcpSetupOriginalNS() (chan bool, string, ns.NetNS, ns.NetNS, error) {
return nil return nil
}) })
return dhcpServerStopCh, socketPath, originalNS, targetNS, err return dhcpServerStopCh, serverIP, socketPath, originalNS, targetNS, err
} }
var _ = Describe("DHCP Lease Unavailable Operations", func() { var _ = Describe("DHCP Lease Unavailable Operations", func() {
@ -511,10 +494,11 @@ var _ = Describe("DHCP Lease Unavailable Operations", func() {
var clientCmd *exec.Cmd var clientCmd *exec.Cmd
var socketPath string var socketPath string
var tmpDir string var tmpDir string
var serverIP net.IPNet
var err error var err error
BeforeEach(func() { BeforeEach(func() {
dhcpServerStopCh, socketPath, originalNS, targetNS, err = dhcpSetupOriginalNS() dhcpServerStopCh, serverIP, socketPath, originalNS, targetNS, err = dhcpSetupOriginalNS()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
// Move the container side to the container's NS // Move the container side to the container's NS
@ -534,7 +518,8 @@ var _ = Describe("DHCP Lease Unavailable Operations", func() {
}) })
// Start the DHCP server // Start the DHCP server
dhcpServerDone = dhcpServerStart(originalNS, 1, dhcpServerStopCh) dhcpServerDone, err = dhcpServerStart(originalNS, net.IPv4(192, 168, 1, 5), serverIP.IP, 1, dhcpServerStopCh)
Expect(err).NotTo(HaveOccurred())
// Start the DHCP client daemon // Start the DHCP client daemon
dhcpPluginPath, err := exec.LookPath("dhcp") dhcpPluginPath, err := exec.LookPath("dhcp")
@ -544,7 +529,7 @@ var _ = Describe("DHCP Lease Unavailable Operations", func() {
// `go test` timeout with default delays. Since our DHCP server // `go test` timeout with default delays. Since our DHCP server
// and client daemon are local processes anyway, we can depend on // and client daemon are local processes anyway, we can depend on
// them to respond very quickly. // them to respond very quickly.
clientCmd = exec.Command(dhcpPluginPath, "daemon", "-socketpath", socketPath, "-timeout", "2s", "-resendmax", "8s", "--resendtimeout", "10s") clientCmd = exec.Command(dhcpPluginPath, "daemon", "-socketpath", socketPath, "-timeout", "2s", "-resendmax", "8s")
// copy dhcp client's stdout/stderr to test stdout // copy dhcp client's stdout/stderr to test stdout
var b bytes.Buffer var b bytes.Buffer
@ -612,7 +597,7 @@ var _ = Describe("DHCP Lease Unavailable Operations", func() {
addResult, err = types100.GetResult(r) addResult, err = types100.GetResult(r)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(addResult.IPs).To(HaveLen(1)) Expect(len(addResult.IPs)).To(Equal(1))
Expect(addResult.IPs[0].Address.String()).To(Equal("192.168.1.5/24")) Expect(addResult.IPs[0].Address.String()).To(Equal("192.168.1.5/24"))
return nil return nil
}) })

View File

@ -15,7 +15,6 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"log" "log"
"math/rand" "math/rand"
@ -25,8 +24,8 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
dhcp4 "github.com/insomniacslk/dhcp/dhcpv4" "github.com/d2g/dhcp4"
"github.com/insomniacslk/dhcp/dhcpv4/nclient4" "github.com/d2g/dhcp4client"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"github.com/containernetworking/cni/pkg/types" "github.com/containernetworking/cni/pkg/types"
@ -35,19 +34,13 @@ import (
// RFC 2131 suggests using exponential backoff, starting with 4sec // RFC 2131 suggests using exponential backoff, starting with 4sec
// and randomized to +/- 1sec // and randomized to +/- 1sec
const ( const resendDelay0 = 4 * time.Second
resendDelay0 = 4 * time.Second const resendDelayMax = 62 * time.Second
resendDelayMax = 62 * time.Second
defaultLeaseTime = 60 * time.Minute
defaultResendTimeout = 208 * time.Second // fast resend + backoff resend
)
// To speed up the retry for first few failures, we retry without // To speed up the retry for first few failures, we retry without
// backoff for a few times // backoff for a few times
const ( const resendFastDelay = 2 * time.Second
resendFastDelay = 2 * time.Second const resendFastMax = 4
resendFastMax = 4
)
const ( const (
leaseStateBound = iota leaseStateBound = iota
@ -63,36 +56,31 @@ const (
type DHCPLease struct { type DHCPLease struct {
clientID string clientID string
latestLease *nclient4.Lease ack *dhcp4.Packet
opts dhcp4.Options
link netlink.Link link netlink.Link
renewalTime time.Time renewalTime time.Time
rebindingTime time.Time rebindingTime time.Time
expireTime time.Time expireTime time.Time
timeout time.Duration timeout time.Duration
resendMax time.Duration resendMax time.Duration
resendTimeout time.Duration
broadcast bool broadcast bool
stopping uint32 stopping uint32
stop chan struct{} stop chan struct{}
check chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
cancelFunc context.CancelFunc
ctx context.Context
// list of requesting and providing options and if they are necessary / their value // list of requesting and providing options and if they are necessary / their value
opts []dhcp4.Option optsRequesting map[dhcp4.OptionCode]bool
optsProviding map[dhcp4.OptionCode][]byte
} }
var requestOptionsDefault = []dhcp4.OptionCode{ var requestOptionsDefault = map[dhcp4.OptionCode]bool{
dhcp4.OptionRouter, dhcp4.OptionRouter: true,
dhcp4.OptionSubnetMask, dhcp4.OptionSubnetMask: true,
} }
func prepareOptions(cniArgs string, provideOptions []ProvideOption, requestOptions []RequestOption) ( func prepareOptions(cniArgs string, ProvideOptions []ProvideOption, RequestOptions []RequestOption) (
[]dhcp4.Option, error, optsRequesting map[dhcp4.OptionCode]bool, optsProviding map[dhcp4.OptionCode][]byte, err error) {
) {
var opts []dhcp4.Option
var err error
// parse CNI args // parse CNI args
cniArgsParsed := map[string]string{} cniArgsParsed := map[string]string{}
for _, argPair := range strings.Split(cniArgs, ";") { for _, argPair := range strings.Split(cniArgs, ";") {
@ -104,51 +92,50 @@ func prepareOptions(cniArgs string, provideOptions []ProvideOption, requestOptio
// parse providing options map // parse providing options map
var optParsed dhcp4.OptionCode var optParsed dhcp4.OptionCode
for _, opt := range provideOptions { optsProviding = make(map[dhcp4.OptionCode][]byte)
for _, opt := range ProvideOptions {
optParsed, err = parseOptionName(string(opt.Option)) optParsed, err = parseOptionName(string(opt.Option))
if err != nil { if err != nil {
return nil, fmt.Errorf("Can not parse option %q: %w", opt.Option, err) err = fmt.Errorf("Can not parse option %q: %w", opt.Option, err)
return
} }
if len(opt.Value) > 0 { if len(opt.Value) > 0 {
if len(opt.Value) > 255 { if len(opt.Value) > 255 {
return nil, fmt.Errorf("value too long for option %q: %q", opt.Option, opt.Value) err = fmt.Errorf("value too long for option %q: %q", opt.Option, opt.Value)
return
} }
opts = append(opts, dhcp4.Option{Code: optParsed, Value: dhcp4.String(opt.Value)}) optsProviding[optParsed] = []byte(opt.Value)
} }
if value, ok := cniArgsParsed[opt.ValueFromCNIArg]; ok { if value, ok := cniArgsParsed[opt.ValueFromCNIArg]; ok {
if len(value) > 255 { if len(value) > 255 {
return nil, fmt.Errorf("value too long for option %q from CNI_ARGS %q: %q", opt.Option, opt.ValueFromCNIArg, opt.Value) err = fmt.Errorf("value too long for option %q from CNI_ARGS %q: %q", opt.Option, opt.ValueFromCNIArg, opt.Value)
return
} }
opts = append(opts, dhcp4.Option{Code: optParsed, Value: dhcp4.String(value)}) optsProviding[optParsed] = []byte(value)
} }
} }
// parse necessary options map // parse necessary options map
var optsRequesting dhcp4.OptionCodeList optsRequesting = make(map[dhcp4.OptionCode]bool)
skipRequireDefault := false skipRequireDefault := false
for _, opt := range requestOptions { for _, opt := range RequestOptions {
if opt.SkipDefault { if opt.SkipDefault {
skipRequireDefault = true skipRequireDefault = true
} }
if opt.Option == "" {
continue
}
optParsed, err = parseOptionName(string(opt.Option)) optParsed, err = parseOptionName(string(opt.Option))
if err != nil { if err != nil {
return nil, fmt.Errorf("Can not parse option %q: %w", opt.Option, err) err = fmt.Errorf("Can not parse option %q: %w", opt.Option, err)
return
} }
optsRequesting.Add(optParsed) optsRequesting[optParsed] = true
} }
if !skipRequireDefault { for k, v := range requestOptionsDefault {
for _, opt := range requestOptionsDefault { // only set if not skipping default and this value does not exists
optsRequesting.Add(opt) if _, ok := optsRequesting[k]; !ok && !skipRequireDefault {
optsRequesting[k] = v
} }
} }
if len(optsRequesting) > 0 { return
opts = append(opts, dhcp4.Option{Code: dhcp4.OptionParameterRequestList, Value: optsRequesting})
}
return opts, err
} }
// AcquireLease gets an DHCP lease and then maintains it in the background // AcquireLease gets an DHCP lease and then maintains it in the background
@ -156,25 +143,18 @@ func prepareOptions(cniArgs string, provideOptions []ProvideOption, requestOptio
// calling DHCPLease.Stop() // calling DHCPLease.Stop()
func AcquireLease( func AcquireLease(
clientID, netns, ifName string, clientID, netns, ifName string,
opts []dhcp4.Option, optsRequesting map[dhcp4.OptionCode]bool, optsProviding map[dhcp4.OptionCode][]byte,
timeout, resendMax time.Duration, resendTimeout time.Duration, broadcast bool, timeout, resendMax time.Duration, broadcast bool,
) (*DHCPLease, error) { ) (*DHCPLease, error) {
errCh := make(chan error, 1) errCh := make(chan error, 1)
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
l := &DHCPLease{ l := &DHCPLease{
clientID: clientID, clientID: clientID,
stop: make(chan struct{}), stop: make(chan struct{}),
check: make(chan struct{}), timeout: timeout,
timeout: timeout, resendMax: resendMax,
resendMax: resendMax, broadcast: broadcast,
resendTimeout: resendTimeout, optsRequesting: optsRequesting,
broadcast: broadcast, optsProviding: optsProviding,
opts: opts,
cancelFunc: cancel,
ctx: ctx,
} }
log.Printf("%v: acquiring lease", clientID) log.Printf("%v: acquiring lease", clientID)
@ -216,74 +196,96 @@ func AcquireLease(
func (l *DHCPLease) Stop() { func (l *DHCPLease) Stop() {
if atomic.CompareAndSwapUint32(&l.stopping, 0, 1) { if atomic.CompareAndSwapUint32(&l.stopping, 0, 1) {
close(l.stop) close(l.stop)
l.cancelFunc()
} }
l.wg.Wait() l.wg.Wait()
} }
func (l *DHCPLease) Check() { func (l *DHCPLease) getOptionsWithClientId() dhcp4.Options {
l.check <- struct{}{} opts := make(dhcp4.Options)
opts[dhcp4.OptionClientIdentifier] = []byte(l.clientID)
// client identifier's first byte is "type"
newClientID := []byte{0}
newClientID = append(newClientID, opts[dhcp4.OptionClientIdentifier]...)
opts[dhcp4.OptionClientIdentifier] = newClientID
return opts
} }
func withClientID(clientID string) dhcp4.Modifier { func (l *DHCPLease) getAllOptions() dhcp4.Options {
return func(d *dhcp4.DHCPv4) { opts := l.getOptionsWithClientId()
optClientID := []byte{0}
optClientID = append(optClientID, []byte(clientID)...)
d.Options.Update(dhcp4.OptClientIdentifier(optClientID))
}
}
func withAllOptions(l *DHCPLease) dhcp4.Modifier { for k, v := range l.optsProviding {
return func(d *dhcp4.DHCPv4) { opts[k] = v
for _, opt := range l.opts {
d.Options.Update(opt)
}
} }
opts[dhcp4.OptionParameterRequestList] = []byte{}
for k := range l.optsRequesting {
opts[dhcp4.OptionParameterRequestList] = append(opts[dhcp4.OptionParameterRequestList], byte(k))
}
return opts
} }
func (l *DHCPLease) acquire() error { func (l *DHCPLease) acquire() error {
if (l.link.Attrs().Flags & net.FlagUp) != net.FlagUp { c, err := newDHCPClient(l.link, l.clientID, l.timeout, l.broadcast)
log.Printf("Link %q down. Attempting to set up", l.link.Attrs().Name)
if err := netlink.LinkSetUp(l.link); err != nil {
return err
}
}
c, err := newDHCPClient(l.link, l.timeout)
if err != nil { if err != nil {
return err return err
} }
defer c.Close() defer c.Close()
timeoutCtx, cancel := context.WithTimeoutCause(l.ctx, l.resendTimeout, errNoMoreTries) if (l.link.Attrs().Flags & net.FlagUp) != net.FlagUp {
defer cancel() log.Printf("Link %q down. Attempting to set up", l.link.Attrs().Name)
pkt, err := backoffRetry(timeoutCtx, l.resendMax, func() (*nclient4.Lease, error) { if err = netlink.LinkSetUp(l.link); err != nil {
return c.Request( return err
timeoutCtx, }
withClientID(l.clientID), }
withAllOptions(l),
) opts := l.getAllOptions()
pkt, err := backoffRetry(l.resendMax, func() (*dhcp4.Packet, error) {
ok, ack, err := DhcpRequest(c, opts)
switch {
case err != nil:
return nil, err
case !ok:
return nil, fmt.Errorf("DHCP server NACK'd own offer")
default:
return &ack, nil
}
}) })
if err != nil { if err != nil {
return err return err
} }
l.commit(pkt) return l.commit(pkt)
return nil
} }
func (l *DHCPLease) commit(lease *nclient4.Lease) { func (l *DHCPLease) commit(ack *dhcp4.Packet) error {
l.latestLease = lease opts := ack.ParseOptions()
ack := lease.ACK
leaseTime := ack.IPAddressLeaseTime(defaultLeaseTime) leaseTime, err := parseLeaseTime(opts)
rebindingTime := ack.IPAddressRebindingTime(leaseTime * 85 / 100) if err != nil {
renewalTime := ack.IPAddressRenewalTime(leaseTime / 2) return err
}
rebindingTime, err := parseRebindingTime(opts)
if err != nil || rebindingTime > leaseTime {
// Per RFC 2131 Section 4.4.5, it should default to 85% of lease time
rebindingTime = leaseTime * 85 / 100
}
renewalTime, err := parseRenewalTime(opts)
if err != nil || renewalTime > rebindingTime {
// Per RFC 2131 Section 4.4.5, it should default to 50% of lease time
renewalTime = leaseTime / 2
}
now := time.Now() now := time.Now()
l.expireTime = now.Add(leaseTime) l.expireTime = now.Add(leaseTime)
l.renewalTime = now.Add(renewalTime) l.renewalTime = now.Add(renewalTime)
l.rebindingTime = now.Add(rebindingTime) l.rebindingTime = now.Add(rebindingTime)
l.ack = ack
l.opts = opts
return nil
} }
func (l *DHCPLease) maintain() { func (l *DHCPLease) maintain() {
@ -294,7 +296,7 @@ func (l *DHCPLease) maintain() {
switch state { switch state {
case leaseStateBound: case leaseStateBound:
sleepDur = time.Until(l.renewalTime) sleepDur = l.renewalTime.Sub(time.Now())
if sleepDur <= 0 { if sleepDur <= 0 {
log.Printf("%v: renewing lease", l.clientID) log.Printf("%v: renewing lease", l.clientID)
state = leaseStateRenewing state = leaseStateRenewing
@ -306,7 +308,7 @@ func (l *DHCPLease) maintain() {
log.Printf("%v: %v", l.clientID, err) log.Printf("%v: %v", l.clientID, err)
if time.Now().After(l.rebindingTime) { if time.Now().After(l.rebindingTime) {
log.Printf("%v: renewal time expired, rebinding", l.clientID) log.Printf("%v: renawal time expired, rebinding", l.clientID)
state = leaseStateRebinding state = leaseStateRebinding
} }
} else { } else {
@ -332,9 +334,6 @@ func (l *DHCPLease) maintain() {
select { select {
case <-time.After(sleepDur): case <-time.After(sleepDur):
case <-l.check:
log.Printf("%v: Checking lease", l.clientID)
case <-l.stop: case <-l.stop:
if err := l.release(); err != nil { if err := l.release(); err != nil {
log.Printf("%v: failed to release DHCP lease: %v", l.clientID, err) log.Printf("%v: failed to release DHCP lease: %v", l.clientID, err)
@ -351,40 +350,44 @@ func (l *DHCPLease) downIface() {
} }
func (l *DHCPLease) renew() error { func (l *DHCPLease) renew() error {
c, err := newDHCPClient(l.link, l.timeout) c, err := newDHCPClient(l.link, l.clientID, l.timeout, l.broadcast)
if err != nil { if err != nil {
return err return err
} }
defer c.Close() defer c.Close()
timeoutCtx, cancel := context.WithTimeoutCause(l.ctx, l.resendTimeout, errNoMoreTries) opts := l.getOptionsWithClientId()
defer cancel() pkt, err := backoffRetry(l.resendMax, func() (*dhcp4.Packet, error) {
lease, err := backoffRetry(timeoutCtx, l.resendMax, func() (*nclient4.Lease, error) { ok, ack, err := DhcpRenew(c, *l.ack, opts)
return c.Renew( switch {
timeoutCtx, case err != nil:
l.latestLease, return nil, err
withClientID(l.clientID), case !ok:
withAllOptions(l), return nil, fmt.Errorf("DHCP server did not renew lease")
) default:
return &ack, nil
}
}) })
if err != nil { if err != nil {
return err return err
} }
l.commit(lease) l.commit(pkt)
return nil return nil
} }
func (l *DHCPLease) release() error { func (l *DHCPLease) release() error {
log.Printf("%v: releasing lease", l.clientID) log.Printf("%v: releasing lease", l.clientID)
c, err := newDHCPClient(l.link, l.timeout) c, err := newDHCPClient(l.link, l.clientID, l.timeout, l.broadcast)
if err != nil { if err != nil {
return err return err
} }
defer c.Close() defer c.Close()
if err = c.Release(l.latestLease, withClientID(l.clientID)); err != nil { opts := l.getOptionsWithClientId()
if err = DhcpRelease(c, *l.ack, opts); err != nil {
return fmt.Errorf("failed to send DHCPRELEASE") return fmt.Errorf("failed to send DHCPRELEASE")
} }
@ -392,47 +395,33 @@ func (l *DHCPLease) release() error {
} }
func (l *DHCPLease) IPNet() (*net.IPNet, error) { func (l *DHCPLease) IPNet() (*net.IPNet, error) {
ack := l.latestLease.ACK mask := parseSubnetMask(l.opts)
mask := ack.SubnetMask()
if mask == nil { if mask == nil {
return nil, fmt.Errorf("DHCP option Subnet Mask not found in DHCPACK") return nil, fmt.Errorf("DHCP option Subnet Mask not found in DHCPACK")
} }
return &net.IPNet{ return &net.IPNet{
IP: ack.YourIPAddr, IP: l.ack.YIAddr(),
Mask: mask, Mask: mask,
}, nil }, nil
} }
func (l *DHCPLease) Gateway() net.IP { func (l *DHCPLease) Gateway() net.IP {
ack := l.latestLease.ACK return parseRouter(l.opts)
gws := ack.Router()
if len(gws) > 0 {
return gws[0]
}
return nil
} }
func (l *DHCPLease) Routes() []*types.Route { func (l *DHCPLease) Routes() []*types.Route {
routes := []*types.Route{} routes := []*types.Route{}
ack := l.latestLease.ACK
// RFC 3442 states that if Classless Static Routes (option 121) // RFC 3442 states that if Classless Static Routes (option 121)
// exist, we ignore Static Routes (option 33) and the Router/Gateway. // exist, we ignore Static Routes (option 33) and the Router/Gateway.
opt121Routes := ack.ClasslessStaticRoute() opt121_routes := parseCIDRRoutes(l.opts)
if len(opt121Routes) > 0 { if len(opt121_routes) > 0 {
for _, r := range opt121Routes { return append(routes, opt121_routes...)
routes = append(routes, &types.Route{Dst: *r.Dest, GW: r.Router})
}
return routes
} }
// Append Static Routes // Append Static Routes
if ack.Options.Has(dhcp4.OptionStaticRoutingTable) { routes = append(routes, parseRoutes(l.opts)...)
routes = append(routes, parseRoutes(ack.Options.Get(dhcp4.OptionStaticRoutingTable))...)
}
// The CNI spec says even if there is a gateway specified, we must // The CNI spec says even if there is a gateway specified, we must
// add a default route in the routes section. // add a default route in the routes section.
@ -449,10 +438,10 @@ func jitter(span time.Duration) time.Duration {
return time.Duration(float64(span) * (2.0*rand.Float64() - 1.0)) return time.Duration(float64(span) * (2.0*rand.Float64() - 1.0))
} }
func backoffRetry(ctx context.Context, resendMax time.Duration, f func() (*nclient4.Lease, error)) (*nclient4.Lease, error) { func backoffRetry(resendMax time.Duration, f func() (*dhcp4.Packet, error)) (*dhcp4.Packet, error) {
baseDelay := resendDelay0 var baseDelay time.Duration = resendDelay0
var sleepTime time.Duration var sleepTime time.Duration
fastRetryLimit := resendFastMax var fastRetryLimit = resendFastMax
for { for {
pkt, err := f() pkt, err := f()
if err == nil { if err == nil {
@ -470,23 +459,33 @@ func backoffRetry(ctx context.Context, resendMax time.Duration, f func() (*nclie
log.Printf("retrying in %f seconds", sleepTime.Seconds()) log.Printf("retrying in %f seconds", sleepTime.Seconds())
select { time.Sleep(sleepTime)
case <-ctx.Done():
return nil, context.Cause(ctx) // only adjust delay time if we are in normal backoff stage
case <-time.After(sleepTime): if baseDelay < resendMax && fastRetryLimit == 0 {
// only adjust delay time if we are in normal backoff stage baseDelay *= 2
if baseDelay < resendMax && fastRetryLimit == 0 { } else if fastRetryLimit == 0 { // only break if we are at normal delay
baseDelay *= 2 break
}
} }
} }
return nil, errNoMoreTries
} }
func newDHCPClient( func newDHCPClient(
link netlink.Link, link netlink.Link, clientID string,
timeout time.Duration, timeout time.Duration,
clientOpts ...nclient4.ClientOpt, broadcast bool,
) (*nclient4.Client, error) { ) (*dhcp4client.Client, error) {
clientOpts = append(clientOpts, nclient4.WithTimeout(timeout)) pktsock, err := dhcp4client.NewPacketSock(link.Attrs().Index)
return nclient4.New(link.Attrs().Name, clientOpts...) if err != nil {
return nil, err
}
return dhcp4client.New(
dhcp4client.HardwareAddr(link.Attrs().HardwareAddr),
dhcp4client.Timeout(timeout),
dhcp4client.Broadcast(broadcast),
dhcp4client.Connection(pktsock),
)
} }

View File

@ -51,8 +51,6 @@ type IPAMConfig struct {
// To override default requesting fields, set `skipDefault` to `false`. // To override default requesting fields, set `skipDefault` to `false`.
// If an field is not optional, but the server failed to provide it, error will be raised. // If an field is not optional, but the server failed to provide it, error will be raised.
RequestOptions []RequestOption `json:"request"` RequestOptions []RequestOption `json:"request"`
// The metric of routes
Priority int `json:"priority,omitempty"`
} }
// DHCPOption represents a DHCP option. It can be a number, or a string defined in manual dhcp-options(5). // DHCPOption represents a DHCP option. It can be a number, or a string defined in manual dhcp-options(5).
@ -80,33 +78,25 @@ func main() {
var broadcast bool var broadcast bool
var timeout time.Duration var timeout time.Duration
var resendMax time.Duration var resendMax time.Duration
var resendTimeout time.Duration
daemonFlags := flag.NewFlagSet("daemon", flag.ExitOnError) daemonFlags := flag.NewFlagSet("daemon", flag.ExitOnError)
daemonFlags.StringVar(&pidfilePath, "pidfile", "", "optional path to write daemon PID to") daemonFlags.StringVar(&pidfilePath, "pidfile", "", "optional path to write daemon PID to")
daemonFlags.StringVar(&hostPrefix, "hostprefix", "", "optional prefix to host root") daemonFlags.StringVar(&hostPrefix, "hostprefix", "", "optional prefix to host root")
daemonFlags.StringVar(&socketPath, "socketpath", "", "optional dhcp server socketpath") daemonFlags.StringVar(&socketPath, "socketpath", "", "optional dhcp server socketpath")
daemonFlags.BoolVar(&broadcast, "broadcast", false, "broadcast DHCP leases") daemonFlags.BoolVar(&broadcast, "broadcast", false, "broadcast DHCP leases")
daemonFlags.DurationVar(&timeout, "timeout", 10*time.Second, "optional dhcp client timeout duration for each request") daemonFlags.DurationVar(&timeout, "timeout", 10*time.Second, "optional dhcp client timeout duration")
daemonFlags.DurationVar(&resendMax, "resendmax", resendDelayMax, "optional dhcp client max resend delay between requests") daemonFlags.DurationVar(&resendMax, "resendmax", resendDelayMax, "optional dhcp client resend max duration")
daemonFlags.DurationVar(&resendTimeout, "resendtimeout", defaultResendTimeout, "optional dhcp client resend timeout, no more retries after this timeout")
daemonFlags.Parse(os.Args[2:]) daemonFlags.Parse(os.Args[2:])
if socketPath == "" { if socketPath == "" {
socketPath = defaultSocketPath socketPath = defaultSocketPath
} }
if err := runDaemon(pidfilePath, hostPrefix, socketPath, timeout, resendMax, resendTimeout, broadcast); err != nil { if err := runDaemon(pidfilePath, hostPrefix, socketPath, timeout, resendMax, broadcast); err != nil {
log.Print(err.Error()) log.Print(err.Error())
os.Exit(1) os.Exit(1)
} }
} else { } else {
skel.PluginMainFuncs(skel.CNIFuncs{ skel.PluginMain(cmdAdd, cmdCheck, cmdDel, version.All, bv.BuildString("dhcp"))
Add: cmdAdd,
Check: cmdCheck,
Del: cmdDel,
/* FIXME GC */
/* FIXME Status */
}, version.All, bv.BuildString("dhcp"))
} }
} }
@ -128,20 +118,27 @@ func cmdAdd(args *skel.CmdArgs) error {
func cmdDel(args *skel.CmdArgs) error { func cmdDel(args *skel.CmdArgs) error {
result := struct{}{} result := struct{}{}
return rpcCall("DHCP.Release", args, &result) if err := rpcCall("DHCP.Release", args, &result); err != nil {
return err
}
return nil
} }
func cmdCheck(args *skel.CmdArgs) error { func cmdCheck(args *skel.CmdArgs) error {
// Plugin must return result in same version as specified in netconf // Plugin must return result in same version as specified in netconf
versionDecoder := &version.ConfigDecoder{} versionDecoder := &version.ConfigDecoder{}
// confVersion, err := versionDecoder.Decode(args.StdinData) //confVersion, err := versionDecoder.Decode(args.StdinData)
_, err := versionDecoder.Decode(args.StdinData) _, err := versionDecoder.Decode(args.StdinData)
if err != nil { if err != nil {
return err return err
} }
result := &current.Result{CNIVersion: current.ImplementedSpecVersion} result := &current.Result{CNIVersion: current.ImplementedSpecVersion}
return rpcCall("DHCP.Allocate", args, result) if err := rpcCall("DHCP.Allocate", args, result); err != nil {
return err
}
return nil
} }
func getSocketPath(stdinData []byte) (string, error) { func getSocketPath(stdinData []byte) (string, error) {

View File

@ -15,13 +15,14 @@
package main package main
import ( import (
"encoding/binary"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
"time"
dhcp4 "github.com/insomniacslk/dhcp/dhcpv4"
"github.com/containernetworking/cni/pkg/types" "github.com/containernetworking/cni/pkg/types"
"github.com/d2g/dhcp4"
) )
var optionNameToID = map[string]dhcp4.OptionCode{ var optionNameToID = map[string]dhcp4.OptionCode{
@ -29,8 +30,8 @@ var optionNameToID = map[string]dhcp4.OptionCode{
"subnet-mask": dhcp4.OptionSubnetMask, "subnet-mask": dhcp4.OptionSubnetMask,
"routers": dhcp4.OptionRouter, "routers": dhcp4.OptionRouter,
"host-name": dhcp4.OptionHostName, "host-name": dhcp4.OptionHostName,
"user-class": dhcp4.OptionUserClassInformation, "user-class": dhcp4.OptionUserClass,
"vendor-class-identifier": dhcp4.OptionClassIdentifier, "vendor-class-identifier": dhcp4.OptionVendorClassIdentifier,
} }
func parseOptionName(option string) (dhcp4.OptionCode, error) { func parseOptionName(option string) (dhcp4.OptionCode, error) {
@ -39,9 +40,18 @@ func parseOptionName(option string) (dhcp4.OptionCode, error) {
} }
i, err := strconv.ParseUint(option, 10, 8) i, err := strconv.ParseUint(option, 10, 8)
if err != nil { if err != nil {
return dhcp4.OptionPad, fmt.Errorf("Can not parse option: %w", err) return 0, fmt.Errorf("Can not parse option: %w", err)
} }
return dhcp4.GenericOptionCode(i), nil return dhcp4.OptionCode(i), nil
}
func parseRouter(opts dhcp4.Options) net.IP {
if opts, ok := opts[dhcp4.OptionRouter]; ok {
if len(opts) == 4 {
return net.IP(opts)
}
}
return nil
} }
func classfulSubnet(sn net.IP) net.IPNet { func classfulSubnet(sn net.IP) net.IPNet {
@ -51,22 +61,100 @@ func classfulSubnet(sn net.IP) net.IPNet {
} }
} }
func parseRoutes(opt []byte) []*types.Route { func parseRoutes(opts dhcp4.Options) []*types.Route {
// StaticRoutes format: pairs of: // StaticRoutes format: pairs of:
// Dest = 4 bytes; Classful IP subnet // Dest = 4 bytes; Classful IP subnet
// Router = 4 bytes; IP address of router // Router = 4 bytes; IP address of router
routes := []*types.Route{} routes := []*types.Route{}
for len(opt) >= 8 { if opt, ok := opts[dhcp4.OptionStaticRoute]; ok {
sn := opt[0:4] for len(opt) >= 8 {
r := opt[4:8] sn := opt[0:4]
rt := &types.Route{ r := opt[4:8]
Dst: classfulSubnet(sn), rt := &types.Route{
GW: r, Dst: classfulSubnet(sn),
GW: r,
}
routes = append(routes, rt)
opt = opt[8:]
} }
routes = append(routes, rt)
opt = opt[8:]
} }
return routes return routes
} }
func parseCIDRRoutes(opts dhcp4.Options) []*types.Route {
// See RFC4332 for format (http://tools.ietf.org/html/rfc3442)
routes := []*types.Route{}
if opt, ok := opts[dhcp4.OptionClasslessRouteFormat]; ok {
for len(opt) >= 5 {
width := int(opt[0])
if width > 32 {
// error: can't have more than /32
return nil
}
// network bits are compacted to avoid zeros
octets := 0
if width > 0 {
octets = (width-1)/8 + 1
}
if len(opt) < 1+octets+4 {
// error: too short
return nil
}
sn := make([]byte, 4)
copy(sn, opt[1:octets+1])
gw := net.IP(opt[octets+1 : octets+5])
rt := &types.Route{
Dst: net.IPNet{
IP: net.IP(sn),
Mask: net.CIDRMask(width, 32),
},
GW: gw,
}
routes = append(routes, rt)
opt = opt[octets+5:]
}
}
return routes
}
func parseSubnetMask(opts dhcp4.Options) net.IPMask {
mask, ok := opts[dhcp4.OptionSubnetMask]
if !ok {
return nil
}
return net.IPMask(mask)
}
func parseDuration(opts dhcp4.Options, code dhcp4.OptionCode, optName string) (time.Duration, error) {
val, ok := opts[code]
if !ok {
return 0, fmt.Errorf("option %v not found", optName)
}
if len(val) != 4 {
return 0, fmt.Errorf("option %v is not 4 bytes", optName)
}
secs := binary.BigEndian.Uint32(val)
return time.Duration(secs) * time.Second, nil
}
func parseLeaseTime(opts dhcp4.Options) (time.Duration, error) {
return parseDuration(opts, dhcp4.OptionIPAddressLeaseTime, "LeaseTime")
}
func parseRenewalTime(opts dhcp4.Options) (time.Duration, error) {
return parseDuration(opts, dhcp4.OptionRenewalTimeValue, "RenewalTime")
}
func parseRebindingTime(opts dhcp4.Options) (time.Duration, error) {
return parseDuration(opts, dhcp4.OptionRebindingTimeValue, "RebindingTime")
}

View File

@ -19,9 +19,8 @@ import (
"reflect" "reflect"
"testing" "testing"
dhcp4 "github.com/insomniacslk/dhcp/dhcpv4"
"github.com/containernetworking/cni/pkg/types" "github.com/containernetworking/cni/pkg/types"
"github.com/d2g/dhcp4"
) )
func validateRoutes(t *testing.T, routes []*types.Route) { func validateRoutes(t *testing.T, routes []*types.Route) {
@ -61,8 +60,17 @@ func validateRoutes(t *testing.T, routes []*types.Route) {
} }
func TestParseRoutes(t *testing.T) { func TestParseRoutes(t *testing.T) {
data := []byte{10, 0, 0, 0, 10, 1, 2, 3, 192, 168, 1, 0, 192, 168, 2, 3} opts := make(dhcp4.Options)
routes := parseRoutes(data) opts[dhcp4.OptionStaticRoute] = []byte{10, 0, 0, 0, 10, 1, 2, 3, 192, 168, 1, 0, 192, 168, 2, 3}
routes := parseRoutes(opts)
validateRoutes(t, routes)
}
func TestParseCIDRRoutes(t *testing.T) {
opts := make(dhcp4.Options)
opts[dhcp4.OptionClasslessRouteFormat] = []byte{8, 10, 10, 1, 2, 3, 24, 192, 168, 1, 192, 168, 2, 3}
routes := parseCIDRRoutes(opts)
validateRoutes(t, routes) validateRoutes(t, routes)
} }
@ -78,10 +86,10 @@ func TestParseOptionName(t *testing.T) {
"hostname", "host-name", dhcp4.OptionHostName, false, "hostname", "host-name", dhcp4.OptionHostName, false,
}, },
{ {
"hostname in number", "12", dhcp4.GenericOptionCode(12), false, "hostname in number", "12", dhcp4.OptionHostName, false,
}, },
{ {
"random string", "doNotparseMe", dhcp4.OptionPad, true, "random string", "doNotparseMe", 0, true,
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -22,6 +22,7 @@ import (
"strconv" "strconv"
current "github.com/containernetworking/cni/pkg/types/100" current "github.com/containernetworking/cni/pkg/types/100"
"github.com/containernetworking/plugins/pkg/ip" "github.com/containernetworking/plugins/pkg/ip"
"github.com/containernetworking/plugins/plugins/ipam/host-local/backend" "github.com/containernetworking/plugins/plugins/ipam/host-local/backend"
) )
@ -196,7 +197,7 @@ func (i *RangeIter) Next() (*net.IPNet, net.IP) {
// If we've reached the end of this range, we need to advance the range // If we've reached the end of this range, we need to advance the range
// RangeEnd is inclusive as well // RangeEnd is inclusive as well
if i.cur.Equal(r.RangeEnd) { if i.cur.Equal(r.RangeEnd) {
i.rangeIdx++ i.rangeIdx += 1
i.rangeIdx %= len(*i.rangeset) i.rangeIdx %= len(*i.rangeset)
r = (*i.rangeset)[i.rangeIdx] r = (*i.rangeset)[i.rangeIdx]

View File

@ -15,10 +15,10 @@
package allocator_test package allocator_test
import ( import (
"testing" . "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"testing"
) )
func TestAllocator(t *testing.T) { func TestAllocator(t *testing.T) {

View File

@ -18,12 +18,12 @@ import (
"fmt" "fmt"
"net" "net"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/containernetworking/cni/pkg/types" "github.com/containernetworking/cni/pkg/types"
current "github.com/containernetworking/cni/pkg/types/100" current "github.com/containernetworking/cni/pkg/types/100"
fakestore "github.com/containernetworking/plugins/plugins/ipam/host-local/backend/testing" fakestore "github.com/containernetworking/plugins/plugins/ipam/host-local/backend/testing"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
) )
type AllocatorTestCase struct { type AllocatorTestCase struct {
@ -77,7 +77,7 @@ func (t AllocatorTestCase) run(idx int) (*current.IPConfig, error) {
p = append(p, Range{Subnet: types.IPNet(*subnet)}) p = append(p, Range{Subnet: types.IPNet(*subnet)})
} }
Expect(p.Canonicalize()).To(Succeed()) Expect(p.Canonicalize()).To(BeNil())
store := fakestore.NewFakeStore(t.ipmap, map[string]net.IP{"rangeid": net.ParseIP(t.lastIP)}) store := fakestore.NewFakeStore(t.ipmap, map[string]net.IP{"rangeid": net.ParseIP(t.lastIP)})
@ -262,6 +262,7 @@ var _ = Describe("host-local ip allocator", func() {
res, err = alloc.Get("ID", "eth0", nil) res, err = alloc.Get("ID", "eth0", nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(res.Address.String()).To(Equal("192.168.1.3/29")) Expect(res.Address.String()).To(Equal("192.168.1.3/29"))
}) })
Context("when requesting a specific IP", func() { Context("when requesting a specific IP", func() {
@ -300,6 +301,7 @@ var _ = Describe("host-local ip allocator", func() {
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
}) })
}) })
}) })
Context("when out of ips", func() { Context("when out of ips", func() {
It("returns a meaningful error", func() { It("returns a meaningful error", func() {
@ -330,7 +332,7 @@ var _ = Describe("host-local ip allocator", func() {
} }
for idx, tc := range testCases { for idx, tc := range testCases {
_, err := tc.run(idx) _, err := tc.run(idx)
Expect(err).To(HaveOccurred()) Expect(err).NotTo(BeNil())
Expect(err.Error()).To(HavePrefix("no IP addresses available in range set")) Expect(err.Error()).To(HavePrefix("no IP addresses available in range set"))
} }
}) })

View File

@ -21,6 +21,7 @@ import (
"github.com/containernetworking/cni/pkg/types" "github.com/containernetworking/cni/pkg/types"
"github.com/containernetworking/cni/pkg/version" "github.com/containernetworking/cni/pkg/version"
"github.com/containernetworking/plugins/pkg/ip" "github.com/containernetworking/plugins/pkg/ip"
) )
@ -42,7 +43,7 @@ type Net struct {
// IPAMConfig represents the IP related network configuration. // IPAMConfig represents the IP related network configuration.
// This nests Range because we initially only supported a single // This nests Range because we initially only supported a single
// range directly, and wish to preserve backwards compatibility // range directly, and wish to preserve backwards compatability
type IPAMConfig struct { type IPAMConfig struct {
*Range *Range
Name string Name string

View File

@ -17,10 +17,9 @@ package allocator
import ( import (
"net" "net"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/containernetworking/cni/pkg/types" "github.com/containernetworking/cni/pkg/types"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
) )
var _ = Describe("IPAM config", func() { var _ = Describe("IPAM config", func() {
@ -416,6 +415,7 @@ var _ = Describe("IPAM config", func() {
}` }`
_, _, err := LoadIPAMConfig([]byte(input), "") _, _, err := LoadIPAMConfig([]byte(input), "")
Expect(err).To(MatchError("invalid range set 0: mixed address families")) Expect(err).To(MatchError("invalid range set 0: mixed address families"))
}) })
It("Should should error on too many ranges", func() { It("Should should error on too many ranges", func() {

View File

@ -125,7 +125,7 @@ func (r *Range) Contains(addr net.IP) bool {
// Overlaps returns true if there is any overlap between ranges // Overlaps returns true if there is any overlap between ranges
func (r *Range) Overlaps(r1 *Range) bool { func (r *Range) Overlaps(r1 *Range) bool {
// different families // different familes
if len(r.RangeStart) != len(r1.RangeStart) { if len(r.RangeStart) != len(r1.RangeStart) {
return false return false
} }

View File

@ -67,8 +67,10 @@ func (s *RangeSet) Canonicalize() error {
} }
if i == 0 { if i == 0 {
fam = len((*s)[i].RangeStart) fam = len((*s)[i].RangeStart)
} else if fam != len((*s)[i].RangeStart) { } else {
return fmt.Errorf("mixed address families") if fam != len((*s)[i].RangeStart) {
return fmt.Errorf("mixed address families")
}
} }
} }

View File

@ -17,7 +17,7 @@ package allocator
import ( import (
"net" "net"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
@ -40,6 +40,7 @@ var _ = Describe("range sets", func() {
r, err = p.RangeFor(net.IP{192, 168, 99, 99}) r, err = p.RangeFor(net.IP{192, 168, 99, 99})
Expect(r).To(BeNil()) Expect(r).To(BeNil())
Expect(err).To(MatchError("192.168.99.99 not in range set 192.168.0.1-192.168.0.254,172.16.1.1-172.16.1.254")) Expect(err).To(MatchError("192.168.99.99 not in range set 192.168.0.1-192.168.0.254,172.16.1.1-172.16.1.254"))
}) })
It("should discover overlaps within a set", func() { It("should discover overlaps within a set", func() {

View File

@ -17,10 +17,11 @@ package allocator
import ( import (
"net" "net"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/containernetworking/cni/pkg/types" "github.com/containernetworking/cni/pkg/types"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
) )
var _ = Describe("IP ranges", func() { var _ = Describe("IP ranges", func() {

View File

@ -15,6 +15,7 @@
package disk package disk
import ( import (
"io/ioutil"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
@ -24,10 +25,8 @@ import (
"github.com/containernetworking/plugins/plugins/ipam/host-local/backend" "github.com/containernetworking/plugins/plugins/ipam/host-local/backend"
) )
const ( const lastIPFilePrefix = "last_reserved_ip."
lastIPFilePrefix = "last_reserved_ip." const LineBreak = "\r\n"
LineBreak = "\r\n"
)
var defaultDataDir = "/var/lib/cni/networks" var defaultDataDir = "/var/lib/cni/networks"
@ -46,7 +45,7 @@ func New(network, dataDir string) (*Store, error) {
dataDir = defaultDataDir dataDir = defaultDataDir
} }
dir := filepath.Join(dataDir, network) dir := filepath.Join(dataDir, network)
if err := os.MkdirAll(dir, 0o755); err != nil { if err := os.MkdirAll(dir, 0755); err != nil {
return nil, err return nil, err
} }
@ -60,7 +59,7 @@ func New(network, dataDir string) (*Store, error) {
func (s *Store) Reserve(id string, ifname string, ip net.IP, rangeID string) (bool, error) { func (s *Store) Reserve(id string, ifname string, ip net.IP, rangeID string) (bool, error) {
fname := GetEscapedPath(s.dataDir, ip.String()) fname := GetEscapedPath(s.dataDir, ip.String())
f, err := os.OpenFile(fname, os.O_RDWR|os.O_EXCL|os.O_CREATE, 0o600) f, err := os.OpenFile(fname, os.O_RDWR|os.O_EXCL|os.O_CREATE, 0644)
if os.IsExist(err) { if os.IsExist(err) {
return false, nil return false, nil
} }
@ -78,7 +77,7 @@ func (s *Store) Reserve(id string, ifname string, ip net.IP, rangeID string) (bo
} }
// store the reserved ip in lastIPFile // store the reserved ip in lastIPFile
ipfile := GetEscapedPath(s.dataDir, lastIPFilePrefix+rangeID) ipfile := GetEscapedPath(s.dataDir, lastIPFilePrefix+rangeID)
err = os.WriteFile(ipfile, []byte(ip.String()), 0o600) err = ioutil.WriteFile(ipfile, []byte(ip.String()), 0644)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -88,21 +87,25 @@ func (s *Store) Reserve(id string, ifname string, ip net.IP, rangeID string) (bo
// LastReservedIP returns the last reserved IP if exists // LastReservedIP returns the last reserved IP if exists
func (s *Store) LastReservedIP(rangeID string) (net.IP, error) { func (s *Store) LastReservedIP(rangeID string) (net.IP, error) {
ipfile := GetEscapedPath(s.dataDir, lastIPFilePrefix+rangeID) ipfile := GetEscapedPath(s.dataDir, lastIPFilePrefix+rangeID)
data, err := os.ReadFile(ipfile) data, err := ioutil.ReadFile(ipfile)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return net.ParseIP(string(data)), nil return net.ParseIP(string(data)), nil
} }
func (s *Store) FindByKey(match string) (bool, error) { func (s *Store) Release(ip net.IP) error {
return os.Remove(GetEscapedPath(s.dataDir, ip.String()))
}
func (s *Store) FindByKey(id string, ifname string, match string) (bool, error) {
found := false found := false
err := filepath.Walk(s.dataDir, func(path string, info os.FileInfo, err error) error { err := filepath.Walk(s.dataDir, func(path string, info os.FileInfo, err error) error {
if err != nil || info.IsDir() { if err != nil || info.IsDir() {
return nil return nil
} }
data, err := os.ReadFile(path) data, err := ioutil.ReadFile(path)
if err != nil { if err != nil {
return nil return nil
} }
@ -112,31 +115,33 @@ func (s *Store) FindByKey(match string) (bool, error) {
return nil return nil
}) })
return found, err return found, err
} }
func (s *Store) FindByID(id string, ifname string) bool { func (s *Store) FindByID(id string, ifname string) bool {
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
found := false
match := strings.TrimSpace(id) + LineBreak + ifname match := strings.TrimSpace(id) + LineBreak + ifname
found, err := s.FindByKey(match) found, err := s.FindByKey(id, ifname, match)
// Match anything created by this id // Match anything created by this id
if !found && err == nil { if !found && err == nil {
match := strings.TrimSpace(id) match := strings.TrimSpace(id)
found, _ = s.FindByKey(match) found, err = s.FindByKey(id, ifname, match)
} }
return found return found
} }
func (s *Store) ReleaseByKey(match string) (bool, error) { func (s *Store) ReleaseByKey(id string, ifname string, match string) (bool, error) {
found := false found := false
err := filepath.Walk(s.dataDir, func(path string, info os.FileInfo, err error) error { err := filepath.Walk(s.dataDir, func(path string, info os.FileInfo, err error) error {
if err != nil || info.IsDir() { if err != nil || info.IsDir() {
return nil return nil
} }
data, err := os.ReadFile(path) data, err := ioutil.ReadFile(path)
if err != nil { if err != nil {
return nil return nil
} }
@ -149,18 +154,20 @@ func (s *Store) ReleaseByKey(match string) (bool, error) {
return nil return nil
}) })
return found, err return found, err
} }
// N.B. This function eats errors to be tolerant and // N.B. This function eats errors to be tolerant and
// release as much as possible // release as much as possible
func (s *Store) ReleaseByID(id string, ifname string) error { func (s *Store) ReleaseByID(id string, ifname string) error {
found := false
match := strings.TrimSpace(id) + LineBreak + ifname match := strings.TrimSpace(id) + LineBreak + ifname
found, err := s.ReleaseByKey(match) found, err := s.ReleaseByKey(id, ifname, match)
// For backwards compatibility, look for files written by a previous version // For backwards compatibility, look for files written by a previous version
if !found && err == nil { if !found && err == nil {
match := strings.TrimSpace(id) match := strings.TrimSpace(id)
_, err = s.ReleaseByKey(match) found, err = s.ReleaseByKey(id, ifname, match)
} }
return err return err
} }
@ -178,7 +185,7 @@ func (s *Store) GetByID(id string, ifname string) []net.IP {
if err != nil || info.IsDir() { if err != nil || info.IsDir() {
return nil return nil
} }
data, err := os.ReadFile(path) data, err := ioutil.ReadFile(path)
if err != nil { if err != nil {
return nil return nil
} }
@ -196,7 +203,7 @@ func (s *Store) GetByID(id string, ifname string) []net.IP {
func GetEscapedPath(dataDir string, fname string) string { func GetEscapedPath(dataDir string, fname string) string {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
fname = strings.ReplaceAll(fname, ":", "_") fname = strings.Replace(fname, ":", "_", -1)
} }
return filepath.Join(dataDir, fname) return filepath.Join(dataDir, fname)
} }

View File

@ -15,10 +15,10 @@
package disk package disk
import ( import (
"testing" . "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"testing"
) )
func TestLock(t *testing.T) { func TestLock(t *testing.T) {

View File

@ -15,10 +15,9 @@
package disk package disk
import ( import (
"github.com/alexflint/go-filemutex"
"os" "os"
"path" "path"
"github.com/alexflint/go-filemutex"
) )
// FileLock wraps os.File to be used as a lock using flock // FileLock wraps os.File to be used as a lock using flock

View File

@ -15,22 +15,23 @@
package disk package disk
import ( import (
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
var _ = Describe("Lock Operations", func() { var _ = Describe("Lock Operations", func() {
It("locks a file path", func() { It("locks a file path", func() {
dir, err := os.MkdirTemp("", "") dir, err := ioutil.TempDir("", "")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
// create a dummy file to lock // create a dummy file to lock
path := filepath.Join(dir, "x") path := filepath.Join(dir, "x")
f, err := os.OpenFile(path, os.O_RDONLY|os.O_CREATE, 0o666) f, err := os.OpenFile(path, os.O_RDONLY|os.O_CREATE, 0666)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = f.Close() err = f.Close()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -46,7 +47,7 @@ var _ = Describe("Lock Operations", func() {
}) })
It("locks a folder path", func() { It("locks a folder path", func() {
dir, err := os.MkdirTemp("", "") dir, err := ioutil.TempDir("", "")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer os.RemoveAll(dir) defer os.RemoveAll(dir)

View File

@ -22,6 +22,7 @@ type Store interface {
Close() error Close() error
Reserve(id string, ifname string, ip net.IP, rangeID string) (bool, error) Reserve(id string, ifname string, ip net.IP, rangeID string) (bool, error)
LastReservedIP(rangeID string) (net.IP, error) LastReservedIP(rangeID string) (net.IP, error)
Release(ip net.IP) error
ReleaseByID(id string, ifname string) error ReleaseByID(id string, ifname string) error
GetByID(id string, ifname string) []net.IP GetByID(id string, ifname string) []net.IP
} }

View File

@ -45,7 +45,7 @@ func (s *FakeStore) Close() error {
return nil return nil
} }
func (s *FakeStore) Reserve(id string, _ string, ip net.IP, rangeID string) (bool, error) { func (s *FakeStore) Reserve(id string, ifname string, ip net.IP, rangeID string) (bool, error) {
key := ip.String() key := ip.String()
if _, ok := s.ipMap[key]; !ok { if _, ok := s.ipMap[key]; !ok {
s.ipMap[key] = id s.ipMap[key] = id
@ -63,7 +63,12 @@ func (s *FakeStore) LastReservedIP(rangeID string) (net.IP, error) {
return ip, nil return ip, nil
} }
func (s *FakeStore) ReleaseByID(id string, _ string) error { func (s *FakeStore) Release(ip net.IP) error {
delete(s.ipMap, ip.String())
return nil
}
func (s *FakeStore) ReleaseByID(id string, ifname string) error {
toDelete := []string{} toDelete := []string{}
for k, v := range s.ipMap { for k, v := range s.ipMap {
if v == id { if v == id {
@ -76,7 +81,7 @@ func (s *FakeStore) ReleaseByID(id string, _ string) error {
return nil return nil
} }
func (s *FakeStore) GetByID(id string, _ string) []net.IP { func (s *FakeStore) GetByID(id string, ifname string) []net.IP {
var ips []net.IP var ips []net.IP
for k, v := range s.ipMap { for k, v := range s.ipMap {
if v == id { if v == id {

View File

@ -28,7 +28,6 @@ func parseResolvConf(filename string) (*types.DNS, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer fp.Close()
dns := types.DNS{} dns := types.DNS{}
scanner := bufio.NewScanner(fp) scanner := bufio.NewScanner(fp)

View File

@ -15,12 +15,12 @@
package main package main
import ( import (
"io/ioutil"
"os" "os"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/containernetworking/cni/pkg/types" "github.com/containernetworking/cni/pkg/types"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
) )
var _ = Describe("parsing resolv.conf", func() { var _ = Describe("parsing resolv.conf", func() {
@ -64,7 +64,7 @@ options four
}) })
func parse(contents string) (*types.DNS, error) { func parse(contents string) (*types.DNS, error) {
f, err := os.CreateTemp("", "host_local_resolv") f, err := ioutil.TempFile("", "host_local_resolv")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -15,10 +15,10 @@
package main package main
import ( import (
"testing" . "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"testing"
) )
func TestHostLocal(t *testing.T) { func TestHostLocal(t *testing.T) {

View File

@ -16,19 +16,20 @@ package main
import ( import (
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/containernetworking/cni/pkg/skel" "github.com/containernetworking/cni/pkg/skel"
"github.com/containernetworking/cni/pkg/types" "github.com/containernetworking/cni/pkg/types"
types100 "github.com/containernetworking/cni/pkg/types/100" "github.com/containernetworking/cni/pkg/types/100"
"github.com/containernetworking/plugins/pkg/testutils" "github.com/containernetworking/plugins/pkg/testutils"
"github.com/containernetworking/plugins/plugins/ipam/host-local/backend/disk" "github.com/containernetworking/plugins/plugins/ipam/host-local/backend/disk"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
) )
const LineBreak = "\r\n" const LineBreak = "\r\n"
@ -42,7 +43,7 @@ var _ = Describe("host-local Operations", func() {
BeforeEach(func() { BeforeEach(func() {
var err error var err error
tmpDir, err = os.MkdirTemp("", "host-local_test") tmpDir, err = ioutil.TempDir("", "host-local_test")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
tmpDir = filepath.ToSlash(tmpDir) tmpDir = filepath.ToSlash(tmpDir)
}) })
@ -57,7 +58,7 @@ var _ = Describe("host-local Operations", func() {
ver := ver ver := ver
It(fmt.Sprintf("[%s] allocates and releases addresses with ADD/DEL", ver), func() { It(fmt.Sprintf("[%s] allocates and releases addresses with ADD/DEL", ver), func() {
err := os.WriteFile(filepath.Join(tmpDir, "resolv.conf"), []byte("nameserver 192.0.2.3"), 0o644) err := ioutil.WriteFile(filepath.Join(tmpDir, "resolv.conf"), []byte("nameserver 192.0.2.3"), 0644)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
conf := fmt.Sprintf(`{ conf := fmt.Sprintf(`{
@ -114,7 +115,7 @@ var _ = Describe("host-local Operations", func() {
Gateway: net.ParseIP("2001:db8:1::1"), Gateway: net.ParseIP("2001:db8:1::1"),
}, },
)) ))
Expect(result.IPs).To(HaveLen(2)) Expect(len(result.IPs)).To(Equal(2))
for _, expectedRoute := range []*types.Route{ for _, expectedRoute := range []*types.Route{
{Dst: mustCIDR("0.0.0.0/0"), GW: nil}, {Dst: mustCIDR("0.0.0.0/0"), GW: nil},
@ -133,22 +134,22 @@ var _ = Describe("host-local Operations", func() {
} }
ipFilePath1 := filepath.Join(tmpDir, "mynet", "10.1.2.2") ipFilePath1 := filepath.Join(tmpDir, "mynet", "10.1.2.2")
contents, err := os.ReadFile(ipFilePath1) contents, err := ioutil.ReadFile(ipFilePath1)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(string(contents)).To(Equal(args.ContainerID + LineBreak + ifname)) Expect(string(contents)).To(Equal(args.ContainerID + LineBreak + ifname))
ipFilePath2 := filepath.Join(tmpDir, disk.GetEscapedPath("mynet", "2001:db8:1::2")) ipFilePath2 := filepath.Join(tmpDir, disk.GetEscapedPath("mynet", "2001:db8:1::2"))
contents, err = os.ReadFile(ipFilePath2) contents, err = ioutil.ReadFile(ipFilePath2)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(string(contents)).To(Equal(args.ContainerID + LineBreak + ifname)) Expect(string(contents)).To(Equal(args.ContainerID + LineBreak + ifname))
lastFilePath1 := filepath.Join(tmpDir, "mynet", "last_reserved_ip.0") lastFilePath1 := filepath.Join(tmpDir, "mynet", "last_reserved_ip.0")
contents, err = os.ReadFile(lastFilePath1) contents, err = ioutil.ReadFile(lastFilePath1)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(string(contents)).To(Equal("10.1.2.2")) Expect(string(contents)).To(Equal("10.1.2.2"))
lastFilePath2 := filepath.Join(tmpDir, "mynet", "last_reserved_ip.1") lastFilePath2 := filepath.Join(tmpDir, "mynet", "last_reserved_ip.1")
contents, err = os.ReadFile(lastFilePath2) contents, err = ioutil.ReadFile(lastFilePath2)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(string(contents)).To(Equal("2001:db8:1::2")) Expect(string(contents)).To(Equal("2001:db8:1::2"))
// Release the IP // Release the IP
@ -166,7 +167,7 @@ var _ = Describe("host-local Operations", func() {
It(fmt.Sprintf("[%s] allocates and releases addresses on specific interface with ADD/DEL", ver), func() { It(fmt.Sprintf("[%s] allocates and releases addresses on specific interface with ADD/DEL", ver), func() {
const ifname1 string = "eth1" const ifname1 string = "eth1"
err := os.WriteFile(filepath.Join(tmpDir, "resolv.conf"), []byte("nameserver 192.0.2.3"), 0o644) err := ioutil.WriteFile(filepath.Join(tmpDir, "resolv.conf"), []byte("nameserver 192.0.2.3"), 0644)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
conf0 := fmt.Sprintf(`{ conf0 := fmt.Sprintf(`{
@ -238,12 +239,12 @@ var _ = Describe("host-local Operations", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
ipFilePath0 := filepath.Join(tmpDir, "mynet0", "10.1.2.2") ipFilePath0 := filepath.Join(tmpDir, "mynet0", "10.1.2.2")
contents, err := os.ReadFile(ipFilePath0) contents, err := ioutil.ReadFile(ipFilePath0)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(string(contents)).To(Equal(args0.ContainerID + LineBreak + ifname)) Expect(string(contents)).To(Equal(args0.ContainerID + LineBreak + ifname))
ipFilePath1 := filepath.Join(tmpDir, "mynet1", "10.2.2.2") ipFilePath1 := filepath.Join(tmpDir, "mynet1", "10.2.2.2")
contents, err = os.ReadFile(ipFilePath1) contents, err = ioutil.ReadFile(ipFilePath1)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(string(contents)).To(Equal(args1.ContainerID + LineBreak + ifname1)) Expect(string(contents)).To(Equal(args1.ContainerID + LineBreak + ifname1))
@ -256,7 +257,7 @@ var _ = Describe("host-local Operations", func() {
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
// reread ipFilePath1, ensure that ifname1 didn't get deleted // reread ipFilePath1, ensure that ifname1 didn't get deleted
contents, err = os.ReadFile(ipFilePath1) contents, err = ioutil.ReadFile(ipFilePath1)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(string(contents)).To(Equal(args1.ContainerID + LineBreak + ifname1)) Expect(string(contents)).To(Equal(args1.ContainerID + LineBreak + ifname1))
@ -310,7 +311,7 @@ var _ = Describe("host-local Operations", func() {
result0, err := types100.GetResult(r0) result0, err := types100.GetResult(r0)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(result0.IPs).Should(HaveLen(1)) Expect(len(result0.IPs)).Should(Equal(1))
Expect(result0.IPs[0].Address.String()).Should(Equal("10.1.2.2/24")) Expect(result0.IPs[0].Address.String()).Should(Equal("10.1.2.2/24"))
// Allocate the IP with the same container ID // Allocate the IP with the same container ID
@ -330,7 +331,7 @@ var _ = Describe("host-local Operations", func() {
result1, err := types100.GetResult(r1) result1, err := types100.GetResult(r1)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(result1.IPs).Should(HaveLen(1)) Expect(len(result1.IPs)).Should(Equal(1))
Expect(result1.IPs[0].Address.String()).Should(Equal("10.1.2.3/24")) Expect(result1.IPs[0].Address.String()).Should(Equal("10.1.2.3/24"))
// Allocate the IP with the same container ID again // Allocate the IP with the same container ID again
@ -356,7 +357,7 @@ var _ = Describe("host-local Operations", func() {
}) })
It(fmt.Sprintf("[%s] verify DEL works on backwards compatible allocate", ver), func() { It(fmt.Sprintf("[%s] verify DEL works on backwards compatible allocate", ver), func() {
err := os.WriteFile(filepath.Join(tmpDir, "resolv.conf"), []byte("nameserver 192.0.2.3"), 0o644) err := ioutil.WriteFile(filepath.Join(tmpDir, "resolv.conf"), []byte("nameserver 192.0.2.3"), 0644)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
conf := fmt.Sprintf(`{ conf := fmt.Sprintf(`{
@ -394,10 +395,10 @@ var _ = Describe("host-local Operations", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
ipFilePath := filepath.Join(tmpDir, "mynet", "10.1.2.2") ipFilePath := filepath.Join(tmpDir, "mynet", "10.1.2.2")
contents, err := os.ReadFile(ipFilePath) contents, err := ioutil.ReadFile(ipFilePath)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(string(contents)).To(Equal(args.ContainerID + LineBreak + ifname)) Expect(string(contents)).To(Equal(args.ContainerID + LineBreak + ifname))
err = os.WriteFile(ipFilePath, []byte(strings.TrimSpace(args.ContainerID)), 0o644) err = ioutil.WriteFile(ipFilePath, []byte(strings.TrimSpace(args.ContainerID)), 0644)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = testutils.CmdDelWithArgs(args, func() error { err = testutils.CmdDelWithArgs(args, func() error {
@ -465,7 +466,7 @@ var _ = Describe("host-local Operations", func() {
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
ipFilePath := filepath.Join(tmpDir, "mynet", result.IPs[0].Address.IP.String()) ipFilePath := filepath.Join(tmpDir, "mynet", result.IPs[0].Address.IP.String())
contents, err := os.ReadFile(ipFilePath) contents, err := ioutil.ReadFile(ipFilePath)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(string(contents)).To(Equal("dummy" + LineBreak + ifname)) Expect(string(contents)).To(Equal("dummy" + LineBreak + ifname))
@ -504,7 +505,7 @@ var _ = Describe("host-local Operations", func() {
return cmdAdd(args) return cmdAdd(args)
}) })
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(strings.Index(string(out), "Error retrieving last reserved ip")).To(Equal(-1)) Expect(strings.Index(string(out), "Error retriving last reserved ip")).To(Equal(-1))
}) })
It(fmt.Sprintf("[%s] allocates a custom IP when requested by config args", ver), func() { It(fmt.Sprintf("[%s] allocates a custom IP when requested by config args", ver), func() {
@ -546,7 +547,7 @@ var _ = Describe("host-local Operations", func() {
}) })
It(fmt.Sprintf("[%s] allocates custom IPs from multiple ranges", ver), func() { It(fmt.Sprintf("[%s] allocates custom IPs from multiple ranges", ver), func() {
err := os.WriteFile(filepath.Join(tmpDir, "resolv.conf"), []byte("nameserver 192.0.2.3"), 0o644) err := ioutil.WriteFile(filepath.Join(tmpDir, "resolv.conf"), []byte("nameserver 192.0.2.3"), 0644)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
conf := fmt.Sprintf(`{ conf := fmt.Sprintf(`{
@ -594,7 +595,7 @@ var _ = Describe("host-local Operations", func() {
}) })
It(fmt.Sprintf("[%s] allocates custom IPs from multiple protocols", ver), func() { It(fmt.Sprintf("[%s] allocates custom IPs from multiple protocols", ver), func() {
err := os.WriteFile(filepath.Join(tmpDir, "resolv.conf"), []byte("nameserver 192.0.2.3"), 0o644) err := ioutil.WriteFile(filepath.Join(tmpDir, "resolv.conf"), []byte("nameserver 192.0.2.3"), 0644)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
conf := fmt.Sprintf(`{ conf := fmt.Sprintf(`{

View File

@ -15,31 +15,26 @@
package main package main
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"strings" "strings"
bv "github.com/containernetworking/plugins/pkg/utils/buildversion"
"github.com/containernetworking/plugins/plugins/ipam/host-local/backend/allocator"
"github.com/containernetworking/plugins/plugins/ipam/host-local/backend/disk"
"github.com/containernetworking/cni/pkg/skel" "github.com/containernetworking/cni/pkg/skel"
"github.com/containernetworking/cni/pkg/types" "github.com/containernetworking/cni/pkg/types"
current "github.com/containernetworking/cni/pkg/types/100" current "github.com/containernetworking/cni/pkg/types/100"
"github.com/containernetworking/cni/pkg/version" "github.com/containernetworking/cni/pkg/version"
bv "github.com/containernetworking/plugins/pkg/utils/buildversion"
"github.com/containernetworking/plugins/plugins/ipam/host-local/backend/allocator"
"github.com/containernetworking/plugins/plugins/ipam/host-local/backend/disk"
) )
func main() { func main() {
skel.PluginMainFuncs(skel.CNIFuncs{ skel.PluginMain(cmdAdd, cmdCheck, cmdDel, version.All, bv.BuildString("host-local"))
Add: cmdAdd,
Check: cmdCheck,
Del: cmdDel,
/* FIXME GC */
/* FIXME Status */
}, version.All, bv.BuildString("host-local"))
} }
func cmdCheck(args *skel.CmdArgs) error { func cmdCheck(args *skel.CmdArgs) error {
ipamConf, _, err := allocator.LoadIPAMConfig(args.StdinData, args.Args) ipamConf, _, err := allocator.LoadIPAMConfig(args.StdinData, args.Args)
if err != nil { if err != nil {
return err return err
@ -53,8 +48,8 @@ func cmdCheck(args *skel.CmdArgs) error {
} }
defer store.Close() defer store.Close()
containerIPFound := store.FindByID(args.ContainerID, args.IfName) containerIpFound := store.FindByID(args.ContainerID, args.IfName)
if !containerIPFound { if containerIpFound == false {
return fmt.Errorf("host-local: Failed to find address added by container %v", args.ContainerID) return fmt.Errorf("host-local: Failed to find address added by container %v", args.ContainerID)
} }
@ -89,7 +84,7 @@ func cmdAdd(args *skel.CmdArgs) error {
// Store all requested IPs in a map, so we can easily remove ones we use // Store all requested IPs in a map, so we can easily remove ones we use
// and error if some remain // and error if some remain
requestedIPs := map[string]net.IP{} // net.IP cannot be a key requestedIPs := map[string]net.IP{} //net.IP cannot be a key
for _, ip := range ipamConf.IPArgs { for _, ip := range ipamConf.IPArgs {
requestedIPs[ip.String()] = ip requestedIPs[ip.String()] = ip
@ -131,7 +126,7 @@ func cmdAdd(args *skel.CmdArgs) error {
for _, ip := range requestedIPs { for _, ip := range requestedIPs {
errstr = errstr + " " + ip.String() errstr = errstr + " " + ip.String()
} }
return errors.New(errstr) return fmt.Errorf(errstr)
} }
result.Routes = ipamConf.Routes result.Routes = ipamConf.Routes
@ -152,18 +147,18 @@ func cmdDel(args *skel.CmdArgs) error {
defer store.Close() defer store.Close()
// Loop through all ranges, releasing all IPs, even if an error occurs // Loop through all ranges, releasing all IPs, even if an error occurs
var errs []string var errors []string
for idx, rangeset := range ipamConf.Ranges { for idx, rangeset := range ipamConf.Ranges {
ipAllocator := allocator.NewIPAllocator(&rangeset, store, idx) ipAllocator := allocator.NewIPAllocator(&rangeset, store, idx)
err := ipAllocator.Release(args.ContainerID, args.IfName) err := ipAllocator.Release(args.ContainerID, args.IfName)
if err != nil { if err != nil {
errs = append(errs, err.Error()) errors = append(errors, err.Error())
} }
} }
if errs != nil { if errors != nil {
return errors.New(strings.Join(errs, ";")) return fmt.Errorf(strings.Join(errors, ";"))
} }
return nil return nil
} }

View File

@ -68,13 +68,7 @@ type Address struct {
} }
func main() { func main() {
skel.PluginMainFuncs(skel.CNIFuncs{ skel.PluginMain(cmdAdd, cmdCheck, cmdDel, version.All, bv.BuildString("static"))
Add: cmdAdd,
Check: cmdCheck,
Del: cmdDel,
/* FIXME GC */
/* FIXME Status */
}, version.All, bv.BuildString("static"))
} }
func loadNetConf(bytes []byte) (*types.NetConf, string, error) { func loadNetConf(bytes []byte) (*types.NetConf, string, error) {
@ -282,7 +276,7 @@ func cmdAdd(args *skel.CmdArgs) error {
return types.PrintResult(result, confVersion) return types.PrintResult(result, confVersion)
} }
func cmdDel(_ *skel.CmdArgs) error { func cmdDel(args *skel.CmdArgs) error {
// Nothing required because of no resource allocation in static plugin. // Nothing required because of no resource allocation in static plugin.
return nil return nil
} }

View File

@ -17,7 +17,7 @@ package main_test
import ( import (
"testing" "testing"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )

View File

@ -19,13 +19,13 @@ import (
"net" "net"
"strings" "strings"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/containernetworking/cni/pkg/skel" "github.com/containernetworking/cni/pkg/skel"
"github.com/containernetworking/cni/pkg/types" "github.com/containernetworking/cni/pkg/types"
types100 "github.com/containernetworking/cni/pkg/types/100" types100 "github.com/containernetworking/cni/pkg/types/100"
"github.com/containernetworking/plugins/pkg/testutils" "github.com/containernetworking/plugins/pkg/testutils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
) )
var _ = Describe("static Operations", func() { var _ = Describe("static Operations", func() {
@ -97,7 +97,7 @@ var _ = Describe("static Operations", func() {
Gateway: net.ParseIP("3ffe:ffff:0::1"), Gateway: net.ParseIP("3ffe:ffff:0::1"),
}, },
)) ))
Expect(result.IPs).To(HaveLen(2)) Expect(len(result.IPs)).To(Equal(2))
Expect(result.Routes).To(Equal([]*types.Route{ Expect(result.Routes).To(Equal([]*types.Route{
{Dst: mustCIDR("0.0.0.0/0")}, {Dst: mustCIDR("0.0.0.0/0")},
@ -206,7 +206,7 @@ var _ = Describe("static Operations", func() {
Gateway: net.ParseIP("10.10.0.254"), Gateway: net.ParseIP("10.10.0.254"),
})) }))
Expect(result.IPs).To(HaveLen(1)) Expect(len(result.IPs)).To(Equal(1))
Expect(result.Routes).To(Equal([]*types.Route{ Expect(result.Routes).To(Equal([]*types.Route{
{Dst: mustCIDR("0.0.0.0/0")}, {Dst: mustCIDR("0.0.0.0/0")},
@ -272,7 +272,7 @@ var _ = Describe("static Operations", func() {
Gateway: nil, Gateway: nil,
})) }))
Expect(result.IPs).To(HaveLen(2)) Expect(len(result.IPs)).To(Equal(2))
// Release the IP // Release the IP
err = testutils.CmdDelWithArgs(args, func() error { err = testutils.CmdDelWithArgs(args, func() error {
@ -337,7 +337,7 @@ var _ = Describe("static Operations", func() {
Address: mustCIDR("3ffe:ffff:0:01ff::1/64"), Address: mustCIDR("3ffe:ffff:0:01ff::1/64"),
}, },
)) ))
Expect(result.IPs).To(HaveLen(2)) Expect(len(result.IPs)).To(Equal(2))
Expect(result.Routes).To(Equal([]*types.Route{ Expect(result.Routes).To(Equal([]*types.Route{
{Dst: mustCIDR("0.0.0.0/0"), GW: net.ParseIP("10.10.0.254")}, {Dst: mustCIDR("0.0.0.0/0"), GW: net.ParseIP("10.10.0.254")},
{Dst: mustCIDR("3ffe:ffff:0:01ff::1/64"), GW: net.ParseIP("3ffe:ffff:0::1")}, {Dst: mustCIDR("3ffe:ffff:0:01ff::1/64"), GW: net.ParseIP("3ffe:ffff:0::1")},
@ -407,7 +407,7 @@ var _ = Describe("static Operations", func() {
Address: mustCIDR("3ffe:ffff:0:01ff::1/64"), Address: mustCIDR("3ffe:ffff:0:01ff::1/64"),
}, },
)) ))
Expect(result.IPs).To(HaveLen(2)) Expect(len(result.IPs)).To(Equal(2))
Expect(result.Routes).To(Equal([]*types.Route{ Expect(result.Routes).To(Equal([]*types.Route{
{Dst: mustCIDR("0.0.0.0/0"), GW: net.ParseIP("10.10.0.254")}, {Dst: mustCIDR("0.0.0.0/0"), GW: net.ParseIP("10.10.0.254")},
{Dst: mustCIDR("3ffe:ffff:0:01ff::1/64"), GW: net.ParseIP("3ffe:ffff:0::1")}, {Dst: mustCIDR("3ffe:ffff:0:01ff::1/64"), GW: net.ParseIP("3ffe:ffff:0::1")},
@ -482,7 +482,7 @@ var _ = Describe("static Operations", func() {
Address: mustCIDR("3ffe:ffff:0:01ff::1/64"), Address: mustCIDR("3ffe:ffff:0:01ff::1/64"),
}, },
)) ))
Expect(result.IPs).To(HaveLen(2)) Expect(len(result.IPs)).To(Equal(2))
Expect(result.Routes).To(Equal([]*types.Route{ Expect(result.Routes).To(Equal([]*types.Route{
{Dst: mustCIDR("0.0.0.0/0"), GW: net.ParseIP("10.10.0.254")}, {Dst: mustCIDR("0.0.0.0/0"), GW: net.ParseIP("10.10.0.254")},
{Dst: mustCIDR("3ffe:ffff:0:01ff::1/64"), GW: net.ParseIP("3ffe:ffff:0::1")}, {Dst: mustCIDR("3ffe:ffff:0:01ff::1/64"), GW: net.ParseIP("3ffe:ffff:0::1")},

Some files were not shown because too many files have changed in this diff Show More