Add sysctl allowlist

Signed-off-by: mmirecki <mmirecki@redhat.com>
This commit is contained in:
mmirecki
2022-01-20 13:53:11 +01:00
parent 27e830b73e
commit 96c3af81e2
2 changed files with 198 additions and 0 deletions

View File

@ -19,12 +19,14 @@ package main
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net"
"os"
"path"
"path/filepath"
"regexp"
"strings"
"github.com/vishvananda/netlink"
@ -40,6 +42,8 @@ import (
)
const defaultDataDir = "/run/cni/tuning"
const defaultAllowlistDir = "/etc/cni/tuning/"
const defaultAllowlistFile = "allowlist.conf"
// TuningConf represents the network tuning configuration.
type TuningConf struct {
@ -305,6 +309,10 @@ func cmdAdd(args *skel.CmdArgs) error {
return err
}
if err = validateSysctlConf(tuningConf); err != nil {
return err
}
// Parse previous result.
if tuningConf.RawPrevResult == nil {
return fmt.Errorf("Required prevResult missing")
@ -477,3 +485,60 @@ func cmdCheck(args *skel.CmdArgs) error {
return nil
}
// Validate the sysctls in the tuning config are on the sysctl allowlist file.
// Note that if the allowlist file is missing no validation takes place.
func validateSysctlConf(tuningConf *TuningConf) error {
isPresent, allowlist, err := readAllowlist()
if err != nil {
return err
}
if !isPresent {
return nil
}
for sysctl, _ := range tuningConf.SysCtl {
match, err := contains(sysctl, allowlist)
if err != nil {
return err
}
if !match {
return errors.New(fmt.Sprintf("Sysctl %s is not allowed. Only the following sysctls are allowed: %+v", sysctl, allowlist))
}
}
return nil
}
// Validate the allowList contains the given sysctl
func contains(sysctl string, allowList []string) (bool, error) {
for _, allowListElement := range allowList {
match, err := regexp.MatchString(allowListElement, sysctl)
if err != nil {
return false, err
}
if match {
return true, nil
}
}
return false, nil
}
// Read the systctl allowlist from file. Return info if the file is present and the read allowList if it is
func readAllowlist() (bool, []string, error) {
if _, err := os.Stat(filepath.Join(defaultAllowlistDir, defaultAllowlistFile)); os.IsNotExist(err) {
return false, nil, nil
}
dat, err := os.ReadFile(filepath.Join(defaultAllowlistDir, defaultAllowlistFile))
if err != nil {
return false, nil, err
}
lines := strings.Split(string(dat), "\n")
allowList := []string{}
for _, line := range lines {
line = strings.TrimSpace(line)
if len(line) > 0 {
allowList = append(allowList, line)
}
}
return true, allowList, nil
}