From 63a6dbcfd675f85099a186100351f9f3bf6a2b9f Mon Sep 17 00:00:00 2001 From: Bruce Ma Date: Thu, 17 Nov 2022 17:54:16 +0800 Subject: [PATCH] fix bug on getting NextIP of addresses with first byte 0 1. get the right next IP of addresses of first byte 0 2. refactor some methods to handle illegal IPs or IPNets 3. add some unit tests Signed-off-by: Bruce Ma --- pkg/ip/cidr.go | 76 +++++++++++--- pkg/ip/cidr_test.go | 247 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 307 insertions(+), 16 deletions(-) create mode 100644 pkg/ip/cidr_test.go diff --git a/pkg/ip/cidr.go b/pkg/ip/cidr.go index 7acc2d47..8b380fc7 100644 --- a/pkg/ip/cidr.go +++ b/pkg/ip/cidr.go @@ -19,43 +19,87 @@ import ( "net" ) -// NextIP returns IP incremented by 1 +// NextIP returns IP incremented by 1, if IP is invalid, return nil func NextIP(ip net.IP) net.IP { - i := ipToInt(ip) - return intToIP(i.Add(i, big.NewInt(1))) + normalizedIP := normalizeIP(ip) + if normalizedIP == nil { + return nil + } + + i := ipToInt(normalizedIP) + return intToIP(i.Add(i, big.NewInt(1)), len(normalizedIP) == net.IPv6len) } -// PrevIP returns IP decremented by 1 +// PrevIP returns IP decremented by 1, if IP is invalid, return nil func PrevIP(ip net.IP) net.IP { - i := ipToInt(ip) - return intToIP(i.Sub(i, big.NewInt(1))) + normalizedIP := normalizeIP(ip) + if normalizedIP == nil { + return nil + } + + i := ipToInt(normalizedIP) + return intToIP(i.Sub(i, big.NewInt(1)), len(normalizedIP) == net.IPv6len) } // Cmp compares two IPs, returning the usual ordering: // a < b : -1 // a == b : 0 // a > b : 1 +// incomparable : -2 func Cmp(a, b net.IP) int { - aa := ipToInt(a) - bb := ipToInt(b) - return aa.Cmp(bb) + normalizedA := normalizeIP(a) + normalizedB := normalizeIP(b) + + if len(normalizedA) == len(normalizedB) && len(normalizedA) != 0 { + return ipToInt(normalizedA).Cmp(ipToInt(normalizedB)) + } + + return -2 } func ipToInt(ip net.IP) *big.Int { - if v := ip.To4(); v != nil { - return big.NewInt(0).SetBytes(v) + return big.NewInt(0).SetBytes(ip) +} + +func intToIP(i *big.Int, isIPv6 bool) net.IP { + intBytes := i.Bytes() + + if len(intBytes) == net.IPv4len || len(intBytes) == net.IPv6len { + return intBytes } - return big.NewInt(0).SetBytes(ip.To16()) + + if isIPv6 { + return append(make([]byte, net.IPv6len-len(intBytes)), intBytes...) + } + + return append(make([]byte, net.IPv4len-len(intBytes)), intBytes...) } -func intToIP(i *big.Int) net.IP { - return net.IP(i.Bytes()) +// normalizeIP will normalize IP by family, +// IPv4 : 4-byte form +// IPv6 : 16-byte form +// others : nil +func normalizeIP(ip net.IP) net.IP { + if ipTo4 := ip.To4(); ipTo4 != nil { + return ipTo4 + } + return ip.To16() } -// Network masks off the host portion of the IP +// Network masks off the host portion of the IP, if IPNet is invalid, +// return nil func Network(ipn *net.IPNet) *net.IPNet { + if ipn == nil { + return nil + } + + maskedIP := ipn.IP.Mask(ipn.Mask) + if maskedIP == nil { + return nil + } + return &net.IPNet{ - IP: ipn.IP.Mask(ipn.Mask), + IP: maskedIP, Mask: ipn.Mask, } } diff --git a/pkg/ip/cidr_test.go b/pkg/ip/cidr_test.go new file mode 100644 index 00000000..1f617ce7 --- /dev/null +++ b/pkg/ip/cidr_test.go @@ -0,0 +1,247 @@ +// Copyright 2022 CNI authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip + +import ( + "net" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("CIDR functions", func() { + It("NextIP", func() { + testCases := []struct { + ip net.IP + nextIP net.IP + }{ + { + []byte{192, 0, 2}, + nil, + }, + { + net.ParseIP("192.168.0.1"), + net.IPv4(192, 168, 0, 2).To4(), + }, + { + net.ParseIP("192.168.0.255"), + net.IPv4(192, 168, 1, 0).To4(), + }, + { + net.ParseIP("0.1.0.5"), + net.IPv4(0, 1, 0, 6).To4(), + }, + { + net.ParseIP("AB12::123"), + net.ParseIP("AB12::124"), + }, + { + net.ParseIP("AB12::FFFF"), + net.ParseIP("AB12::1:0"), + }, + { + net.ParseIP("0::123"), + net.ParseIP("0::124"), + }, + } + + for _, test := range testCases { + ip := NextIP(test.ip) + + Expect(ip).To(Equal(test.nextIP)) + } + }) + + It("PrevIP", func() { + testCases := []struct { + ip net.IP + prevIP net.IP + }{ + { + []byte{192, 0, 2}, + nil, + }, + { + net.ParseIP("192.168.0.2"), + net.IPv4(192, 168, 0, 1).To4(), + }, + { + net.ParseIP("192.168.1.0"), + net.IPv4(192, 168, 0, 255).To4(), + }, + { + net.ParseIP("0.1.0.5"), + net.IPv4(0, 1, 0, 4).To4(), + }, + { + net.ParseIP("AB12::123"), + net.ParseIP("AB12::122"), + }, + { + net.ParseIP("AB12::1:0"), + net.ParseIP("AB12::FFFF"), + }, + { + net.ParseIP("0::124"), + net.ParseIP("0::123"), + }, + } + + for _, test := range testCases { + ip := PrevIP(test.ip) + + Expect(ip).To(Equal(test.prevIP)) + } + }) + + It("Cmp", func() { + testCases := []struct { + a net.IP + b net.IP + result int + }{ + { + net.ParseIP("192.168.0.2"), + nil, + -2, + }, + { + net.ParseIP("192.168.0.2"), + []byte{192, 168, 5}, + -2, + }, + { + net.ParseIP("192.168.0.2"), + net.ParseIP("AB12::123"), + -2, + }, + { + net.ParseIP("192.168.0.2"), + net.ParseIP("192.168.0.5"), + -1, + }, + { + net.ParseIP("192.168.0.2"), + net.ParseIP("192.168.0.5").To4(), + -1, + }, + { + net.ParseIP("192.168.0.10"), + net.ParseIP("192.168.0.5"), + 1, + }, + { + net.ParseIP("192.168.0.10"), + net.ParseIP("192.168.0.10"), + 0, + }, + { + net.ParseIP("192.168.0.10"), + net.ParseIP("192.168.0.10").To4(), + 0, + }, + { + net.ParseIP("AB12::122"), + net.ParseIP("AB12::123"), + -1, + }, + { + net.ParseIP("AB12::210"), + net.ParseIP("AB12::123"), + 1, + }, + { + net.ParseIP("AB12::210"), + net.ParseIP("AB12::210"), + 0, + }, + } + + for _, test := range testCases { + result := Cmp(test.a, test.b) + + Expect(result).To(Equal(test.result)) + } + }) + + It("Network", func() { + testCases := []struct { + ipNet *net.IPNet + result *net.IPNet + }{ + { + nil, + nil, + }, + { + &net.IPNet{ + IP: nil, + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + nil, + }, + { + &net.IPNet{ + IP: net.IPv4(192, 168, 0, 1), + Mask: nil, + }, + nil, + }, + { + &net.IPNet{ + IP: net.ParseIP("AB12::123"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + nil, + }, + { + &net.IPNet{ + IP: net.IPv4(192, 168, 0, 100).To4(), + Mask: net.CIDRMask(120, 128), + }, + &net.IPNet{ + IP: net.IPv4(192, 168, 0, 0).To4(), + Mask: net.CIDRMask(120, 128), + }, + }, + { + &net.IPNet{ + IP: net.IPv4(192, 168, 0, 100), + Mask: net.CIDRMask(24, 32), + }, + &net.IPNet{ + IP: net.IPv4(192, 168, 0, 0).To4(), + Mask: net.CIDRMask(24, 32), + }, + }, + { + &net.IPNet{ + IP: net.ParseIP("AB12::123"), + Mask: net.CIDRMask(120, 128), + }, + &net.IPNet{ + IP: net.ParseIP("AB12::100"), + Mask: net.CIDRMask(120, 128), + }, + }, + } + + for _, test := range testCases { + result := Network(test.ipNet) + + Expect(result).To(Equal(test.result)) + } + }) +})