#!/usr/bin/python3
import argparse
import os
from subprocess import check_call, check_output
import subprocess
import tempfile
import pathlib
import platform
import glob
import shutil
import re
import sys
import hashlib
from collections import namedtuple
from enum import Enum, auto
import json


class ModTable:
    """Dict-like table keyed by kernel module names.

    The kernel treats '_' and '-' in module names as interchangeable, so
    keys are normalized to the dashed spelling before every lookup or
    store; both spellings therefore address the same entry.  Keys must
    be strings.
    """

    @staticmethod
    def _normalize(key):
        # Reject non-string keys explicitly, then fold '_' to '-'.
        if not isinstance(key, str):
            raise TypeError
        return key.replace('_', '-')

    def __init__(self):
        self._table = {}

    def get(self, key, default=None):
        """Return the value for `key`, or `default` if absent."""
        return self._table.get(self._normalize(key), default)

    def __contains__(self, key):
        return key.replace('_', '-') in self._table

    def __getitem__(self, key):
        return self._table[self._normalize(key)]

    def __setitem__(self, key, value):
        self._table[self._normalize(key)] = value

class ModuleDb:
    """Database of the kernel modules available in one module tree.

    Built from the depmod-generated metadata under `module_path`
    (modules.builtin, modules.builtin.modinfo, modules.alias,
    modules.symbols, modules.dep) plus the *.ko files found on disk.
    It also tracks which modules the initramfs build has installed so
    that duplicate or redundant configuration entries can be reported.
    """

    # builtin: compiled into the kernel image
    # aliases: alias strings resolving to this module
    # symbols: "symbol:..." aliases exported by this module
    ModuleInfo = namedtuple('ModuleInfo', ('builtin', 'aliases', 'symbols'))
    alias_line = re.compile(r"^alias +(?P<alias>[^ ]+) +(?P<target>[^ ]+)$")
    modinfo_line = re.compile(r"^(?P<name>[^.]+)[.](?P<arg>[^.]+)=(?P<value>.+)$")

    class Installed(Enum):
        # Named directly by a configuration file.
        EXPLICIT = auto()
        # Pulled in as a dependency of an explicitly installed module.
        IMPLICIT = auto()

    @staticmethod
    def __each_alias(path):
        """Yield (alias, target) pairs from a modules.alias-format file."""
        with open(path, "r") as f:
            data = f.read()
        for line in data.splitlines():
            m = ModuleDb.alias_line.fullmatch(line)
            if m:
                yield m['alias'], m['target']

    def __init__(self, module_path):
        self._modules = ModTable()    # name -> ModuleInfo
        self._alias = ModTable()      # alias -> module name
        self._deps = ModTable()       # name -> list of direct dependencies
        self._installed = ModTable()  # name -> (Installed, origin, source)
        not_found = set()

        # Builtin modules: one path per line.
        with open(os.path.join(module_path, "modules.builtin"), "r") as f:
            for line in f.read().splitlines():
                name = os.path.splitext(os.path.basename(line))[0]
                self._modules[name] = ModuleDb.ModuleInfo(True, [], [])

        # modinfo for builtins is NUL-separated "name.key=value" records.
        with open(os.path.join(module_path, "modules.builtin.modinfo"), "r") as f:
            for line in f.read().split('\0'):
                m = ModuleDb.modinfo_line.fullmatch(line)
                if m:
                    if m['arg'] == 'alias':
                        original = self._modules.get(m['name'], ModuleDb.ModuleInfo(True, [], []))
                        self._modules[m['name']] = ModuleDb.ModuleInfo(original.builtin,
                                                                       original.aliases + [m['value']],
                                                                       original.symbols)
                    self._alias[m['value']] = m['name']

        # Loadable modules: *.ko (possibly compressed, e.g. *.ko.zst).
        for root, dirs, files in os.walk(module_path):
            for f in files:
                if f.endswith('.ko'):
                    name, _ = os.path.splitext(f)
                elif '.ko.' in f:
                    name, _ = f.split('.ko.', maxsplit=1)
                else:
                    continue
                self._modules[name] = ModuleDb.ModuleInfo(False, [], [])

        for alias, target in ModuleDb.__each_alias(os.path.join(module_path, "modules.alias")):
            try:
                original = self._modules[target]
            except KeyError:
                if target not in not_found:
                    print(f"module.alias: cannot find module {target}", file=sys.stderr)
                    not_found.add(target)
            else:
                self._modules[target] = ModuleDb.ModuleInfo(original.builtin,
                                                            original.aliases + [alias],
                                                            original.symbols)
                self._alias[alias] = target

        for alias, target in ModuleDb.__each_alias(os.path.join(module_path, "modules.symbols")):
            try:
                original = self._modules[target]
            except KeyError:
                if target not in not_found:
                    print(f"modules.symbols: cannot find module {target}", file=sys.stderr)
                    not_found.add(target)
            else:
                self._modules[target] = ModuleDb.ModuleInfo(original.builtin,
                                                            original.aliases,
                                                            original.symbols + [alias])

        # modules.dep lines are "path/to/mod.ko: dep1.ko dep2.ko ...".
        with open(os.path.join(module_path, "modules.dep"), "r") as f:
            for line in f.read().splitlines():
                mod, deps = line.split(':', 1)
                deps = deps.strip()
                mod = os.path.basename(mod)
                mod = os.path.splitext(mod)[0]
                if len(deps) != 0:
                    for dep in deps.split():
                        dep = os.path.basename(dep)
                        dep = os.path.splitext(dep)[0]
                        if mod not in self._deps:
                            self._deps[mod] = []
                        self._deps[mod].append(dep)

    def mark_installed(self, mod, source):
        """Record that `mod` (a name or alias) was installed by `source`.

        Warns on stderr when a module is installed twice or when an
        explicit entry is already covered by a dependency.
        """
        real_mod = self.solve_alias(mod)
        if real_mod is None:
            print(f"WARNING: Module {mod} is not known in {source}", file=sys.stderr)
            return
        if real_mod in self._installed:
            old_mode, old_origin, old_source = self._installed[real_mod]
            if old_mode == ModuleDb.Installed.IMPLICIT:
                print(f"WARNING: Module {mod} installed by {source}, but is dependency of {old_origin} installed by {old_source}", file=sys.stderr)
            elif old_mode == ModuleDb.Installed.EXPLICIT:
                print(f"WARNING: Module {mod} installed by {source}, but it was already installed by {old_source} - you could remove it from one of the sources", file=sys.stderr)
                return
            else:
                # Print the unexpected mode value (the old code printed
                # the class object itself, which was useless).
                print(f"INTERNAL ERROR: unexpected value in ModuleDb: {old_mode}",
                      file=sys.stderr)
                sys.exit(1)
        # Key by the resolved name so later lookups (which also resolve
        # aliases, see the check above) find this entry; the old code
        # stored under the unresolved alias and the duplicate check
        # never fired for alias-installed modules.
        self._installed[real_mod] = (ModuleDb.Installed.EXPLICIT, mod, source)
        for dep in self.all_deps(real_mod):
            if dep in self._installed:
                dep_mode, dep_origin, dep_source = self._installed[dep]
                if dep_mode == ModuleDb.Installed.EXPLICIT:
                    print(f"WARNING: Module {dep} installed by {dep_source}, but is dependency of {mod} installed by {source}", file=sys.stderr)
            else:
                self._installed[dep] = (ModuleDb.Installed.IMPLICIT, mod, source)

    def all_deps(self, mod):
        """Yield the transitive dependencies of `mod` (dash-normalized)."""
        real_mod = self.solve_alias(mod)
        if real_mod is None:
            return
        # Copy the stored list: the old code popped from / extended the
        # list object held in self._deps, corrupting the database on
        # every traversal.
        queue = list(self._deps.get(real_mod, []))
        done = set()
        while queue:
            m = queue.pop().replace('_', '-')
            if m in done:
                continue
            yield m
            done.add(m)
            if m in self._deps:
                queue.extend(self._deps[m])

    def solve_alias(self, name):
        """Resolve `name` through the alias table to a real module name.

        Returns None when the name is neither a module nor an alias.
        """
        if name in self._modules:
            return name
        if name in self._alias:
            return self.solve_alias(self._alias[name])
        return None

    def __contains__(self, name):
        if name in self._modules:
            return True
        if name not in self._alias:
            return False
        return self._alias[name] in self

    def __getitem__(self, name):
        if not isinstance(name, str):
            raise TypeError
        if name in self._modules:
            return self._modules[name]
        if name in self._alias:
            return self[self._alias[name]]
        raise KeyError(name)


