summaryrefslogtreecommitdiffstats
path: root/roles/openshift_health_checker/library/rpm_version.py
blob: 8ea223055e4d3fcfd6d5415328c4b3e36324649c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/python
"""
Ansible module for rpm-based systems that checks whether the expected versions of the given packages are installed on a host.
"""

from ansible.module_utils.basic import AnsibleModule

# Remember a failed rpm import instead of raising at import time; main()
# reports it through the module's normal failure path. This also lets the
# module be imported where the rpm bindings are unavailable (see below).
IMPORT_EXCEPTION = None
try:
    import rpm  # pylint: disable=import-error
except ImportError as err:
    IMPORT_EXCEPTION = err  # in tox test env, rpm import fails


class RpmVersionException(Exception):
    """Error raised when expected packages are missing or have the wrong version.

    The offending packages are exposed on ``problem_pkgs`` (``None`` when not
    supplied) so callers can inspect them without parsing the message.
    """

    def __init__(self, message, problem_pkgs=None):
        super(RpmVersionException, self).__init__(message)
        self.problem_pkgs = problem_pkgs


def main():
    """Entrypoint for this Ansible module.

    Reads ``package_list`` (entries must be dicts with ``name`` and
    ``version`` keys), checks that each listed package is installed with a
    matching version, and exits with ``changed=False`` on success. Any
    problem is reported through ``module.fail_json``.
    """
    module = AnsibleModule(
        argument_spec=dict(
            package_list=dict(type="list", required=True),
        ),
        supports_check_mode=True
    )

    if IMPORT_EXCEPTION:
        module.fail_json(msg="rpm_version module could not import rpm: %s" % IMPORT_EXCEPTION)

    # determine the packages we will look for
    pkg_list = module.params['package_list']
    if not pkg_list:
        module.fail_json(msg="package_list must not be empty")

    # Validate entry shape up front: a malformed entry would otherwise surface
    # as an unhandled KeyError/TypeError traceback instead of a clean failure.
    malformed = [
        pkg for pkg in pkg_list
        if not isinstance(pkg, dict) or "name" not in pkg or "version" not in pkg
    ]
    if malformed:
        module.fail_json(
            msg="each package_list entry must be a dict with 'name' and 'version' keys; "
                "invalid entries: %s" % malformed
        )

    # Build the name -> spec mapping once and share it between both steps.
    expected_pkgs_dict = _to_dict(pkg_list)

    # get list of packages available and complain if any
    # of them are missing or if any errors occur
    try:
        pkg_versions = _retrieve_expected_pkg_versions(expected_pkgs_dict)
        _check_pkg_versions(pkg_versions, expected_pkgs_dict)
    except RpmVersionException as excinfo:
        module.fail_json(msg=str(excinfo))
    module.exit_json(changed=False)


def _to_dict(pkg_list):
    return {pkg["name"]: pkg for pkg in pkg_list}


def _to_text(value):
    """Return ``value`` as native text, decoding it if it is raw bytes.

    Python 3 rpm bindings older than rpm 4.15 return header string tags as
    ``bytes``; comparing those with ``str`` package names would never match.
    On Python 2 ``bytes`` is ``str``, so values pass through unchanged there.
    """
    if isinstance(value, bytes) and not isinstance(value, str):
        return value.decode("utf-8")
    return value


def _retrieve_expected_pkg_versions(expected_pkgs_dict):
    """Search the rpm database for installed packages with the given names.

    :param expected_pkgs_dict: mapping whose keys are the package names to look up.
    :returns: dict ``{pkg_name: [versions]}`` containing only names with at
        least one installed match; names with no match are omitted.
    """
    transaction = rpm.TransactionSet()
    pkgs = {}

    for pkg_name in expected_pkgs_dict:
        matched_pkgs = transaction.dbMatch("name", pkg_name)
        if not matched_pkgs:
            continue

        for header in matched_pkgs:
            # Defensive re-check that the header's name is exactly the one requested.
            if _to_text(header['name']) == pkg_name:
                pkgs.setdefault(pkg_name, []).append(_to_text(header['version']))

    return pkgs


def _check_pkg_versions(found_pkgs_dict, expected_pkgs_dict):
    """Compare installed package versions against the expected ones.

    Both sides are normalized with ``_parse_version`` before comparing, so
    only the leading dot-separated components are significant.

    :param found_pkgs_dict: ``{pkg_name: [installed version strings]}``.
    :param expected_pkgs_dict: ``{pkg_name: {"name": ..., "version": ...}}``.
    :raises RpmVersionException: if any expected package is not installed
        (``problem_pkgs`` is the list of missing names); otherwise, if any
        installed package has no version matching the expected one
        (``problem_pkgs`` maps each such name to its found/required versions).
        Missing packages are reported first; version mismatches are only
        reported once every package is present.
    """
    invalid_pkg_versions = {}
    not_found_pkgs = []

    for pkg_name, pkg in expected_pkgs_dict.items():
        installed_versions = found_pkgs_dict.get(pkg_name)
        if not installed_versions:
            not_found_pkgs.append(pkg_name)
            continue

        found_versions = [_parse_version(version) for version in installed_versions]
        expected_version = _parse_version(pkg["version"])
        if expected_version not in found_versions:
            invalid_pkg_versions[pkg_name] = {
                "found_versions": found_versions,
                "required_version": expected_version,
            }

    if not_found_pkgs:
        # One indented line per missing package.
        raise RpmVersionException(
            "The following packages were not found to be installed:{}".format(
                "".join("\n    {}".format(name) for name in not_found_pkgs)
            ),
            not_found_pkgs,
        )

    if invalid_pkg_versions:
        # One indented entry per package, with its details indented beneath it.
        raise RpmVersionException(
            "The following packages were found to be installed with an incorrect version:{}".format(
                "".join(
                    "\n    {}\n        Required version: {}\n        Found versions: {}".format(
                        name,
                        details["required_version"],
                        ", ".join(details["found_versions"]),
                    )
                    for name, details in invalid_pkg_versions.items()
                )
            ),
            invalid_pkg_versions,
        )


def _parse_version(version_str):
    segs = version_str.split('.')
    if not segs or len(segs) <= 2:
        return version_str

    return '.'.join(segs[0:2])


if __name__ == "__main__":
    # Run the module only when executed directly, not when imported.
    main()