diff --git a/pkg/ip/ip.go b/pkg/ip/ip.go new file mode 100644 index 00000000..4469e1b5 --- /dev/null +++ b/pkg/ip/ip.go @@ -0,0 +1,105 @@ +// Copyright 2021 CNI authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip + +import ( + "fmt" + "net" + "strings" +) + +// IP is a CNI maintained type inherited from net.IPNet which can +// represent a single IP address with or without prefix. +type IP struct { + net.IPNet +} + +// newIP will create an IP with net.IP and net.IPMask +func newIP(ip net.IP, mask net.IPMask) *IP { + return &IP{ + IPNet: net.IPNet{ + IP: ip, + Mask: mask, + }, + } +} + +// ParseIP will parse string s as an IP, and return it. +// The string s must be formed like [/]. +// If s is not a valid textual representation of an IP, +// will return nil. +func ParseIP(s string) *IP { + if strings.ContainsAny(s, "/") { + ip, ipNet, err := net.ParseCIDR(s) + if err != nil { + return nil + } + return newIP(ip, ipNet.Mask) + } else { + ip := net.ParseIP(s) + if ip == nil { + return nil + } + return newIP(ip, nil) + } +} + +// ToIP will return a net.IP in standard form from this IP. +// If this IP can not be converted to a valid net.IP, will return nil. +func (i *IP) ToIP() net.IP { + switch { + case i.IP.To4() != nil: + return i.IP.To4() + case i.IP.To16() != nil: + return i.IP.To16() + default: + return nil + } +} + +// String returns the string form of this IP. +func (i *IP) String() string { + if len(i.Mask) > 0 { + return i.IPNet.String() + } + return i.IP.String() +} + +// MarshalText implements the encoding.TextMarshaler interface. +// The encoding is the same as returned by String, +// But when len(ip) is zero, will return an empty slice. +func (i *IP) MarshalText() ([]byte, error) { + if len(i.IP) == 0 { + return []byte{}, nil + } + return []byte(i.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// The textual bytes are expected in a form accepted by Parse, +// But when len(b) is zero, will return an empty IP. +func (i *IP) UnmarshalText(b []byte) error { + if len(b) == 0 { + *i = IP{} + return nil + } + + ip := ParseIP(string(b)) + if ip == nil { + return fmt.Errorf("invalid IP address %s", string(b)) + } + *i = *ip + return nil +} diff --git a/pkg/ip/ip_test.go b/pkg/ip/ip_test.go new file mode 100644 index 00000000..41ddee26 --- /dev/null +++ b/pkg/ip/ip_test.go @@ -0,0 +1,272 @@ +// Copyright 2021 CNI authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip + +import ( + "encoding/json" + "fmt" + "net" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("IP Operations", func() { + It("Parse", func() { + testCases := []struct { + ipStr string + expected *IP + }{ + { + "192.168.0.10", + newIP(net.IPv4(192, 168, 0, 10), nil), + }, + { + "2001:db8::1", + newIP(net.ParseIP("2001:db8::1"), nil), + }, + { + "192.168.0.10/24", + newIP(net.IPv4(192, 168, 0, 10), net.IPv4Mask(255, 255, 255, 0)), + }, + { + "2001:db8::1/64", + newIP(net.ParseIP("2001:db8::1"), net.CIDRMask(64, 128)), + }, + { + "invalid", + nil, + }, + } + + for _, test := range testCases { + ip := ParseIP(test.ipStr) + + Expect(ip).To(Equal(test.expected)) + } + }) + + It("String", func() { + testCases := []struct { + ip *IP + expected string + }{ + { + newIP(net.IPv4(192, 168, 0, 1), net.IPv4Mask(255, 255, 255, 0)), + "192.168.0.1/24", + }, + { + newIP(net.IPv4(192, 168, 0, 2), nil), + "192.168.0.2", + }, + { + newIP(net.ParseIP("2001:db8::1"), nil), + "2001:db8::1", + }, + { + newIP(net.ParseIP("2001:db8::1"), net.CIDRMask(64, 128)), + "2001:db8::1/64", + }, + { + newIP(nil, nil), + "", + }, + } + + for _, test := range testCases { + Expect(test.ip.String()).To(Equal(test.expected)) + } + }) + + It("ToIP", func() { + testCases := []struct { + ip *IP + expectedLen int + expectedIP net.IP + }{ + { + newIP(net.IPv4(192, 168, 0, 1), net.IPv4Mask(255, 255, 255, 0)), + net.IPv4len, + net.IP{192, 168, 0, 1}, + }, + { + newIP(net.IPv4(192, 168, 0, 2), nil), + net.IPv4len, + net.IP{192, 168, 0, 2}, + }, + { + newIP(net.ParseIP("2001:db8::1"), nil), + net.IPv6len, + net.IP{32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, + }, + { + newIP(net.ParseIP("2001:db8::1"), net.CIDRMask(64, 128)), + net.IPv6len, + net.IP{32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}, + }, + { + newIP(nil, nil), + 0, + nil, + }, + } + + for _, test := range testCases { + Expect(len(test.ip.ToIP())).To(Equal(test.expectedLen)) + Expect(test.ip.ToIP()).To(Equal(test.expectedIP)) + } + }) + + It("Encode", func() { + testCases := []struct { + object interface{} + expected string + }{ + { + newIP(net.IPv4(192, 168, 0, 1), net.IPv4Mask(255, 255, 255, 0)), + `"192.168.0.1/24"`, + }, + { + newIP(net.IPv4(192, 168, 0, 2), nil), + `"192.168.0.2"`, + }, + { + newIP(net.ParseIP("2001:db8::1"), nil), + `"2001:db8::1"`, + }, + { + newIP(net.ParseIP("2001:db8::1"), net.CIDRMask(64, 128)), + `"2001:db8::1/64"`, + }, + { + newIP(nil, nil), + `""`, + }, + { + []*IP{ + newIP(net.IPv4(192, 168, 0, 1), net.IPv4Mask(255, 255, 255, 0)), + newIP(net.IPv4(192, 168, 0, 2), nil), + newIP(net.ParseIP("2001:db8::1"), nil), + newIP(net.ParseIP("2001:db8::1"), net.CIDRMask(64, 128)), + newIP(nil, nil), + }, + `["192.168.0.1/24","192.168.0.2","2001:db8::1","2001:db8::1/64",""]`, + }, + } + + for _, test := range testCases { + bytes, err := json.Marshal(test.object) + + Expect(err).NotTo(HaveOccurred()) + Expect(string(bytes)).To(Equal(test.expected)) + } + }) + + It("Decode", func() { + Context("valid IP", func() { + testCases := []struct { + text string + expected *IP + }{ + { + `"192.168.0.1"`, + newIP(net.IPv4(192, 168, 0, 1), nil), + }, + { + `"192.168.0.1/24"`, + newIP(net.IPv4(192, 168, 0, 1), net.IPv4Mask(255, 255, 255, 0)), + }, + { + `"2001:db8::1"`, + newIP(net.ParseIP("2001:db8::1"), nil), + }, + { + `"2001:db8::1/64"`, + newIP(net.ParseIP("2001:db8::1"), net.CIDRMask(64, 128)), + }, + } + + for _, test := range testCases { + ip := &IP{} + err := json.Unmarshal([]byte(test.text), ip) + + Expect(err).NotTo(HaveOccurred()) + Expect(ip).To(Equal(test.expected)) + } + + }) + + Context("empty text", func() { + ip := &IP{} + err := json.Unmarshal([]byte(`""`), ip) + + Expect(err).NotTo(HaveOccurred()) + Expect(ip).To(Equal(newIP(nil, nil))) + }) + + Context("invalid IP", func() { + testCases := []struct { + text string + expectedErr error + }{ + { + `"192.168.0.1000"`, + fmt.Errorf("invalid IP address 192.168.0.1000"), + }, + { + `"2001:db8::1/256"`, + fmt.Errorf("invalid IP address 2001:db8::1/256"), + }, + { + `"test"`, + fmt.Errorf("invalid IP address test"), + }, + } + + for _, test := range testCases { + err := json.Unmarshal([]byte(test.text), &IP{}) + + Expect(err).To(HaveOccurred()) + Expect(err).To(Equal(test.expectedErr)) + } + }) + + Context("IP slice", func() { + testCases := []struct { + text string + expected []*IP + }{ + { + `["192.168.0.1/24","192.168.0.2","2001:db8::1","2001:db8::1/64",""]`, + []*IP{ + newIP(net.IPv4(192, 168, 0, 1), net.IPv4Mask(255, 255, 255, 0)), + newIP(net.IPv4(192, 168, 0, 2), nil), + newIP(net.ParseIP("2001:db8::1"), nil), + newIP(net.ParseIP("2001:db8::1"), net.CIDRMask(64, 128)), + newIP(nil, nil), + }, + }, + } + + for _, test := range testCases { + ips := make([]*IP, 0) + err := json.Unmarshal([]byte(test.text), &ips) + + Expect(err).NotTo(HaveOccurred()) + Expect(ips).To(Equal(test.expected)) + } + }) + }) +})