def path_join_make_rel_paths(path, *paths):
    """Join `paths` onto `path`, forcing later components to be relative.

    os.path.join() throws away previous components when it finds one
    that is an absolute path - we do not want that when using --root as
    many defaults are absolute, so each trailing component is first
    re-rooted at '/' and then made relative before joining.
    """
    relative = [os.path.relpath(os.path.join('/', p), '/') for p in paths]
    return os.path.join(path, *relative)


# Patterns (dash-normalized names) of modules that are typically loaded
# automatically, so listing them explicitly is usually redundant.
known_automatic_modules = [re.compile(pattern) for pattern in (
    r'crypto-.*',
    r'algif-.*',
    r'fs-.*',
    r'nls-.*',
    r'dm-.*',
    r'md-.*',
)]


def _is_known_automatic(module):
    # True when the (dash-normalized) module name matches a pattern of
    # modules that are usually loaded automatically.
    return any(r.fullmatch(module.replace('_', '-'))
               for r in known_automatic_modules)


def add_modules_from_file(dest_d, kernel_root, modules_d, fw_d, conf_file, db,
                          warn_discoverable=False):
    """Include kernel modules specified in `conf_file`

    Args:
        dest_d (str): Path to sysroot
        kernel_root (str): Path to kernel directory
        modules_d (str): Path to module directory
        fw_d (str): Path to firmware directory
        conf_file (str): Path to configuration file
        db (ModuleDb): A ModuleDb instance which contains information about modules from `modules_d`
        warn_discoverable (bool): Whether to warn of modules that can be discovered through aliases

    """

    with open(conf_file) as f:
        for line in f.readlines():
            module = line.strip()
            # Skip blank lines and comments.
            if not module or module[0] in ["#", ";"]:
                continue

            try:
                mod = db[module]
            except KeyError:
                print(f"NOTE: {conf_file}: Module {module} not found", file=sys.stderr)
                if warn_discoverable and _is_known_automatic(module):
                    print(f"WARNING: {conf_file}: Module {module} is probably automatically loaded", file=sys.stderr)
            else:
                if mod.builtin:
                    print(f"NOTE: {conf_file}: Module {module} is builtin", file=sys.stderr)
                if warn_discoverable:
                    if len(mod.aliases) > 0:
                        print(f"WARNING: {conf_file}: Module {module} has aliases:", file=sys.stderr)
                        for alias in mod.aliases:
                            print(f" * {alias}", file=sys.stderr)
                    elif _is_known_automatic(module):
                        # (the old code re-tested warn_discoverable here,
                        # redundantly - we are already inside that branch)
                        print(f"WARNING: {conf_file}: Module {module} is probably automatically loaded", file=sys.stderr)

                # A module that has symbols but no alias is likely to be loaded through aliases
                if len(mod.aliases) == 0 and len(mod.symbols) > 0:
                    print(f"WARNING: {conf_file}: Module {module} exports symbols:", file=sys.stderr)
                    for symbol in mod.symbols:
                        print(f" * {symbol}", file=sys.stderr)
                if not mod.builtin:
                    db.mark_installed(module, conf_file)

            # Install the module (and firmware) even when it was not
            # found in the db: dracut-install with --optional handles
            # the missing case itself.
            check_call(
                [
                    "/usr/lib/dracut/dracut-install",
                    "-D",
                    dest_d,
                    "-r", kernel_root,
                    "--kerneldir",
                    modules_d,
                    "--firmwaredirs",
                    fw_d,
                    "--module",
                    "--optional",
                    module,
                ]
            )


