Cleanups to the common.sys_info API

* Move get_all_subclasses out of sys_info as it is unrelated to system
  information.
* get_all_subclasses now returns a set() instead of a list.
* Don't port get_platform to sys_info as it is deprecated.  Code using
  the common API should just use platform.system() directly.
* Rename load_platform_subclass() to get_platform_subclass and do not
  instantiate the rturned class.
* Test the compat shims in module_utils/basic.py separately from the new
  API in module_utils/common/sys_info.py and module_utils/common/_utils.py
This commit is contained in:
Toshio Kuratomi 2018-12-31 15:10:51 -08:00
parent 79dc9a75c3
commit 5844c8c7f0
7 changed files with 284 additions and 47 deletions

View file

@ -165,11 +165,9 @@ from ansible.module_utils.common.file import (
get_flags_from_attributes, get_flags_from_attributes,
) )
from ansible.module_utils.common.sys_info import ( from ansible.module_utils.common.sys_info import (
get_platform,
get_distribution, get_distribution,
get_distribution_version, get_distribution_version,
load_platform_subclass, get_platform_subclass,
get_all_subclasses,
) )
from ansible.module_utils.pycompat24 import get_exception, literal_eval from ansible.module_utils.pycompat24 import get_exception, literal_eval
from ansible.module_utils.six import ( from ansible.module_utils.six import (
@ -184,6 +182,7 @@ from ansible.module_utils.six import (
) )
from ansible.module_utils.six.moves import map, reduce, shlex_quote from ansible.module_utils.six.moves import map, reduce, shlex_quote
from ansible.module_utils._text import to_native, to_bytes, to_text from ansible.module_utils._text import to_native, to_bytes, to_text
from ansible.module_utils.common._utils import get_all_subclasses as _get_all_subclasses
from ansible.module_utils.parsing.convert_bool import BOOLEANS, BOOLEANS_FALSE, BOOLEANS_TRUE, boolean from ansible.module_utils.parsing.convert_bool import BOOLEANS, BOOLEANS_FALSE, BOOLEANS_TRUE, boolean
@ -276,6 +275,42 @@ if not _PY_MIN:
sys.exit(1) sys.exit(1)
#
# Deprecated functions
#
def get_platform():
'''
**Deprecated** Use :py:func:`platform.system` directly.
:returns: Name of the platform the module is running on in a native string
Returns a native string that labels the platform ("Linux", "Solaris", etc). Currently, this is
the result of calling :py:func:`platform.system`.
'''
return platform.system()
# End deprecated functions
#
# Compat shims
#
def load_platform_subclass(cls, *args, **kwargs):
"""**Deprecated**: Use ansible.module_utils.common.sys_info.get_platform_subclass instead"""
platform_cls = get_platform_subclass(cls)
return super(cls, platform_cls).__new__(platform_cls)
def get_all_subclasses(cls):
"""**Deprecated**: Use ansible.module_utils.common._utils.get_all_subclasses instead"""
return list(_get_all_subclasses(cls))
# End compat shims
def json_dict_unicode_to_bytes(d, encoding='utf-8', errors='surrogate_or_strict'): def json_dict_unicode_to_bytes(d, encoding='utf-8', errors='surrogate_or_strict'):
''' Recursively convert dict keys and values to byte str ''' Recursively convert dict keys and values to byte str

View file

@ -0,0 +1,40 @@
# Copyright (c) 2018, Ansible Project
# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
"""
Modules in _utils are waiting to find a better home. If you need to use them, be prepared for them
to move to a different location in the future.
"""
def get_all_subclasses(cls):
'''
Recursively search and find all subclasses of a given class
:arg cls: A python class
:rtype: set
:returns: The set of python classes which are the subclasses of `cls`.
In python, you can use a class's :py:meth:`__subclasses__` method to determine what subclasses
of a class exist. However, `__subclasses__` only goes one level deep. This function searches
each child class's `__subclasses__` method to find all of the descendent classes. It then
returns an iterable of the descendent classes.
'''
# Retrieve direct subclasses
subclasses = set(cls.__subclasses__())
to_visit = list(subclasses)
# Then visit all subclasses
while to_visit:
for sc in to_visit:
# The current class is now visited, so remove it from list
to_visit.remove(sc)
# Appending all subclasses to visit and keep a reference of available class
for ssc in sc.__subclasses__():
if ssc not in subclasses:
to_visit.append(ssc)
subclasses.add(ssc)
return subclasses

View file

@ -9,21 +9,22 @@ import os
import platform import platform
from ansible.module_utils import distro from ansible.module_utils import distro
from ansible.module_utils.common._utils import get_all_subclasses
# Backwards compat. New code should just use platform.system() __all__ = ('get_distribution', 'get_distribution_version', 'get_platform_subclass')
def get_platform():
'''
:rtype: NativeString
:returns: Name of the platform the module is running on
'''
return platform.system()
def get_distribution(): def get_distribution():
''' '''
Return the name of the distribution the module is running on
:rtype: NativeString or None :rtype: NativeString or None
:returns: Name of the distribution the module is running on :returns: Name of the distribution the module is running on
This function attempts to determine what Linux distribution the code is running on and return
a string representing that value. If the distribution cannot be determined, it returns
``OtherLinux``. If not run on Linux it returns None.
''' '''
distribution = None distribution = None
@ -42,9 +43,11 @@ def get_distribution():
def get_distribution_version(): def get_distribution_version():
''' '''
Get the version of the Linux distribution the code is running on
:rtype: NativeString or None :rtype: NativeString or None
:returns: A string representation of the version of the distribution. None if this is not run :returns: A string representation of the version of the distribution. If it cannot determine
on a Linux machine the version, it returns empty string. If this is not run on a Linux machine it returns None
''' '''
version = None version = None
if platform.system() == 'Linux': if platform.system() == 'Linux':
@ -82,33 +85,36 @@ def get_distribution_codename():
return codename return codename
def get_all_subclasses(cls): def get_platform_subclass(cls):
''' '''
used by modules like Hardware or Network fact classes to recursively retrieve all Finds a subclass implementing desired functionality on the platform the code is running on
subclasses of a given class not only the direct sub classes.
'''
# Retrieve direct subclasses
subclasses = cls.__subclasses__()
to_visit = list(subclasses)
# Then visit all subclasses
while to_visit:
for sc in to_visit:
# The current class is now visited, so remove it from list
to_visit.remove(sc)
# Appending all subclasses to visit and keep a reference of available class
for ssc in sc.__subclasses__():
subclasses.append(ssc)
to_visit.append(ssc)
return subclasses
:arg cls: Class to find an appropriate subclass for
:returns: A class that implements the functionality on this platform
def load_platform_subclass(cls, *args, **kwargs): Some Ansible modules have different implementations depending on the platform they run on. This
''' function is used to select between the various implementations and choose one. You can look at
used by modules like User to have different implementations based on detected platform. See User the implementation of the Ansible :ref:`User module<user_module>` module for an example of how to use this.
module for an example.
This function replaces ``basic.load_platform_subclass()``. When you port code, you need to
change the callers to be explicit about instantiating the class. For instance, code in the
Ansible User module changed from::
.. code-block:: python
# Old
class User:
def __new__(cls, args, kwargs):
return load_platform_subclass(User, args, kwargs)
# New
class User:
def __new__(cls, args, kwargs):
new_cls = get_platform_subclass(User)
return super(cls, new_cls).__new__(new_cls, args, kwargs)
''' '''
this_platform = get_platform() this_platform = platform.system()
distribution = get_distribution() distribution = get_distribution()
subclass = None subclass = None
@ -124,4 +130,4 @@ def load_platform_subclass(cls, *args, **kwargs):
if subclass is None: if subclass is None:
subclass = cls subclass = cls
return super(cls, subclass).__new__(subclass) return subclass

View file

@ -47,6 +47,7 @@ MODULE_UTILS_BASIC_IMPORTS = frozenset((('_text',),
('common', 'file'), ('common', 'file'),
('common', 'process'), ('common', 'process'),
('common', 'sys_info'), ('common', 'sys_info'),
('common', '_utils'),
('distro', '__init__'), ('distro', '__init__'),
('distro', '_distro'), ('distro', '_distro'),
('parsing', '__init__'), ('parsing', '__init__'),
@ -55,19 +56,20 @@ MODULE_UTILS_BASIC_IMPORTS = frozenset((('_text',),
('six', '__init__'), ('six', '__init__'),
)) ))
MODULE_UTILS_BASIC_FILES = frozenset(('ansible/module_utils/parsing/__init__.py', MODULE_UTILS_BASIC_FILES = frozenset(('ansible/module_utils/_text.py',
'ansible/module_utils/common/process.py',
'ansible/module_utils/basic.py', 'ansible/module_utils/basic.py',
'ansible/module_utils/six/__init__.py',
'ansible/module_utils/_text.py',
'ansible/module_utils/common/_collections_compat.py',
'ansible/module_utils/parsing/convert_bool.py',
'ansible/module_utils/common/__init__.py', 'ansible/module_utils/common/__init__.py',
'ansible/module_utils/common/_collections_compat.py',
'ansible/module_utils/common/file.py', 'ansible/module_utils/common/file.py',
'ansible/module_utils/common/process.py',
'ansible/module_utils/common/sys_info.py', 'ansible/module_utils/common/sys_info.py',
'ansible/module_utils/common/_utils.py',
'ansible/module_utils/distro/__init__.py', 'ansible/module_utils/distro/__init__.py',
'ansible/module_utils/distro/_distro.py', 'ansible/module_utils/distro/_distro.py',
'ansible/module_utils/parsing/__init__.py',
'ansible/module_utils/parsing/convert_bool.py',
'ansible/module_utils/pycompat24.py', 'ansible/module_utils/pycompat24.py',
'ansible/module_utils/six/__init__.py',
)) ))
ONLY_BASIC_IMPORT = frozenset((('basic',),)) ONLY_BASIC_IMPORT = frozenset((('basic',),))

View file

@ -16,11 +16,11 @@ from units.compat.mock import patch
from ansible.module_utils.six.moves import builtins from ansible.module_utils.six.moves import builtins
# Functions being tested # Functions being tested
from ansible.module_utils.common.sys_info import get_all_subclasses from ansible.module_utils.basic import get_platform
from ansible.module_utils.common.sys_info import get_distribution from ansible.module_utils.basic import get_all_subclasses
from ansible.module_utils.common.sys_info import get_distribution_version from ansible.module_utils.basic import get_distribution
from ansible.module_utils.common.sys_info import get_platform from ansible.module_utils.basic import get_distribution_version
from ansible.module_utils.common.sys_info import load_platform_subclass from ansible.module_utils.basic import load_platform_subclass
realimport = builtins.__import__ realimport = builtins.__import__
@ -104,7 +104,7 @@ class TestLoadPlatformSubclass:
def test_not_linux(self): def test_not_linux(self):
# if neither match, the fallback should be the top-level class # if neither match, the fallback should be the top-level class
with patch('ansible.module_utils.common.sys_info.get_platform', return_value="Foo"): with patch('platform.system', return_value="Foo"):
with patch('ansible.module_utils.common.sys_info.get_distribution', return_value=None): with patch('ansible.module_utils.common.sys_info.get_distribution', return_value=None):
assert isinstance(load_platform_subclass(self.LinuxTest), self.LinuxTest) assert isinstance(load_platform_subclass(self.LinuxTest), self.LinuxTest)

View file

@ -0,0 +1,108 @@
# -*- coding: utf-8 -*-
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
# (c) 2016 Toshio Kuratomi <tkuratomi@ansible.com>
# (c) 2017-2018 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest
from units.compat.mock import patch
from ansible.module_utils.six.moves import builtins
# Functions being tested
from ansible.module_utils.common.sys_info import get_distribution
from ansible.module_utils.common.sys_info import get_distribution_version
from ansible.module_utils.common.sys_info import get_platform_subclass
realimport = builtins.__import__
@pytest.fixture
def platform_linux(mocker):
mocker.patch('platform.system', return_value='Linux')
#
# get_distribution tests
#
def test_get_distribution_not_linux():
"""If it's not Linux, then it has no distribution"""
with patch('platform.system', return_value='Foo'):
assert get_distribution() is None
@pytest.mark.usefixtures("platform_linux")
class TestGetDistribution:
""" Tests for get_distribution that have to find somethine"""
def test_distro_known(self):
with patch('ansible.module_utils.distro.name', return_value="foo"):
assert get_distribution() == "Foo"
def test_distro_unknown(self):
with patch('ansible.module_utils.distro.name', return_value=""):
assert get_distribution() == "OtherLinux"
def test_distro_amazon_part_of_another_name(self):
with patch('ansible.module_utils.distro.name', return_value="AmazonFooBar"):
assert get_distribution() == "Amazonfoobar"
def test_distro_amazon_linux(self):
with patch('ansible.module_utils.distro.name', return_value="Amazon Linux AMI"):
assert get_distribution() == "Amazon"
#
# get_distribution_version tests
#
def test_get_distribution_version_not_linux():
"""If it's not Linux, then it has no distribution"""
with patch('platform.system', return_value='Foo'):
assert get_distribution_version() is None
@pytest.mark.usefixtures("platform_linux")
def test_distro_found():
with patch('ansible.module_utils.distro.version', return_value="1"):
assert get_distribution_version() == "1"
#
# Tests for get_platform_subclass
#
class TestGetPlatformSubclass:
class LinuxTest:
pass
class Foo(LinuxTest):
platform = "Linux"
distribution = None
class Bar(LinuxTest):
platform = "Linux"
distribution = "Bar"
def test_not_linux(self):
# if neither match, the fallback should be the top-level class
with patch('platform.system', return_value="Foo"):
with patch('ansible.module_utils.common.sys_info.get_distribution', return_value=None):
assert get_platform_subclass(self.LinuxTest) is self.LinuxTest
@pytest.mark.usefixtures("platform_linux")
def test_get_distribution_none(self):
# match just the platform class, not a specific distribution
with patch('ansible.module_utils.common.sys_info.get_distribution', return_value=None):
assert get_platform_subclass(self.LinuxTest) is self.Foo
@pytest.mark.usefixtures("platform_linux")
def test_get_distribution_found(self):
# match both the distribution and platform class
with patch('ansible.module_utils.common.sys_info.get_distribution', return_value="Bar"):
assert get_platform_subclass(self.LinuxTest) is self.Bar

View file

@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# (c) 2018 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from ansible.module_utils.common.sys_info import get_all_subclasses
#
# Tests for get_all_subclasses
#
class TestGetAllSubclasses:
class Base:
pass
class BranchI(Base):
pass
class BranchII(Base):
pass
class BranchIA(BranchI):
pass
class BranchIB(BranchI):
pass
class BranchIIA(BranchII):
pass
class BranchIIB(BranchII):
pass
def test_bottom_level(self):
assert get_all_subclasses(self.BranchIIB) == set()
def test_one_inheritance(self):
assert set(get_all_subclasses(self.BranchII)) == set([self.BranchIIA, self.BranchIIB])
def test_toplevel(self):
assert set(get_all_subclasses(self.Base)) == set([self.BranchI, self.BranchII,
self.BranchIA, self.BranchIB,
self.BranchIIA, self.BranchIIB])