From 752fc9b8ba3b487c6e8c777c5260de2db5efb1b4 Mon Sep 17 00:00:00 2001 From: Nadeem Douba Date: Mon, 3 Jul 2023 09:58:18 -0400 Subject: [PATCH] Added support for v1 and v2 cgroups No longer using a plugin to deploy solution --- Dockerfile | 19 +-- build.sh | 21 +-- config.json | 33 +++-- ctypes.h | 33 +++++ docker-compose.yml | 26 ++-- go.mod | 10 +- go.sum | 21 ++- internal/cgroup/api.go | 91 ++++++++++++ internal/cgroup/ebpf.go | 299 ++++++++++++++++++++++++++++++++++++++++ internal/cgroup/v1.go | 150 ++++++++++++++++++++ internal/cgroup/v2.go | 206 +++++++++++++++++++++++++++ main.go | 282 ++++++++++++++++++------------------- 12 files changed, 987 insertions(+), 204 deletions(-) create mode 100644 ctypes.h create mode 100644 internal/cgroup/api.go create mode 100644 internal/cgroup/ebpf.go create mode 100644 internal/cgroup/v1.go create mode 100644 internal/cgroup/v2.go diff --git a/Dockerfile b/Dockerfile index 7a2e736..ddb09be 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,17 +1,20 @@ -FROM debian +# syntax=docker/dockerfile:1 + +FROM golang:1.19 ENV DEBIAN_FRONTEND noninteractive WORKDIR /go/src/github.com/allfro/device-volume-driver + COPY . . -RUN apt update && \ - apt install -y musl-dev musl-tools git curl && \ - curl -L -o go.tgz https://go.dev/dl/go1.19.3.linux-amd64.tar.gz && \ - tar -zxvf go.tgz && \ - export PATH=$PATH:go/bin && \ - go get && \ - CC=/usr/bin/musl-gcc go build -ldflags "-linkmode external -extldflags -static" -o /dvd + +RUN CGO_ENABLED=1 GOOS=linux go build -ldflags "-linkmode external -extldflags -static" -o /dvd FROM alpine + +WORKDIR / + COPY --from=0 /dvd /dvd + ENTRYPOINT ["/dvd"] + diff --git a/build.sh b/build.sh index 2ea9713..22a3b9e 100755 --- a/build.sh +++ b/build.sh @@ -1,21 +1,4 @@ #!/bin/sh -set -eux - -ROOTFS=plugin/rootfs -CONFIG=plugin/config.json - -tag=redcanari/dvd -docker build -t "$tag" -f Dockerfile . -id=$(docker create "$tag" true) -rm -Rf $ROOTFS -mkdir -p $ROOTFS -docker export "$id" | tar -x -C $ROOTFS -docker rm -vf "$id" -docker rmi "$tag" -cp config.json $CONFIG - -docker plugin rm -f $tag || echo -docker plugin create $tag ./plugin -docker plugin push $tag -#docker plugin enable $tag \ No newline at end of file +docker build . -t ndouba/device-mapping-manager +docker push ndouba/device-mapping-manager \ No newline at end of file diff --git a/config.json b/config.json index 9e64066..ad8d469 100644 --- a/config.json +++ b/config.json @@ -28,26 +28,37 @@ }, "Linux": { "Capabilities": [ + "CAP_BPF", + "CAP_SYS_PTRACE", "CAP_SYS_ADMIN" ], "AllowAllDevices": true, "Devices": null }, "Mounts": [ - { - "source": "/sys/fs/cgroup/devices", - "destination": "/sys/fs/cgroup/devices", - "options": [ - "rw", - "rbind" - ], - "type": "rbind" - }, { "destination": "/dev", "source": "/dev", "options": [ - "bind", + "rbind", + "rw" + ], + "type": "bind" + }, + { + "destination": "/host/sys", + "source": "/sys", + "options": [ + "rbind", + "rw" + ], + "type": "bind" + }, + { + "destination": "/host/proc", + "source": "/proc", + "options": [ + "rbind", "rw" ], "type": "bind" @@ -66,7 +77,7 @@ "Network": { "Type": "" }, - "PropagatedMount": "/dev", + "PropagatedMount": null, "User": {}, "Workdir": "/" } \ No newline at end of file diff --git a/ctypes.h b/ctypes.h new file mode 100644 index 0000000..4134ed3 --- /dev/null +++ b/ctypes.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef HEADER_NVCGO_CTYPES_H +#define HEADER_NVCGO_CTYPES_H + +#include + +#include +#include + +struct device_rule { + bool allow; + const char *type; + const char *access; + dev_t major; + dev_t minor; +}; + +#endif /* HEADER_NVCGO_CTYPES_H */ \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index b8de8fb..e918022 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,15 +1,21 @@ version: "3.8" -volumes: - fuse: - driver: redcanari/device-volume-driver - driver_opts: - device: /dev/fuse services: - rdesktop: - image: lscr.io/linuxserver/rdesktop:ubuntu-xfce + dmm: + image: docker + entrypoint: docker + command: | + run + -i + --name device-manager + --restart always + --privileged + --cgroupns=host + --pid=host + --userns=host + -v /sys:/host/sys + -v /var/run/docker.sock:/var/run/docker.sock + ndouba/device-mapping-manager volumes: - - fuse:/dev/fuse - ports: - - 3390:3389 \ No newline at end of file + - /var/run/docker.sock:/var/run/docker.sock diff --git a/go.mod b/go.mod index c84ed78..d2f2dea 100644 --- a/go.mod +++ b/go.mod @@ -3,33 +3,31 @@ module device-volume-driver go 1.19 require ( - github.com/containerd/cgroups/v3 v3.0.0-20221112182753-e8802a182774 + github.com/cilium/ebpf v0.9.1 github.com/docker/docker v20.10.21+incompatible github.com/docker/go-plugins-helpers v0.0.0-20211224144127-6eecb7beb651 + github.com/google/uuid v1.3.0 github.com/opencontainers/runtime-spec v1.0.2 + github.com/sirupsen/logrus v1.8.1 golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f ) require ( github.com/Microsoft/go-winio v0.6.0 // indirect - github.com/cilium/ebpf v0.9.1 // indirect github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf // indirect - github.com/coreos/go-systemd/v22 v22.3.2 // indirect github.com/docker/distribution v2.8.1+incompatible // indirect github.com/docker/go-connections v0.4.0 // indirect github.com/docker/go-units v0.5.0 // indirect - github.com/godbus/dbus/v5 v5.0.4 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/moby/term v0.0.0-20221105221325-4eb28fa6025c // indirect github.com/morikuni/aec v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.0.2 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/sirupsen/logrus v1.8.1 // indirect + github.com/stretchr/testify v1.8.0 // indirect golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect golang.org/x/net v0.0.0-20220722155237-a158d28d115b // indirect golang.org/x/time v0.2.0 // indirect golang.org/x/tools v0.1.12 // indirect - google.golang.org/protobuf v1.27.1 // indirect gotest.tools/v3 v3.4.0 // indirect ) diff --git a/go.sum b/go.sum index 9e992f8..88b7fa7 100644 --- a/go.sum +++ b/go.sum @@ -3,12 +3,9 @@ github.com/Microsoft/go-winio v0.6.0 h1:slsWYD/zyx7lCXoZVlvQrj0hPTM1HI4+v1sIda2y github.com/Microsoft/go-winio v0.6.0/go.mod h1:cTAf44im0RAYeL23bpB+fzCyDH2MJiz2BO69KH/soAE= github.com/cilium/ebpf v0.9.1 h1:64sn2K3UKw8NbP/blsixRpF3nXuyhz/VjRlRzvlBRu4= github.com/cilium/ebpf v0.9.1/go.mod h1:+OhNOIXx/Fnu1IE8bJz2dzOA+VSfyTfdNUVdlQnxUFY= -github.com/containerd/cgroups/v3 v3.0.0-20221112182753-e8802a182774 h1:Tej/o6wjJ3icV9qkPopNXJxk2oeVAmRc7JL0q5JeUq8= -github.com/containerd/cgroups/v3 v3.0.0-20221112182753-e8802a182774/go.mod h1:/vtwk1VXrtoa5AaZLkypuOJgA/6DyPMZHJPGQNtlHnw= github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU= github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/go-systemd/v22 v22.3.2 h1:D9/bQk5vlXQFZ6Kwuu6zaiXJ9oTPe68++AzAJc1DzSI= -github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/docker/distribution v2.8.1+incompatible h1:Q50tZOPR6T/hjNsyc9g8/syEs6bk8XXApsHjKukMl68= @@ -22,13 +19,12 @@ github.com/docker/go-plugins-helpers v0.0.0-20211224144127-6eecb7beb651/go.mod h github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/frankban/quicktest v1.14.0 h1:+cqqvzZV87b4adx/5ayVOaYZ2CrvM4ejQvUdBzPPUss= -github.com/godbus/dbus/v5 v5.0.4 h1:9349emZab16e7zQvpmsbtjc18ykshndd8y2PG3sgJbA= -github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= @@ -50,11 +46,14 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k= github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -94,9 +93,9 @@ golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= -google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.4.0 h1:ZazjZUfuVeZGLAmlKKuyv3IKP5orXcwtOwDQH6YVr6o= gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g= diff --git a/internal/cgroup/api.go b/internal/cgroup/api.go new file mode 100644 index 0000000..306e5ce --- /dev/null +++ b/internal/cgroup/api.go @@ -0,0 +1,91 @@ +//go:build linux + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cgroup + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/opencontainers/runtime-spec/specs-go" +) + +type DeviceRule = specs.LinuxDeviceCgroup + +type Interface interface { + GetDeviceCGroupMountPath(procRootPath string, pid int) (string, string, error) + GetDeviceCGroupRootPath(procRootPath string, prefix string, pid int) (string, error) + AddDeviceRules(cgroupPath string, devices []DeviceRule) error +} + +func New(version int) (Interface, error) { + switch version { + case 1: + return &cgroupv1{}, nil + case 2: + return &cgroupv2{}, nil + default: + return nil, fmt.Errorf("invalid version") + } +} + +type cgroupv1 struct{} +type cgroupv2 struct{} + +var _ Interface = (*cgroupv1)(nil) +var _ Interface = (*cgroupv2)(nil) + +// GetDeviceCGroupVersion returns the version of linux cgroups in use +func GetDeviceCGroupVersion(rootPath string, pid int) (int, error) { + // Open the pid's cgroup file in /proc. + path := fmt.Sprintf(filepath.Join(rootPath, "proc", "%v", "cgroup"), pid) + file, err := os.Open(path) + if err != nil { + return -1, fmt.Errorf("failed to open cgroup path for pid '%d': %v", pid, err) + } + defer file.Close() + + // Create a scanner to loop through the file's contents. + scanner := bufio.NewScanner(file) + scanner.Split(bufio.ScanLines) + + // Loop through the file looking for either a 'devices' or a '' (i.e. unified) entry + found := make(map[string]bool) + for scanner.Scan() { + parts := strings.SplitN(scanner.Text(), ":", 3) + if len(parts) != 3 { + return -1, fmt.Errorf("malformed cgroup entry: %v", scanner.Text()) + } + found[parts[1]] = true + } + + // If a 'devices' entry was found, return version 1. + if found["devices"] { + return 1, nil + } + + // If a '', (i.e. 'unified') entry was found, return version 2. + if found[""] { + return 2, nil + } + + return -1, fmt.Errorf("no devices or unified cgroup entries found") +} diff --git a/internal/cgroup/ebpf.go b/internal/cgroup/ebpf.go new file mode 100644 index 0000000..6674cb8 --- /dev/null +++ b/internal/cgroup/ebpf.go @@ -0,0 +1,299 @@ +//go:build linux + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// The implementation of the device filter eBPF program in this file is based on: +// https://github.com/containers/crun/blob/0.10.2/src/libcrun/ebpf.c +// +// Although ebpf.c is originally licensed under LGPL-3.0-or-later, the author (Giuseppe Scrivano) +// agreed to relicense the file in Apache License 2.0: https://github.com/opencontainers/runc/issues/2144#issuecomment-543116397 +// +// Much of the go code in this file is borrowed heavily from the (Apache licensed) file found here: +// https://github.com/opencontainers/runc/blob/8e7ab26104352f4214ab1daec5c1d4bf75eddb54/libcontainer/cgroups/ebpf/devicefilter/devicefilter.go + +package cgroup + +import ( + "errors" + "fmt" + "math" + "os" + "runtime" + "unsafe" + + "github.com/cilium/ebpf" + "github.com/cilium/ebpf/asm" + "github.com/cilium/ebpf/link" + "github.com/google/uuid" + "github.com/opencontainers/runtime-spec/specs-go" + "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" +) + +func nilCloser() error { + return nil +} + +type program struct { + insts asm.Instructions + hasWildCard bool + blockID int +} + +func (p *program) init() { + // struct bpf_cgroup_dev_ctx: https://elixir.bootlin.com/linux/v5.3.6/source/include/uapi/linux/bpf.h#L3423 + /* + u32 access_type + u32 major + u32 minor + */ + // R2 <- type (lower 16 bit of u32 access_type at R1[0]) + p.insts = append(p.insts, + asm.LoadMem(asm.R2, asm.R1, 0, asm.Half)) + + // R3 <- access (upper 16 bit of u32 access_type at R1[0]) + p.insts = append(p.insts, + asm.LoadMem(asm.R3, asm.R1, 0, asm.Word), + // RSh: bitwise shift right + asm.RSh.Imm32(asm.R3, 16)) + + // R4 <- major (u32 major at R1[4]) + p.insts = append(p.insts, + asm.LoadMem(asm.R4, asm.R1, 4, asm.Word)) + + // R5 <- minor (u32 minor at R1[8]) + p.insts = append(p.insts, + asm.LoadMem(asm.R5, asm.R1, 8, asm.Word)) +} + +// appendDevice needs to be called from the last element of OCI linux.resources.devices to the head element. +func (p *program) appendDevice(dev specs.LinuxDeviceCgroup, labelPrefix string) error { + if p.blockID < 0 { + return errors.New("the program is finalized") + } + if p.hasWildCard { + // All entries after wildcard entry are ignored + return nil + } + + bpfType := int32(-1) + hasType := true + switch dev.Type { + case string('c'): + bpfType = int32(unix.BPF_DEVCG_DEV_CHAR) + case string('b'): + bpfType = int32(unix.BPF_DEVCG_DEV_BLOCK) + case string('a'): + hasType = false + default: + // if not specified in OCI json, typ is set to DeviceTypeAll + return fmt.Errorf("invalid DeviceType %q", dev.Type) + } + if *dev.Major > math.MaxUint32 { + return fmt.Errorf("invalid major %d", *dev.Major) + } + if *dev.Minor > math.MaxUint32 { + return fmt.Errorf("invalid minor %d", *dev.Major) + } + hasMajor := *dev.Major >= 0 // if not specified in OCI json, major is set to -1 + hasMinor := *dev.Minor >= 0 + bpfAccess := int32(0) + for _, r := range dev.Access { + switch r { + case 'r': + bpfAccess |= unix.BPF_DEVCG_ACC_READ + case 'w': + bpfAccess |= unix.BPF_DEVCG_ACC_WRITE + case 'm': + bpfAccess |= unix.BPF_DEVCG_ACC_MKNOD + default: + return fmt.Errorf("unknown device access %v", r) + } + } + // If the access is rwm, skip the check. + hasAccess := bpfAccess != (unix.BPF_DEVCG_ACC_READ | unix.BPF_DEVCG_ACC_WRITE | unix.BPF_DEVCG_ACC_MKNOD) + + blockSym := fmt.Sprintf("%s-block-%d", labelPrefix, p.blockID) + nextBlockSym := fmt.Sprintf("%s-block-%d", labelPrefix, p.blockID+1) + prevBlockLastIdx := len(p.insts) - 1 + if hasType { + p.insts = append(p.insts, + // if (R2 != bpfType) goto next + asm.JNE.Imm(asm.R2, bpfType, nextBlockSym), + ) + } + if hasAccess { + p.insts = append(p.insts, + // if (R3 & bpfAccess == 0 /* use R2 as a temp var */) goto next + asm.Mov.Reg32(asm.R2, asm.R3), + asm.And.Imm32(asm.R2, bpfAccess), + asm.JEq.Imm(asm.R2, 0, nextBlockSym), + ) + } + if hasMajor { + p.insts = append(p.insts, + // if (R4 != major) goto next + asm.JNE.Imm(asm.R4, int32(*dev.Major), nextBlockSym), + ) + } + if hasMinor { + p.insts = append(p.insts, + // if (R5 != minor) goto next + asm.JNE.Imm(asm.R5, int32(*dev.Minor), nextBlockSym), + ) + } + if !hasType && !hasAccess && !hasMajor && !hasMinor { + p.hasWildCard = true + } + p.insts = append(p.insts, p.acceptBlock(dev.Allow)...) + // set blockSym to the first instruction we added in this iteration + p.insts[prevBlockLastIdx+1] = p.insts[prevBlockLastIdx+1].Sym(blockSym) + p.blockID++ + return nil +} + +func (p *program) acceptBlock(accept bool) asm.Instructions { + v := int32(0) + if accept { + v = 1 + } + return []asm.Instruction{ + // R0 <- v + asm.Mov.Imm32(asm.R0, v), + asm.Return(), + } +} + +func (p *program) finalize(origInsts asm.Instructions, labelPrefix string) (asm.Instructions, error) { + lenInsts := len(p.insts) + // set blockSym to the first instruction of origInsts so we are able to jump to it properly + blockSym := fmt.Sprintf("%s-block-%d", labelPrefix, p.blockID) + p.insts = append(p.insts, origInsts...) + p.insts[lenInsts] = p.insts[lenInsts].Sym(blockSym) + p.blockID = -1 + return p.insts, nil +} + +// FindAttachedCgroupDeviceFilters finds all ebpf prgrams associated with 'dirFd' that control device access +func FindAttachedCgroupDeviceFilters(dirFd int) ([]*ebpf.Program, error) { + type bpfAttrQuery struct { + TargetFd uint32 + AttachType uint32 + QueryType uint32 + AttachFlags uint32 + ProgIds uint64 // __aligned_u64 + ProgCnt uint32 + } + + // Currently you can only have 64 eBPF programs attached to a cgroup. + size := 64 + retries := 0 + for retries < 10 { + progIds := make([]uint32, size) + query := bpfAttrQuery{ + TargetFd: uint32(dirFd), + AttachType: uint32(unix.BPF_CGROUP_DEVICE), + ProgIds: uint64(uintptr(unsafe.Pointer(&progIds[0]))), + ProgCnt: uint32(len(progIds)), + } + + // Fetch the list of program ids. + _, _, errno := unix.Syscall(unix.SYS_BPF, + uintptr(unix.BPF_PROG_QUERY), + uintptr(unsafe.Pointer(&query)), + unsafe.Sizeof(query)) + size = int(query.ProgCnt) + runtime.KeepAlive(query) + if errno != 0 { + // On ENOSPC we get the correct number of programs. + if errno == unix.ENOSPC { + retries++ + continue + } + return nil, fmt.Errorf("bpf_prog_query(BPF_CGROUP_DEVICE) failed: %w", errno) + } + + // Convert the ids to program handles. + progIds = progIds[:size] + programs := make([]*ebpf.Program, 0, len(progIds)) + for _, progId := range progIds { + program, err := ebpf.NewProgramFromID(ebpf.ProgramID(progId)) + if err != nil { + // We skip over programs that give us -EACCES or -EPERM. This + // is necessary because there may be BPF programs that have + // been attached (such as with --systemd-cgroup) which have an + // LSM label that blocks us from interacting with the program. + // + // Because additional BPF_CGROUP_DEVICE programs only can add + // restrictions, there's no real issue with just ignoring these + // programs (and stops runc from breaking on distributions with + // very strict SELinux policies). + if errors.Is(err, os.ErrPermission) { + logrus.Debugf("ignoring existing CGROUP_DEVICE program (prog_id=%v) which cannot be accessed by runc -- likely due to LSM policy: %v", progId, err) + continue + } + return nil, fmt.Errorf("cannot fetch program from id: %w", err) + } + programs = append(programs, program) + } + runtime.KeepAlive(progIds) + return programs, nil + } + + return nil, errors.New("could not get complete list of CGROUP_DEVICE programs") +} + +// PrependDeviceFilter prepends a set of instructions for further device filtering to an existing device filtering ebpf program +func PrependDeviceFilter(devices []specs.LinuxDeviceCgroup, origInsts asm.Instructions) (asm.Instructions, error) { + labelPrefix := uuid.New().String() + p := &program{} + p.init() + for i := len(devices) - 1; i >= 0; i-- { + if err := p.appendDevice(devices[i], labelPrefix); err != nil { + return nil, err + } + } + insts, err := p.finalize(origInsts, labelPrefix) + return insts, err +} + +// DetachCgroupDeviceFilter detaches an existing device filter ebpf program from a cgroup. +func DetachCgroupDeviceFilter(prog *ebpf.Program, dirFd int) error { + err := link.RawDetachProgram(link.RawDetachProgramOptions{ + Target: dirFd, + Program: prog, + Attach: ebpf.AttachCGroupDevice, + }) + if err != nil { + return fmt.Errorf("failed to call BPF_PROG_DETACH (BPF_CGROUP_DEVICE): %w", err) + } + return nil +} + +// AttachCgroupDeviceFilter attaches a new device filter ebpf program to a cgroup. +func AttachCgroupDeviceFilter(prog *ebpf.Program, dirFd int) error { + err := link.RawAttachProgram(link.RawAttachProgramOptions{ + Target: dirFd, + Program: prog, + Attach: ebpf.AttachCGroupDevice, + Flags: unix.BPF_F_ALLOW_MULTI, + }) + if err != nil { + return fmt.Errorf("failed to call BPF_PROG_ATTACH (BPF_CGROUP_DEVICE, BPF_F_ALLOW_MULTI): %w", err) + } + return nil +} diff --git a/internal/cgroup/v1.go b/internal/cgroup/v1.go new file mode 100644 index 0000000..992ba99 --- /dev/null +++ b/internal/cgroup/v1.go @@ -0,0 +1,150 @@ +//go:build linux + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cgroup + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "strings" +) + +// GetDeviceCGroupMountPath returns the mount path (and its prefix) for the device cgroup controller associated with pid +func (c *cgroupv1) GetDeviceCGroupMountPath(procRootPath string, pid int) (string, string, error) { + // Open the pid's mountinfo file in /proc. + path := fmt.Sprintf(filepath.Join(procRootPath, "proc", "%v", "mountinfo"), pid) + file, err := os.Open(path) + if err != nil { + return "", "", err + } + defer file.Close() + + // Create a scanner to loop through the file's contents. + scanner := bufio.NewScanner(file) + scanner.Split(bufio.ScanLines) + + // Loop through the file looking for a subsystem of 'devices' entry. + for scanner.Scan() { + // Split each entry by '[space]' + parts := strings.Split(scanner.Text(), " ") + if len(parts) < 5 { + return "", "", fmt.Errorf("malformed mountinfo entry: %v", scanner.Text()) + } + // Look for an entry with cgroup as the mount type. + if parts[len(parts)-3] != "cgroup" { + continue + } + // Look for an entry with 'devices' as the basename of the mountpath + if filepath.Base(parts[4]) != "devices" { + continue + } + // Make sure the mount prefix is not a relative path. + if strings.HasPrefix(parts[3], "/..") { + return "", "", fmt.Errorf("relative path in mount prefix: %v", parts[3]) + } + // Return the 3rd element as the prefix of the mount point for + // the devices cgroup and the 4th element as the mount point of + // the devices cgroup itself. + return parts[3], parts[4], nil + } + + return "", "", fmt.Errorf("no cgroup filesystem mounted for the devices subsytem in mountinfo file") +} + +// GetDeviceCGroupRootPath returns the root path for the device cgroup controller associated with pid +func (c *cgroupv1) GetDeviceCGroupRootPath(procRootPath string, prefix string, pid int) (string, error) { + // Open the pid's cgroup file in /proc. + path := fmt.Sprintf(filepath.Join(procRootPath, "proc", "%v", "cgroup"), pid) + file, err := os.Open(path) + if err != nil { + return "", err + } + defer file.Close() + + // Create a scanner to loop through the file's contents. + scanner := bufio.NewScanner(file) + scanner.Split(bufio.ScanLines) + + // Loop through the file looking for either a subsystem of 'devices' entry. + for scanner.Scan() { + // Split each entry by ':' + parts := strings.SplitN(scanner.Text(), ":", 3) + if len(parts) != 3 { + return "", fmt.Errorf("malformed cgroup entry: %v", scanner.Text()) + } + // Look for the devices subsystem in the 1st element. + if parts[1] != "devices" { + continue + } + // Return the cgroup root from the 2nd element + // (with the prefix possibly stripped off). + if prefix == "/" { + return parts[2], nil + } + return strings.TrimPrefix(parts[2], prefix), nil + } + + return "", fmt.Errorf("no devices cgroup entries found") +} + +// AddDeviceRules adds a set of device rules for the device cgroup at cgroupPath +func (c *cgroupv1) AddDeviceRules(cgroupPath string, rules []DeviceRule) error { + // Loop through all rules in the set of device rules and add that rule to the device. + for _, rule := range rules { + err := c.addDeviceRule(cgroupPath, &rule) + if err != nil { + return err + } + } + + return nil +} + +func (c *cgroupv1) addDeviceRule(cgroupPath string, rule *DeviceRule) error { + // Check the major/minor numbers of the device in the device rule. + if rule.Major == nil { + return fmt.Errorf("no major set in device rule") + } + + if rule.Minor == nil { + return fmt.Errorf("no minor set in device rule") + } + + // Open the appropriate allow/deny file. + var path string + if rule.Allow { + path = filepath.Join(cgroupPath, "devices.allow") + } else { + path = filepath.Join(cgroupPath, "devices.deny") + } + file, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0600) + if err != nil { + return err + } + defer file.Close() + + // Write the device rule into the file. + _, err = file.WriteString(fmt.Sprintf("%s %d:%d %s", rule.Type, *rule.Major, *rule.Minor, rule.Access)) + if err != nil { + return err + } + + return nil +} diff --git a/internal/cgroup/v2.go b/internal/cgroup/v2.go new file mode 100644 index 0000000..ed92278 --- /dev/null +++ b/internal/cgroup/v2.go @@ -0,0 +1,206 @@ +//go:build linux + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cgroup + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/cilium/ebpf" + "github.com/cilium/ebpf/asm" + "golang.org/x/sys/unix" +) + +const ( + BpfProgramLicense = "Apache" +) + +// GetDeviceCGroupMountPath returns the mount path (and its prefix) for the device cgroup controller associated with pid +func (c *cgroupv2) GetDeviceCGroupMountPath(procRootPath string, pid int) (string, string, error) { + // Open the pid's mountinfo file in /proc. + path := fmt.Sprintf(filepath.Join(procRootPath, "proc", "%v", "mountinfo"), pid) + file, err := os.Open(path) + if err != nil { + return "", "", err + } + defer file.Close() + + // Create a scanner to loop through the file's contents. + scanner := bufio.NewScanner(file) + scanner.Split(bufio.ScanLines) + + // Loop through the file looking for a subsystem of '' (i.e. unified) entry. + for scanner.Scan() { + // Split each entry by '[space]' + parts := strings.Split(scanner.Text(), " ") + if len(parts) < 5 { + return "", "", fmt.Errorf("malformed mountinfo entry: %v", scanner.Text()) + } + // Look for an entry with cgroup2 as the mount type. + if parts[len(parts)-3] != "cgroup2" { + continue + } + // Make sure the mount prefix is not a relative path. + if strings.HasPrefix(parts[3], "/..") { + return "", "", fmt.Errorf("relative path in mount prefix: %v", parts[3]) + } + // Return the 3rd element as the prefix of the mount point for + // the devices cgroup and the 4th element as the mount point of + // the devices cgroup itself. + return parts[3], parts[4], nil + } + + return "", "", fmt.Errorf("no cgroup2 filesystem in mountinfo file") +} + +// GetDeviceCGroupRootPath returns the root path for the device cgroup controller associated with pid +func (c *cgroupv2) GetDeviceCGroupRootPath(procRootPath string, prefix string, pid int) (string, error) { + // Open the pid's cgroup file in /proc. + path := fmt.Sprintf(filepath.Join(procRootPath, "proc", "%v", "cgroup"), pid) + file, err := os.Open(path) + if err != nil { + return "", err + } + defer file.Close() + + // Create a scanner to loop through the file's contents. + scanner := bufio.NewScanner(file) + scanner.Split(bufio.ScanLines) + + // Loop through the file looking for either a '' (i.e. unified) entry. + for scanner.Scan() { + // Split each entry by ':' + parts := strings.SplitN(scanner.Text(), ":", 3) + if len(parts) != 3 { + return "", fmt.Errorf("malformed cgroup entry: %v", scanner.Text()) + } + // Look for the (empty) subsystem in the 1st element. + if parts[1] != "" { + continue + } + // Return the cgroup root from the 2nd element + // (with the prefix possibly stripped off). + if prefix == "/" { + return parts[2], nil + } + return strings.TrimPrefix(parts[2], prefix), nil + } + + return "", fmt.Errorf("no cgroupv2 entries in file") +} + +// AddDeviceRules adds a set of device rules for the device cgroup at cgroupPath +func (c *cgroupv2) AddDeviceRules(cgroupPath string, rules []DeviceRule) error { + // Open the cgroup path. + dirFD, err := unix.Open(cgroupPath, unix.O_DIRECTORY|unix.O_RDONLY, 0600) + if err != nil { + return fmt.Errorf("unable to open the cgroup path: %v", err) + } + defer unix.Close(dirFD) + + // Find any existing eBPF device filter programs attached to this cgroup. + oldProgs, err := FindAttachedCgroupDeviceFilters(dirFD) + if err != nil { + return fmt.Errorf("unable to find any existing device filters attached to the cgroup: %v", err) + } + + // Generate a new set of eBPF programs by prepending instructions for the + // new devices to the instructions of each existing program. + // If no existing programs found, create a new program with just our device filter. + var newProgs []*ebpf.Program + if len(oldProgs) == 0 { + oldInsts := asm.Instructions{asm.Return()} + + newProg, err := generateNewProgram(rules, oldInsts) + if err != nil { + return fmt.Errorf("unable to generate new device filter program with no existing programs: %v", err) + } + + newProgs = append(newProgs, newProg) + } + for _, oldProg := range oldProgs { + oldInfo, err := oldProg.Info() + if err != nil { + return fmt.Errorf("unable to get Info() of the original device filters program: %v", err) + } + + oldInsts, err := oldInfo.Instructions() + if err != nil { + return fmt.Errorf("unable to get the instructions of the original device filters program: %v", err) + } + + newProg, err := generateNewProgram(rules, oldInsts) + if err != nil { + return fmt.Errorf("unable to generate new device filter program from existing programs: %v", err) + } + + newProgs = append(newProgs, newProg) + } + + // Increase `ulimit -l` limit to avoid BPF_PROG_LOAD error below. + // This limit is not inherited into the container. + memlockLimit := &unix.Rlimit{ + Cur: unix.RLIM_INFINITY, + Max: unix.RLIM_INFINITY, + } + _ = unix.Setrlimit(unix.RLIMIT_MEMLOCK, memlockLimit) + + // Replace the set of existing eBPF programs with the new ones. + // We don't have to worry about atomically replacing each program (i.e. by + // using BPF_F_REPLACE) because we know that the code here is always run + // strictly *before* a container begins executing. + for _, oldProg := range oldProgs { + err = DetachCgroupDeviceFilter(oldProg, dirFD) + if err != nil { + return fmt.Errorf("unable to detach original device filters program: %v", err) + } + } + for _, newProg := range newProgs { + err = AttachCgroupDeviceFilter(newProg, dirFD) + if err != nil { + return fmt.Errorf("unable to attach new device filters program: %v", err) + } + } + + return nil +} + +func generateNewProgram(rules []DeviceRule, oldInsts asm.Instructions) (*ebpf.Program, error) { + // Prepend instructions for the new devices to the original set of instructions. + newInsts, err := PrependDeviceFilter(rules, oldInsts) + if err != nil { + return nil, fmt.Errorf("unable to prepend new device filters to the original device filters program: %v", err) + } + + // Generate new eBPF program for the merged device filter instructions. + spec := &ebpf.ProgramSpec{ + Type: ebpf.CGroupDevice, + Instructions: newInsts, + License: BpfProgramLicense, + } + newProg, err := ebpf.NewProgram(spec) + if err != nil { + return nil, fmt.Errorf("unable to create new device filters program: %v", err) + } + + return newProg, nil +} diff --git a/main.go b/main.go index 7c062ef..c6e0a45 100644 --- a/main.go +++ b/main.go @@ -1,176 +1,180 @@ +//go:build linux + package main +// #include "ctypes.h" +import "C" import ( "context" - "errors" + "device-volume-driver/internal/cgroup" "fmt" - "github.com/containerd/cgroups/v3" "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/filters" "github.com/docker/docker/client" - "github.com/docker/go-plugins-helpers/volume" _ "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/sys/unix" "log" "os" "path" + "path/filepath" "strings" - "time" ) const pluginId = "dvd" +const rootPath = "/host" + +func Ptr[T any](v T) *T { + return &v +} func main() { - driver := DeviceVolumeDriver() - handler := volume.NewHandler(driver) - log.Println(handler.ServeUnix(pluginId, 0)) + listenForMounts() } -type deviceVolumeDriver struct { - *client.Client -} +func getDeviceInfo(devicePath string) (string, int64, int64, error) { + var stat unix.Stat_t -type mountPoint struct { - name string - device string -} - -var mountPoints = make(map[string]mountPoint) - -func (d deviceVolumeDriver) Create(request *volume.CreateRequest) error { - device, ok := request.Options["device"] - - if !ok { - return errors.New("you must specify the `device` option") + if err := unix.Lstat(devicePath, &stat); err != nil { + log.Println(err) + return "", -1, -1, err } - mountPoints[request.Name] = mountPoint{ - name: strings.Clone(request.Name), - device: strings.Clone(device), + var deviceType string + + switch stat.Mode & unix.S_IFMT { + case unix.S_IFBLK: + deviceType = "b" + case unix.S_IFCHR: + deviceType = "c" + default: + log.Println("aborting: device is neither a character or block device") + return "", -1, -1, fmt.Errorf("unsupported device type... aborting") } - return nil + major := int64(unix.Major(stat.Rdev)) + minor := int64(unix.Minor(stat.Rdev)) + + log.Printf("Found device: %s %s %d:%d\n", devicePath, deviceType, major, minor) + + return deviceType, major, minor, nil } -func (d deviceVolumeDriver) List() (*volume.ListResponse, error) { - var volumes []*volume.Volume +func listenForMounts() { + ctx := context.Background() - for _, mountPoint := range mountPoints { - volumes = append(volumes, &volume.Volume{Name: mountPoint.name, Mountpoint: mountPoint.device}) - } - - return &volume.ListResponse{ - Volumes: volumes, - }, nil -} - -func (d deviceVolumeDriver) Get(request *volume.GetRequest) (*volume.GetResponse, error) { - mountPoint, ok := mountPoints[request.Name] - - if !ok { - return nil, errors.New("no such volume") - } - - return &volume.GetResponse{Volume: &volume.Volume{Name: mountPoint.name, Mountpoint: mountPoint.device}}, nil -} - -func (d deviceVolumeDriver) Remove(request *volume.RemoveRequest) error { - delete(mountPoints, request.Name) - return nil -} - -func (d deviceVolumeDriver) Path(request *volume.PathRequest) (*volume.PathResponse, error) { - mountPoint, ok := mountPoints[request.Name] - if !ok { - return nil, errors.New("no such volume") - } - return &volume.PathResponse{Mountpoint: mountPoint.device}, nil -} - -func (d deviceVolumeDriver) Mount(request *volume.MountRequest) (*volume.MountResponse, error) { - mountPoint, ok := mountPoints[request.Name] - - if !ok { - return nil, errors.New("mountpoint does not exist") - } - - go func() { - time.Sleep(time.Second * 1) - filter := filters.NewArgs(filters.KeyValuePair{Key: "volume", Value: request.Name}) - - containers, err := d.ContainerList( - context.Background(), - types.ContainerListOptions{Filters: filter}, - ) - - if err != nil { - log.Println(err) - return - } else if len(containers) == 0 { - log.Println("aborting: could not find container that uses volume " + mountPoint.name) - return - } - - devicesAllowPath := path.Join("/sys/fs/cgroup/devices/docker", containers[0].ID, "devices.allow") - - if _, err := os.Stat(devicesAllowPath); os.IsNotExist(err) { - //return nil, errors.New("could not find cgroup `devices.allow` file for specified container: " + devicesAllowPath) - log.Println(errors.New("could not find cgroup `devices.allow` file for specified container: " + devicesAllowPath)) - return - } - - var stat unix.Stat_t - - if err := unix.Lstat(mountPoint.device, &stat); err != nil { - //return nil, err - log.Println(err) - return - } - - var deviceType string - - switch stat.Mode & unix.S_IFMT { - case unix.S_IFBLK: - deviceType = "b" - case unix.S_IFCHR: - deviceType = "c" - default: - log.Println("aborting: device is neither a character or block device") - return - } - - input := fmt.Sprintf("%s %d:%d rwm\n", deviceType, unix.Major(stat.Rdev), unix.Minor(stat.Rdev)) - - log.Println("Whitelisting `" + mountPoint.device + "` in `" + devicesAllowPath + "`") - - if err := os.WriteFile(devicesAllowPath, []byte(input), 0400); err != nil { - //return nil, err - log.Println(err) - return - } - }() - - return &volume.MountResponse{Mountpoint: mountPoint.device}, nil -} - -func (d deviceVolumeDriver) Unmount(request *volume.UnmountRequest) error { - return nil -} - -func (d deviceVolumeDriver) Capabilities() *volume.CapabilitiesResponse { - return &volume.CapabilitiesResponse{Capabilities: volume.Capability{Scope: "local"}} -} - -func DeviceVolumeDriver() *deviceVolumeDriver { - cli, err := client.NewClientWithOpts(client.FromEnv) + cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation()) if err != nil { log.Fatal(err) } - if cgroups.Mode() == cgroups.Unified { - log.Fatal(errors.New("cgroupv2 is not supported")) + defer cli.Close() + + msgs, errs := cli.Events( + ctx, + types.EventsOptions{Filters: filters.NewArgs(filters.Arg("event", "start"))}, + ) + + for { + select { + case err := <-errs: + log.Fatal(err) + case msg := <-msgs: + info, err := cli.ContainerInspect(ctx, msg.Actor.ID) + + if err != nil { + panic(err) + } else { + pid := info.State.Pid + version, err := cgroup.GetDeviceCGroupVersion("/", pid) + + log.Printf("The cgroup version for process %d is: %v\n", pid, version) + + if err != nil { + log.Println(err) + break + } + + log.Printf("Checking mounts for process %d\n", pid) + + for _, mount := range info.Mounts { + log.Printf( + "%s/%v requested a volume mount for %s at %s\n", + msg.Actor.ID, info.State.Pid, mount.Source, mount.Destination, + ) + + if !strings.HasPrefix(mount.Source, "/dev") { + log.Printf("%s is not a device... skipping\n", mount.Source) + continue + } + + api, err := cgroup.New(version) + cgroupPath, sysfsPath, err := api.GetDeviceCGroupMountPath("/", pid) + + if err != nil { + log.Println(err) + break + } + + cgroupPath = path.Join(rootPath, sysfsPath, cgroupPath) + + log.Printf("The cgroup path for process %d is at %v\n", pid, cgroupPath) + + if fileInfo, err := os.Stat(mount.Source); err != nil { + log.Println(err) + continue + } else { + if fileInfo.IsDir() { + err := filepath.Walk(mount.Source, + func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } else if info.IsDir() { + return nil + } else if err = applyDeviceRules(api, path, cgroupPath, pid); err != nil { + log.Println(err) + } + return nil + }) + if err != nil { + log.Println(err) + } + } else { + if err = applyDeviceRules(api, mount.Source, cgroupPath, pid); err != nil { + log.Println(err) + } + } + } + + } + } + } + } +} + +func applyDeviceRules(api cgroup.Interface, mountPath string, cgroupPath string, pid int) error { + deviceType, major, minor, err := getDeviceInfo(mountPath) + + if err != nil { + log.Println(err) + return err + } else { + log.Printf("Adding device rule for process %d at %s\n", pid, cgroupPath) + err = api.AddDeviceRules(cgroupPath, []cgroup.DeviceRule{ + { + Access: "rwm", + Major: Ptr[int64](major), + Minor: Ptr[int64](minor), + Type: deviceType, + Allow: true, + }, + }) + + if err != nil { + log.Println(err) + return err + } } - return &deviceVolumeDriver{cli} + return nil }