From 7e05d7478f3f7ab636db5ba52acca967313f175c Mon Sep 17 00:00:00 2001 From: Joshua Tauberer Date: Sat, 31 Jan 2015 20:40:20 +0000 Subject: [PATCH] run status checks asynchronously so that they finish faster, since many checks are waiting on network replies and ought not to block the whole thing --- management/status_checks.py | 155 ++++++++++++++++++++++++------------ 1 file changed, 105 insertions(+), 50 deletions(-) diff --git a/management/status_checks.py b/management/status_checks.py index 4f040372..e2e1ad00 100755 --- a/management/status_checks.py +++ b/management/status_checks.py @@ -6,7 +6,7 @@ __ALL__ = ['check_certificate'] -import os, os.path, re, subprocess, datetime +import os, os.path, re, subprocess, datetime, multiprocessing.pool import dns.reversename, dns.resolver import dateutil.parser, dateutil.tz @@ -35,15 +35,17 @@ def run_checks(env, output): run_system_checks(env, output) - # perform other checks - run_network_checks(env, output) - run_domain_checks(env, output) + # perform other checks asynchronously + + pool = multiprocessing.pool.Pool(processes=1) + r1 = pool.apply_async(run_network_checks, [env]) + r2 = run_domain_checks(env) + r1.get().playback(output) + r2.playback(output) def run_services_checks(env, output): # Check that system services are running. - import socket - services = [ { "name": "Local DNS (bind9)", "port": 53, "public": False, }, #{ "name": "NSD Control", "port": 8952, "public": False, }, @@ -66,33 +68,47 @@ def run_services_checks(env, output): { "name": "HTTPS Web (nginx)", "port": 443, "public": True, }, ] - ok = True - - for service in services: - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.settimeout(.1) - try: - s.connect(( - "127.0.0.1" if not service["public"] else env['PUBLIC_IP'], - service["port"])) - except OSError as e: - output.print_error("%s is not running (%s)." % (service['name'], str(e))) + all_running = True + fatal = False + pool = multiprocessing.pool.Pool(processes=10) + ret = pool.starmap(check_service, ((i, service, env) for i, service in enumerate(services)), chunksize=1) + for i, running, fatal2, output2 in sorted(ret): + all_running = all_running and running + fatal = fatal or fatal2 + output2.playback(output) - # Why is nginx not running? - if service["port"] in (80, 443): - output.print_line(shell('check_output', ['nginx', '-t'], capture_stderr=True, trap=True)[1].strip()) - - # Flag if local DNS is not running. - if service["port"] == 53 and service["public"] == False: - ok = False - - finally: - s.close() - - if ok: + if all_running: output.print_ok("All system services are running.") - return ok + return not fatal + +def check_service(i, service, env): + import socket + output = BufferedOutput() + running = False + fatal = False + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.settimeout(1) + try: + s.connect(( + "127.0.0.1" if not service["public"] else env['PUBLIC_IP'], + service["port"])) + running = True + + except OSError as e: + output.print_error("%s is not running (%s)." % (service['name'], str(e))) + + # Why is nginx not running? + if service["port"] in (80, 443): + output.print_line(shell('check_output', ['nginx', '-t'], capture_stderr=True, trap=True)[1].strip()) + + # Flag if local DNS is not running. + if service["port"] == 53 and service["public"] == False: + fatal = True + finally: + s.close() + + return (i, running, fatal, output) def run_system_checks(env, output): check_ssh_password(env, output) @@ -146,9 +162,10 @@ def check_free_disk_space(env, output): else: output.print_error(disk_msg) -def run_network_checks(env, output): +def run_network_checks(env): # Also see setup/network-checks.sh. + output = BufferedOutput() output.add_heading("Network") # Stop if we cannot make an outbound connection on port 25. Many residential @@ -176,7 +193,9 @@ def run_network_checks(env, output): which may prevent recipients from receiving your email. See http://www.spamhaus.org/query/ip/%s.""" % (env['PUBLIC_IP'], zen, env['PUBLIC_IP'])) -def run_domain_checks(env, output): + return output + +def run_domain_checks(env): # Get the list of domains we handle mail for. mail_domains = get_mail_domains(env) @@ -187,24 +206,44 @@ def run_domain_checks(env, output): # Get the list of domains we serve HTTPS for. web_domains = set(get_web_domains(env)) - # Check the domains. - for domain in sort_domains(mail_domains | dns_domains | web_domains, env): - output.add_heading(domain) + domains_to_check = mail_domains | dns_domains | web_domains - if domain == env["PRIMARY_HOSTNAME"]: - check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles) + # Serial version: + #for domain in sort_domains(domains_to_check, env): + # run_domain_checks_on_domain(domain, env, dns_domains, dns_zonefiles, mail_domains, web_domains) + + # Parallelize the checks across a worker pool. + args = ((domain, env, dns_domains, dns_zonefiles, mail_domains, web_domains) + for domain in domains_to_check) + pool = multiprocessing.pool.Pool(processes=10) + ret = pool.starmap(run_domain_checks_on_domain, args, chunksize=1) + ret = dict(ret) # (domain, output) => { domain: output } + output = BufferedOutput() + for domain in sort_domains(ret, env): + ret[domain].playback(output) + return output + +def run_domain_checks_on_domain(domain, env, dns_domains, dns_zonefiles, mail_domains, web_domains): + output = BufferedOutput() + + output.add_heading(domain) + + if domain == env["PRIMARY_HOSTNAME"]: + check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles) - if domain in dns_domains: - check_dns_zone(domain, env, output, dns_zonefiles) + if domain in dns_domains: + check_dns_zone(domain, env, output, dns_zonefiles) - if domain in mail_domains: - check_mail_domain(domain, env, output) + if domain in mail_domains: + check_mail_domain(domain, env, output) - if domain in web_domains: - check_web_domain(domain, env, output) + if domain in web_domains: + check_web_domain(domain, env, output) - if domain in dns_domains: - check_dns_zone_suggestions(domain, env, output, dns_zonefiles) + if domain in dns_domains: + check_dns_zone_suggestions(domain, env, output, dns_zonefiles) + + return (domain, output) def check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles): # If a DS record is set on the zone containing this domain, check DNSSEC now. @@ -655,11 +694,12 @@ def list_apt_updates(apt_update=True): return pkgs -try: - terminal_columns = int(shell('check_output', ['stty', 'size']).split()[1]) -except: - terminal_columns = 76 class ConsoleOutput: + try: + terminal_columns = int(shell('check_output', ['stty', 'size']).split()[1]) + except: + terminal_columns = 76 + def add_heading(self, heading): print() print(heading) @@ -680,7 +720,7 @@ class ConsoleOutput: words = re.split("(\s+)", message) linelen = 0 for w in words: - if linelen + len(w) > terminal_columns-1-len(first_line): + if linelen + len(w) > self.terminal_columns-1-len(first_line): print() print(" ", end="") linelen = 0 @@ -693,6 +733,21 @@ class ConsoleOutput: for line in message.split("\n"): self.print_block(line) +class BufferedOutput: + # Record all of the instance method calls so we can play them back later. + def __init__(self): + self.buf = [] + def __getattr__(self, attr): + if attr not in ("add_heading", "print_ok", "print_error", "print_warning", "print_block", "print_line"): + raise AttributeError + # Return a function that just records the call & arguments to our buffer. + def w(*args, **kwargs): + self.buf.append((attr, args, kwargs)) + return w + def playback(self, output): + for attr, args, kwargs in self.buf: + getattr(output, attr)(*args, **kwargs) + if __name__ == "__main__": import sys from utils import load_environment