def install_files(files, dest_dir, sysroot):
    """Copy `files` plus their shared-library dependencies into `dest_dir`.

    dracut-install is run with --ldd so dynamically linked dependencies
    are pulled in too; LD_PRELOAD is cleared so the helper does not run
    with interposed libraries.
    """
    env = os.environ.copy()
    env["LD_PRELOAD"] = ""
    args = [
        "/usr/lib/dracut/dracut-install",
        "-D", dest_dir,
        "--ldd", "--all",
    ]
    # dracut-install has some bugs if you set "/" as sysroot, avoid that
    if sysroot != "/":
        args.extend(["-r", sysroot])
    check_call(args + files, env=env)


def install_file_to_path(file, dest_dir, dest_path, sysroot):
    """Install `file` into `dest_dir` under the explicit `dest_path`.

    This is useful if the destination path inside dest_dir is different
    from the current file path. Shared-library dependencies are pulled
    in via --ldd; LD_PRELOAD is cleared for the helper process.
    """
    env = os.environ.copy()
    env["LD_PRELOAD"] = ""
    args = [
        "/usr/lib/dracut/dracut-install",
        "-D", dest_dir,
        "--ldd",
    ]
    # dracut-install has some bugs if you set "/" as sysroot, avoid that
    if sysroot != "/":
        args.extend(["-r", sysroot])
    args.append(file)
    args.append(dest_path)
    check_call(args, env=env)


def package_files(pkgs, sysroot):
    """Return the list of file paths shipped by the given deb packages."""
    cmd = ["dpkg", "--admindir=" + sysroot + "/var/lib/dpkg/", "-L"]
    raw = check_output(cmd + pkgs)
    return raw.decode("utf-8").splitlines()


def dpkg_architecture(sysroot, variable):
    """Return the value of a dpkg-architecture `variable` for `sysroot`."""
    env = os.environ.copy()
    # Point dpkg-architecture at the sysroot's dpkg tables.
    env["DPKG_DATADIR"] = sysroot + "/usr/share/dpkg"
    output = check_output(["dpkg-architecture", "-q", variable], env=env)
    return output.decode("utf-8").splitlines()[0]


def install_systemd_files(dest_dir, sysroot, ubuntu_release):
    """Install the subset of systemd/udev files the initramfs needs.

    Args:
        dest_dir (str): initramfs staging directory
        sysroot (str): root used to look up packages and their files
        ubuntu_release (list): [major, minor] release numbers
    """
    # Build list of files and directories

    syd_pkgs = ["systemd", "systemd-sysv"]
    # systemd-cryptsetup is a separate package from 24.10 onwards.
    if release_higher_or_equal(ubuntu_release, (24, 10)):
        syd_pkgs.append("systemd-cryptsetup")
    lines = package_files(syd_pkgs, sysroot)
    # From systemd, we pull
    # * units configuration
    # * Executables
    # * Module load options
    # * Configuration of kernel parameters
    # * udev rules
    # * Configuration for systemd-tmpfiles
    # TODO: some of this can be cleaned up
    to_include = re.compile(r"^(/usr)?/lib/systemd/system|"
                            r"/usr/share/dbus-1/system-services/org.freedesktop.systemd1.service|"
                            r"/usr/share/dbus-1/system.d/org.freedesktop.systemd1.conf|"
                            r".*/modprobe\.d/|"
                            r".*/sysctl\.d/|"
                            r".*/rules\.d/|"
                            r".*/tmpfiles\.d/|"
                            r".*bin/|"
                            r".*sbin/"
                            )
    files = [i for i in lines if to_include.match(i)]

    lines = package_files(["udev"], sysroot)
    # From udev, we pull
    # * Executables
    # * systemd configuration (units, tmpfiles)
    # * udev rules
    # * hwdb
    # TODO: some of this can be cleaned up
    to_include = re.compile(r".*bin/|"
                            r".*lib/|"
                            r".*rules\.d/"
                            )
    files += [i for i in lines if to_include.match(i)]

    files.append("/var/lib/systemd/")

    # Filter out some units we don't want
    to_remove = re.compile(".*systemd-gpt-auto-generator|"
                           ".*proc-sys-fs-binfmt_misc.automount|"
                           ".*systemd-pcrphase.*|"
                           ".*systemd-tpm2-setup.*")
    filtered = [i for i in files if not to_remove.match(i)]
    # Install
    install_files(filtered, dest_dir, sysroot)

    # Local modifications
    # This hack should be removed with PR#113
    # The sed removes plymouth-start.service from After= lines of the
    # installed systemd-ask-password-* units, deleting any After= line
    # that ends up empty.
    check_call([r"sed -i '/^After=/"
                r"{;s, *plymouth-start[.]service *, ,;/"
                r"^After= *$/d;}' " + os.path.join(dest_dir,
                                                   "usr/lib/systemd/system/systemd-ask-password-*")],
               shell=True)
    # Generate hw database (/usr/lib/udev/hwdb.bin) for udev and
    # remove redundant definitions after that.
    check_call(["systemd-hwdb", "--root", dest_dir,
                "update", "--usr", "--strict"])
    shutil.rmtree(os.path.join(dest_dir, "usr/lib/udev/hwdb.d"))


def install_busybox(dest_dir, sysroot):
    """Install the busybox binary and create its applet symlinks."""
    install_file_to_path("/usr/lib/initramfs-tools/bin/busybox", dest_dir,
                         "usr/bin/busybox", sysroot)
    # Commands we want for busybox (static for more control and to help
    # with cross-building).
    applets = (
        "[", "[[", "acpid", "arch", "ascii", "ash", "awk", "base32", "basename",
        "blockdev", "cat", "chmod", "chroot", "chvt", "clear", "cmp", "cp",
        "crc32", "cut", "date", "deallocvt", "deluser", "devmem", "df", "dirname",
        "du", "dumpkmap", "echo", "egrep", "env", "expr", "false", "fbset", "fgrep",
        "find", "fold", "fstrim", "grep", "gunzip", "gzip", "hostname", "hwclock",
        "i2ctransfer", "ifconfig", "ip", "kill", "ln", "loadfont", "loadkmap", "ls",
        "lzop", "mkdir", "mkfifo", "mknod", "mkswap", "mktemp", "more",
        "mv", "nuke", "openvt", "pidof", "printf", "ps", "pwd", "readlink",
        "reset", "rm", "rmdir", "run-init", "sed", "seq", "setkeycodes",
        "sh", "sleep", "sort", "stat", "static-sh", "stty", "switch_root", "sync",
        "tail", "tee", "test", "touch", "tr", "true", "ts", "tty", "uname",
        "uniq", "wc", "wget", "which", "yes",
    )
    bin_dir = os.path.join(dest_dir, "usr/bin")
    for applet in applets:
        os.symlink("busybox", os.path.join(bin_dir, applet))


