1
0
mirror of https://github.com/mail-in-a-box/mailinabox.git synced 2024-11-22 02:17:26 +00:00

run status checks asynchronously so that they finish faster, since many checks are waiting on network replies and ought not to block the whole thing

This commit is contained in:
Joshua Tauberer 2015-01-31 20:40:20 +00:00
parent 8fd98d7db3
commit 7e05d7478f

View File

@ -6,7 +6,7 @@
__ALL__ = ['check_certificate'] __ALL__ = ['check_certificate']
import os, os.path, re, subprocess, datetime import os, os.path, re, subprocess, datetime, multiprocessing.pool
import dns.reversename, dns.resolver import dns.reversename, dns.resolver
import dateutil.parser, dateutil.tz import dateutil.parser, dateutil.tz
@ -35,15 +35,17 @@ def run_checks(env, output):
run_system_checks(env, output) run_system_checks(env, output)
# perform other checks # perform other checks asynchronously
run_network_checks(env, output)
run_domain_checks(env, output) pool = multiprocessing.pool.Pool(processes=1)
r1 = pool.apply_async(run_network_checks, [env])
r2 = run_domain_checks(env)
r1.get().playback(output)
r2.playback(output)
def run_services_checks(env, output): def run_services_checks(env, output):
# Check that system services are running. # Check that system services are running.
import socket
services = [ services = [
{ "name": "Local DNS (bind9)", "port": 53, "public": False, }, { "name": "Local DNS (bind9)", "port": 53, "public": False, },
#{ "name": "NSD Control", "port": 8952, "public": False, }, #{ "name": "NSD Control", "port": 8952, "public": False, },
@ -66,33 +68,47 @@ def run_services_checks(env, output):
{ "name": "HTTPS Web (nginx)", "port": 443, "public": True, }, { "name": "HTTPS Web (nginx)", "port": 443, "public": True, },
] ]
ok = True all_running = True
fatal = False
pool = multiprocessing.pool.Pool(processes=10)
ret = pool.starmap(check_service, ((i, service, env) for i, service in enumerate(services)), chunksize=1)
for i, running, fatal2, output2 in sorted(ret):
all_running = all_running and running
fatal = fatal or fatal2
output2.playback(output)
for service in services: if all_running:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.settimeout(.1)
try:
s.connect((
"127.0.0.1" if not service["public"] else env['PUBLIC_IP'],
service["port"]))
except OSError as e:
output.print_error("%s is not running (%s)." % (service['name'], str(e)))
# Why is nginx not running?
if service["port"] in (80, 443):
output.print_line(shell('check_output', ['nginx', '-t'], capture_stderr=True, trap=True)[1].strip())
# Flag if local DNS is not running.
if service["port"] == 53 and service["public"] == False:
ok = False
finally:
s.close()
if ok:
output.print_ok("All system services are running.") output.print_ok("All system services are running.")
return ok return not fatal
def check_service(i, service, env):
import socket
output = BufferedOutput()
running = False
fatal = False
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.settimeout(1)
try:
s.connect((
"127.0.0.1" if not service["public"] else env['PUBLIC_IP'],
service["port"]))
running = True
except OSError as e:
output.print_error("%s is not running (%s)." % (service['name'], str(e)))
# Why is nginx not running?
if service["port"] in (80, 443):
output.print_line(shell('check_output', ['nginx', '-t'], capture_stderr=True, trap=True)[1].strip())
# Flag if local DNS is not running.
if service["port"] == 53 and service["public"] == False:
fatal = True
finally:
s.close()
return (i, running, fatal, output)
def run_system_checks(env, output): def run_system_checks(env, output):
check_ssh_password(env, output) check_ssh_password(env, output)
@ -146,9 +162,10 @@ def check_free_disk_space(env, output):
else: else:
output.print_error(disk_msg) output.print_error(disk_msg)
def run_network_checks(env, output): def run_network_checks(env):
# Also see setup/network-checks.sh. # Also see setup/network-checks.sh.
output = BufferedOutput()
output.add_heading("Network") output.add_heading("Network")
# Stop if we cannot make an outbound connection on port 25. Many residential # Stop if we cannot make an outbound connection on port 25. Many residential
@ -176,7 +193,9 @@ def run_network_checks(env, output):
which may prevent recipients from receiving your email. See http://www.spamhaus.org/query/ip/%s.""" which may prevent recipients from receiving your email. See http://www.spamhaus.org/query/ip/%s."""
% (env['PUBLIC_IP'], zen, env['PUBLIC_IP'])) % (env['PUBLIC_IP'], zen, env['PUBLIC_IP']))
def run_domain_checks(env, output): return output
def run_domain_checks(env):
# Get the list of domains we handle mail for. # Get the list of domains we handle mail for.
mail_domains = get_mail_domains(env) mail_domains = get_mail_domains(env)
@ -187,24 +206,44 @@ def run_domain_checks(env, output):
# Get the list of domains we serve HTTPS for. # Get the list of domains we serve HTTPS for.
web_domains = set(get_web_domains(env)) web_domains = set(get_web_domains(env))
# Check the domains. domains_to_check = mail_domains | dns_domains | web_domains
for domain in sort_domains(mail_domains | dns_domains | web_domains, env):
output.add_heading(domain)
if domain == env["PRIMARY_HOSTNAME"]: # Serial version:
check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles) #for domain in sort_domains(domains_to_check, env):
# run_domain_checks_on_domain(domain, env, dns_domains, dns_zonefiles, mail_domains, web_domains)
if domain in dns_domains: # Parallelize the checks across a worker pool.
check_dns_zone(domain, env, output, dns_zonefiles) args = ((domain, env, dns_domains, dns_zonefiles, mail_domains, web_domains)
for domain in domains_to_check)
pool = multiprocessing.pool.Pool(processes=10)
ret = pool.starmap(run_domain_checks_on_domain, args, chunksize=1)
ret = dict(ret) # (domain, output) => { domain: output }
output = BufferedOutput()
for domain in sort_domains(ret, env):
ret[domain].playback(output)
return output
if domain in mail_domains: def run_domain_checks_on_domain(domain, env, dns_domains, dns_zonefiles, mail_domains, web_domains):
check_mail_domain(domain, env, output) output = BufferedOutput()
if domain in web_domains: output.add_heading(domain)
check_web_domain(domain, env, output)
if domain in dns_domains: if domain == env["PRIMARY_HOSTNAME"]:
check_dns_zone_suggestions(domain, env, output, dns_zonefiles) check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles)
if domain in dns_domains:
check_dns_zone(domain, env, output, dns_zonefiles)
if domain in mail_domains:
check_mail_domain(domain, env, output)
if domain in web_domains:
check_web_domain(domain, env, output)
if domain in dns_domains:
check_dns_zone_suggestions(domain, env, output, dns_zonefiles)
return (domain, output)
def check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles): def check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles):
# If a DS record is set on the zone containing this domain, check DNSSEC now. # If a DS record is set on the zone containing this domain, check DNSSEC now.
@ -655,11 +694,12 @@ def list_apt_updates(apt_update=True):
return pkgs return pkgs
try:
terminal_columns = int(shell('check_output', ['stty', 'size']).split()[1])
except:
terminal_columns = 76
class ConsoleOutput: class ConsoleOutput:
try:
terminal_columns = int(shell('check_output', ['stty', 'size']).split()[1])
except:
terminal_columns = 76
def add_heading(self, heading): def add_heading(self, heading):
print() print()
print(heading) print(heading)
@ -680,7 +720,7 @@ class ConsoleOutput:
words = re.split("(\s+)", message) words = re.split("(\s+)", message)
linelen = 0 linelen = 0
for w in words: for w in words:
if linelen + len(w) > terminal_columns-1-len(first_line): if linelen + len(w) > self.terminal_columns-1-len(first_line):
print() print()
print(" ", end="") print(" ", end="")
linelen = 0 linelen = 0
@ -693,6 +733,21 @@ class ConsoleOutput:
for line in message.split("\n"): for line in message.split("\n"):
self.print_block(line) self.print_block(line)
class BufferedOutput:
# Record all of the instance method calls so we can play them back later.
def __init__(self):
self.buf = []
def __getattr__(self, attr):
if attr not in ("add_heading", "print_ok", "print_error", "print_warning", "print_block", "print_line"):
raise AttributeError
# Return a function that just records the call & arguments to our buffer.
def w(*args, **kwargs):
self.buf.append((attr, args, kwargs))
return w
def playback(self, output):
for attr, args, kwargs in self.buf:
getattr(output, attr)(*args, **kwargs)
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
from utils import load_environment from utils import load_environment