fix bug on getting NextIP of addresses with first byte 0

1. get the right next IP of addresses of first byte 0
2. refactor some methods to handle illegal IPs or IPNets
3. add some unit tests

Signed-off-by: Bruce Ma <brucema19901024@gmail.com>
This commit is contained in:
Bruce Ma
2022-11-17 17:54:16 +08:00
parent 7e9ada51e7
commit 63a6dbcfd6
2 changed files with 307 additions and 16 deletions

View File

@ -19,43 +19,87 @@ import (
"net"
)
// NextIP returns IP incremented by 1
// NextIP returns IP incremented by 1, if IP is invalid, return nil
func NextIP(ip net.IP) net.IP {
i := ipToInt(ip)
return intToIP(i.Add(i, big.NewInt(1)))
normalizedIP := normalizeIP(ip)
if normalizedIP == nil {
return nil
}
i := ipToInt(normalizedIP)
return intToIP(i.Add(i, big.NewInt(1)), len(normalizedIP) == net.IPv6len)
}
// PrevIP returns IP decremented by 1
// PrevIP returns IP decremented by 1, if IP is invalid, return nil
func PrevIP(ip net.IP) net.IP {
i := ipToInt(ip)
return intToIP(i.Sub(i, big.NewInt(1)))
normalizedIP := normalizeIP(ip)
if normalizedIP == nil {
return nil
}
i := ipToInt(normalizedIP)
return intToIP(i.Sub(i, big.NewInt(1)), len(normalizedIP) == net.IPv6len)
}
// Cmp compares two IPs, returning the usual ordering:
// a < b : -1
// a == b : 0
// a > b : 1
// incomparable : -2
func Cmp(a, b net.IP) int {
aa := ipToInt(a)
bb := ipToInt(b)
return aa.Cmp(bb)
normalizedA := normalizeIP(a)
normalizedB := normalizeIP(b)
if len(normalizedA) == len(normalizedB) && len(normalizedA) != 0 {
return ipToInt(normalizedA).Cmp(ipToInt(normalizedB))
}
return -2
}
func ipToInt(ip net.IP) *big.Int {
if v := ip.To4(); v != nil {
return big.NewInt(0).SetBytes(v)
return big.NewInt(0).SetBytes(ip)
}
func intToIP(i *big.Int, isIPv6 bool) net.IP {
intBytes := i.Bytes()
if len(intBytes) == net.IPv4len || len(intBytes) == net.IPv6len {
return intBytes
}
return big.NewInt(0).SetBytes(ip.To16())
if isIPv6 {
return append(make([]byte, net.IPv6len-len(intBytes)), intBytes...)
}
return append(make([]byte, net.IPv4len-len(intBytes)), intBytes...)
}
func intToIP(i *big.Int) net.IP {
return net.IP(i.Bytes())
// normalizeIP will normalize IP by family,
// IPv4 : 4-byte form
// IPv6 : 16-byte form
// others : nil
func normalizeIP(ip net.IP) net.IP {
if ipTo4 := ip.To4(); ipTo4 != nil {
return ipTo4
}
return ip.To16()
}
// Network masks off the host portion of the IP
// Network masks off the host portion of the IP, if IPNet is invalid,
// return nil
func Network(ipn *net.IPNet) *net.IPNet {
if ipn == nil {
return nil
}
maskedIP := ipn.IP.Mask(ipn.Mask)
if maskedIP == nil {
return nil
}
return &net.IPNet{
IP: ipn.IP.Mask(ipn.Mask),
IP: maskedIP,
Mask: ipn.Mask,
}
}