def install_misc(dest_dir, sysroot, deb_arch):
    """Install assorted tools, libraries and data files into `dest_dir`.

    Args:
        dest_dir (str): initramfs staging directory
        sysroot (str): source root for the files
        deb_arch (str): multiarch triplet (e.g. "x86_64-linux-gnu")
    """
    # dmsetup rules
    rules = package_files(["dmsetup"], sysroot)
    to_include = re.compile(r".*rules.d/")
    rules = [i for i in rules if to_include.match(i)]
    install_files(rules, dest_dir, sysroot)

    files = [
        "/usr/bin/kmod",
        "/usr/bin/mount",
        "/usr/sbin/sulogin",
        "/usr/bin/tar",
        "/usr/lib/" + deb_arch + "/libgcc_s.so.1",
        "/usr/lib/" + deb_arch + "/libkmod.so.2",
        "/usr/lib/systemd/system/dbus.service",
        "/usr/lib/systemd/system/dbus.socket",
        "/usr/lib/systemd/system/plymouth-start.service",
        "/usr/lib/systemd/system/plymouth-switch-root.service",
        "/usr/lib/systemd/systemd-bootchart",
        "/usr/sbin/cryptsetup",
        "/usr/sbin/dmsetup",
        "/usr/sbin/e2fsck",
        "/usr/sbin/fsck",
        "/usr/bin/umount",
        # (duplicate fsck.vfat entry removed)
        "/usr/sbin/fsck.vfat",
        "/usr/sbin/mkfs.ext4",
        "/usr/sbin/mkfs.vfat",
        "/usr/sbin/sfdisk",
        "/usr/bin/dbus-daemon",
        "/usr/bin/mountpoint",
        "/usr/bin/partx",
        "/usr/bin/plymouth",
        "/usr/bin/unsquashfs",
        "/usr/lib/" + deb_arch + "/plymouth/label-freetype.so",
        "/usr/lib/" + deb_arch + "/plymouth/script.so",
        "/usr/lib/" + deb_arch + "/plymouth/two-step.so",
        "/usr/sbin/depmod",
        "/usr/sbin/insmod",
        "/usr/sbin/lsmod",
        "/usr/sbin/modinfo",
        "/usr/sbin/modprobe",
        "/usr/sbin/plymouthd",
        "/usr/sbin/rmmod",
        "/usr/share/dbus-1/system.conf",
        "/usr/share/libdrm/amdgpu.ids",
    ]
    files += glob.glob("/lib/" + deb_arch + "/libnss_compat.so.*")
    files += glob.glob("/lib/" + deb_arch + "/libnss_files.so.*")
    files += glob.glob("/usr/lib/" + deb_arch + "/plymouth/renderers/*.so")
    files += glob.glob("/usr/share/plymouth/themes/bgrt/*")
    files += glob.glob("/usr/share/plymouth/themes/spinner/*")

    # the arm64 build of snap-bootstrap requires libteec
    if deb_arch == "aarch64-linux-gnu":
        files += glob.glob("/usr/lib/" + deb_arch + "/libteec.so*")

    install_files(files, dest_dir, sysroot)
    # Links for fsck
    os.symlink("e2fsck", os.path.join(dest_dir, "usr/sbin", "fsck.ext4"))
    # Get deps for shared objects that have not the exec bit set and that
    # are loaded with dlopen (which means that they are not pull by --ldd
    # option, instead use --resolvelazy).
    proc_env = os.environ.copy()
    proc_env["LD_PRELOAD"] = ""
    to_resolve = [
        "/usr/lib/" + deb_arch + "/plymouth/label-freetype.so",
        "/usr/lib/" + deb_arch + "/plymouth/script.so",
        "/usr/lib/" + deb_arch + "/plymouth/two-step.so",
    ]
    to_resolve += glob.glob("/usr/lib/" + deb_arch + "/plymouth/renderers/*.so")
    cmd = [
        "/usr/lib/dracut/dracut-install",
        "-D", dest_dir,
        "--resolvelazy",
    ]
    # dracut-install has some bugs if you set "/" as sysroot, avoid that
    if sysroot != "/":
        cmd += ["-r", sysroot]
    check_call(cmd + to_resolve, env=proc_env)
    # Build ld cache (make sure /etc is there before)
    etc_d = os.path.join(dest_dir, "etc")
    os.makedirs(etc_d, exist_ok=True)
    check_call(["ldconfig", "-r", dest_dir])

    # /sbin/modprobe is a static path in the kernel
    os.makedirs(os.path.join(dest_dir, "sbin"))
    os.symlink("../usr/sbin/modprobe", os.path.join(dest_dir, "sbin", "modprobe"))
    # FIXME: systemd is configured with the wrong path to dmsetup
    os.symlink("../usr/sbin/dmsetup", os.path.join(dest_dir, "sbin", "dmsetup"))
    # FIXME: snap-bootstrap uses /sbin/reboot
    os.symlink("../usr/bin/systemctl", os.path.join(dest_dir, "sbin", "reboot"))


class AptRepo:
    """A deb package repository: base URL plus its suites and components."""

    def __init__(self, url):
        self.url = url
        # Filled in by the caller while parsing apt-cache policy output.
        self.suites = []
        self.components = []

    def __repr__(self):
        # For debuggability when dumping parsed repositories.
        return (f"{type(self).__name__}(url={self.url!r}, "
                f"suites={self.suites!r}, components={self.components!r})")


