mirror of
https://github.com/mail-in-a-box/mailinabox.git
synced 2024-11-22 02:17:26 +00:00
run status checks asynchronously so that they finish faster, since many checks are waiting on network replies and ought not to block the whole thing
This commit is contained in:
parent
8fd98d7db3
commit
7e05d7478f
@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
__ALL__ = ['check_certificate']
|
__ALL__ = ['check_certificate']
|
||||||
|
|
||||||
import os, os.path, re, subprocess, datetime
|
import os, os.path, re, subprocess, datetime, multiprocessing.pool
|
||||||
|
|
||||||
import dns.reversename, dns.resolver
|
import dns.reversename, dns.resolver
|
||||||
import dateutil.parser, dateutil.tz
|
import dateutil.parser, dateutil.tz
|
||||||
@ -35,15 +35,17 @@ def run_checks(env, output):
|
|||||||
|
|
||||||
run_system_checks(env, output)
|
run_system_checks(env, output)
|
||||||
|
|
||||||
# perform other checks
|
# perform other checks asynchronously
|
||||||
run_network_checks(env, output)
|
|
||||||
run_domain_checks(env, output)
|
pool = multiprocessing.pool.Pool(processes=1)
|
||||||
|
r1 = pool.apply_async(run_network_checks, [env])
|
||||||
|
r2 = run_domain_checks(env)
|
||||||
|
r1.get().playback(output)
|
||||||
|
r2.playback(output)
|
||||||
|
|
||||||
def run_services_checks(env, output):
|
def run_services_checks(env, output):
|
||||||
# Check that system services are running.
|
# Check that system services are running.
|
||||||
|
|
||||||
import socket
|
|
||||||
|
|
||||||
services = [
|
services = [
|
||||||
{ "name": "Local DNS (bind9)", "port": 53, "public": False, },
|
{ "name": "Local DNS (bind9)", "port": 53, "public": False, },
|
||||||
#{ "name": "NSD Control", "port": 8952, "public": False, },
|
#{ "name": "NSD Control", "port": 8952, "public": False, },
|
||||||
@ -66,15 +68,33 @@ def run_services_checks(env, output):
|
|||||||
{ "name": "HTTPS Web (nginx)", "port": 443, "public": True, },
|
{ "name": "HTTPS Web (nginx)", "port": 443, "public": True, },
|
||||||
]
|
]
|
||||||
|
|
||||||
ok = True
|
all_running = True
|
||||||
|
fatal = False
|
||||||
|
pool = multiprocessing.pool.Pool(processes=10)
|
||||||
|
ret = pool.starmap(check_service, ((i, service, env) for i, service in enumerate(services)), chunksize=1)
|
||||||
|
for i, running, fatal2, output2 in sorted(ret):
|
||||||
|
all_running = all_running and running
|
||||||
|
fatal = fatal or fatal2
|
||||||
|
output2.playback(output)
|
||||||
|
|
||||||
for service in services:
|
if all_running:
|
||||||
|
output.print_ok("All system services are running.")
|
||||||
|
|
||||||
|
return not fatal
|
||||||
|
|
||||||
|
def check_service(i, service, env):
|
||||||
|
import socket
|
||||||
|
output = BufferedOutput()
|
||||||
|
running = False
|
||||||
|
fatal = False
|
||||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
s.settimeout(.1)
|
s.settimeout(1)
|
||||||
try:
|
try:
|
||||||
s.connect((
|
s.connect((
|
||||||
"127.0.0.1" if not service["public"] else env['PUBLIC_IP'],
|
"127.0.0.1" if not service["public"] else env['PUBLIC_IP'],
|
||||||
service["port"]))
|
service["port"]))
|
||||||
|
running = True
|
||||||
|
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
output.print_error("%s is not running (%s)." % (service['name'], str(e)))
|
output.print_error("%s is not running (%s)." % (service['name'], str(e)))
|
||||||
|
|
||||||
@ -84,15 +104,11 @@ def run_services_checks(env, output):
|
|||||||
|
|
||||||
# Flag if local DNS is not running.
|
# Flag if local DNS is not running.
|
||||||
if service["port"] == 53 and service["public"] == False:
|
if service["port"] == 53 and service["public"] == False:
|
||||||
ok = False
|
fatal = True
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
s.close()
|
s.close()
|
||||||
|
|
||||||
if ok:
|
return (i, running, fatal, output)
|
||||||
output.print_ok("All system services are running.")
|
|
||||||
|
|
||||||
return ok
|
|
||||||
|
|
||||||
def run_system_checks(env, output):
|
def run_system_checks(env, output):
|
||||||
check_ssh_password(env, output)
|
check_ssh_password(env, output)
|
||||||
@ -146,9 +162,10 @@ def check_free_disk_space(env, output):
|
|||||||
else:
|
else:
|
||||||
output.print_error(disk_msg)
|
output.print_error(disk_msg)
|
||||||
|
|
||||||
def run_network_checks(env, output):
|
def run_network_checks(env):
|
||||||
# Also see setup/network-checks.sh.
|
# Also see setup/network-checks.sh.
|
||||||
|
|
||||||
|
output = BufferedOutput()
|
||||||
output.add_heading("Network")
|
output.add_heading("Network")
|
||||||
|
|
||||||
# Stop if we cannot make an outbound connection on port 25. Many residential
|
# Stop if we cannot make an outbound connection on port 25. Many residential
|
||||||
@ -176,7 +193,9 @@ def run_network_checks(env, output):
|
|||||||
which may prevent recipients from receiving your email. See http://www.spamhaus.org/query/ip/%s."""
|
which may prevent recipients from receiving your email. See http://www.spamhaus.org/query/ip/%s."""
|
||||||
% (env['PUBLIC_IP'], zen, env['PUBLIC_IP']))
|
% (env['PUBLIC_IP'], zen, env['PUBLIC_IP']))
|
||||||
|
|
||||||
def run_domain_checks(env, output):
|
return output
|
||||||
|
|
||||||
|
def run_domain_checks(env):
|
||||||
# Get the list of domains we handle mail for.
|
# Get the list of domains we handle mail for.
|
||||||
mail_domains = get_mail_domains(env)
|
mail_domains = get_mail_domains(env)
|
||||||
|
|
||||||
@ -187,8 +206,26 @@ def run_domain_checks(env, output):
|
|||||||
# Get the list of domains we serve HTTPS for.
|
# Get the list of domains we serve HTTPS for.
|
||||||
web_domains = set(get_web_domains(env))
|
web_domains = set(get_web_domains(env))
|
||||||
|
|
||||||
# Check the domains.
|
domains_to_check = mail_domains | dns_domains | web_domains
|
||||||
for domain in sort_domains(mail_domains | dns_domains | web_domains, env):
|
|
||||||
|
# Serial version:
|
||||||
|
#for domain in sort_domains(domains_to_check, env):
|
||||||
|
# run_domain_checks_on_domain(domain, env, dns_domains, dns_zonefiles, mail_domains, web_domains)
|
||||||
|
|
||||||
|
# Parallelize the checks across a worker pool.
|
||||||
|
args = ((domain, env, dns_domains, dns_zonefiles, mail_domains, web_domains)
|
||||||
|
for domain in domains_to_check)
|
||||||
|
pool = multiprocessing.pool.Pool(processes=10)
|
||||||
|
ret = pool.starmap(run_domain_checks_on_domain, args, chunksize=1)
|
||||||
|
ret = dict(ret) # (domain, output) => { domain: output }
|
||||||
|
output = BufferedOutput()
|
||||||
|
for domain in sort_domains(ret, env):
|
||||||
|
ret[domain].playback(output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def run_domain_checks_on_domain(domain, env, dns_domains, dns_zonefiles, mail_domains, web_domains):
|
||||||
|
output = BufferedOutput()
|
||||||
|
|
||||||
output.add_heading(domain)
|
output.add_heading(domain)
|
||||||
|
|
||||||
if domain == env["PRIMARY_HOSTNAME"]:
|
if domain == env["PRIMARY_HOSTNAME"]:
|
||||||
@ -206,6 +243,8 @@ def run_domain_checks(env, output):
|
|||||||
if domain in dns_domains:
|
if domain in dns_domains:
|
||||||
check_dns_zone_suggestions(domain, env, output, dns_zonefiles)
|
check_dns_zone_suggestions(domain, env, output, dns_zonefiles)
|
||||||
|
|
||||||
|
return (domain, output)
|
||||||
|
|
||||||
def check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles):
|
def check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles):
|
||||||
# If a DS record is set on the zone containing this domain, check DNSSEC now.
|
# If a DS record is set on the zone containing this domain, check DNSSEC now.
|
||||||
for zone in dns_domains:
|
for zone in dns_domains:
|
||||||
@ -655,11 +694,12 @@ def list_apt_updates(apt_update=True):
|
|||||||
return pkgs
|
return pkgs
|
||||||
|
|
||||||
|
|
||||||
|
class ConsoleOutput:
|
||||||
try:
|
try:
|
||||||
terminal_columns = int(shell('check_output', ['stty', 'size']).split()[1])
|
terminal_columns = int(shell('check_output', ['stty', 'size']).split()[1])
|
||||||
except:
|
except:
|
||||||
terminal_columns = 76
|
terminal_columns = 76
|
||||||
class ConsoleOutput:
|
|
||||||
def add_heading(self, heading):
|
def add_heading(self, heading):
|
||||||
print()
|
print()
|
||||||
print(heading)
|
print(heading)
|
||||||
@ -680,7 +720,7 @@ class ConsoleOutput:
|
|||||||
words = re.split("(\s+)", message)
|
words = re.split("(\s+)", message)
|
||||||
linelen = 0
|
linelen = 0
|
||||||
for w in words:
|
for w in words:
|
||||||
if linelen + len(w) > terminal_columns-1-len(first_line):
|
if linelen + len(w) > self.terminal_columns-1-len(first_line):
|
||||||
print()
|
print()
|
||||||
print(" ", end="")
|
print(" ", end="")
|
||||||
linelen = 0
|
linelen = 0
|
||||||
@ -693,6 +733,21 @@ class ConsoleOutput:
|
|||||||
for line in message.split("\n"):
|
for line in message.split("\n"):
|
||||||
self.print_block(line)
|
self.print_block(line)
|
||||||
|
|
||||||
|
class BufferedOutput:
|
||||||
|
# Record all of the instance method calls so we can play them back later.
|
||||||
|
def __init__(self):
|
||||||
|
self.buf = []
|
||||||
|
def __getattr__(self, attr):
|
||||||
|
if attr not in ("add_heading", "print_ok", "print_error", "print_warning", "print_block", "print_line"):
|
||||||
|
raise AttributeError
|
||||||
|
# Return a function that just records the call & arguments to our buffer.
|
||||||
|
def w(*args, **kwargs):
|
||||||
|
self.buf.append((attr, args, kwargs))
|
||||||
|
return w
|
||||||
|
def playback(self, output):
|
||||||
|
for attr, args, kwargs in self.buf:
|
||||||
|
getattr(output, attr)(*args, **kwargs)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import sys
|
import sys
|
||||||
from utils import load_environment
|
from utils import load_environment
|
||||||
|
Loading…
Reference in New Issue
Block a user