#!/usr/bin/env python2
# vim: expandtab:tabstop=4:shiftwidth=4
"""Ansible dynamic inventory that merges the output of several providers.

Each account in the config runs its provider script (``--list`` or
``--host``) with that account's environment variables. ``--list`` results
from all accounts are merged into one inventory and written to a JSON cache
file; ``--host`` results are not cached.
"""

from time import time
import argparse
import os
import sys
import pdb
import subprocess
import json
import pprint

import yaml

CONFIG_FILE_NAME = 'multi_ec2.yaml'


class MultiEc2(object):
    """Build a merged inventory from every configured provider account.

    All work happens in ``__init__``: CLI args are parsed, the config is
    loaded, and ``self.result`` is filled either from the providers or from
    the on-disk cache.
    """

    def __init__(self):
        self.config = None
        self.all_ec2_results = {}
        self.result = {}
        self.cache_path = os.path.expanduser('~/.ansible/tmp/multi_ec2_inventory.cache')
        self.file_path = os.path.dirname(os.path.realpath(__file__))

        same_dir_config_file = os.path.join(self.file_path, CONFIG_FILE_NAME)
        etc_dir_config_file = os.path.join(os.path.sep, 'etc', 'ansible', CONFIG_FILE_NAME)

        # Prefer a file in the same directory, fall back to a file in etc
        if os.path.isfile(same_dir_config_file):
            self.config_file = same_dir_config_file
        elif os.path.isfile(etc_dir_config_file):
            self.config_file = etc_dir_config_file
        else:
            self.config_file = None  # expect env vars

        self.parse_cli_args()

        # load yaml
        if self.config_file and os.path.isfile(self.config_file):
            self.config = self.load_yaml_config()
        elif "AWS_ACCESS_KEY_ID" in os.environ and "AWS_SECRET_ACCESS_KEY" in os.environ:
            # No config file: synthesize a single-account config from the
            # environment, with caching effectively disabled.
            self.config = {}
            self.config['accounts'] = [
                {
                    'name': 'default',
                    'provider': 'aws/ec2.py',
                    'env_vars': {
                        'AWS_ACCESS_KEY_ID': os.environ["AWS_ACCESS_KEY_ID"],
                        'AWS_SECRET_ACCESS_KEY': os.environ["AWS_SECRET_ACCESS_KEY"],
                    }
                },
            ]
            self.config['cache_max_age'] = 0
        else:
            raise RuntimeError("Could not find valid ec2 credentials in the environment.")

        if self.args.cache_only:
            # get data from disk; only query providers when no cache exists
            if not self.get_inventory_from_cache():
                self.get_inventory()
                self.write_to_cache()
        elif self.args.host:
            # if its a host query, fetch and do not cache
            self.get_inventory()
        elif not self.is_cache_valid():
            # go fetch the inventories and cache them if cache is expired
            self.get_inventory()
            self.write_to_cache()
        else:
            # get data from disk
            self.get_inventory_from_cache()

    def load_yaml_config(self, conf_file=None):
        """Load the yaml config file listing the accounts to query.

        :param conf_file: path to read; defaults to ``self.config_file``.
        :returns: the parsed config (``yaml.safe_load`` result).
        """
        if not conf_file:
            conf_file = self.config_file

        with open(conf_file) as conf:
            return yaml.safe_load(conf)

    def get_provider_tags(self, provider, env=None):
        """Start the provider script and return its running process.

        :param provider: path to an executable provider script; a path
            relative to this file's directory is accepted.
        :param env: environment for the child process. If empty/None the
            current ``os.environ`` is used; otherwise it REPLACES the
            environment entirely.
        :returns: a ``subprocess.Popen`` with stdout/stderr piped.
        :raises RuntimeError: if the provider is missing or not executable.
        """
        if not env:
            env = os.environ

        # Allow relatively path'd providers in config file
        if os.path.isfile(os.path.join(self.file_path, provider)):
            provider = os.path.join(self.file_path, provider)

        # check to see if provider exists
        if not os.path.isfile(provider) or not os.access(provider, os.X_OK):
            raise RuntimeError("Problem with the provider. Please check path "
                               "and that it is executable. (%s)" % provider)

        cmds = [provider]
        if self.args.host:
            cmds.append("--host")
            cmds.append(self.args.host)
        else:
            cmds.append('--list')

        cmds.append('--refresh-cache')

        return subprocess.Popen(cmds, stderr=subprocess.PIPE,
                                stdout=subprocess.PIPE, env=env)

    def get_inventory(self):
        """Run every account's provider and store the outcome in ``self.result``.

        Host query:
            Use the single account whose provider exited 0 with empty stderr
            and non-``{}`` output. More than one such account is an error.

        List query:
            Every provider must exit 0; each JSON inventory is kept in
            ``self.all_ec2_results`` and merged into ``self.result``.

        :raises RuntimeError: on a failing provider (list query) or on more
            than one matching account (host query).
        """
        # Start all providers first so they run concurrently.
        processes = {}
        for account in self.config['accounts']:
            processes[account['name']] = self.get_provider_tags(
                account['provider'], account['env_vars'])

        # for each process collect stdout when its available
        all_results = []
        for name, process in processes.items():
            out, err = process.communicate()
            all_results.append({
                "name": name,
                "out": out.strip(),
                "err": err.strip(),
                "code": process.returncode,
            })

        if not self.args.host:
            # For any non-zero, raise an error on it
            for result in all_results:
                if result['code'] != 0:
                    raise RuntimeError(result['err'])
                self.all_ec2_results[result['name']] = json.loads(result['out'])

            for inventory in self.all_ec2_results.values():
                MultiEc2.merge_destructively(self.result, inventory)
        else:
            matches = [result for result in all_results
                       if result['code'] == 0 and result['err'] == ''
                       and result['out'] != '{}']
            if len(matches) > 1:
                raise RuntimeError("Found > 1 results for --host %s. "
                                   "This is an invalid state." % self.args.host)
            if matches:
                # Parse the matching account's own output (not whichever
                # process happened to be read last).
                self.result = json.loads(matches[0]['out'])

    @staticmethod
    def merge_destructively(a, b):
        """Merge ``b`` into ``a`` in place and return ``a``.

        Keys only in ``b`` are copied over. Nested dicts are merged
        recursively. Equal values are left alone. Otherwise the value becomes
        a list: lists are unioned (order kept, duplicates skipped) and scalars
        are combined with ``[a_value, b_value]``.
        """
        for key in b:
            if key not in a:
                a[key] = b[key]
                continue

            a_val = a[key]
            b_val = b[key]
            if isinstance(a_val, dict) and isinstance(b_val, dict):
                MultiEc2.merge_destructively(a_val, b_val)
            elif a_val == b_val:
                pass  # same leaf value
            elif isinstance(a_val, list) and isinstance(b_val, list):
                # both lists so add each element in b to a if it does not exist
                for item in b_val:
                    if item not in a_val:
                        a_val.append(item)
            elif isinstance(a_val, list):
                # a is a list and not b
                if b_val not in a_val:
                    a_val.append(b_val)
            elif isinstance(b_val, list):
                a[key] = [a_val] + [item for item in b_val if item != a_val]
            else:
                a[key] = [a_val, b_val]
        return a

    def is_cache_valid(self):
        """Return True if the cache file exists and is younger than
        ``cache_max_age`` seconds (a missing setting counts as 0).
        """
        if os.path.isfile(self.cache_path):
            mod_time = os.path.getmtime(self.cache_path)
            max_age = self.config.get('cache_max_age', 0)
            if (mod_time + max_age) > time():
                return True

        return False

    def parse_cli_args(self):
        """Parse command line arguments into ``self.args``."""
        parser = argparse.ArgumentParser(
            description='Produce an Ansible Inventory file based on a provider')
        parser.add_argument('--cache-only', action='store_true', default=False,
                            help='Fetch cached only instances (default: False)')
        parser.add_argument('--list', action='store_true', default=True,
                            help='List instances (default: True)')
        parser.add_argument('--host', action='store', default=False,
                            help='Get all the variables about a specific instance')
        self.args = parser.parse_args()

    def write_to_cache(self):
        """Write ``self.result`` as pretty JSON to the cache file, creating
        the cache directory if it does not exist yet.
        """
        cache_dir = os.path.dirname(self.cache_path)
        if cache_dir and not os.path.isdir(cache_dir):
            os.makedirs(cache_dir)

        json_data = self.json_format_dict(self.result, True)
        with open(self.cache_path, 'w') as cache:
            cache.write(json_data)

    def get_inventory_from_cache(self):
        """Load ``self.result`` from the cache file.

        :returns: True if the cache was read, None if the file does not exist.
        """
        if not os.path.isfile(self.cache_path):
            return None

        with open(self.cache_path, 'r') as cache:
            self.result = json.loads(cache.read())

        return True

    def json_format_dict(self, data, pretty=False):
        """Serialize ``data`` to JSON; sorted keys and 2-space indent if pretty."""
        if pretty:
            return json.dumps(data, sort_keys=True, indent=2)
        return json.dumps(data)

    def result_str(self):
        """Return the merged inventory as pretty-printed JSON."""
        return self.json_format_dict(self.result, True)


if __name__ == "__main__":
    mi = MultiEc2()
    print(mi.result_str())