def find_apt_repos(sysroot):
    """Return a list of AptRepo objects parsed from `sysroot`'s apt config.

    Runs `apt-cache policy` against the sysroot's sources and state
    directories, extracts one (url, suite, component) per line, then
    merges repositories that share the same component set.
    """
    repos = {}
    sources = os.path.join(sysroot, "etc/apt/sources.list")
    sources_d = os.path.join(sysroot, "etc/apt/sources.list.d")
    state_d = os.path.join(sysroot, "var/lib/apt/lists")
    repo_lines = check_output(["apt-cache", "policy",
                               "-o", "Dir::Etc::SourceList=" + sources,
                               "-o", "Dir::Etc::SourceParts=" + sources_d,
                               "-o", "Dir::State::Lists=" + state_d]).decode("utf-8").splitlines()
    # Keep only lines that look like "<priority> <uri> ...", for any of
    # apt's transport schemes.
    uri_line = re.compile(r"^ *-?[0-9]+ (http|mirror|file|cdrom|ftp|copy|rsh|ssh)")
    repo_lines = [r for r in repo_lines if uri_line.match(r)]
    # Strip the leading priority so the line starts at the URI.
    uri_start = re.compile(r"^ *-?[0-9]+ ")
    for r in repo_lines:
        r = uri_start.sub("", r)
        fields = r.split(" ")
        if len(fields) < 2:
            continue
        # suite/comp in the line
        suite_comp = fields[1].split("/")
        if len(suite_comp) != 2:
            continue
        # Key will be url + suite (we only have one suite per line)
        key = (fields[0], suite_comp[0])
        if key in repos:
            repos[key].components.append(suite_comp[1])
        else:
            repo = AptRepo(fields[0])
            repo.suites = [suite_comp[0]]
            repo.components = [suite_comp[1]]
            repos[key] = repo
    # Now merge suites that share components
    repo_ls = list(repos.values())
    merged = set()
    final_repos = []
    for i, repo in enumerate(repo_ls):
        if i in merged:
            continue
        for j in range(i + 1, len(repo_ls)):
            # Same component set (order-insensitive): fold j's suites
            # into i and drop j.
            if sorted(repo.components) == sorted(repo_ls[j].components):
                repo.suites += repo_ls[j].suites
                merged.add(j)
        final_repos.append(repo)

    return final_repos


def pkgs_used_in_initrd(dest_dir, sysroot):
    MD5_HEX_CHARS = 2 * hashlib.md5().digest_size

    # Get list of files we have in the initramfs and calculate their md5sum
    initrd_digests = []
    for dirpath, dirs, files in os.walk(dest_dir):
        for f in files:
            path = os.path.join(dirpath, f)
            if os.path.islink(path) or not os.path.isfile(path):
                continue
            f_stats = os.stat(path)
            if f_stats.st_size == 0:
                continue
            dig = hashlib.file_digest(open(path, "rb"), "md5").hexdigest()
            initrd_digests.append(dig)

    # Read digests from dpkg database
    root_digests = {}
    dpkg_d = os.path.join(sysroot, "var/lib/dpkg/info")
    for dirpath, dirs, files in os.walk(dpkg_d):
        for f in files:
            if not f.endswith(".md5sums"):
                continue
            pkg = f.removesuffix(".md5sums")
            # Files from libraries usually are of form pkg_name:arch, consider
            # that case
            pkg = pkg.split(":")[0]
            with open(os.path.join(dirpath, f), "r") as digests_f:
                for ln in digests_f.readlines():
                    if len(ln) >= MD5_HEX_CHARS:
                        root_digests[ln[:MD5_HEX_CHARS]] = pkg

    # Now find matches
    pkgs = set()
    for d in initrd_digests:
        pkg = root_digests.get(d)
        if pkg is not None:
            pkgs.add(pkg)

    # Files we take from systemd-sysv are symlinks and have been ignored
    pkgs.add("systemd-sysv")
    # Return sorted list
    return sorted(pkgs)


def create_initrd_pkg_list(dest_dir, sysroot):
    """Write a YAML manifest of repositories and packages in the initramfs.

    Creates usr/share/doc/dpkg.yaml inside `dest_dir` and returns its
    path.
    """
    # Find used repos
    repos = find_apt_repos(sysroot)
    # and packages used to build the initramfs
    pkgs = pkgs_used_in_initrd(dest_dir, sysroot)

    docs_dir = os.path.join(dest_dir, "usr", "share", "doc")
    os.makedirs(docs_dir)
    manifest_path = os.path.join(docs_dir, "dpkg.yaml")

    # Write active repositories first so we can have an idea of where
    # things come from
    chunks = ["package-repositories:\n"]
    for repo in repos:
        chunks.append("- type: apt\n")
        chunks.append("  url: " + repo.url + "\n")
        chunks.append("  suites: [ " + ', '.join(repo.suites) + " ]\n")
        chunks.append("  components: [ " + ', '.join(repo.components) + " ]\n")

    # Now write package information
    chunks.append("packages:\n")
    chunks.append(check_output(["dpkg-query", "--admindir=" + sysroot + "/var/lib/dpkg/",
                                "--show", "--showformat=- ${binary:Package}=${Version}\\n"] + pkgs).decode("utf-8"))

    with open(manifest_path, "w") as pkg_list:
        pkg_list.write("".join(chunks))

    return manifest_path


def verify_missing_dlopen(destdir, libdir):
    """Check that dlopen'ed libraries of installed ELF files are present.

    Looks at the .notes.dlopen section of ELF binaries (via the
    dlopen-notes tool) to find libraries that are not in the dynamic
    section and will be loaded with dlopen when needed.
    See https://systemd.io/ELF_DLOPEN_METADATA/

    Args:
        destdir (str): initramfs staging directory to scan
        libdir (str): library directory (absolute, relative to /)

    Returns:
        bool: False when a "required" or "recommended" soname is
        missing, True otherwise.
    """
    missing = {}
    for dirpath, dirs, files in os.walk(destdir):
        for f in files:
            path = os.path.join(dirpath, f)
            if os.path.islink(path) or not os.path.isfile(path):
                continue
            # Only ELF files carry dlopen notes.
            with open(path, 'rb') as b:
                if b.read(4) != b'\x7fELF':
                    continue
            out = check_output(["dlopen-notes", path])
            split = out.splitlines()
            # dlopen-notes prefixes comment lines with '#'; the rest is JSON.
            json_doc = b'\n'.join([s for s in split if not s[:1] == b'#'])
            doc = json.loads(json_doc)
            for dep in doc:
                sonames = dep["soname"]
                priority = dep["priority"]
                found_sonames = []
                for soname in sonames:
                    dest = os.path.join(destdir, os.path.relpath(libdir, "/"), soname)
                    # Note: the old code joined destdir onto `dest` a
                    # second time here, which broke with a relative
                    # destdir; `dest` already includes it.
                    if os.path.exists(dest):
                        found_sonames.append(soname)
                if not found_sonames:
                    # We did not find any library.
                    # In this case we need to mark all sonames as
                    # missing. This is required because some features
                    # may have common subset of sonames and those
                    # features might have different priorities.
                    for soname in sonames:
                        # Keep the highest priority seen for each soname.
                        current_priority = missing.get(soname)
                        if current_priority == "required":
                            continue
                        elif current_priority == "recommended" and priority not in ["required"]:
                            continue
                        elif current_priority == "suggested" and priority not in ["required", "recommended"]:
                            continue
                        else:
                            missing[soname] = priority

    fatal = False
    if missing:
        print("WARNING: These sonames are missing:", file=sys.stderr)
        for m, priority in missing.items():
            print(f" * {m} ({priority})", file=sys.stderr)
            if priority in ["required", "recommended"]:
                fatal = True
        if fatal:
            print("WARNING: Some missing sonames are required or recommended. Failing.", file=sys.stderr)

    return not fatal


def mkcpio(path, output, *, compress=True):
    """Pack the tree under `path` into a newc cpio archive.

    The archive is written to the file object `output`, optionally
    compressed through zstd. Entries are fed to cpio NUL-separated and
    sorted for reproducible output.

    Args:
        path (str): directory to archive
        output: writable file object (must have a real fd)
        compress (bool): pipe the archive through zstd when True

    Raises:
        RuntimeError: when cpio or zstd exit with a non-zero status.
    """
    cpio_cmd = subprocess.Popen(['cpio', '--null', '--reproducible', '--create', '--quiet', '--format=newc', '--owner=0:0'],
                                stdin=subprocess.PIPE, stdout=subprocess.PIPE if compress else output,
                                cwd=path)

    if compress:
        zstd_cmd = subprocess.Popen(['zstd', '-1', '-T0'],
                                    stdin=cpio_cmd.stdout, stdout=output)
        # Drop our reference to the pipe so zstd is the only reader and
        # cpio receives SIGPIPE if zstd exits early (standard pipeline
        # idiom from the subprocess docs; the old code kept it open).
        cpio_cmd.stdout.close()
    else:
        zstd_cmd = None

    # Collect every directory and file relative to `path`; sort so the
    # archive contents are deterministic.
    entries = []
    for root, dirs, files in os.walk(path):
        for name in dirs:
            entries.append(os.path.join(os.path.relpath(root, path), name))
        for name in files:
            entries.append(os.path.join(os.path.relpath(root, path), name))
    # (loop variable renamed: the old code shadowed the `path` argument)
    for entry in sorted(entries):
        cpio_cmd.stdin.write(entry.encode('ascii'))
        cpio_cmd.stdin.write(b'\0')

    cpio_cmd.stdin.close()
    if zstd_cmd is not None:
        zstd_cmd.wait()
    cpio_cmd.wait()
    if cpio_cmd.returncode != 0:
        raise RuntimeError("cpio failed")
    if zstd_cmd is not None and zstd_cmd.returncode != 0:
        raise RuntimeError("zstd failed")


def mk_amd_ucode(root, output):
    """Concatenate AMD microcode blobs from `root` into `output`.

    The glob matches are sorted: glob.glob() order is filesystem
    dependent, and an unsorted concatenation would make the microcode
    image non-reproducible.
    """
    for path in sorted(glob.glob(f'{root}/lib/firmware/amd-ucode/*.bin')):
        with open(path, 'rb') as f:
            shutil.copyfileobj(f, output)


def mk_intel_ucode(root, output_path):
    """Bundle Intel microcode from `root` into `output_path` via iucode_tool."""
    ucode_dir = f'{root}/lib/firmware/intel-ucode/'
    check_call(['iucode_tool', f'--write-to={output_path}', ucode_dir])


def reset_time(path):
    """Set atime/mtime to the epoch on everything under `path`.

    Used to keep image builds reproducible regardless of build time.
    """
    epoch = (0, 0)
    for root, dirs, files in os.walk(path):
        for entry in dirs + files:
            os.utime(os.path.join(root, entry), epoch)


def get_ubuntu_release(sysroot):
    """Return [major, minor] parsed from `sysroot`'s /etc/os-release.

    Exits the program with status 1 when VERSION_ID cannot be parsed.
    """
    version_re = re.compile('^VERSION_ID="([0-9]+)[.]([0-9]+)"$', re.MULTILINE)
    with open(os.path.join(sysroot, "etc/os-release"), "r") as f:
        match = version_re.search(f.read())
    if match:
        return [int(match.group(1)), int(match.group(2))]
    print("ERROR: cannot find Ubuntu release", file=sys.stderr)
    sys.exit(1)


def release_higher_or_equal(release1, release2):
    """Return True when release1's (major, minor) is >= release2's.

    Both arguments are (major, minor) sequences; only the first two
    elements are compared.
    """
    return (release1[0], release1[1]) >= (release2[0], release2[1])


def create_initrd(parser, args):
    """Build the ubuntu-core initramfs image (the "create-initrd" subcommand).

    Stages busybox, systemd pieces, kernel modules, firmware and the
    selected feature skeletons into a temporary tree, then archives it as
    a zstd-compressed cpio at args.output.  On amd64, an uncompressed
    early-microcode cpio is written ahead of the main archive.

    *parser* matches the subcommand dispatch signature but is unused here.
    """
    rootfs = "/"
    if not args.kerneldir:
        args.kerneldir = "/lib/modules/%s" % args.kernelver
    if args.root:
        # Rebase every input/output path below the alternate root.
        # NOTE(review): path_join_make_rel_paths is defined elsewhere in this
        # file; presumably it joins root with the path made relative — verify.
        args.skeleton = path_join_make_rel_paths(args.root, args.skeleton)
        args.kerneldir = path_join_make_rel_paths(args.root, args.kerneldir)
        args.firmwaredir = path_join_make_rel_paths(args.root, args.firmwaredir)
        args.output = path_join_make_rel_paths(args.root, args.output)
        rootfs = args.root
    if args.kernelver:
        # Output becomes e.g. ubuntu-core-initramfs.img-<kernelver>.
        args.output = "-".join([args.output, args.kernelver])
    with tempfile.TemporaryDirectory(suffix=".ubuntu-core-initramfs") as d:
        deb_arch = dpkg_architecture(rootfs, "DEB_HOST_MULTIARCH")
        ubuntu_release = get_ubuntu_release(rootfs)

        # Stage the kernel tree: usr/lib/modules/<kernelver> plus firmware.
        kernel_root = os.path.join(d, "kernel")
        modules = os.path.join(kernel_root, "usr", "lib", "modules")
        os.makedirs(modules, exist_ok=True)
        modules = os.path.join(modules, args.kernelver)
        check_call(["cp", "-ar", args.kerneldir, modules])

        firmware = os.path.join(kernel_root, "usr", "lib", "firmware")
        check_call(["cp", "-ar", args.firmwaredir, firmware])

        db = ModuleDb(modules)
        # NOTE(review): local `main` (the staging root of the initramfs
        # payload) shadows the module-level main() function within this scope.
        main = os.path.join(d, "main")
        os.makedirs(main, exist_ok=True)
        # copy busybox first so we get already the shell interpreter we
        # want (busybox) instead of dracut-install pulling the systemd
        # default (dash) when it pulls a shell script later.
        install_busybox(main, rootfs)
        # Copy systemd bits
        install_systemd_files(main, rootfs, ubuntu_release)
        # Other miscellaneous stuff
        install_misc(main, rootfs, deb_arch)
        # Copy features
        for feature in args.features:
            # Add feature files
            feature_path = os.path.join(args.skeleton, feature)
            if os.path.isdir(feature_path):
                check_call(["cp", "-aT", feature_path, main])
            # Add feature kernel modules
            extra_modules = os.path.join(args.skeleton, "modules", feature,
                                         "extra-modules.conf")
            if os.path.exists(extra_modules):
                add_modules_from_file(main, kernel_root, modules, firmware, extra_modules, db)

        # TODO xnox: fips actually needs additional runtime dependencies than
        # this, and currently forked in fips PPA. we need to figure out how to
        # make ubuntu-core-initramfs-fips a reality.
        if "fips" in args.features:
            install_files([
                "/usr/bin/kcapi-hasher",
                "/usr/bin/.kcapi-hasher.hmac",
            ] + glob.glob("/usr/lib/*/.libkcapi.so.*.hmac"), main, rootfs)

        # Update epoch
        pathlib.Path("%s/main/usr/lib/clock-epoch" % d).touch()
        # Should iterate all the .conf drop ins
        for module_load in glob.iglob("%s/usr/lib/modules-load.d/*.conf" % main):
            add_modules_from_file(main, kernel_root, modules, firmware, module_load, db,
                                  warn_discoverable=True)

        # Install the module index/metadata files so depmod can regenerate
        # its databases inside the staged tree.
        for modulesf in ["modules.order", "modules.builtin", "modules.builtin.bin", "modules.builtin.modinfo"]:
            check_call(
                [
                    "/usr/lib/dracut/dracut-install",
                    "-D",
                    main,
                    "-r", kernel_root,
                    os.path.join(modules, modulesf),
                    os.path.join("usr/lib/modules", args.kernelver, modulesf),
                ]
            )
        check_call(["depmod", "-a", "-b", main, args.kernelver])

        # Since 24.10, fail hard if dlopen'ed libraries are missing.
        if release_higher_or_equal(ubuntu_release, (24, 10)):
            if not verify_missing_dlopen(main, os.path.join("/usr/lib", deb_arch)):
                sys.exit(1)

        # Create manifest with packages with files included in the initramfs
        manifest_path = create_initrd_pkg_list(main, rootfs)
        out_dir = os.path.dirname(args.output)
        manifest_out_path = os.path.join(out_dir, "manifest-initramfs.yaml")
        if args.kernelver:
            # Suffix matches the image naming: manifest-initramfs.yaml-<kver>.
            manifest_out_path = "-".join([manifest_out_path, args.kernelver])
        shutil.copyfile(manifest_path, manifest_out_path)

        arch = dpkg_architecture(args.root, "DEB_HOST_ARCH")

        # Finally, write it
        with open(args.output, "wb") as output:
            if arch == "amd64":
                # NOTE(review): this inner temp dir shadows the outer `d`;
                # harmless because `d` itself is not referenced again below,
                # but renaming it would be clearer.
                with tempfile.TemporaryDirectory(suffix='ubuntu-core-initramfs') as d:
                    microcode_dir = os.path.join(d, 'kernel/x86/microcode')
                    os.makedirs(microcode_dir)
                    with open(os.path.join(microcode_dir, 'AuthenticAMD.bin'), 'wb') as f:
                        mk_amd_ucode(args.root, f)
                    mk_intel_ucode(args.root, os.path.join(microcode_dir, 'GenuineIntel.bin'))
                    # Time is irrelevant. And we should make it reproducible.
                    # It is also what update-microcode-initrd from package does.
                    reset_time(d)
                    mkcpio(d, output, compress=False)

            mkcpio(main, output)


def create_efi(parser, args):
    """Build a Unified Kernel Image (the "create-efi" subcommand).

    Combines kernel, initrd, os-release and stub into a single EFI binary
    via systemd's ukify, then signs it with sbsign unless --unsigned was
    given.  *parser* is used only for error reporting on bad option
    combinations.
    """
    if args.root:
        if not args.stub:
            parser.error("--stub is mandatory when --root is given")
        # Rebase every input/output path below the alternate root.
        args.os_release = path_join_make_rel_paths(args.root, args.os_release)
        args.stub = path_join_make_rel_paths(args.root, args.stub)
        args.kernel = path_join_make_rel_paths(args.root, args.kernel)
        args.initrd = path_join_make_rel_paths(args.root, args.initrd)
        args.key = path_join_make_rel_paths(args.root, args.key)
        args.cert = path_join_make_rel_paths(args.root, args.cert)
        args.output = path_join_make_rel_paths(args.root, args.output)
    if args.kernelver:
        # Inputs and output carry a -<kernelver> suffix.
        args.kernel = "-".join([args.kernel, args.kernelver])
        args.initrd = "-".join([args.initrd, args.kernelver])
        args.output = "-".join([args.output, args.kernelver])

        # NOTE(review): this aarch64 decompression only runs when
        # --kernelver is set (it is nested inside the if above) — confirm
        # that is intentional.
        if platform.machine() == "aarch64":
            import gzip
            # raw_kernel must stay referenced so the NamedTemporaryFile is
            # not deleted before ukify reads it below.
            raw_kernel = tempfile.NamedTemporaryFile(mode='wb')
            try:
                # Kernel may be a gzip'd Image; ukify needs it uncompressed.
                # A non-gzip kernel raises BadGzipFile on first read and is
                # used as-is.
                with gzip.open(args.kernel, 'rb') as comp_kernel:
                    shutil.copyfileobj(comp_kernel, raw_kernel)
                raw_kernel.flush()
                args.kernel = raw_kernel.name
            except gzip.BadGzipFile:
                pass

    # TODO: add --splash
    ukify_cmd = [
        '/usr/lib/systemd/ukify', 'build',
        '--linux', args.kernel,
        '--initrd', args.initrd,
        '--output', args.output,
        # '@path' tells ukify to read the section contents from a file.
        '--os-release', '@{}'.format(args.os_release),
    ]

    # Optional ukify arguments, forwarded only when set.
    if args.kernelver:
        ukify_cmd += [
            '--uname', args.kernelver,
        ]
    if args.stub:
        ukify_cmd += [
            '--stub', args.stub,
        ]
    if args.efi_arch:
        ukify_cmd += [
            '--efi-arch', args.efi_arch,
        ]
    if args.cmdline:
        ukify_cmd += [
            '--cmdline', args.cmdline,
        ]
    check_call(ukify_cmd)
    if not args.unsigned:
        # Sign in place: input and output are the same path.
        check_call(
            [
                "sbsign",
                "--key",
                args.key,
                "--cert",
                args.cert,
                "--output",
                args.output,
                args.output,
            ]
        )
    os.chmod(args.output, 0o644)


def main():
    """Parse the command line and dispatch to create_efi()/create_initrd().

    The chosen subcommand name is mapped onto the matching module-level
    function (dashes become underscores) and called with (parser, args).
    """
    # Default kernel version: the currently running kernel.
    kernelver = check_output(
        ["uname", "-r"], universal_newlines=True
    ).strip()
    # Default feature set; x86 additionally pulls in the server drivers.
    features = {"x86_64": "main server"}.get(platform.machine(), "main")
    parser = argparse.ArgumentParser()
    subparser = parser.add_subparsers(dest="subcmd", required=True)
    efi_parser = subparser.add_parser(
        "create-efi", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    efi_parser.add_argument("--root", help="path to root")
    efi_parser.add_argument("--stub", help="path to stub")
    efi_parser.add_argument("--efi-arch", help="efi architecture name")
    efi_parser.add_argument("--os-release", help="path to os-release", default="/etc/os-release")
    efi_parser.add_argument("--kernel", help="path to kernel", default="/boot/vmlinuz")
    efi_parser.add_argument(
        "--kernelver", help="kernel version suffix", default=kernelver
    )
    efi_parser.add_argument(
        "--initrd", help="path to initramfs", default="/boot/ubuntu-core-initramfs.img"
    )
    efi_parser.add_argument(
        "--cmdline", help="commandline to embed (can be overridden with LoadOptions)"
    )
    efi_parser.add_argument(
        "--unsigned", help="do not sign efi app", default=False, action="store_true"
    )
    efi_parser.add_argument(
        "--key",
        help="signing key",
        default="/usr/lib/ubuntu-core-initramfs/snakeoil/PkKek-1-snakeoil.key",
    )
    efi_parser.add_argument(
        "--cert",
        help="certificate",
        default="/usr/lib/ubuntu-core-initramfs/snakeoil/PkKek-1-snakeoil.pem",
    )
    # TODO we should have an output-dir instead and maybe another override for
    # the file name.
    efi_parser.add_argument(
        "--output", help="path to output", default="/boot/kernel.efi"
    )
    initrd_parser = subparser.add_parser(
        "create-initrd", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    initrd_parser.add_argument("--root", help="path to root", default="")
    initrd_parser.add_argument(
        "--skeleton", help="path to skeleton", default="/usr/lib/ubuntu-core-initramfs"
    )
    initrd_parser.add_argument(
        "--feature",
        dest="features",
        help="list of features to enable, if unspecified defaults to %s" % features,
        nargs="+"
    )
    initrd_parser.add_argument(
        "--kerneldir", help="path to kernel modules dir default /lib/modules/KERNELVER"
    )
    initrd_parser.add_argument("--kernelver", help="kernel version", default=kernelver)
    initrd_parser.add_argument(
        "--firmwaredir", help="path to kernel firmware", default="/lib/firmware"
    )
    # TODO we should have an output-dir instead and maybe another override for
    # the file name.
    initrd_parser.add_argument(
        "--output", help="path to output", default="/boot/ubuntu-core-initramfs.img"
    )
    # NOTE: a redundant initrd_parser.set_defaults(kernelver=...) that
    # re-ran "uname -r" was removed; the --kernelver default above already
    # provides the same value.

    args = parser.parse_args()
    if args.subcmd == "create-initrd" and not args.features:
        # Set arch specific default set of features
        args.features = features.split()
        # For generic kernels, enable more server drivers too
        if args.kernelver.endswith("-generic") and "server" not in args.features:
            args.features.append("server")
        if args.kernelver.endswith("-fips"):
            args.features.append("fips")
    globals()[args.subcmd.replace("-", "_")](parser, args)


# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
