1
0
mirror of https://github.com/mail-in-a-box/mailinabox.git synced 2024-11-23 02:27:05 +00:00

Fixed errors found by the Ruff Python linter (#2343)

This commit is contained in:
Joshua Tauberer 2024-03-10 07:57:19 -04:00 committed by GitHub
commit 315d2cf691
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 423 additions and 469 deletions

View File

@ -1,4 +1,4 @@
import base64, os, os.path, hmac, json, secrets import base64, hmac, json, secrets
from datetime import timedelta from datetime import timedelta
from expiringdict import ExpiringDict from expiringdict import ExpiringDict
@ -22,7 +22,7 @@ class AuthService:
def init_system_api_key(self): def init_system_api_key(self):
"""Write an API key to a local file so local processes can use the API""" """Write an API key to a local file so local processes can use the API"""
with open(self.key_path, 'r') as file: with open(self.key_path, encoding='utf-8') as file:
self.key = file.read() self.key = file.read()
def authenticate(self, request, env, login_only=False, logout=False): def authenticate(self, request, env, login_only=False, logout=False):
@ -48,11 +48,13 @@ class AuthService:
return username, password return username, password
username, password = parse_http_authorization_basic(request.headers.get('Authorization', '')) username, password = parse_http_authorization_basic(request.headers.get('Authorization', ''))
if username in (None, ""): if username in {None, ""}:
raise ValueError("Authorization header invalid.") msg = "Authorization header invalid."
raise ValueError(msg)
if username.strip() == "" and password.strip() == "": if username.strip() == "" and password.strip() == "":
raise ValueError("No email address, password, session key, or API key provided.") msg = "No email address, password, session key, or API key provided."
raise ValueError(msg)
# If user passed the system API key, grant administrative privs. This key # If user passed the system API key, grant administrative privs. This key
# is not associated with a user. # is not associated with a user.
@ -72,7 +74,8 @@ class AuthService:
# If no password was given, but a username was given, we're missing some information. # If no password was given, but a username was given, we're missing some information.
elif password.strip() == "": elif password.strip() == "":
raise ValueError("Enter a password.") msg = "Enter a password."
raise ValueError(msg)
else: else:
# The user is trying to log in with a username and a password # The user is trying to log in with a username and a password
@ -114,7 +117,8 @@ class AuthService:
]) ])
except: except:
# Login failed. # Login failed.
raise ValueError("Incorrect email address or password.") msg = "Incorrect email address or password."
raise ValueError(msg)
# If MFA is enabled, check that MFA passes. # If MFA is enabled, check that MFA passes.
status, hints = validate_auth_mfa(email, request, env) status, hints = validate_auth_mfa(email, request, env)

View File

@ -7,7 +7,7 @@
# 4) The stopped services are restarted. # 4) The stopped services are restarted.
# 5) STORAGE_ROOT/backup/after-backup is executed if it exists. # 5) STORAGE_ROOT/backup/after-backup is executed if it exists.
import os, os.path, shutil, glob, re, datetime, sys import os, os.path, re, datetime, sys
import dateutil.parser, dateutil.relativedelta, dateutil.tz import dateutil.parser, dateutil.relativedelta, dateutil.tz
import rtyaml import rtyaml
from exclusiveprocess import Lock from exclusiveprocess import Lock
@ -59,7 +59,7 @@ def backup_status(env):
"--archive-dir", backup_cache_dir, "--archive-dir", backup_cache_dir,
"--gpg-options", "'--cipher-algo=AES256'", "--gpg-options", "'--cipher-algo=AES256'",
"--log-fd", "1", "--log-fd", "1",
] + get_duplicity_additional_args(env) + [ *get_duplicity_additional_args(env),
get_duplicity_target_url(config) get_duplicity_target_url(config)
], ],
get_duplicity_env_vars(env), get_duplicity_env_vars(env),
@ -69,7 +69,7 @@ def backup_status(env):
# destination for the backups or the last backup job terminated unexpectedly. # destination for the backups or the last backup job terminated unexpectedly.
raise Exception("Something is wrong with the backup: " + collection_status) raise Exception("Something is wrong with the backup: " + collection_status)
for line in collection_status.split('\n'): for line in collection_status.split('\n'):
if line.startswith(" full") or line.startswith(" inc"): if line.startswith((" full", " inc")):
backup = parse_line(line) backup = parse_line(line)
backups[backup["date"]] = backup backups[backup["date"]] = backup
@ -185,7 +185,7 @@ def get_passphrase(env):
# only needs to be 43 base64-characters to match AES256's key # only needs to be 43 base64-characters to match AES256's key
# length of 32 bytes. # length of 32 bytes.
backup_root = os.path.join(env["STORAGE_ROOT"], 'backup') backup_root = os.path.join(env["STORAGE_ROOT"], 'backup')
with open(os.path.join(backup_root, 'secret_key.txt')) as f: with open(os.path.join(backup_root, 'secret_key.txt'), encoding="utf-8") as f:
passphrase = f.readline().strip() passphrase = f.readline().strip()
if len(passphrase) < 43: raise Exception("secret_key.txt's first line is too short!") if len(passphrase) < 43: raise Exception("secret_key.txt's first line is too short!")
@ -257,8 +257,7 @@ def get_duplicity_env_vars(env):
return env return env
def get_target_type(config): def get_target_type(config):
protocol = config["target"].split(":")[0] return config["target"].split(":")[0]
return protocol
def perform_backup(full_backup): def perform_backup(full_backup):
env = load_environment() env = load_environment()
@ -323,8 +322,8 @@ def perform_backup(full_backup):
"--exclude", backup_root, "--exclude", backup_root,
"--volsize", "250", "--volsize", "250",
"--gpg-options", "'--cipher-algo=AES256'", "--gpg-options", "'--cipher-algo=AES256'",
"--allow-source-mismatch" "--allow-source-mismatch",
] + get_duplicity_additional_args(env) + [ *get_duplicity_additional_args(env),
env["STORAGE_ROOT"], env["STORAGE_ROOT"],
get_duplicity_target_url(config), get_duplicity_target_url(config),
], ],
@ -345,7 +344,7 @@ def perform_backup(full_backup):
"--verbosity", "error", "--verbosity", "error",
"--archive-dir", backup_cache_dir, "--archive-dir", backup_cache_dir,
"--force", "--force",
] + get_duplicity_additional_args(env) + [ *get_duplicity_additional_args(env),
get_duplicity_target_url(config) get_duplicity_target_url(config)
], ],
get_duplicity_env_vars(env)) get_duplicity_env_vars(env))
@ -361,7 +360,7 @@ def perform_backup(full_backup):
"--verbosity", "error", "--verbosity", "error",
"--archive-dir", backup_cache_dir, "--archive-dir", backup_cache_dir,
"--force", "--force",
] + get_duplicity_additional_args(env) + [ *get_duplicity_additional_args(env),
get_duplicity_target_url(config) get_duplicity_target_url(config)
], ],
get_duplicity_env_vars(env)) get_duplicity_env_vars(env))
@ -400,7 +399,7 @@ def run_duplicity_verification():
"--compare-data", "--compare-data",
"--archive-dir", backup_cache_dir, "--archive-dir", backup_cache_dir,
"--exclude", backup_root, "--exclude", backup_root,
] + get_duplicity_additional_args(env) + [ *get_duplicity_additional_args(env),
get_duplicity_target_url(config), get_duplicity_target_url(config),
env["STORAGE_ROOT"], env["STORAGE_ROOT"],
], get_duplicity_env_vars(env)) ], get_duplicity_env_vars(env))
@ -413,9 +412,9 @@ def run_duplicity_restore(args):
"/usr/bin/duplicity", "/usr/bin/duplicity",
"restore", "restore",
"--archive-dir", backup_cache_dir, "--archive-dir", backup_cache_dir,
] + get_duplicity_additional_args(env) + [ *get_duplicity_additional_args(env),
get_duplicity_target_url(config) get_duplicity_target_url(config),
] + args, *args],
get_duplicity_env_vars(env)) get_duplicity_env_vars(env))
def print_duplicity_command(): def print_duplicity_command():
@ -427,7 +426,7 @@ def print_duplicity_command():
print(f"export {k}={shlex.quote(v)}") print(f"export {k}={shlex.quote(v)}")
print("duplicity", "{command}", shlex.join([ print("duplicity", "{command}", shlex.join([
"--archive-dir", backup_cache_dir, "--archive-dir", backup_cache_dir,
] + get_duplicity_additional_args(env) + [ *get_duplicity_additional_args(env),
get_duplicity_target_url(config) get_duplicity_target_url(config)
])) ]))
@ -483,16 +482,17 @@ def list_target_files(config):
if 'Permission denied (publickey).' in listing: if 'Permission denied (publickey).' in listing:
reason = "Invalid user or check you correctly copied the SSH key." reason = "Invalid user or check you correctly copied the SSH key."
elif 'No such file or directory' in listing: elif 'No such file or directory' in listing:
reason = "Provided path {} is invalid.".format(target_path) reason = f"Provided path {target_path} is invalid."
elif 'Network is unreachable' in listing: elif 'Network is unreachable' in listing:
reason = "The IP address {} is unreachable.".format(target.hostname) reason = f"The IP address {target.hostname} is unreachable."
elif 'Could not resolve hostname' in listing: elif 'Could not resolve hostname' in listing:
reason = "The hostname {} cannot be resolved.".format(target.hostname) reason = f"The hostname {target.hostname} cannot be resolved."
else: else:
reason = "Unknown error." \ reason = ("Unknown error."
"Please check running 'management/backup.py --verify'" \ "Please check running 'management/backup.py --verify'"
"from mailinabox sources to debug the issue." "from mailinabox sources to debug the issue.")
raise ValueError("Connection to rsync host failed: {}".format(reason)) msg = f"Connection to rsync host failed: {reason}"
raise ValueError(msg)
elif target.scheme == "s3": elif target.scheme == "s3":
import boto3.s3 import boto3.s3
@ -507,7 +507,8 @@ def list_target_files(config):
path = '' path = ''
if bucket == "": if bucket == "":
raise ValueError("Enter an S3 bucket name.") msg = "Enter an S3 bucket name."
raise ValueError(msg)
# connect to the region & bucket # connect to the region & bucket
try: try:
@ -534,8 +535,9 @@ def list_target_files(config):
try: try:
b2_api.authorize_account("production", b2_application_keyid, b2_application_key) b2_api.authorize_account("production", b2_application_keyid, b2_application_key)
bucket = b2_api.get_bucket_by_name(b2_bucket) bucket = b2_api.get_bucket_by_name(b2_bucket)
except NonExistentBucket as e: except NonExistentBucket:
raise ValueError("B2 Bucket does not exist. Please double check your information!") msg = "B2 Bucket does not exist. Please double check your information!"
raise ValueError(msg)
return [(key.file_name, key.size) for key, _ in bucket.ls()] return [(key.file_name, key.size) for key, _ in bucket.ls()]
else: else:
@ -556,7 +558,7 @@ def backup_set_custom(env, target, target_user, target_pass, min_age):
# Validate. # Validate.
try: try:
if config["target"] not in ("off", "local"): if config["target"] not in {"off", "local"}:
# these aren't supported by the following function, which expects a full url in the target key, # these aren't supported by the following function, which expects a full url in the target key,
# which is what is there except when loading the config prior to saving # which is what is there except when loading the config prior to saving
list_target_files(config) list_target_files(config)
@ -578,9 +580,9 @@ def get_backup_config(env, for_save=False, for_ui=False):
# Merge in anything written to custom.yaml. # Merge in anything written to custom.yaml.
try: try:
with open(os.path.join(backup_root, 'custom.yaml'), 'r') as f: with open(os.path.join(backup_root, 'custom.yaml'), encoding="utf-8") as f:
custom_config = rtyaml.load(f) custom_config = rtyaml.load(f)
if not isinstance(custom_config, dict): raise ValueError() # caught below if not isinstance(custom_config, dict): raise ValueError # caught below
config.update(custom_config) config.update(custom_config)
except: except:
pass pass
@ -604,18 +606,17 @@ def get_backup_config(env, for_save=False, for_ui=False):
config["target"] = "file://" + config["file_target_directory"] config["target"] = "file://" + config["file_target_directory"]
ssh_pub_key = os.path.join('/root', '.ssh', 'id_rsa_miab.pub') ssh_pub_key = os.path.join('/root', '.ssh', 'id_rsa_miab.pub')
if os.path.exists(ssh_pub_key): if os.path.exists(ssh_pub_key):
with open(ssh_pub_key, 'r') as f: with open(ssh_pub_key, encoding="utf-8") as f:
config["ssh_pub_key"] = f.read() config["ssh_pub_key"] = f.read()
return config return config
def write_backup_config(env, newconfig): def write_backup_config(env, newconfig):
backup_root = os.path.join(env["STORAGE_ROOT"], 'backup') backup_root = os.path.join(env["STORAGE_ROOT"], 'backup')
with open(os.path.join(backup_root, 'custom.yaml'), "w") as f: with open(os.path.join(backup_root, 'custom.yaml'), "w", encoding="utf-8") as f:
f.write(rtyaml.dump(newconfig)) f.write(rtyaml.dump(newconfig))
if __name__ == "__main__": if __name__ == "__main__":
import sys
if sys.argv[-1] == "--verify": if sys.argv[-1] == "--verify":
# Run duplicity's verification command to check a) the backup files # Run duplicity's verification command to check a) the backup files
# are readable, and b) report if they are up to date. # are readable, and b) report if they are up to date.
@ -624,7 +625,7 @@ if __name__ == "__main__":
elif sys.argv[-1] == "--list": elif sys.argv[-1] == "--list":
# List the saved backup files. # List the saved backup files.
for fn, size in list_target_files(get_backup_config(load_environment())): for fn, size in list_target_files(get_backup_config(load_environment())):
print("{}\t{}".format(fn, size)) print(f"{fn}\t{size}")
elif sys.argv[-1] == "--status": elif sys.argv[-1] == "--status":
# Show backup status. # Show backup status.

View File

@ -6,7 +6,8 @@
# root API key. This file is readable only by root, so this # root API key. This file is readable only by root, so this
# tool can only be used as root. # tool can only be used as root.
import sys, getpass, urllib.request, urllib.error, json, re, csv import sys, getpass, urllib.request, urllib.error, json, csv
import contextlib
def mgmt(cmd, data=None, is_json=False): def mgmt(cmd, data=None, is_json=False):
# The base URL for the management daemon. (Listens on IPv4 only.) # The base URL for the management daemon. (Listens on IPv4 only.)
@ -19,10 +20,8 @@ def mgmt(cmd, data=None, is_json=False):
response = urllib.request.urlopen(req) response = urllib.request.urlopen(req)
except urllib.error.HTTPError as e: except urllib.error.HTTPError as e:
if e.code == 401: if e.code == 401:
try: with contextlib.suppress(Exception):
print(e.read().decode("utf8")) print(e.read().decode("utf8"))
except:
pass
print("The management daemon refused access. The API key file may be out of sync. Try 'service mailinabox restart'.", file=sys.stderr) print("The management daemon refused access. The API key file may be out of sync. Try 'service mailinabox restart'.", file=sys.stderr)
elif hasattr(e, 'read'): elif hasattr(e, 'read'):
print(e.read().decode('utf8'), file=sys.stderr) print(e.read().decode('utf8'), file=sys.stderr)
@ -47,7 +46,7 @@ def read_password():
return first return first
def setup_key_auth(mgmt_uri): def setup_key_auth(mgmt_uri):
with open('/var/lib/mailinabox/api.key', 'r') as f: with open('/var/lib/mailinabox/api.key', encoding='utf-8') as f:
key = f.read().strip() key = f.read().strip()
auth_handler = urllib.request.HTTPBasicAuthHandler() auth_handler = urllib.request.HTTPBasicAuthHandler()
@ -91,12 +90,9 @@ elif sys.argv[1] == "user" and len(sys.argv) == 2:
print("*", end='') print("*", end='')
print() print()
elif sys.argv[1] == "user" and sys.argv[2] in ("add", "password"): elif sys.argv[1] == "user" and sys.argv[2] in {"add", "password"}:
if len(sys.argv) < 5: if len(sys.argv) < 5:
if len(sys.argv) < 4: email = input('email: ') if len(sys.argv) < 4 else sys.argv[3]
email = input("email: ")
else:
email = sys.argv[3]
pw = read_password() pw = read_password()
else: else:
email, pw = sys.argv[3:5] email, pw = sys.argv[3:5]
@ -109,11 +105,8 @@ elif sys.argv[1] == "user" and sys.argv[2] in ("add", "password"):
elif sys.argv[1] == "user" and sys.argv[2] == "remove" and len(sys.argv) == 4: elif sys.argv[1] == "user" and sys.argv[2] == "remove" and len(sys.argv) == 4:
print(mgmt("/mail/users/remove", { "email": sys.argv[3] })) print(mgmt("/mail/users/remove", { "email": sys.argv[3] }))
elif sys.argv[1] == "user" and sys.argv[2] in ("make-admin", "remove-admin") and len(sys.argv) == 4: elif sys.argv[1] == "user" and sys.argv[2] in {"make-admin", "remove-admin"} and len(sys.argv) == 4:
if sys.argv[2] == "make-admin": action = 'add' if sys.argv[2] == 'make-admin' else 'remove'
action = "add"
else:
action = "remove"
print(mgmt("/mail/users/privileges/" + action, { "email": sys.argv[3], "privilege": "admin" })) print(mgmt("/mail/users/privileges/" + action, { "email": sys.argv[3], "privilege": "admin" }))
elif sys.argv[1] == "user" and sys.argv[2] == "admins": elif sys.argv[1] == "user" and sys.argv[2] == "admins":
@ -132,7 +125,7 @@ elif sys.argv[1] == "user" and len(sys.argv) == 5 and sys.argv[2:4] == ["mfa", "
for mfa in status["enabled_mfa"]: for mfa in status["enabled_mfa"]:
W.writerow([mfa["id"], mfa["type"], mfa["label"]]) W.writerow([mfa["id"], mfa["type"], mfa["label"]])
elif sys.argv[1] == "user" and len(sys.argv) in (5, 6) and sys.argv[2:4] == ["mfa", "disable"]: elif sys.argv[1] == "user" and len(sys.argv) in {5, 6} and sys.argv[2:4] == ["mfa", "disable"]:
# Disable MFA (all or a particular device) for a user. # Disable MFA (all or a particular device) for a user.
print(mgmt("/mfa/disable", { "user": sys.argv[4], "mfa-id": sys.argv[5] if len(sys.argv) == 6 else None })) print(mgmt("/mfa/disable", { "user": sys.argv[4], "mfa-id": sys.argv[5] if len(sys.argv) == 6 else None }))

View File

@ -11,17 +11,18 @@
# service mailinabox start # when done debugging, start it up again # service mailinabox start # when done debugging, start it up again
import os, os.path, re, json, time import os, os.path, re, json, time
import multiprocessing.pool, subprocess import multiprocessing.pool
from functools import wraps from functools import wraps
from flask import Flask, request, render_template, abort, Response, send_from_directory, make_response from flask import Flask, request, render_template, Response, send_from_directory, make_response
import auth, utils import auth, utils
from mailconfig import get_mail_users, get_mail_users_ex, get_admins, add_mail_user, set_mail_password, remove_mail_user from mailconfig import get_mail_users, get_mail_users_ex, get_admins, add_mail_user, set_mail_password, remove_mail_user
from mailconfig import get_mail_user_privileges, add_remove_mail_user_privilege from mailconfig import get_mail_user_privileges, add_remove_mail_user_privilege
from mailconfig import get_mail_aliases, get_mail_aliases_ex, get_mail_domains, add_mail_alias, remove_mail_alias from mailconfig import get_mail_aliases, get_mail_aliases_ex, get_mail_domains, add_mail_alias, remove_mail_alias
from mfa import get_public_mfa_state, provision_totp, validate_totp_secret, enable_mfa, disable_mfa from mfa import get_public_mfa_state, provision_totp, validate_totp_secret, enable_mfa, disable_mfa
import contextlib
env = utils.load_environment() env = utils.load_environment()
@ -29,14 +30,12 @@ auth_service = auth.AuthService()
# We may deploy via a symbolic link, which confuses flask's template finding. # We may deploy via a symbolic link, which confuses flask's template finding.
me = __file__ me = __file__
try: with contextlib.suppress(OSError):
me = os.readlink(__file__) me = os.readlink(__file__)
except OSError:
pass
# for generating CSRs we need a list of country codes # for generating CSRs we need a list of country codes
csr_country_codes = [] csr_country_codes = []
with open(os.path.join(os.path.dirname(me), "csr_country_codes.tsv")) as f: with open(os.path.join(os.path.dirname(me), "csr_country_codes.tsv"), encoding="utf-8") as f:
for line in f: for line in f:
if line.strip() == "" or line.startswith("#"): continue if line.strip() == "" or line.startswith("#"): continue
code, name = line.strip().split("\t")[0:2] code, name = line.strip().split("\t")[0:2]
@ -80,7 +79,7 @@ def authorized_personnel_only(viewfunc):
# Not authorized. Return a 401 (send auth) and a prompt to authorize by default. # Not authorized. Return a 401 (send auth) and a prompt to authorize by default.
status = 401 status = 401
headers = { headers = {
'WWW-Authenticate': 'Basic realm="{0}"'.format(auth_service.auth_realm), 'WWW-Authenticate': f'Basic realm="{auth_service.auth_realm}"',
'X-Reason': error, 'X-Reason': error,
} }
@ -90,7 +89,7 @@ def authorized_personnel_only(viewfunc):
status = 403 status = 403
headers = None headers = None
if request.headers.get('Accept') in (None, "", "*/*"): if request.headers.get('Accept') in {None, "", "*/*"}:
# Return plain text output. # Return plain text output.
return Response(error+"\n", status=status, mimetype='text/plain', headers=headers) return Response(error+"\n", status=status, mimetype='text/plain', headers=headers)
else: else:
@ -164,7 +163,7 @@ def login():
"api_key": auth_service.create_session_key(email, env, type='login'), "api_key": auth_service.create_session_key(email, env, type='login'),
} }
app.logger.info("New login session created for {}".format(email)) app.logger.info(f"New login session created for {email}")
# Return. # Return.
return json_response(resp) return json_response(resp)
@ -173,8 +172,8 @@ def login():
def logout(): def logout():
try: try:
email, _ = auth_service.authenticate(request, env, logout=True) email, _ = auth_service.authenticate(request, env, logout=True)
app.logger.info("{} logged out".format(email)) app.logger.info(f"{email} logged out")
except ValueError as e: except ValueError:
pass pass
finally: finally:
return json_response({ "status": "ok" }) return json_response({ "status": "ok" })
@ -355,9 +354,9 @@ def dns_set_record(qname, rtype="A"):
# Get the existing records matching the qname and rtype. # Get the existing records matching the qname and rtype.
return dns_get_records(qname, rtype) return dns_get_records(qname, rtype)
elif request.method in ("POST", "PUT"): elif request.method in {"POST", "PUT"}:
# There is a default value for A/AAAA records. # There is a default value for A/AAAA records.
if rtype in ("A", "AAAA") and value == "": if rtype in {"A", "AAAA"} and value == "":
value = request.environ.get("HTTP_X_FORWARDED_FOR") # normally REMOTE_ADDR but we're behind nginx as a reverse proxy value = request.environ.get("HTTP_X_FORWARDED_FOR") # normally REMOTE_ADDR but we're behind nginx as a reverse proxy
# Cannot add empty records. # Cannot add empty records.
@ -419,7 +418,7 @@ def ssl_get_status():
{ {
"domain": d["domain"], "domain": d["domain"],
"status": d["ssl_certificate"][0], "status": d["ssl_certificate"][0],
"text": d["ssl_certificate"][1] + ((" " + cant_provision[d["domain"]] if d["domain"] in cant_provision else "")) "text": d["ssl_certificate"][1] + (" " + cant_provision[d["domain"]] if d["domain"] in cant_provision else "")
} for d in domains_status ] } for d in domains_status ]
# Warn the user about domain names not hosted here because of other settings. # Warn the user about domain names not hosted here because of other settings.
@ -491,7 +490,7 @@ def totp_post_enable():
secret = request.form.get('secret') secret = request.form.get('secret')
token = request.form.get('token') token = request.form.get('token')
label = request.form.get('label') label = request.form.get('label')
if type(token) != str: if not isinstance(token, str):
return ("Bad Input", 400) return ("Bad Input", 400)
try: try:
validate_totp_secret(secret) validate_totp_secret(secret)
@ -580,8 +579,7 @@ def system_status():
def show_updates(): def show_updates():
from status_checks import list_apt_updates from status_checks import list_apt_updates
return "".join( return "".join(
"%s (%s)\n" "{} ({})\n".format(p["package"], p["version"])
% (p["package"], p["version"])
for p in list_apt_updates()) for p in list_apt_updates())
@app.route('/system/update-packages', methods=["POST"]) @app.route('/system/update-packages', methods=["POST"])
@ -751,14 +749,11 @@ def log_failed_login(request):
# During setup we call the management interface directly to determine the user # During setup we call the management interface directly to determine the user
# status. So we can't always use X-Forwarded-For because during setup that header # status. So we can't always use X-Forwarded-For because during setup that header
# will not be present. # will not be present.
if request.headers.getlist("X-Forwarded-For"): ip = request.headers.getlist("X-Forwarded-For")[0] if request.headers.getlist("X-Forwarded-For") else request.remote_addr
ip = request.headers.getlist("X-Forwarded-For")[0]
else:
ip = request.remote_addr
# We need to add a timestamp to the log message, otherwise /dev/log will eat the "duplicate" # We need to add a timestamp to the log message, otherwise /dev/log will eat the "duplicate"
# message. # message.
app.logger.warning( "Mail-in-a-Box Management Daemon: Failed login attempt from ip %s - timestamp %s" % (ip, time.time())) app.logger.warning( f"Mail-in-a-Box Management Daemon: Failed login attempt from ip {ip} - timestamp {time.time()}")
# APP # APP

View File

@ -4,19 +4,20 @@
# and mail aliases and restarts nsd. # and mail aliases and restarts nsd.
######################################################################## ########################################################################
import sys, os, os.path, urllib.parse, datetime, re, hashlib, base64 import sys, os, os.path, datetime, re, hashlib, base64
import ipaddress import ipaddress
import rtyaml import rtyaml
import dns.resolver import dns.resolver
from utils import shell, load_env_vars_from_file, safe_domain_name, sort_domains from utils import shell, load_env_vars_from_file, safe_domain_name, sort_domains
from ssl_certificates import get_ssl_certificates, check_certificate from ssl_certificates import get_ssl_certificates, check_certificate
import contextlib
# From https://stackoverflow.com/questions/3026957/how-to-validate-a-domain-name-using-regex-php/16491074#16491074 # From https://stackoverflow.com/questions/3026957/how-to-validate-a-domain-name-using-regex-php/16491074#16491074
# This regular expression matches domain names according to RFCs, it also accepts fqdn with an leading dot, # This regular expression matches domain names according to RFCs, it also accepts fqdn with an leading dot,
# underscores, as well as asteriks which are allowed in domain names but not hostnames (i.e. allowed in # underscores, as well as asteriks which are allowed in domain names but not hostnames (i.e. allowed in
# DNS but not in URLs), which are common in certain record types like for DKIM. # DNS but not in URLs), which are common in certain record types like for DKIM.
DOMAIN_RE = "^(?!\-)(?:[*][.])?(?:[a-zA-Z\d\-_]{0,62}[a-zA-Z\d_]\.){1,126}(?!\d+)[a-zA-Z\d_]{1,63}(\.?)$" DOMAIN_RE = r"^(?!\-)(?:[*][.])?(?:[a-zA-Z\d\-_]{0,62}[a-zA-Z\d_]\.){1,126}(?!\d+)[a-zA-Z\d_]{1,63}(\.?)$"
def get_dns_domains(env): def get_dns_domains(env):
# Add all domain names in use by email users and mail aliases, any # Add all domain names in use by email users and mail aliases, any
@ -38,7 +39,7 @@ def get_dns_zones(env):
# Exclude domains that are subdomains of other domains we know. Proceed # Exclude domains that are subdomains of other domains we know. Proceed
# by looking at shorter domains first. # by looking at shorter domains first.
zone_domains = set() zone_domains = set()
for domain in sorted(domains, key=lambda d : len(d)): for domain in sorted(domains, key=len):
for d in zone_domains: for d in zone_domains:
if domain.endswith("." + d): if domain.endswith("." + d):
# We found a parent domain already in the list. # We found a parent domain already in the list.
@ -48,9 +49,7 @@ def get_dns_zones(env):
zone_domains.add(domain) zone_domains.add(domain)
# Make a nice and safe filename for each domain. # Make a nice and safe filename for each domain.
zonefiles = [] zonefiles = [[domain, safe_domain_name(domain) + ".txt"] for domain in zone_domains]
for domain in zone_domains:
zonefiles.append([domain, safe_domain_name(domain) + ".txt"])
# Sort the list so that the order is nice and so that nsd.conf has a # Sort the list so that the order is nice and so that nsd.conf has a
# stable order so we don't rewrite the file & restart the service # stable order so we don't rewrite the file & restart the service
@ -194,8 +193,7 @@ def build_zone(domain, domain_properties, additional_records, env, is_zone=True)
# User may provide one or more additional nameservers # User may provide one or more additional nameservers
secondary_ns_list = get_secondary_dns(additional_records, mode="NS") \ secondary_ns_list = get_secondary_dns(additional_records, mode="NS") \
or ["ns2." + env["PRIMARY_HOSTNAME"]] or ["ns2." + env["PRIMARY_HOSTNAME"]]
for secondary_ns in secondary_ns_list: records.extend((None, "NS", secondary_ns+'.', False) for secondary_ns in secondary_ns_list)
records.append((None, "NS", secondary_ns+'.', False))
# In PRIMARY_HOSTNAME... # In PRIMARY_HOSTNAME...
@ -212,8 +210,7 @@ def build_zone(domain, domain_properties, additional_records, env, is_zone=True)
records.append(("_443._tcp", "TLSA", build_tlsa_record(env), "Optional. When DNSSEC is enabled, provides out-of-band HTTPS certificate validation for a few web clients that support it.")) records.append(("_443._tcp", "TLSA", build_tlsa_record(env), "Optional. When DNSSEC is enabled, provides out-of-band HTTPS certificate validation for a few web clients that support it."))
# Add a SSHFP records to help SSH key validation. One per available SSH key on this system. # Add a SSHFP records to help SSH key validation. One per available SSH key on this system.
for value in build_sshfp_records(): records.extend((None, "SSHFP", value, "Optional. Provides an out-of-band method for verifying an SSH key before connecting. Use 'VerifyHostKeyDNS yes' (or 'VerifyHostKeyDNS ask') when connecting with ssh.") for value in build_sshfp_records())
records.append((None, "SSHFP", value, "Optional. Provides an out-of-band method for verifying an SSH key before connecting. Use 'VerifyHostKeyDNS yes' (or 'VerifyHostKeyDNS ask') when connecting with ssh."))
# Add DNS records for any subdomains of this domain. We should not have a zone for # Add DNS records for any subdomains of this domain. We should not have a zone for
# both a domain and one of its subdomains. # both a domain and one of its subdomains.
@ -223,7 +220,7 @@ def build_zone(domain, domain_properties, additional_records, env, is_zone=True)
subdomain_qname = subdomain[0:-len("." + domain)] subdomain_qname = subdomain[0:-len("." + domain)]
subzone = build_zone(subdomain, domain_properties, additional_records, env, is_zone=False) subzone = build_zone(subdomain, domain_properties, additional_records, env, is_zone=False)
for child_qname, child_rtype, child_value, child_explanation in subzone: for child_qname, child_rtype, child_value, child_explanation in subzone:
if child_qname == None: if child_qname is None:
child_qname = subdomain_qname child_qname = subdomain_qname
else: else:
child_qname += "." + subdomain_qname child_qname += "." + subdomain_qname
@ -231,10 +228,7 @@ def build_zone(domain, domain_properties, additional_records, env, is_zone=True)
has_rec_base = list(records) # clone current state has_rec_base = list(records) # clone current state
def has_rec(qname, rtype, prefix=None): def has_rec(qname, rtype, prefix=None):
for rec in has_rec_base: return any(rec[0] == qname and rec[1] == rtype and (prefix is None or rec[2].startswith(prefix)) for rec in has_rec_base)
if rec[0] == qname and rec[1] == rtype and (prefix is None or rec[2].startswith(prefix)):
return True
return False
# The user may set other records that don't conflict with our settings. # The user may set other records that don't conflict with our settings.
# Don't put any TXT records above this line, or it'll prevent any custom TXT records. # Don't put any TXT records above this line, or it'll prevent any custom TXT records.
@ -262,7 +256,7 @@ def build_zone(domain, domain_properties, additional_records, env, is_zone=True)
has_rec_base = list(records) has_rec_base = list(records)
a_expl = "Required. May have a different value. Sets the IP address that %s resolves to for web hosting and other services besides mail. The A record must be present but its value does not affect mail delivery." % domain a_expl = "Required. May have a different value. Sets the IP address that %s resolves to for web hosting and other services besides mail. The A record must be present but its value does not affect mail delivery." % domain
if domain_properties[domain]["auto"]: if domain_properties[domain]["auto"]:
if domain.startswith("ns1.") or domain.startswith("ns2."): a_expl = False # omit from 'External DNS' page since this only applies if box is its own DNS server if domain.startswith(("ns1.", "ns2.")): a_expl = False # omit from 'External DNS' page since this only applies if box is its own DNS server
if domain.startswith("www."): a_expl = "Optional. Sets the IP address that %s resolves to so that the box can provide a redirect to the parent domain." % domain if domain.startswith("www."): a_expl = "Optional. Sets the IP address that %s resolves to so that the box can provide a redirect to the parent domain." % domain
if domain.startswith("mta-sts."): a_expl = "Optional. MTA-STS Policy Host serving /.well-known/mta-sts.txt." if domain.startswith("mta-sts."): a_expl = "Optional. MTA-STS Policy Host serving /.well-known/mta-sts.txt."
if domain.startswith("autoconfig."): a_expl = "Provides email configuration autodiscovery support for Thunderbird Autoconfig." if domain.startswith("autoconfig."): a_expl = "Provides email configuration autodiscovery support for Thunderbird Autoconfig."
@ -298,7 +292,7 @@ def build_zone(domain, domain_properties, additional_records, env, is_zone=True)
# Append the DKIM TXT record to the zone as generated by OpenDKIM. # Append the DKIM TXT record to the zone as generated by OpenDKIM.
# Skip if the user has set a DKIM record already. # Skip if the user has set a DKIM record already.
opendkim_record_file = os.path.join(env['STORAGE_ROOT'], 'mail/dkim/mail.txt') opendkim_record_file = os.path.join(env['STORAGE_ROOT'], 'mail/dkim/mail.txt')
with open(opendkim_record_file) as orf: with open(opendkim_record_file, encoding="utf-8") as orf:
m = re.match(r'(\S+)\s+IN\s+TXT\s+\( ((?:"[^"]+"\s+)+)\)', orf.read(), re.S) m = re.match(r'(\S+)\s+IN\s+TXT\s+\( ((?:"[^"]+"\s+)+)\)', orf.read(), re.S)
val = "".join(re.findall(r'"([^"]+)"', m.group(2))) val = "".join(re.findall(r'"([^"]+)"', m.group(2)))
if not has_rec(m.group(1), "TXT", prefix="v=DKIM1; "): if not has_rec(m.group(1), "TXT", prefix="v=DKIM1; "):
@ -364,8 +358,8 @@ def build_zone(domain, domain_properties, additional_records, env, is_zone=True)
# non-mail domain and also may include qnames from custom DNS records. # non-mail domain and also may include qnames from custom DNS records.
# Do this once at the end of generating a zone. # Do this once at the end of generating a zone.
if is_zone: if is_zone:
qnames_with_a = set(qname for (qname, rtype, value, explanation) in records if rtype in ("A", "AAAA")) qnames_with_a = {qname for (qname, rtype, value, explanation) in records if rtype in {"A", "AAAA"}}
qnames_with_mx = set(qname for (qname, rtype, value, explanation) in records if rtype == "MX") qnames_with_mx = {qname for (qname, rtype, value, explanation) in records if rtype == "MX"}
for qname in qnames_with_a - qnames_with_mx: for qname in qnames_with_a - qnames_with_mx:
# Mark this domain as not sending mail with hard-fail SPF and DMARC records. # Mark this domain as not sending mail with hard-fail SPF and DMARC records.
d = (qname+"." if qname else "") + domain d = (qname+"." if qname else "") + domain
@ -455,14 +449,12 @@ def build_sshfp_records():
# specify that port to sshkeyscan. # specify that port to sshkeyscan.
port = 22 port = 22
with open('/etc/ssh/sshd_config', 'r') as f: with open('/etc/ssh/sshd_config', encoding="utf-8") as f:
for line in f: for line in f:
s = line.rstrip().split() s = line.rstrip().split()
if len(s) == 2 and s[0] == 'Port': if len(s) == 2 and s[0] == 'Port':
try: with contextlib.suppress(ValueError):
port = int(s[1]) port = int(s[1])
except ValueError:
pass
break break
keys = shell("check_output", ["ssh-keyscan", "-4", "-t", "rsa,dsa,ecdsa,ed25519", "-p", str(port), "localhost"]) keys = shell("check_output", ["ssh-keyscan", "-4", "-t", "rsa,dsa,ecdsa,ed25519", "-p", str(port), "localhost"])
@ -471,7 +463,7 @@ def build_sshfp_records():
for key in keys: for key in keys:
if key.strip() == "" or key[0] == "#": continue if key.strip() == "" or key[0] == "#": continue
try: try:
host, keytype, pubkey = key.split(" ") _host, keytype, pubkey = key.split(" ")
yield "%d %d ( %s )" % ( yield "%d %d ( %s )" % (
algorithm_number[keytype], algorithm_number[keytype],
2, # specifies we are using SHA-256 on next line 2, # specifies we are using SHA-256 on next line
@ -516,7 +508,7 @@ $TTL 86400 ; default time to live
zone = zone.format(domain=domain, primary_domain=env["PRIMARY_HOSTNAME"]) zone = zone.format(domain=domain, primary_domain=env["PRIMARY_HOSTNAME"])
# Add records. # Add records.
for subdomain, querytype, value, explanation in records: for subdomain, querytype, value, _explanation in records:
if subdomain: if subdomain:
zone += subdomain zone += subdomain
zone += "\tIN\t" + querytype + "\t" zone += "\tIN\t" + querytype + "\t"
@ -534,7 +526,7 @@ $TTL 86400 ; default time to live
zone += value + "\n" zone += value + "\n"
# Append a stable hash of DNSSEC signing keys in a comment. # Append a stable hash of DNSSEC signing keys in a comment.
zone += "\n; DNSSEC signing keys hash: {}\n".format(hash_dnssec_keys(domain, env)) zone += f"\n; DNSSEC signing keys hash: {hash_dnssec_keys(domain, env)}\n"
# DNSSEC requires re-signing a zone periodically. That requires # DNSSEC requires re-signing a zone periodically. That requires
# bumping the serial number even if no other records have changed. # bumping the serial number even if no other records have changed.
@ -550,7 +542,7 @@ $TTL 86400 ; default time to live
# We've signed the domain. Check if we are close to the expiration # We've signed the domain. Check if we are close to the expiration
# time of the signature. If so, we'll force a bump of the serial # time of the signature. If so, we'll force a bump of the serial
# number so we can re-sign it. # number so we can re-sign it.
with open(zonefile + ".signed") as f: with open(zonefile + ".signed", encoding="utf-8") as f:
signed_zone = f.read() signed_zone = f.read()
expiration_times = re.findall(r"\sRRSIG\s+SOA\s+\d+\s+\d+\s\d+\s+(\d{14})", signed_zone) expiration_times = re.findall(r"\sRRSIG\s+SOA\s+\d+\s+\d+\s\d+\s+(\d{14})", signed_zone)
if len(expiration_times) == 0: if len(expiration_times) == 0:
@ -569,7 +561,7 @@ $TTL 86400 ; default time to live
if os.path.exists(zonefile): if os.path.exists(zonefile):
# If the zone already exists, is different, and has a later serial number, # If the zone already exists, is different, and has a later serial number,
# increment the number. # increment the number.
with open(zonefile) as f: with open(zonefile, encoding="utf-8") as f:
existing_zone = f.read() existing_zone = f.read()
m = re.search(r"(\d+)\s*;\s*serial number", existing_zone) m = re.search(r"(\d+)\s*;\s*serial number", existing_zone)
if m: if m:
@ -593,7 +585,7 @@ $TTL 86400 ; default time to live
zone = zone.replace("__SERIAL__", serial) zone = zone.replace("__SERIAL__", serial)
# Write the zone file. # Write the zone file.
with open(zonefile, "w") as f: with open(zonefile, "w", encoding="utf-8") as f:
f.write(zone) f.write(zone)
return True # file is updated return True # file is updated
@ -606,7 +598,7 @@ def get_dns_zonefile(zone, env):
raise ValueError("%s is not a domain name that corresponds to a zone." % zone) raise ValueError("%s is not a domain name that corresponds to a zone." % zone)
nsd_zonefile = "/etc/nsd/zones/" + fn nsd_zonefile = "/etc/nsd/zones/" + fn
with open(nsd_zonefile, "r") as f: with open(nsd_zonefile, encoding="utf-8") as f:
return f.read() return f.read()
######################################################################## ########################################################################
@ -618,11 +610,11 @@ def write_nsd_conf(zonefiles, additional_records, env):
# Append the zones. # Append the zones.
for domain, zonefile in zonefiles: for domain, zonefile in zonefiles:
nsdconf += """ nsdconf += f"""
zone: zone:
name: %s name: {domain}
zonefile: %s zonefile: {zonefile}
""" % (domain, zonefile) """
# If custom secondary nameservers have been set, allow zone transfers # If custom secondary nameservers have been set, allow zone transfers
# and, if not a subnet, notifies to them. # and, if not a subnet, notifies to them.
@ -634,13 +626,13 @@ zone:
# Check if the file is changing. If it isn't changing, # Check if the file is changing. If it isn't changing,
# return False to flag that no change was made. # return False to flag that no change was made.
if os.path.exists(nsd_conf_file): if os.path.exists(nsd_conf_file):
with open(nsd_conf_file) as f: with open(nsd_conf_file, encoding="utf-8") as f:
if f.read() == nsdconf: if f.read() == nsdconf:
return False return False
# Write out new contents and return True to signal that # Write out new contents and return True to signal that
# configuration changed. # configuration changed.
with open(nsd_conf_file, "w") as f: with open(nsd_conf_file, "w", encoding="utf-8") as f:
f.write(nsdconf) f.write(nsdconf)
return True return True
@ -674,9 +666,8 @@ def hash_dnssec_keys(domain, env):
keydata = [] keydata = []
for keytype, keyfn in sorted(find_dnssec_signing_keys(domain, env)): for keytype, keyfn in sorted(find_dnssec_signing_keys(domain, env)):
oldkeyfn = os.path.join(env['STORAGE_ROOT'], 'dns/dnssec', keyfn + ".private") oldkeyfn = os.path.join(env['STORAGE_ROOT'], 'dns/dnssec', keyfn + ".private")
keydata.append(keytype) keydata.extend((keytype, keyfn))
keydata.append(keyfn) with open(oldkeyfn, encoding="utf-8") as fr:
with open(oldkeyfn, "r") as fr:
keydata.append( fr.read() ) keydata.append( fr.read() )
keydata = "".join(keydata).encode("utf8") keydata = "".join(keydata).encode("utf8")
return hashlib.sha1(keydata).hexdigest() return hashlib.sha1(keydata).hexdigest()
@ -704,12 +695,12 @@ def sign_zone(domain, zonefile, env):
# Use os.umask and open().write() to securely create a copy that only # Use os.umask and open().write() to securely create a copy that only
# we (root) can read. # we (root) can read.
oldkeyfn = os.path.join(env['STORAGE_ROOT'], 'dns/dnssec', keyfn + ext) oldkeyfn = os.path.join(env['STORAGE_ROOT'], 'dns/dnssec', keyfn + ext)
with open(oldkeyfn, "r") as fr: with open(oldkeyfn, encoding="utf-8") as fr:
keydata = fr.read() keydata = fr.read()
keydata = keydata.replace("_domain_", domain) keydata = keydata.replace("_domain_", domain)
prev_umask = os.umask(0o77) # ensure written file is not world-readable prev_umask = os.umask(0o77) # ensure written file is not world-readable
try: try:
with open(newkeyfn + ext, "w") as fw: with open(newkeyfn + ext, "w", encoding="utf-8") as fw:
fw.write(keydata) fw.write(keydata)
finally: finally:
os.umask(prev_umask) # other files we write should be world-readable os.umask(prev_umask) # other files we write should be world-readable
@ -743,7 +734,7 @@ def sign_zone(domain, zonefile, env):
# be used, so we'll pre-generate all for each key. One DS record per line. Only one # be used, so we'll pre-generate all for each key. One DS record per line. Only one
# needs to actually be deployed at the registrar. We'll select the preferred one # needs to actually be deployed at the registrar. We'll select the preferred one
# in the status checks. # in the status checks.
with open("/etc/nsd/zones/" + zonefile + ".ds", "w") as f: with open("/etc/nsd/zones/" + zonefile + ".ds", "w", encoding="utf-8") as f:
for key in ksk_keys: for key in ksk_keys:
for digest_type in ('1', '2', '4'): for digest_type in ('1', '2', '4'):
rr_ds = shell('check_output', ["/usr/bin/ldns-key2ds", rr_ds = shell('check_output', ["/usr/bin/ldns-key2ds",
@ -780,7 +771,7 @@ def write_opendkim_tables(domains, env):
# So we must have a separate KeyTable entry for each domain. # So we must have a separate KeyTable entry for each domain.
"SigningTable": "SigningTable":
"".join( "".join(
"*@{domain} {domain}\n".format(domain=domain) f"*@{domain} {domain}\n"
for domain in domains for domain in domains
), ),
@ -789,7 +780,7 @@ def write_opendkim_tables(domains, env):
# signing domain must match the sender's From: domain. # signing domain must match the sender's From: domain.
"KeyTable": "KeyTable":
"".join( "".join(
"{domain} {domain}:mail:{key_file}\n".format(domain=domain, key_file=opendkim_key_file) f"{domain} {domain}:mail:{opendkim_key_file}\n"
for domain in domains for domain in domains
), ),
} }
@ -798,12 +789,12 @@ def write_opendkim_tables(domains, env):
for filename, content in config.items(): for filename, content in config.items():
# Don't write the file if it doesn't need an update. # Don't write the file if it doesn't need an update.
if os.path.exists("/etc/opendkim/" + filename): if os.path.exists("/etc/opendkim/" + filename):
with open("/etc/opendkim/" + filename) as f: with open("/etc/opendkim/" + filename, encoding="utf-8") as f:
if f.read() == content: if f.read() == content:
continue continue
# The contents needs to change. # The contents needs to change.
with open("/etc/opendkim/" + filename, "w") as f: with open("/etc/opendkim/" + filename, "w", encoding="utf-8") as f:
f.write(content) f.write(content)
did_update = True did_update = True
@ -815,9 +806,9 @@ def write_opendkim_tables(domains, env):
def get_custom_dns_config(env, only_real_records=False): def get_custom_dns_config(env, only_real_records=False):
try: try:
with open(os.path.join(env['STORAGE_ROOT'], 'dns/custom.yaml'), 'r') as f: with open(os.path.join(env['STORAGE_ROOT'], 'dns/custom.yaml'), encoding="utf-8") as f:
custom_dns = rtyaml.load(f) custom_dns = rtyaml.load(f)
if not isinstance(custom_dns, dict): raise ValueError() # caught below if not isinstance(custom_dns, dict): raise ValueError # caught below
except: except:
return [ ] return [ ]
@ -835,7 +826,7 @@ def get_custom_dns_config(env, only_real_records=False):
# No other type of data is allowed. # No other type of data is allowed.
else: else:
raise ValueError() raise ValueError
for rtype, value2 in values: for rtype, value2 in values:
if isinstance(value2, str): if isinstance(value2, str):
@ -845,7 +836,7 @@ def get_custom_dns_config(env, only_real_records=False):
yield (qname, rtype, value3) yield (qname, rtype, value3)
# No other type of data is allowed. # No other type of data is allowed.
else: else:
raise ValueError() raise ValueError
def filter_custom_records(domain, custom_dns_iter): def filter_custom_records(domain, custom_dns_iter):
for qname, rtype, value in custom_dns_iter: for qname, rtype, value in custom_dns_iter:
@ -861,10 +852,7 @@ def filter_custom_records(domain, custom_dns_iter):
# our short form (None => domain, or a relative QNAME) if # our short form (None => domain, or a relative QNAME) if
# domain is not None. # domain is not None.
if domain is not None: if domain is not None:
if qname == domain: qname = None if qname == domain else qname[0:len(qname) - len("." + domain)]
qname = None
else:
qname = qname[0:len(qname)-len("." + domain)]
yield (qname, rtype, value) yield (qname, rtype, value)
@ -900,12 +888,12 @@ def write_custom_dns_config(config, env):
# Write. # Write.
config_yaml = rtyaml.dump(dns) config_yaml = rtyaml.dump(dns)
with open(os.path.join(env['STORAGE_ROOT'], 'dns/custom.yaml'), "w") as f: with open(os.path.join(env['STORAGE_ROOT'], 'dns/custom.yaml'), "w", encoding="utf-8") as f:
f.write(config_yaml) f.write(config_yaml)
def set_custom_dns_record(qname, rtype, value, action, env): def set_custom_dns_record(qname, rtype, value, action, env):
# validate qname # validate qname
for zone, fn in get_dns_zones(env): for zone, _fn in get_dns_zones(env):
# It must match a zone apex or be a subdomain of a zone # It must match a zone apex or be a subdomain of a zone
# that we are otherwise hosting. # that we are otherwise hosting.
if qname == zone or qname.endswith("."+zone): if qname == zone or qname.endswith("."+zone):
@ -919,24 +907,27 @@ def set_custom_dns_record(qname, rtype, value, action, env):
rtype = rtype.upper() rtype = rtype.upper()
if value is not None and qname != "_secondary_nameserver": if value is not None and qname != "_secondary_nameserver":
if not re.search(DOMAIN_RE, qname): if not re.search(DOMAIN_RE, qname):
raise ValueError("Invalid name.") msg = "Invalid name."
raise ValueError(msg)
if rtype in ("A", "AAAA"): if rtype in {"A", "AAAA"}:
if value != "local": # "local" is a special flag for us if value != "local": # "local" is a special flag for us
v = ipaddress.ip_address(value) # raises a ValueError if there's a problem v = ipaddress.ip_address(value) # raises a ValueError if there's a problem
if rtype == "A" and not isinstance(v, ipaddress.IPv4Address): raise ValueError("That's an IPv6 address.") if rtype == "A" and not isinstance(v, ipaddress.IPv4Address): raise ValueError("That's an IPv6 address.")
if rtype == "AAAA" and not isinstance(v, ipaddress.IPv6Address): raise ValueError("That's an IPv4 address.") if rtype == "AAAA" and not isinstance(v, ipaddress.IPv6Address): raise ValueError("That's an IPv4 address.")
elif rtype in ("CNAME", "NS"): elif rtype in {"CNAME", "NS"}:
if rtype == "NS" and qname == zone: if rtype == "NS" and qname == zone:
raise ValueError("NS records can only be set for subdomains.") msg = "NS records can only be set for subdomains."
raise ValueError(msg)
# ensure value has a trailing dot # ensure value has a trailing dot
if not value.endswith("."): if not value.endswith("."):
value = value + "." value = value + "."
if not re.search(DOMAIN_RE, value): if not re.search(DOMAIN_RE, value):
raise ValueError("Invalid value.") msg = "Invalid value."
elif rtype in ("CNAME", "TXT", "SRV", "MX", "SSHFP", "CAA"): raise ValueError(msg)
elif rtype in {"CNAME", "TXT", "SRV", "MX", "SSHFP", "CAA"}:
# anything goes # anything goes
pass pass
else: else:
@ -969,7 +960,7 @@ def set_custom_dns_record(qname, rtype, value, action, env):
# Drop this record. # Drop this record.
made_change = True made_change = True
continue continue
if value == None and (_qname, _rtype) == (qname, rtype): if value is None and (_qname, _rtype) == (qname, rtype):
# Drop all qname-rtype records. # Drop all qname-rtype records.
made_change = True made_change = True
continue continue
@ -979,7 +970,7 @@ def set_custom_dns_record(qname, rtype, value, action, env):
# Preserve this record. # Preserve this record.
newconfig.append((_qname, _rtype, _value)) newconfig.append((_qname, _rtype, _value))
if action in ("add", "set") and needs_add and value is not None: if action in {"add", "set"} and needs_add and value is not None:
newconfig.append((qname, rtype, value)) newconfig.append((qname, rtype, value))
made_change = True made_change = True
@ -996,11 +987,11 @@ def get_secondary_dns(custom_dns, mode=None):
resolver.lifetime = 10 resolver.lifetime = 10
values = [] values = []
for qname, rtype, value in custom_dns: for qname, _rtype, value in custom_dns:
if qname != '_secondary_nameserver': continue if qname != '_secondary_nameserver': continue
for hostname in value.split(" "): for hostname in value.split(" "):
hostname = hostname.strip() hostname = hostname.strip()
if mode == None: if mode is None:
# Just return the setting. # Just return the setting.
values.append(hostname) values.append(hostname)
continue continue
@ -1046,19 +1037,19 @@ def set_secondary_dns(hostnames, env):
if not item.startswith("xfr:"): if not item.startswith("xfr:"):
# Resolve hostname. # Resolve hostname.
try: try:
response = resolver.resolve(item, "A") resolver.resolve(item, "A")
except (dns.resolver.NoNameservers, dns.resolver.NXDOMAIN, dns.resolver.NoAnswer, dns.resolver.Timeout): except (dns.resolver.NoNameservers, dns.resolver.NXDOMAIN, dns.resolver.NoAnswer, dns.resolver.Timeout):
try: try:
response = resolver.resolve(item, "AAAA") resolver.resolve(item, "AAAA")
except (dns.resolver.NoNameservers, dns.resolver.NXDOMAIN, dns.resolver.NoAnswer, dns.resolver.Timeout): except (dns.resolver.NoNameservers, dns.resolver.NXDOMAIN, dns.resolver.NoAnswer, dns.resolver.Timeout):
raise ValueError("Could not resolve the IP address of %s." % item) raise ValueError("Could not resolve the IP address of %s." % item)
else: else:
# Validate IP address. # Validate IP address.
try: try:
if "/" in item[4:]: if "/" in item[4:]:
v = ipaddress.ip_network(item[4:]) # raises a ValueError if there's a problem ipaddress.ip_network(item[4:]) # raises a ValueError if there's a problem
else: else:
v = ipaddress.ip_address(item[4:]) # raises a ValueError if there's a problem ipaddress.ip_address(item[4:]) # raises a ValueError if there's a problem
except ValueError: except ValueError:
raise ValueError("'%s' is not an IPv4 or IPv6 address or subnet." % item[4:]) raise ValueError("'%s' is not an IPv4 or IPv6 address or subnet." % item[4:])
@ -1076,13 +1067,12 @@ def get_custom_dns_records(custom_dns, qname, rtype):
for qname1, rtype1, value in custom_dns: for qname1, rtype1, value in custom_dns:
if qname1 == qname and rtype1 == rtype: if qname1 == qname and rtype1 == rtype:
yield value yield value
return None
######################################################################## ########################################################################
def build_recommended_dns(env): def build_recommended_dns(env):
ret = [] ret = []
for (domain, zonefile, records) in build_zones(env): for (domain, _zonefile, records) in build_zones(env):
# remove records that we don't display # remove records that we don't display
records = [r for r in records if r[3] is not False] records = [r for r in records if r[3] is not False]
@ -1091,10 +1081,7 @@ def build_recommended_dns(env):
# expand qnames # expand qnames
for i in range(len(records)): for i in range(len(records)):
if records[i][0] == None: qname = domain if records[i][0] is None else records[i][0] + "." + domain
qname = domain
else:
qname = records[i][0] + "." + domain
records[i] = { records[i] = {
"qname": qname, "qname": qname,
@ -1113,7 +1100,7 @@ if __name__ == "__main__":
if sys.argv[-1] == "--lint": if sys.argv[-1] == "--lint":
write_custom_dns_config(get_custom_dns_config(env), env) write_custom_dns_config(get_custom_dns_config(env), env)
else: else:
for zone, records in build_recommended_dns(env): for _zone, records in build_recommended_dns(env):
for record in records: for record in records:
print("; " + record['explanation']) print("; " + record['explanation'])
print(record['qname'], record['rtype'], record['value'], sep="\t") print(record['qname'], record['rtype'], record['value'], sep="\t")

View File

@ -37,11 +37,11 @@ msg = MIMEMultipart('alternative')
# In Python 3.6: # In Python 3.6:
#msg = Message() #msg = Message()
msg['From'] = "\"%s\" <%s>" % (env['PRIMARY_HOSTNAME'], admin_addr) msg['From'] = '"{}" <{}>'.format(env['PRIMARY_HOSTNAME'], admin_addr)
msg['To'] = admin_addr msg['To'] = admin_addr
msg['Subject'] = "[%s] %s" % (env['PRIMARY_HOSTNAME'], subject) msg['Subject'] = "[{}] {}".format(env['PRIMARY_HOSTNAME'], subject)
content_html = '<html><body><pre style="overflow-x: scroll; white-space: pre;">{}</pre></body></html>'.format(html.escape(content)) content_html = f'<html><body><pre style="overflow-x: scroll; white-space: pre;">{html.escape(content)}</pre></body></html>'
msg.attach(MIMEText(content, 'plain')) msg.attach(MIMEText(content, 'plain'))
msg.attach(MIMEText(content_html, 'html')) msg.attach(MIMEText(content_html, 'html'))

View File

@ -116,12 +116,11 @@ def scan_mail_log(env):
try: try:
import mailconfig import mailconfig
collector["known_addresses"] = (set(mailconfig.get_mail_users(env)) | collector["known_addresses"] = (set(mailconfig.get_mail_users(env)) |
set(alias[0] for alias in mailconfig.get_mail_aliases(env))) {alias[0] for alias in mailconfig.get_mail_aliases(env)})
except ImportError: except ImportError:
pass pass
print("Scanning logs from {:%Y-%m-%d %H:%M:%S} to {:%Y-%m-%d %H:%M:%S}".format( print(f"Scanning logs from {START_DATE:%Y-%m-%d %H:%M:%S} to {END_DATE:%Y-%m-%d %H:%M:%S}"
START_DATE, END_DATE)
) )
# Scan the lines in the log files until the date goes out of range # Scan the lines in the log files until the date goes out of range
@ -227,7 +226,7 @@ def scan_mail_log(env):
], ],
sub_data=[ sub_data=[
("Protocol and Source", [[ ("Protocol and Source", [[
"{} {}: {} times".format(protocol_name, host, count) f"{protocol_name} {host}: {count} times"
for (protocol_name, host), count for (protocol_name, host), count
in sorted(u["totals_by_protocol_and_host"].items(), key=lambda kv:-kv[1]) in sorted(u["totals_by_protocol_and_host"].items(), key=lambda kv:-kv[1])
] for u in data.values()]) ] for u in data.values()])
@ -303,8 +302,7 @@ def scan_mail_log(env):
for date, sender, message in user_data["blocked"]: for date, sender, message in user_data["blocked"]:
if len(sender) > 64: if len(sender) > 64:
sender = sender[:32] + "" + sender[-32:] sender = sender[:32] + "" + sender[-32:]
user_rejects.append("%s - %s " % (date, sender)) user_rejects.extend((f'{date} - {sender} ', ' %s' % message))
user_rejects.append(" %s" % message)
rejects.append(user_rejects) rejects.append(user_rejects)
print_user_table( print_user_table(
@ -322,7 +320,7 @@ def scan_mail_log(env):
if collector["other-services"] and VERBOSE and False: if collector["other-services"] and VERBOSE and False:
print_header("Other services") print_header("Other services")
print("The following unkown services were found in the log file.") print("The following unkown services were found in the log file.")
print(" ", *sorted(list(collector["other-services"])), sep='\n') print(" ", *sorted(collector["other-services"]), sep='\n')
def scan_mail_log_line(line, collector): def scan_mail_log_line(line, collector):
@ -333,7 +331,7 @@ def scan_mail_log_line(line, collector):
if not m: if not m:
return True return True
date, system, service, log = m.groups() date, _system, service, log = m.groups()
collector["scan_count"] += 1 collector["scan_count"] += 1
# print() # print()
@ -376,9 +374,9 @@ def scan_mail_log_line(line, collector):
elif service == "postfix/smtpd": elif service == "postfix/smtpd":
if SCAN_BLOCKED: if SCAN_BLOCKED:
scan_postfix_smtpd_line(date, log, collector) scan_postfix_smtpd_line(date, log, collector)
elif service in ("postfix/qmgr", "postfix/pickup", "postfix/cleanup", "postfix/scache", elif service in {"postfix/qmgr", "postfix/pickup", "postfix/cleanup", "postfix/scache",
"spampd", "postfix/anvil", "postfix/master", "opendkim", "postfix/lmtp", "spampd", "postfix/anvil", "postfix/master", "opendkim", "postfix/lmtp",
"postfix/tlsmgr", "anvil"): "postfix/tlsmgr", "anvil"}:
# nothing to look at # nothing to look at
return True return True
else: else:
@ -392,7 +390,7 @@ def scan_mail_log_line(line, collector):
def scan_postgrey_line(date, log, collector): def scan_postgrey_line(date, log, collector):
""" Scan a postgrey log line and extract interesting data """ """ Scan a postgrey log line and extract interesting data """
m = re.match("action=(greylist|pass), reason=(.*?), (?:delay=\d+, )?client_name=(.*), " m = re.match(r"action=(greylist|pass), reason=(.*?), (?:delay=\d+, )?client_name=(.*), "
"client_address=(.*), sender=(.*), recipient=(.*)", "client_address=(.*), sender=(.*), recipient=(.*)",
log) log)
@ -435,8 +433,7 @@ def scan_postfix_smtpd_line(date, log, collector):
return return
# only log mail to known recipients # only log mail to known recipients
if user_match(user): if user_match(user) and (collector["known_addresses"] is None or user in collector["known_addresses"]):
if collector["known_addresses"] is None or user in collector["known_addresses"]:
data = collector["rejected"].get( data = collector["rejected"].get(
user, user,
{ {
@ -500,7 +497,7 @@ def add_login(user, date, protocol_name, host, collector):
data["totals_by_protocol"][protocol_name] += 1 data["totals_by_protocol"][protocol_name] += 1
data["totals_by_protocol_and_host"][(protocol_name, host)] += 1 data["totals_by_protocol_and_host"][(protocol_name, host)] += 1
if host not in ("127.0.0.1", "::1") or True: if host not in {"127.0.0.1", "::1"} or True:
data["activity-by-hour"][protocol_name][date.hour] += 1 data["activity-by-hour"][protocol_name][date.hour] += 1
collector["logins"][user] = data collector["logins"][user] = data
@ -514,7 +511,7 @@ def scan_postfix_lmtp_line(date, log, collector):
""" """
m = re.match("([A-Z0-9]+): to=<(\S+)>, .* Saved", log) m = re.match(r"([A-Z0-9]+): to=<(\S+)>, .* Saved", log)
if m: if m:
_, user = m.groups() _, user = m.groups()
@ -552,10 +549,10 @@ def scan_postfix_submission_line(date, log, collector):
# Match both the 'plain' and 'login' sasl methods, since both authentication methods are # Match both the 'plain' and 'login' sasl methods, since both authentication methods are
# allowed by Dovecot. Exclude trailing comma after the username when additional fields # allowed by Dovecot. Exclude trailing comma after the username when additional fields
# follow after. # follow after.
m = re.match("([A-Z0-9]+): client=(\S+), sasl_method=(PLAIN|LOGIN), sasl_username=(\S+)(?<!,)", log) m = re.match(r"([A-Z0-9]+): client=(\S+), sasl_method=(PLAIN|LOGIN), sasl_username=(\S+)(?<!,)", log)
if m: if m:
_, client, method, user = m.groups() _, client, _method, user = m.groups()
if user_match(user): if user_match(user):
# Get the user data, or create it if the user is new # Get the user data, or create it if the user is new
@ -588,7 +585,7 @@ def scan_postfix_submission_line(date, log, collector):
def readline(filename): def readline(filename):
""" A generator that returns the lines of a file """ A generator that returns the lines of a file
""" """
with open(filename, errors='replace') as file: with open(filename, errors='replace', encoding='utf-8') as file:
while True: while True:
line = file.readline() line = file.readline()
if not line: if not line:
@ -622,10 +619,7 @@ def print_time_table(labels, data, do_print=True):
data.insert(0, [str(h) for h in range(24)]) data.insert(0, [str(h) for h in range(24)])
temp = "{:<%d} " % max(len(l) for l in labels) temp = "{:<%d} " % max(len(l) for l in labels)
lines = [] lines = [temp.format(label) for label in labels]
for label in labels:
lines.append(temp.format(label))
for h in range(24): for h in range(24):
max_len = max(len(str(d[h])) for d in data) max_len = max(len(str(d[h])) for d in data)
@ -639,6 +633,7 @@ def print_time_table(labels, data, do_print=True):
if do_print: if do_print:
print("\n".join(lines)) print("\n".join(lines))
return None
else: else:
return lines return lines
@ -672,7 +667,7 @@ def print_user_table(users, data=None, sub_data=None, activity=None, latest=None
col_str = str_temp.format(d[row][:31] + "" if len(d[row]) > 32 else d[row]) col_str = str_temp.format(d[row][:31] + "" if len(d[row]) > 32 else d[row])
col_left[col] = True col_left[col] = True
elif isinstance(d[row], datetime.datetime): elif isinstance(d[row], datetime.datetime):
col_str = "{:<20}".format(str(d[row])) col_str = f"{d[row]!s:<20}"
col_left[col] = True col_left[col] = True
else: else:
temp = "{:>%s}" % max(5, len(l) + 1, len(str(d[row])) + 1) temp = "{:>%s}" % max(5, len(l) + 1, len(str(d[row])) + 1)
@ -684,7 +679,7 @@ def print_user_table(users, data=None, sub_data=None, activity=None, latest=None
data_accum[col] += d[row] data_accum[col] += d[row]
try: try:
if None not in [latest, earliest]: if None not in {latest, earliest}:
vert_pos = len(line) vert_pos = len(line)
e = earliest[row] e = earliest[row]
l = latest[row] l = latest[row]
@ -712,10 +707,7 @@ def print_user_table(users, data=None, sub_data=None, activity=None, latest=None
if sub_data is not None: if sub_data is not None:
for l, d in sub_data: for l, d in sub_data:
if d[row]: if d[row]:
lines.append("") lines.extend(('', '%s' % l, '├─%s' % (len(l) * ''), ''))
lines.append("%s" % l)
lines.append("├─%s" % (len(l) * ""))
lines.append("")
max_len = 0 max_len = 0
for v in list(d[row]): for v in list(d[row]):
lines.append("%s" % v) lines.append("%s" % v)
@ -740,7 +732,7 @@ def print_user_table(users, data=None, sub_data=None, activity=None, latest=None
else: else:
header += l.rjust(max(5, len(l) + 1, col_widths[col])) header += l.rjust(max(5, len(l) + 1, col_widths[col]))
if None not in (latest, earliest): if None not in {latest, earliest}:
header += " │ timespan " header += " │ timespan "
lines.insert(0, header.rstrip()) lines.insert(0, header.rstrip())
@ -765,7 +757,7 @@ def print_user_table(users, data=None, sub_data=None, activity=None, latest=None
footer += temp.format(data_accum[row]) footer += temp.format(data_accum[row])
try: try:
if None not in [latest, earliest]: if None not in {latest, earliest}:
max_l = max(latest) max_l = max(latest)
min_e = min(earliest) min_e = min(earliest)
timespan = relativedelta(max_l, min_e) timespan = relativedelta(max_l, min_e)
@ -844,7 +836,7 @@ if __name__ == "__main__":
END_DATE = args.enddate END_DATE = args.enddate
if args.timespan == 'today': if args.timespan == 'today':
args.timespan = 'day' args.timespan = 'day'
print("Setting end date to {}".format(END_DATE)) print(f"Setting end date to {END_DATE}")
START_DATE = END_DATE - TIME_DELTAS[args.timespan] START_DATE = END_DATE - TIME_DELTAS[args.timespan]

View File

@ -9,7 +9,7 @@
# Python 3 in setup/questions.sh to validate the email # Python 3 in setup/questions.sh to validate the email
# address entered by the user. # address entered by the user.
import subprocess, shutil, os, sqlite3, re import os, sqlite3, re
import utils import utils
from email_validator import validate_email as validate_email_, EmailNotValidError from email_validator import validate_email as validate_email_, EmailNotValidError
import idna import idna
@ -86,10 +86,7 @@ def prettify_idn_email_address(email):
def is_dcv_address(email): def is_dcv_address(email):
email = email.lower() email = email.lower()
for localpart in ("admin", "administrator", "postmaster", "hostmaster", "webmaster", "abuse"): return any(email.startswith((localpart + "@", localpart + "+")) for localpart in ("admin", "administrator", "postmaster", "hostmaster", "webmaster", "abuse"))
if email.startswith(localpart+"@") or email.startswith(localpart+"+"):
return True
return False
def open_database(env, with_connection=False): def open_database(env, with_connection=False):
conn = sqlite3.connect(env["STORAGE_ROOT"] + "/mail/users.sqlite") conn = sqlite3.connect(env["STORAGE_ROOT"] + "/mail/users.sqlite")
@ -192,8 +189,7 @@ def get_mail_aliases(env):
aliases = { row[0]: row for row in c.fetchall() } # make dict aliases = { row[0]: row for row in c.fetchall() } # make dict
# put in a canonical order: sort by domain, then by email address lexicographically # put in a canonical order: sort by domain, then by email address lexicographically
aliases = [ aliases[address] for address in utils.sort_email_addresses(aliases.keys(), env) ] return [ aliases[address] for address in utils.sort_email_addresses(aliases.keys(), env) ]
return aliases
def get_mail_aliases_ex(env): def get_mail_aliases_ex(env):
# Returns a complex data structure of all mail aliases, similar # Returns a complex data structure of all mail aliases, similar
@ -225,7 +221,7 @@ def get_mail_aliases_ex(env):
domain = get_domain(address) domain = get_domain(address)
# add to list # add to list
if not domain in domains: if domain not in domains:
domains[domain] = { domains[domain] = {
"domain": domain, "domain": domain,
"aliases": [], "aliases": [],
@ -477,10 +473,7 @@ def add_mail_alias(address, forwards_to, permitted_senders, env, update_if_exist
forwards_to = ",".join(validated_forwards_to) forwards_to = ",".join(validated_forwards_to)
if len(validated_permitted_senders) == 0: permitted_senders = None if len(validated_permitted_senders) == 0 else ",".join(validated_permitted_senders)
permitted_senders = None
else:
permitted_senders = ",".join(validated_permitted_senders)
conn, c = open_database(env, with_connection=True) conn, c = open_database(env, with_connection=True)
try: try:
@ -498,6 +491,7 @@ def add_mail_alias(address, forwards_to, permitted_senders, env, update_if_exist
if do_kick: if do_kick:
# Update things in case any new domains are added. # Update things in case any new domains are added.
return kick(env, return_status) return kick(env, return_status)
return None
def remove_mail_alias(address, env, do_kick=True): def remove_mail_alias(address, env, do_kick=True):
# convert Unicode domain to IDNA # convert Unicode domain to IDNA
@ -513,10 +507,11 @@ def remove_mail_alias(address, env, do_kick=True):
if do_kick: if do_kick:
# Update things in case any domains are removed. # Update things in case any domains are removed.
return kick(env, "alias removed") return kick(env, "alias removed")
return None
def add_auto_aliases(aliases, env): def add_auto_aliases(aliases, env):
conn, c = open_database(env, with_connection=True) conn, c = open_database(env, with_connection=True)
c.execute("DELETE FROM auto_aliases"); c.execute("DELETE FROM auto_aliases")
for source, destination in aliases.items(): for source, destination in aliases.items():
c.execute("INSERT INTO auto_aliases (source, destination) VALUES (?, ?)", (source, destination)) c.execute("INSERT INTO auto_aliases (source, destination) VALUES (?, ?)", (source, destination))
conn.commit() conn.commit()
@ -586,14 +581,14 @@ def kick(env, mail_result=None):
# Remove auto-generated postmaster/admin/abuse alises from the main aliases table. # Remove auto-generated postmaster/admin/abuse alises from the main aliases table.
# They are now stored in the auto_aliases table. # They are now stored in the auto_aliases table.
for address, forwards_to, permitted_senders, auto in get_mail_aliases(env): for address, forwards_to, _permitted_senders, auto in get_mail_aliases(env):
user, domain = address.split("@") user, domain = address.split("@")
if user in ("postmaster", "admin", "abuse") \ if user in {"postmaster", "admin", "abuse"} \
and address not in required_aliases \ and address not in required_aliases \
and forwards_to == get_system_administrator(env) \ and forwards_to == get_system_administrator(env) \
and not auto: and not auto:
remove_mail_alias(address, env, do_kick=False) remove_mail_alias(address, env, do_kick=False)
results.append("removed alias %s (was to %s; domain no longer used for email)\n" % (address, forwards_to)) results.append(f"removed alias {address} (was to {forwards_to}; domain no longer used for email)\n")
# Update DNS and nginx in case any domains are added/removed. # Update DNS and nginx in case any domains are added/removed.
@ -608,9 +603,11 @@ def kick(env, mail_result=None):
def validate_password(pw): def validate_password(pw):
# validate password # validate password
if pw.strip() == "": if pw.strip() == "":
raise ValueError("No password provided.") msg = "No password provided."
raise ValueError(msg)
if len(pw) < 8: if len(pw) < 8:
raise ValueError("Passwords must be at least eight characters.") msg = "Passwords must be at least eight characters."
raise ValueError(msg)
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys

View File

@ -41,9 +41,11 @@ def enable_mfa(email, type, secret, token, label, env):
# Sanity check with the provide current token. # Sanity check with the provide current token.
totp = pyotp.TOTP(secret) totp = pyotp.TOTP(secret)
if not totp.verify(token, valid_window=1): if not totp.verify(token, valid_window=1):
raise ValueError("Invalid token.") msg = "Invalid token."
raise ValueError(msg)
else: else:
raise ValueError("Invalid MFA type.") msg = "Invalid MFA type."
raise ValueError(msg)
conn, c = open_database(env, with_connection=True) conn, c = open_database(env, with_connection=True)
c.execute('INSERT INTO mfa (user_id, type, secret, label) VALUES (?, ?, ?, ?)', (get_user_id(email, c), type, secret, label)) c.execute('INSERT INTO mfa (user_id, type, secret, label) VALUES (?, ?, ?, ?)', (get_user_id(email, c), type, secret, label))
@ -66,10 +68,12 @@ def disable_mfa(email, mfa_id, env):
return c.rowcount > 0 return c.rowcount > 0
def validate_totp_secret(secret): def validate_totp_secret(secret):
if type(secret) != str or secret.strip() == "": if not isinstance(secret, str) or secret.strip() == "":
raise ValueError("No secret provided.") msg = "No secret provided."
raise ValueError(msg)
if len(secret) != 32: if len(secret) != 32:
raise ValueError("Secret should be a 32 characters base32 string") msg = "Secret should be a 32 characters base32 string"
raise ValueError(msg)
def provision_totp(email, env): def provision_totp(email, env):
# Make a new secret. # Make a new secret.

View File

@ -4,7 +4,8 @@
import os, os.path, re, shutil, subprocess, tempfile import os, os.path, re, shutil, subprocess, tempfile
from utils import shell, safe_domain_name, sort_domains from utils import shell, safe_domain_name, sort_domains
import idna import functools
import operator
# SELECTING SSL CERTIFICATES FOR USE IN WEB # SELECTING SSL CERTIFICATES FOR USE IN WEB
@ -83,8 +84,7 @@ def get_ssl_certificates(env):
for domain in cert_domains: for domain in cert_domains:
# The primary hostname can only use a certificate mapped # The primary hostname can only use a certificate mapped
# to the system private key. # to the system private key.
if domain == env['PRIMARY_HOSTNAME']: if domain == env['PRIMARY_HOSTNAME'] and cert["private_key"]["filename"] != os.path.join(env['STORAGE_ROOT'], 'ssl', 'ssl_private_key.pem'):
if cert["private_key"]["filename"] != os.path.join(env['STORAGE_ROOT'], 'ssl', 'ssl_private_key.pem'):
continue continue
domains.setdefault(domain, []).append(cert) domains.setdefault(domain, []).append(cert)
@ -150,13 +150,12 @@ def get_domain_ssl_files(domain, ssl_certificates, env, allow_missing_cert=False
"certificate_object": load_pem(load_cert_chain(ssl_certificate)[0]), "certificate_object": load_pem(load_cert_chain(ssl_certificate)[0]),
} }
if use_main_cert: if use_main_cert and domain == env['PRIMARY_HOSTNAME']:
if domain == env['PRIMARY_HOSTNAME']:
# The primary domain must use the server certificate because # The primary domain must use the server certificate because
# it is hard-coded in some service configuration files. # it is hard-coded in some service configuration files.
return system_certificate return system_certificate
wildcard_domain = re.sub("^[^\.]+", "*", domain) wildcard_domain = re.sub(r"^[^\.]+", "*", domain)
if domain in ssl_certificates: if domain in ssl_certificates:
return ssl_certificates[domain] return ssl_certificates[domain]
elif wildcard_domain in ssl_certificates: elif wildcard_domain in ssl_certificates:
@ -212,7 +211,7 @@ def get_certificates_to_provision(env, limit_domains=None, show_valid_certs=True
if not value: continue # IPv6 is not configured if not value: continue # IPv6 is not configured
response = query_dns(domain, rtype) response = query_dns(domain, rtype)
if response != normalize_ip(value): if response != normalize_ip(value):
bad_dns.append("%s (%s)" % (response, rtype)) bad_dns.append(f"{response} ({rtype})")
if bad_dns: if bad_dns:
domains_cant_provision[domain] = "The domain name does not resolve to this machine: " \ domains_cant_provision[domain] = "The domain name does not resolve to this machine: " \
@ -265,11 +264,11 @@ def provision_certificates(env, limit_domains):
# primary domain listed in each certificate. # primary domain listed in each certificate.
from dns_update import get_dns_zones from dns_update import get_dns_zones
certs = { } certs = { }
for zone, zonefile in get_dns_zones(env): for zone, _zonefile in get_dns_zones(env):
certs[zone] = [[]] certs[zone] = [[]]
for domain in sort_domains(domains, env): for domain in sort_domains(domains, env):
# Does the domain end with any domain we've seen so far. # Does the domain end with any domain we've seen so far.
for parent in certs.keys(): for parent in certs:
if domain.endswith("." + parent): if domain.endswith("." + parent):
# Add this to the parent's list of domains. # Add this to the parent's list of domains.
# Start a new group if the list already has # Start a new group if the list already has
@ -286,7 +285,7 @@ def provision_certificates(env, limit_domains):
# Flatten to a list of lists of domains (from a mapping). Remove empty # Flatten to a list of lists of domains (from a mapping). Remove empty
# lists (zones with no domains that need certs). # lists (zones with no domains that need certs).
certs = sum(certs.values(), []) certs = functools.reduce(operator.iadd, certs.values(), [])
certs = [_ for _ in certs if len(_) > 0] certs = [_ for _ in certs if len(_) > 0]
# Prepare to provision. # Prepare to provision.
@ -414,7 +413,7 @@ def create_csr(domain, ssl_key, country_code, env):
"openssl", "req", "-new", "openssl", "req", "-new",
"-key", ssl_key, "-key", ssl_key,
"-sha256", "-sha256",
"-subj", "/C=%s/CN=%s" % (country_code, domain)]) "-subj", f"/C={country_code}/CN={domain}"])
def install_cert(domain, ssl_cert, ssl_chain, env, raw=False): def install_cert(domain, ssl_cert, ssl_chain, env, raw=False):
# Write the combined cert+chain to a temporary path and validate that it is OK. # Write the combined cert+chain to a temporary path and validate that it is OK.
@ -450,8 +449,8 @@ def install_cert_copy_file(fn, env):
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
from binascii import hexlify from binascii import hexlify
cert = load_pem(load_cert_chain(fn)[0]) cert = load_pem(load_cert_chain(fn)[0])
all_domains, cn = get_certificate_domains(cert) _all_domains, cn = get_certificate_domains(cert)
path = "%s-%s-%s.pem" % ( path = "{}-{}-{}.pem".format(
safe_domain_name(cn), # common name, which should be filename safe because it is IDNA-encoded, but in case of a malformed cert make sure it's ok to use as a filename safe_domain_name(cn), # common name, which should be filename safe because it is IDNA-encoded, but in case of a malformed cert make sure it's ok to use as a filename
cert.not_valid_after.date().isoformat().replace("-", ""), # expiration date cert.not_valid_after.date().isoformat().replace("-", ""), # expiration date
hexlify(cert.fingerprint(hashes.SHA256())).decode("ascii")[0:8], # fingerprint prefix hexlify(cert.fingerprint(hashes.SHA256())).decode("ascii")[0:8], # fingerprint prefix
@ -522,12 +521,12 @@ def check_certificate(domain, ssl_certificate, ssl_private_key, warn_if_expiring
# First check that the domain name is one of the names allowed by # First check that the domain name is one of the names allowed by
# the certificate. # the certificate.
if domain is not None: if domain is not None:
certificate_names, cert_primary_name = get_certificate_domains(cert) certificate_names, _cert_primary_name = get_certificate_domains(cert)
# Check that the domain appears among the acceptable names, or a wildcard # Check that the domain appears among the acceptable names, or a wildcard
# form of the domain name (which is a stricter check than the specs but # form of the domain name (which is a stricter check than the specs but
# should work in normal cases). # should work in normal cases).
wildcard_domain = re.sub("^[^\.]+", "*", domain) wildcard_domain = re.sub(r"^[^\.]+", "*", domain)
if domain not in certificate_names and wildcard_domain not in certificate_names: if domain not in certificate_names and wildcard_domain not in certificate_names:
return ("The certificate is for the wrong domain name. It is for %s." return ("The certificate is for the wrong domain name. It is for %s."
% ", ".join(sorted(certificate_names)), None) % ", ".join(sorted(certificate_names)), None)
@ -538,7 +537,7 @@ def check_certificate(domain, ssl_certificate, ssl_private_key, warn_if_expiring
with open(ssl_private_key, 'rb') as f: with open(ssl_private_key, 'rb') as f:
priv_key = load_pem(f.read()) priv_key = load_pem(f.read())
except ValueError as e: except ValueError as e:
return ("The private key file %s is not a private key file: %s" % (ssl_private_key, str(e)), None) return (f"The private key file {ssl_private_key} is not a private key file: {e!s}", None)
if not isinstance(priv_key, RSAPrivateKey): if not isinstance(priv_key, RSAPrivateKey):
return ("The private key file %s is not a private key file." % ssl_private_key, None) return ("The private key file %s is not a private key file." % ssl_private_key, None)
@ -566,7 +565,7 @@ def check_certificate(domain, ssl_certificate, ssl_private_key, warn_if_expiring
import datetime import datetime
now = datetime.datetime.utcnow() now = datetime.datetime.utcnow()
if not(cert.not_valid_before <= now <= cert.not_valid_after): if not(cert.not_valid_before <= now <= cert.not_valid_after):
return ("The certificate has expired or is not yet valid. It is valid from %s to %s." % (cert.not_valid_before, cert.not_valid_after), None) return (f"The certificate has expired or is not yet valid. It is valid from {cert.not_valid_before} to {cert.not_valid_after}.", None)
# Next validate that the certificate is valid. This checks whether the certificate # Next validate that the certificate is valid. This checks whether the certificate
# is self-signed, that the chain of trust makes sense, that it is signed by a CA # is self-signed, that the chain of trust makes sense, that it is signed by a CA
@ -625,7 +624,8 @@ def load_cert_chain(pemfile):
pem = f.read() + b"\n" # ensure trailing newline pem = f.read() + b"\n" # ensure trailing newline
pemblocks = re.findall(re_pem, pem) pemblocks = re.findall(re_pem, pem)
if len(pemblocks) == 0: if len(pemblocks) == 0:
raise ValueError("File does not contain valid PEM data.") msg = "File does not contain valid PEM data."
raise ValueError(msg)
return pemblocks return pemblocks
def load_pem(pem): def load_pem(pem):
@ -636,9 +636,10 @@ def load_pem(pem):
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
pem_type = re.match(b"-+BEGIN (.*?)-+[\r\n]", pem) pem_type = re.match(b"-+BEGIN (.*?)-+[\r\n]", pem)
if pem_type is None: if pem_type is None:
raise ValueError("File is not a valid PEM-formatted file.") msg = "File is not a valid PEM-formatted file."
raise ValueError(msg)
pem_type = pem_type.group(1) pem_type = pem_type.group(1)
if pem_type in (b"RSA PRIVATE KEY", b"PRIVATE KEY"): if pem_type in {b"RSA PRIVATE KEY", b"PRIVATE KEY"}:
return serialization.load_pem_private_key(pem, password=None, backend=default_backend()) return serialization.load_pem_private_key(pem, password=None, backend=default_backend())
if pem_type == b"CERTIFICATE": if pem_type == b"CERTIFICATE":
return load_pem_x509_certificate(pem, default_backend()) return load_pem_x509_certificate(pem, default_backend())

View File

@ -4,11 +4,10 @@
# TLS certificates have been signed, etc., and if not tells the user # TLS certificates have been signed, etc., and if not tells the user
# what to do next. # what to do next.
import sys, os, os.path, re, subprocess, datetime, multiprocessing.pool import sys, os, os.path, re, datetime, multiprocessing.pool
import asyncio import asyncio
import dns.reversename, dns.resolver import dns.reversename, dns.resolver
import dateutil.parser, dateutil.tz
import idna import idna
import psutil import psutil
import postfix_mta_sts_resolver.resolver import postfix_mta_sts_resolver.resolver
@ -89,7 +88,7 @@ def run_services_checks(env, output, pool):
all_running = True all_running = True
fatal = False fatal = False
ret = pool.starmap(check_service, ((i, service, env) for i, service in enumerate(get_services())), chunksize=1) ret = pool.starmap(check_service, ((i, service, env) for i, service in enumerate(get_services())), chunksize=1)
for i, running, fatal2, output2 in sorted(ret): for _i, running, fatal2, output2 in sorted(ret):
if output2 is None: continue # skip check (e.g. no port was set, e.g. no sshd) if output2 is None: continue # skip check (e.g. no port was set, e.g. no sshd)
all_running = all_running and running all_running = all_running and running
fatal = fatal or fatal2 fatal = fatal or fatal2
@ -125,7 +124,7 @@ def check_service(i, service, env):
try: try:
s.connect((ip, service["port"])) s.connect((ip, service["port"]))
return True return True
except OSError as e: except OSError:
# timed out or some other odd error # timed out or some other odd error
return False return False
finally: finally:
@ -152,18 +151,17 @@ def check_service(i, service, env):
output.print_error("%s is not running (port %d)." % (service['name'], service['port'])) output.print_error("%s is not running (port %d)." % (service['name'], service['port']))
# Why is nginx not running? # Why is nginx not running?
if not running and service["port"] in (80, 443): if not running and service["port"] in {80, 443}:
output.print_line(shell('check_output', ['nginx', '-t'], capture_stderr=True, trap=True)[1].strip()) output.print_line(shell('check_output', ['nginx', '-t'], capture_stderr=True, trap=True)[1].strip())
else:
# Service should be running locally. # Service should be running locally.
if try_connect("127.0.0.1"): elif try_connect("127.0.0.1"):
running = True running = True
else: else:
output.print_error("%s is not running (port %d)." % (service['name'], service['port'])) output.print_error("%s is not running (port %d)." % (service['name'], service['port']))
# Flag if local DNS is not running. # Flag if local DNS is not running.
if not running and service["port"] == 53 and service["public"] == False: if not running and service["port"] == 53 and service["public"] is False:
fatal = True fatal = True
return (i, running, fatal, output) return (i, running, fatal, output)
@ -195,7 +193,7 @@ def check_ufw(env, output):
for service in get_services(): for service in get_services():
if service["public"] and not is_port_allowed(ufw, service["port"]): if service["public"] and not is_port_allowed(ufw, service["port"]):
not_allowed_ports += 1 not_allowed_ports += 1
output.print_error("Port %s (%s) should be allowed in the firewall, please re-run the setup." % (service["port"], service["name"])) output.print_error("Port {} ({}) should be allowed in the firewall, please re-run the setup.".format(service["port"], service["name"]))
if not_allowed_ports == 0: if not_allowed_ports == 0:
output.print_ok("Firewall is active.") output.print_ok("Firewall is active.")
@ -213,10 +211,10 @@ def check_ssh_password(env, output):
# the configuration file. # the configuration file.
if not os.path.exists("/etc/ssh/sshd_config"): if not os.path.exists("/etc/ssh/sshd_config"):
return return
with open("/etc/ssh/sshd_config", "r") as f: with open("/etc/ssh/sshd_config", encoding="utf-8") as f:
sshd = f.read() sshd = f.read()
if re.search("\nPasswordAuthentication\s+yes", sshd) \ if re.search("\nPasswordAuthentication\\s+yes", sshd) \
or not re.search("\nPasswordAuthentication\s+no", sshd): or not re.search("\nPasswordAuthentication\\s+no", sshd):
output.print_error("""The SSH server on this machine permits password-based login. A more secure output.print_error("""The SSH server on this machine permits password-based login. A more secure
way to log in is using a public key. Add your SSH public key to $HOME/.ssh/authorized_keys, check way to log in is using a public key. Add your SSH public key to $HOME/.ssh/authorized_keys, check
that you can log in without a password, set the option 'PasswordAuthentication no' in that you can log in without a password, set the option 'PasswordAuthentication no' in
@ -237,7 +235,7 @@ def check_software_updates(env, output):
else: else:
output.print_error("There are %d software packages that can be updated." % len(pkgs)) output.print_error("There are %d software packages that can be updated." % len(pkgs))
for p in pkgs: for p in pkgs:
output.print_line("%s (%s)" % (p["package"], p["version"])) output.print_line("{} ({})".format(p["package"], p["version"]))
def check_system_aliases(env, output): def check_system_aliases(env, output):
# Check that the administrator alias exists since that's where all # Check that the administrator alias exists since that's where all
@ -269,8 +267,7 @@ def check_free_disk_space(rounded_values, env, output):
except: except:
backup_cache_count = 0 backup_cache_count = 0
if backup_cache_count > 1: if backup_cache_count > 1:
output.print_warning("The backup cache directory {} has more than one backup target cache. Consider clearing this directory to save disk space." output.print_warning(f"The backup cache directory {backup_cache_path} has more than one backup target cache. Consider clearing this directory to save disk space.")
.format(backup_cache_path))
def check_free_memory(rounded_values, env, output): def check_free_memory(rounded_values, env, output):
# Check free memory. # Check free memory.
@ -296,7 +293,7 @@ def run_network_checks(env, output):
# Stop if we cannot make an outbound connection on port 25. Many residential # Stop if we cannot make an outbound connection on port 25. Many residential
# networks block outbound port 25 to prevent their network from sending spam. # networks block outbound port 25 to prevent their network from sending spam.
# See if we can reach one of Google's MTAs with a 5-second timeout. # See if we can reach one of Google's MTAs with a 5-second timeout.
code, ret = shell("check_call", ["/bin/nc", "-z", "-w5", "aspmx.l.google.com", "25"], trap=True) _code, ret = shell("check_call", ["/bin/nc", "-z", "-w5", "aspmx.l.google.com", "25"], trap=True)
if ret == 0: if ret == 0:
output.print_ok("Outbound mail (SMTP port 25) is not blocked.") output.print_ok("Outbound mail (SMTP port 25) is not blocked.")
else: else:
@ -318,9 +315,8 @@ def run_network_checks(env, output):
elif zen == "[Not Set]": elif zen == "[Not Set]":
output.print_warning("Could not connect to zen.spamhaus.org. We could not determine whether your server's IP address is blacklisted. Please try again later.") output.print_warning("Could not connect to zen.spamhaus.org. We could not determine whether your server's IP address is blacklisted. Please try again later.")
else: else:
output.print_error("""The IP address of this machine %s is listed in the Spamhaus Block List (code %s), output.print_error("""The IP address of this machine {} is listed in the Spamhaus Block List (code {}),
which may prevent recipients from receiving your email. See http://www.spamhaus.org/query/ip/%s.""" which may prevent recipients from receiving your email. See http://www.spamhaus.org/query/ip/{}.""".format(env['PUBLIC_IP'], zen, env['PUBLIC_IP']))
% (env['PUBLIC_IP'], zen, env['PUBLIC_IP']))
def run_domain_checks(rounded_time, env, output, pool, domains_to_check=None): def run_domain_checks(rounded_time, env, output, pool, domains_to_check=None):
# Get the list of domains we handle mail for. # Get the list of domains we handle mail for.
@ -341,7 +337,7 @@ def run_domain_checks(rounded_time, env, output, pool, domains_to_check=None):
domains_to_check = [ domains_to_check = [
d for d in domains_to_check d for d in domains_to_check
if not ( if not (
d.split(".", 1)[0] in ("www", "autoconfig", "autodiscover", "mta-sts") d.split(".", 1)[0] in {"www", "autoconfig", "autodiscover", "mta-sts"}
and len(d.split(".", 1)) == 2 and len(d.split(".", 1)) == 2
and d.split(".", 1)[1] in domains_to_check and d.split(".", 1)[1] in domains_to_check
) )
@ -423,8 +419,7 @@ def check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles):
# If a DS record is set on the zone containing this domain, check DNSSEC now. # If a DS record is set on the zone containing this domain, check DNSSEC now.
has_dnssec = False has_dnssec = False
for zone in dns_domains: for zone in dns_domains:
if zone == domain or domain.endswith("." + zone): if (zone == domain or domain.endswith("." + zone)) and query_dns(zone, "DS", nxdomain=None) is not None:
if query_dns(zone, "DS", nxdomain=None) is not None:
has_dnssec = True has_dnssec = True
check_dnssec(zone, env, output, dns_zonefiles, is_checking_primary=True) check_dnssec(zone, env, output, dns_zonefiles, is_checking_primary=True)
@ -438,44 +433,41 @@ def check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles):
# the nameserver, are reporting the right info --- but if the glue is incorrect this # the nameserver, are reporting the right info --- but if the glue is incorrect this
# will probably fail. # will probably fail.
if ns_ips == env['PUBLIC_IP'] + '/' + env['PUBLIC_IP']: if ns_ips == env['PUBLIC_IP'] + '/' + env['PUBLIC_IP']:
output.print_ok("Nameserver glue records are correct at registrar. [ns1/ns2.%s%s]" % (env['PRIMARY_HOSTNAME'], env['PUBLIC_IP'])) output.print_ok("Nameserver glue records are correct at registrar. [ns1/ns2.{}{}]".format(env['PRIMARY_HOSTNAME'], env['PUBLIC_IP']))
elif ip == env['PUBLIC_IP']: elif ip == env['PUBLIC_IP']:
# The NS records are not what we expect, but the domain resolves correctly, so # The NS records are not what we expect, but the domain resolves correctly, so
# the user may have set up external DNS. List this discrepancy as a warning. # the user may have set up external DNS. List this discrepancy as a warning.
output.print_warning("""Nameserver glue records (ns1.%s and ns2.%s) should be configured at your domain name output.print_warning("""Nameserver glue records (ns1.{} and ns2.{}) should be configured at your domain name
registrar as having the IP address of this box (%s). They currently report addresses of %s. If you have set up External DNS, this may be OK.""" registrar as having the IP address of this box ({}). They currently report addresses of {}. If you have set up External DNS, this may be OK.""".format(env['PRIMARY_HOSTNAME'], env['PRIMARY_HOSTNAME'], env['PUBLIC_IP'], ns_ips))
% (env['PRIMARY_HOSTNAME'], env['PRIMARY_HOSTNAME'], env['PUBLIC_IP'], ns_ips))
else: else:
output.print_error("""Nameserver glue records are incorrect. The ns1.%s and ns2.%s nameservers must be configured at your domain name output.print_error("""Nameserver glue records are incorrect. The ns1.{} and ns2.{} nameservers must be configured at your domain name
registrar as having the IP address %s. They currently report addresses of %s. It may take several hours for registrar as having the IP address {}. They currently report addresses of {}. It may take several hours for
public DNS to update after a change.""" public DNS to update after a change.""".format(env['PRIMARY_HOSTNAME'], env['PRIMARY_HOSTNAME'], env['PUBLIC_IP'], ns_ips))
% (env['PRIMARY_HOSTNAME'], env['PRIMARY_HOSTNAME'], env['PUBLIC_IP'], ns_ips))
# Check that PRIMARY_HOSTNAME resolves to PUBLIC_IP[V6] in public DNS. # Check that PRIMARY_HOSTNAME resolves to PUBLIC_IP[V6] in public DNS.
ipv6 = query_dns(domain, "AAAA") if env.get("PUBLIC_IPV6") else None ipv6 = query_dns(domain, "AAAA") if env.get("PUBLIC_IPV6") else None
if ip == env['PUBLIC_IP'] and not (ipv6 and env['PUBLIC_IPV6'] and ipv6 != normalize_ip(env['PUBLIC_IPV6'])): if ip == env['PUBLIC_IP'] and not (ipv6 and env['PUBLIC_IPV6'] and ipv6 != normalize_ip(env['PUBLIC_IPV6'])):
output.print_ok("Domain resolves to box's IP address. [%s%s]" % (env['PRIMARY_HOSTNAME'], my_ips)) output.print_ok("Domain resolves to box's IP address. [{}{}]".format(env['PRIMARY_HOSTNAME'], my_ips))
else: else:
output.print_error("""This domain must resolve to your box's IP address (%s) in public DNS but it currently resolves output.print_error("""This domain must resolve to your box's IP address ({}) in public DNS but it currently resolves
to %s. It may take several hours for public DNS to update after a change. This problem may result from other to {}. It may take several hours for public DNS to update after a change. This problem may result from other
issues listed above.""" issues listed above.""".format(my_ips, ip + ((" / " + ipv6) if ipv6 is not None else "")))
% (my_ips, ip + ((" / " + ipv6) if ipv6 is not None else "")))
# Check reverse DNS matches the PRIMARY_HOSTNAME. Note that it might not be # Check reverse DNS matches the PRIMARY_HOSTNAME. Note that it might not be
# a DNS zone if it is a subdomain of another domain we have a zone for. # a DNS zone if it is a subdomain of another domain we have a zone for.
existing_rdns_v4 = query_dns(dns.reversename.from_address(env['PUBLIC_IP']), "PTR") existing_rdns_v4 = query_dns(dns.reversename.from_address(env['PUBLIC_IP']), "PTR")
existing_rdns_v6 = query_dns(dns.reversename.from_address(env['PUBLIC_IPV6']), "PTR") if env.get("PUBLIC_IPV6") else None existing_rdns_v6 = query_dns(dns.reversename.from_address(env['PUBLIC_IPV6']), "PTR") if env.get("PUBLIC_IPV6") else None
if existing_rdns_v4 == domain and existing_rdns_v6 in (None, domain): if existing_rdns_v4 == domain and existing_rdns_v6 in {None, domain}:
output.print_ok("Reverse DNS is set correctly at ISP. [%s%s]" % (my_ips, env['PRIMARY_HOSTNAME'])) output.print_ok("Reverse DNS is set correctly at ISP. [{}{}]".format(my_ips, env['PRIMARY_HOSTNAME']))
elif existing_rdns_v4 == existing_rdns_v6 or existing_rdns_v6 is None: elif existing_rdns_v4 == existing_rdns_v6 or existing_rdns_v6 is None:
output.print_error("""Your box's reverse DNS is currently %s, but it should be %s. Your ISP or cloud provider will have instructions output.print_error(f"""Your box's reverse DNS is currently {existing_rdns_v4}, but it should be {domain}. Your ISP or cloud provider will have instructions
on setting up reverse DNS for your box.""" % (existing_rdns_v4, domain) ) on setting up reverse DNS for your box.""" )
else: else:
output.print_error("""Your box's reverse DNS is currently %s (IPv4) and %s (IPv6), but it should be %s. Your ISP or cloud provider will have instructions output.print_error(f"""Your box's reverse DNS is currently {existing_rdns_v4} (IPv4) and {existing_rdns_v6} (IPv6), but it should be {domain}. Your ISP or cloud provider will have instructions
on setting up reverse DNS for your box.""" % (existing_rdns_v4, existing_rdns_v6, domain) ) on setting up reverse DNS for your box.""" )
# Check the TLSA record. # Check the TLSA record.
tlsa_qname = "_25._tcp." + domain tlsa_qname = "_25._tcp." + domain
@ -489,18 +481,17 @@ def check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles):
# since TLSA shouldn't be used without DNSSEC. # since TLSA shouldn't be used without DNSSEC.
output.print_warning("""The DANE TLSA record for incoming mail is not set. This is optional.""") output.print_warning("""The DANE TLSA record for incoming mail is not set. This is optional.""")
else: else:
output.print_error("""The DANE TLSA record for incoming mail (%s) is not correct. It is '%s' but it should be '%s'. output.print_error(f"""The DANE TLSA record for incoming mail ({tlsa_qname}) is not correct. It is '{tlsa25}' but it should be '{tlsa25_expected}'.
It may take several hours for public DNS to update after a change.""" It may take several hours for public DNS to update after a change.""")
% (tlsa_qname, tlsa25, tlsa25_expected))
# Check that the hostmaster@ email address exists. # Check that the hostmaster@ email address exists.
check_alias_exists("Hostmaster contact address", "hostmaster@" + domain, env, output) check_alias_exists("Hostmaster contact address", "hostmaster@" + domain, env, output)
def check_alias_exists(alias_name, alias, env, output): def check_alias_exists(alias_name, alias, env, output):
mail_aliases = dict([(address, receivers) for address, receivers, *_ in get_mail_aliases(env)]) mail_aliases = {address: receivers for address, receivers, *_ in get_mail_aliases(env)}
if alias in mail_aliases: if alias in mail_aliases:
if mail_aliases[alias]: if mail_aliases[alias]:
output.print_ok("%s exists as a mail alias. [%s%s]" % (alias_name, alias, mail_aliases[alias])) output.print_ok(f"{alias_name} exists as a mail alias. [{alias}{mail_aliases[alias]}]")
else: else:
output.print_error("""You must set the destination of the mail alias for %s to direct email to you or another administrator.""" % alias) output.print_error("""You must set the destination of the mail alias for %s to direct email to you or another administrator.""" % alias)
else: else:
@ -526,7 +517,7 @@ def check_dns_zone(domain, env, output, dns_zonefiles):
secondary_ns = custom_secondary_ns or ["ns2." + env['PRIMARY_HOSTNAME']] secondary_ns = custom_secondary_ns or ["ns2." + env['PRIMARY_HOSTNAME']]
existing_ns = query_dns(domain, "NS") existing_ns = query_dns(domain, "NS")
correct_ns = "; ".join(sorted(["ns1." + env['PRIMARY_HOSTNAME']] + secondary_ns)) correct_ns = "; ".join(sorted(["ns1." + env["PRIMARY_HOSTNAME"], *secondary_ns]))
ip = query_dns(domain, "A") ip = query_dns(domain, "A")
probably_external_dns = False probably_external_dns = False
@ -535,14 +526,12 @@ def check_dns_zone(domain, env, output, dns_zonefiles):
output.print_ok("Nameservers are set correctly at registrar. [%s]" % correct_ns) output.print_ok("Nameservers are set correctly at registrar. [%s]" % correct_ns)
elif ip == correct_ip: elif ip == correct_ip:
# The domain resolves correctly, so maybe the user is using External DNS. # The domain resolves correctly, so maybe the user is using External DNS.
output.print_warning("""The nameservers set on this domain at your domain name registrar should be %s. They are currently %s. output.print_warning(f"""The nameservers set on this domain at your domain name registrar should be {correct_ns}. They are currently {existing_ns}.
If you are using External DNS, this may be OK.""" If you are using External DNS, this may be OK.""" )
% (correct_ns, existing_ns) )
probably_external_dns = True probably_external_dns = True
else: else:
output.print_error("""The nameservers set on this domain are incorrect. They are currently %s. Use your domain name registrar's output.print_error(f"""The nameservers set on this domain are incorrect. They are currently {existing_ns}. Use your domain name registrar's
control panel to set the nameservers to %s.""" control panel to set the nameservers to {correct_ns}.""" )
% (existing_ns, correct_ns) )
# Check that each custom secondary nameserver resolves the IP address. # Check that each custom secondary nameserver resolves the IP address.
@ -563,7 +552,7 @@ def check_dns_zone(domain, env, output, dns_zonefiles):
elif ip is None: elif ip is None:
output.print_error("Secondary nameserver %s is not configured to resolve this domain." % ns) output.print_error("Secondary nameserver %s is not configured to resolve this domain." % ns)
else: else:
output.print_error("Secondary nameserver %s is not configured correctly. (It resolved this domain as %s. It should be %s.)" % (ns, ip, correct_ip)) output.print_error(f"Secondary nameserver {ns} is not configured correctly. (It resolved this domain as {ip}. It should be {correct_ip}.)")
def check_dns_zone_suggestions(domain, env, output, dns_zonefiles, domains_with_a_records): def check_dns_zone_suggestions(domain, env, output, dns_zonefiles, domains_with_a_records):
# Warn if a custom DNS record is preventing this or the automatic www redirect from # Warn if a custom DNS record is preventing this or the automatic www redirect from
@ -592,7 +581,7 @@ def check_dnssec(domain, env, output, dns_zonefiles, is_checking_primary=False):
expected_ds_records = { } expected_ds_records = { }
ds_file = '/etc/nsd/zones/' + dns_zonefiles[domain] + '.ds' ds_file = '/etc/nsd/zones/' + dns_zonefiles[domain] + '.ds'
if not os.path.exists(ds_file): return # Domain is in our database but DNS has not yet been updated. if not os.path.exists(ds_file): return # Domain is in our database but DNS has not yet been updated.
with open(ds_file) as f: with open(ds_file, encoding="utf-8") as f:
for rr_ds in f: for rr_ds in f:
rr_ds = rr_ds.rstrip() rr_ds = rr_ds.rstrip()
ds_keytag, ds_alg, ds_digalg, ds_digest = rr_ds.split("\t")[4].split(" ") ds_keytag, ds_alg, ds_digalg, ds_digest = rr_ds.split("\t")[4].split(" ")
@ -601,7 +590,7 @@ def check_dnssec(domain, env, output, dns_zonefiles, is_checking_primary=False):
# record that we suggest using is for the KSK (and that's how the DS records were generated). # record that we suggest using is for the KSK (and that's how the DS records were generated).
# We'll also give the nice name for the key algorithm. # We'll also give the nice name for the key algorithm.
dnssec_keys = load_env_vars_from_file(os.path.join(env['STORAGE_ROOT'], 'dns/dnssec/%s.conf' % alg_name_map[ds_alg])) dnssec_keys = load_env_vars_from_file(os.path.join(env['STORAGE_ROOT'], 'dns/dnssec/%s.conf' % alg_name_map[ds_alg]))
with open(os.path.join(env['STORAGE_ROOT'], 'dns/dnssec/' + dnssec_keys['KSK'] + '.key'), 'r') as f: with open(os.path.join(env['STORAGE_ROOT'], 'dns/dnssec/' + dnssec_keys['KSK'] + '.key'), encoding="utf-8") as f:
dnsssec_pubkey = f.read().split("\t")[3].split(" ")[3] dnsssec_pubkey = f.read().split("\t")[3].split(" ")[3]
expected_ds_records[ (ds_keytag, ds_alg, ds_digalg, ds_digest) ] = { expected_ds_records[ (ds_keytag, ds_alg, ds_digalg, ds_digest) ] = {
@ -634,10 +623,10 @@ def check_dnssec(domain, env, output, dns_zonefiles, is_checking_primary=False):
# #
# But it may not be preferred. Only algorithm 13 is preferred. Warn if any of the # But it may not be preferred. Only algorithm 13 is preferred. Warn if any of the
# matched zones uses a different algorithm. # matched zones uses a different algorithm.
if set(r[1] for r in matched_ds) == { '13' } and set(r[2] for r in matched_ds) <= { '2', '4' }: # all are alg 13 and digest type 2 or 4 if {r[1] for r in matched_ds} == { '13' } and {r[2] for r in matched_ds} <= { '2', '4' }: # all are alg 13 and digest type 2 or 4
output.print_ok("DNSSEC 'DS' record is set correctly at registrar.") output.print_ok("DNSSEC 'DS' record is set correctly at registrar.")
return return
elif len([r for r in matched_ds if r[1] == '13' and r[2] in ( '2', '4' )]) > 0: # some but not all are alg 13 elif len([r for r in matched_ds if r[1] == '13' and r[2] in { '2', '4' }]) > 0: # some but not all are alg 13
output.print_ok("DNSSEC 'DS' record is set correctly at registrar. (Records using algorithm other than ECDSAP256SHA256 and digest types other than SHA-256/384 should be removed.)") output.print_ok("DNSSEC 'DS' record is set correctly at registrar. (Records using algorithm other than ECDSAP256SHA256 and digest types other than SHA-256/384 should be removed.)")
return return
else: # no record uses alg 13 else: # no record uses alg 13
@ -669,8 +658,8 @@ def check_dnssec(domain, env, output, dns_zonefiles, is_checking_primary=False):
output.print_line("----------") output.print_line("----------")
output.print_line("Key Tag: " + ds_suggestion['keytag']) output.print_line("Key Tag: " + ds_suggestion['keytag'])
output.print_line("Key Flags: KSK / 257") output.print_line("Key Flags: KSK / 257")
output.print_line("Algorithm: %s / %s" % (ds_suggestion['alg'], ds_suggestion['alg_name'])) output.print_line("Algorithm: {} / {}".format(ds_suggestion['alg'], ds_suggestion['alg_name']))
output.print_line("Digest Type: %s / %s" % (ds_suggestion['digalg'], ds_suggestion['digalg_name'])) output.print_line("Digest Type: {} / {}".format(ds_suggestion['digalg'], ds_suggestion['digalg_name']))
output.print_line("Digest: " + ds_suggestion['digest']) output.print_line("Digest: " + ds_suggestion['digest'])
output.print_line("Public Key: ") output.print_line("Public Key: ")
output.print_line(ds_suggestion['pubkey'], monospace=True) output.print_line(ds_suggestion['pubkey'], monospace=True)
@ -681,7 +670,7 @@ def check_dnssec(domain, env, output, dns_zonefiles, is_checking_primary=False):
output.print_line("") output.print_line("")
output.print_line("The DS record is currently set to:") output.print_line("The DS record is currently set to:")
for rr in sorted(ds): for rr in sorted(ds):
output.print_line("Key Tag: {0}, Algorithm: {1}, Digest Type: {2}, Digest: {3}".format(*rr)) output.print_line("Key Tag: {}, Algorithm: {}, Digest Type: {}, Digest: {}".format(*rr))
def check_mail_domain(domain, env, output): def check_mail_domain(domain, env, output):
# Check the MX record. # Check the MX record.
@ -689,21 +678,19 @@ def check_mail_domain(domain, env, output):
recommended_mx = "10 " + env['PRIMARY_HOSTNAME'] recommended_mx = "10 " + env['PRIMARY_HOSTNAME']
mx = query_dns(domain, "MX", nxdomain=None) mx = query_dns(domain, "MX", nxdomain=None)
if mx is None: if mx is None or mx == "[timeout]":
mxhost = None
elif mx == "[timeout]":
mxhost = None mxhost = None
else: else:
# query_dns returns a semicolon-delimited list # query_dns returns a semicolon-delimited list
# of priority-host pairs. # of priority-host pairs.
mxhost = mx.split('; ')[0].split(' ')[1] mxhost = mx.split('; ')[0].split(' ')[1]
if mxhost == None: if mxhost is None:
# A missing MX record is okay on the primary hostname because # A missing MX record is okay on the primary hostname because
# the primary hostname's A record (the MX fallback) is... itself, # the primary hostname's A record (the MX fallback) is... itself,
# which is what we want the MX to be. # which is what we want the MX to be.
if domain == env['PRIMARY_HOSTNAME']: if domain == env['PRIMARY_HOSTNAME']:
output.print_ok("Domain's email is directed to this domain. [%s has no MX record, which is ok]" % (domain,)) output.print_ok(f"Domain's email is directed to this domain. [{domain} has no MX record, which is ok]")
# And a missing MX record is okay on other domains if the A record # And a missing MX record is okay on other domains if the A record
# matches the A record of the PRIMARY_HOSTNAME. Actually this will # matches the A record of the PRIMARY_HOSTNAME. Actually this will
@ -711,17 +698,17 @@ def check_mail_domain(domain, env, output):
else: else:
domain_a = query_dns(domain, "A", nxdomain=None) domain_a = query_dns(domain, "A", nxdomain=None)
primary_a = query_dns(env['PRIMARY_HOSTNAME'], "A", nxdomain=None) primary_a = query_dns(env['PRIMARY_HOSTNAME'], "A", nxdomain=None)
if domain_a != None and domain_a == primary_a: if domain_a is not None and domain_a == primary_a:
output.print_ok("Domain's email is directed to this domain. [%s has no MX record but its A record is OK]" % (domain,)) output.print_ok(f"Domain's email is directed to this domain. [{domain} has no MX record but its A record is OK]")
else: else:
output.print_error("""This domain's DNS MX record is not set. It should be '%s'. Mail will not output.print_error(f"""This domain's DNS MX record is not set. It should be '{recommended_mx}'. Mail will not
be delivered to this box. It may take several hours for public DNS to update after a be delivered to this box. It may take several hours for public DNS to update after a
change. This problem may result from other issues listed here.""" % (recommended_mx,)) change. This problem may result from other issues listed here.""")
elif mxhost == env['PRIMARY_HOSTNAME']: elif mxhost == env['PRIMARY_HOSTNAME']:
good_news = "Domain's email is directed to this domain. [%s%s]" % (domain, mx) good_news = f"Domain's email is directed to this domain. [{domain}{mx}]"
if mx != recommended_mx: if mx != recommended_mx:
good_news += " This configuration is non-standard. The recommended configuration is '%s'." % (recommended_mx,) good_news += f" This configuration is non-standard. The recommended configuration is '{recommended_mx}'."
output.print_ok(good_news) output.print_ok(good_news)
# Check MTA-STS policy. # Check MTA-STS policy.
@ -732,14 +719,14 @@ def check_mail_domain(domain, env, output):
if policy[1].get("mx") == [env['PRIMARY_HOSTNAME']] and policy[1].get("mode") == "enforce": # policy[0] is the policyid if policy[1].get("mx") == [env['PRIMARY_HOSTNAME']] and policy[1].get("mode") == "enforce": # policy[0] is the policyid
output.print_ok("MTA-STS policy is present.") output.print_ok("MTA-STS policy is present.")
else: else:
output.print_error("MTA-STS policy is present but has unexpected settings. [{}]".format(policy[1])) output.print_error(f"MTA-STS policy is present but has unexpected settings. [{policy[1]}]")
else: else:
output.print_error("MTA-STS policy is missing: {}".format(valid)) output.print_error(f"MTA-STS policy is missing: {valid}")
else: else:
output.print_error("""This domain's DNS MX record is incorrect. It is currently set to '%s' but should be '%s'. Mail will not output.print_error(f"""This domain's DNS MX record is incorrect. It is currently set to '{mx}' but should be '{recommended_mx}'. Mail will not
be delivered to this box. It may take several hours for public DNS to update after a change. This problem may result from be delivered to this box. It may take several hours for public DNS to update after a change. This problem may result from
other issues listed here.""" % (mx, recommended_mx)) other issues listed here.""")
# Check that the postmaster@ email address exists. Not required if the domain has a # Check that the postmaster@ email address exists. Not required if the domain has a
# catch-all address or domain alias. # catch-all address or domain alias.
@ -753,13 +740,13 @@ def check_mail_domain(domain, env, output):
if dbl is None: if dbl is None:
output.print_ok("Domain is not blacklisted by dbl.spamhaus.org.") output.print_ok("Domain is not blacklisted by dbl.spamhaus.org.")
elif dbl == "[timeout]": elif dbl == "[timeout]":
output.print_warning("Connection to dbl.spamhaus.org timed out. We could not determine whether the domain {} is blacklisted. Please try again later.".format(domain)) output.print_warning(f"Connection to dbl.spamhaus.org timed out. We could not determine whether the domain {domain} is blacklisted. Please try again later.")
elif dbl == "[Not Set]": elif dbl == "[Not Set]":
output.print_warning("Could not connect to dbl.spamhaus.org. We could not determine whether the domain {} is blacklisted. Please try again later.".format(domain)) output.print_warning(f"Could not connect to dbl.spamhaus.org. We could not determine whether the domain {domain} is blacklisted. Please try again later.")
else: else:
output.print_error("""This domain is listed in the Spamhaus Domain Block List (code %s), output.print_error(f"""This domain is listed in the Spamhaus Domain Block List (code {dbl}),
which may prevent recipients from receiving your mail. which may prevent recipients from receiving your mail.
See http://www.spamhaus.org/dbl/ and http://www.spamhaus.org/query/domain/%s.""" % (dbl, domain)) See http://www.spamhaus.org/dbl/ and http://www.spamhaus.org/query/domain/{domain}.""")
def check_web_domain(domain, rounded_time, ssl_certificates, env, output): def check_web_domain(domain, rounded_time, ssl_certificates, env, output):
# See if the domain's A record resolves to our PUBLIC_IP. This is already checked # See if the domain's A record resolves to our PUBLIC_IP. This is already checked
@ -773,13 +760,13 @@ def check_web_domain(domain, rounded_time, ssl_certificates, env, output):
if value == normalize_ip(expected): if value == normalize_ip(expected):
ok_values.append(value) ok_values.append(value)
else: else:
output.print_error("""This domain should resolve to your box's IP address (%s %s) if you would like the box to serve output.print_error(f"""This domain should resolve to your box's IP address ({rtype} {expected}) if you would like the box to serve
webmail or a website on this domain. The domain currently resolves to %s in public DNS. It may take several hours for webmail or a website on this domain. The domain currently resolves to {value} in public DNS. It may take several hours for
public DNS to update after a change. This problem may result from other issues listed here.""" % (rtype, expected, value)) public DNS to update after a change. This problem may result from other issues listed here.""")
return return
# If both A and AAAA are correct... # If both A and AAAA are correct...
output.print_ok("Domain resolves to this box's IP address. [%s%s]" % (domain, '; '.join(ok_values))) output.print_ok("Domain resolves to this box's IP address. [{}{}]".format(domain, '; '.join(ok_values)))
# We need a TLS certificate for PRIMARY_HOSTNAME because that's where the # We need a TLS certificate for PRIMARY_HOSTNAME because that's where the
@ -826,7 +813,7 @@ def query_dns(qname, rtype, nxdomain='[Not Set]', at=None, as_list=False):
# be expressed in equivalent string forms. Canonicalize the form before # be expressed in equivalent string forms. Canonicalize the form before
# returning them. The caller should normalize any IP addresses the result # returning them. The caller should normalize any IP addresses the result
# of this method is compared with. # of this method is compared with.
if rtype in ("A", "AAAA"): if rtype in {"A", "AAAA"}:
response = [normalize_ip(str(r)) for r in response] response = [normalize_ip(str(r)) for r in response]
if as_list: if as_list:
@ -842,7 +829,7 @@ def check_ssl_cert(domain, rounded_time, ssl_certificates, env, output):
# Check that TLS certificate is signed. # Check that TLS certificate is signed.
# Skip the check if the A record is not pointed here. # Skip the check if the A record is not pointed here.
if query_dns(domain, "A", None) not in (env['PUBLIC_IP'], None): return if query_dns(domain, "A", None) not in {env['PUBLIC_IP'], None}: return
# Where is the certificate file stored? # Where is the certificate file stored?
tls_cert = get_domain_ssl_files(domain, ssl_certificates, env, allow_missing_cert=True) tls_cert = get_domain_ssl_files(domain, ssl_certificates, env, allow_missing_cert=True)
@ -916,18 +903,16 @@ def what_version_is_this(env):
# Git may not be installed and Mail-in-a-Box may not have been cloned from github, # Git may not be installed and Mail-in-a-Box may not have been cloned from github,
# so this function may raise all sorts of exceptions. # so this function may raise all sorts of exceptions.
miab_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) miab_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
tag = shell("check_output", ["/usr/bin/git", "describe", "--always", "--abbrev=0"], env={"GIT_DIR": os.path.join(miab_dir, '.git')}).strip() return shell("check_output", ["/usr/bin/git", "describe", "--always", "--abbrev=0"], env={"GIT_DIR": os.path.join(miab_dir, '.git')}).strip()
return tag
def get_latest_miab_version(): def get_latest_miab_version():
# This pings https://mailinabox.email/setup.sh and extracts the tag named in # This pings https://mailinabox.email/setup.sh and extracts the tag named in
# the script to determine the current product version. # the script to determine the current product version.
from urllib.request import urlopen, HTTPError, URLError from urllib.request import urlopen, HTTPError, URLError
from socket import timeout
try: try:
return re.search(b'TAG=(.*)', urlopen("https://mailinabox.email/setup.sh?ping=1", timeout=5).read()).group(1).decode("utf8") return re.search(b'TAG=(.*)', urlopen("https://mailinabox.email/setup.sh?ping=1", timeout=5).read()).group(1).decode("utf8")
except (HTTPError, URLError, timeout): except (TimeoutError, HTTPError, URLError):
return None return None
def check_miab_version(env, output): def check_miab_version(env, output):
@ -948,8 +933,7 @@ def check_miab_version(env, output):
elif latest_ver is None: elif latest_ver is None:
output.print_error("Latest Mail-in-a-Box version could not be determined. You are running version %s." % this_ver) output.print_error("Latest Mail-in-a-Box version could not be determined. You are running version %s." % this_ver)
else: else:
output.print_error("A new version of Mail-in-a-Box is available. You are running version %s. The latest version is %s. For upgrade instructions, see https://mailinabox.email. " output.print_error(f"A new version of Mail-in-a-Box is available. You are running version {this_ver}. The latest version is {latest_ver}. For upgrade instructions, see https://mailinabox.email. ")
% (this_ver, latest_ver))
def run_and_output_changes(env, pool): def run_and_output_changes(env, pool):
import json import json
@ -964,7 +948,7 @@ def run_and_output_changes(env, pool):
# Load previously saved status checks. # Load previously saved status checks.
cache_fn = "/var/cache/mailinabox/status_checks.json" cache_fn = "/var/cache/mailinabox/status_checks.json"
if os.path.exists(cache_fn): if os.path.exists(cache_fn):
with open(cache_fn, 'r') as f: with open(cache_fn, encoding="utf-8") as f:
try: try:
prev = json.load(f) prev = json.load(f)
except json.JSONDecodeError: except json.JSONDecodeError:
@ -1003,14 +987,14 @@ def run_and_output_changes(env, pool):
out.add_heading(category + " -- Previously:") out.add_heading(category + " -- Previously:")
elif op == "delete": elif op == "delete":
out.add_heading(category + " -- Removed") out.add_heading(category + " -- Removed")
if op in ("replace", "delete"): if op in {"replace", "delete"}:
BufferedOutput(with_lines=prev_lines[i1:i2]).playback(out) BufferedOutput(with_lines=prev_lines[i1:i2]).playback(out)
if op == "replace": if op == "replace":
out.add_heading(category + " -- Currently:") out.add_heading(category + " -- Currently:")
elif op == "insert": elif op == "insert":
out.add_heading(category + " -- Added") out.add_heading(category + " -- Added")
if op in ("replace", "insert"): if op in {"replace", "insert"}:
BufferedOutput(with_lines=cur_lines[j1:j2]).playback(out) BufferedOutput(with_lines=cur_lines[j1:j2]).playback(out)
for category, prev_lines in prev_status.items(): for category, prev_lines in prev_status.items():
@ -1020,7 +1004,7 @@ def run_and_output_changes(env, pool):
# Store the current status checks output for next time. # Store the current status checks output for next time.
os.makedirs(os.path.dirname(cache_fn), exist_ok=True) os.makedirs(os.path.dirname(cache_fn), exist_ok=True)
with open(cache_fn, "w") as f: with open(cache_fn, "w", encoding="utf-8") as f:
json.dump(cur.buf, f, indent=True) json.dump(cur.buf, f, indent=True)
def normalize_ip(ip): def normalize_ip(ip):
@ -1054,8 +1038,8 @@ class FileOutput:
def print_block(self, message, first_line=" "): def print_block(self, message, first_line=" "):
print(first_line, end='', file=self.buf) print(first_line, end='', file=self.buf)
message = re.sub("\n\s*", " ", message) message = re.sub("\n\\s*", " ", message)
words = re.split("(\s+)", message) words = re.split(r"(\s+)", message)
linelen = 0 linelen = 0
for w in words: for w in words:
if self.width and (linelen + len(w) > self.width-1-len(first_line)): if self.width and (linelen + len(w) > self.width-1-len(first_line)):
@ -1094,9 +1078,9 @@ class ConsoleOutput(FileOutput):
class BufferedOutput: class BufferedOutput:
# Record all of the instance method calls so we can play them back later. # Record all of the instance method calls so we can play them back later.
def __init__(self, with_lines=None): def __init__(self, with_lines=None):
self.buf = [] if not with_lines else with_lines self.buf = with_lines if with_lines else []
def __getattr__(self, attr): def __getattr__(self, attr):
if attr not in ("add_heading", "print_ok", "print_error", "print_warning", "print_block", "print_line"): if attr not in {"add_heading", "print_ok", "print_error", "print_warning", "print_block", "print_line"}:
raise AttributeError raise AttributeError
# Return a function that just records the call & arguments to our buffer. # Return a function that just records the call & arguments to our buffer.
def w(*args, **kwargs): def w(*args, **kwargs):

View File

@ -14,31 +14,31 @@ def load_env_vars_from_file(fn):
# Load settings from a KEY=VALUE file. # Load settings from a KEY=VALUE file.
import collections import collections
env = collections.OrderedDict() env = collections.OrderedDict()
with open(fn, 'r') as f: with open(fn, encoding="utf-8") as f:
for line in f: for line in f:
env.setdefault(*line.strip().split("=", 1)) env.setdefault(*line.strip().split("=", 1))
return env return env
def save_environment(env): def save_environment(env):
with open("/etc/mailinabox.conf", "w") as f: with open("/etc/mailinabox.conf", "w", encoding="utf-8") as f:
for k, v in env.items(): for k, v in env.items():
f.write("%s=%s\n" % (k, v)) f.write(f"{k}={v}\n")
# THE SETTINGS FILE AT STORAGE_ROOT/settings.yaml. # THE SETTINGS FILE AT STORAGE_ROOT/settings.yaml.
def write_settings(config, env): def write_settings(config, env):
import rtyaml import rtyaml
fn = os.path.join(env['STORAGE_ROOT'], 'settings.yaml') fn = os.path.join(env['STORAGE_ROOT'], 'settings.yaml')
with open(fn, "w") as f: with open(fn, "w", encoding="utf-8") as f:
f.write(rtyaml.dump(config)) f.write(rtyaml.dump(config))
def load_settings(env): def load_settings(env):
import rtyaml import rtyaml
fn = os.path.join(env['STORAGE_ROOT'], 'settings.yaml') fn = os.path.join(env['STORAGE_ROOT'], 'settings.yaml')
try: try:
with open(fn, "r") as f: with open(fn, encoding="utf-8") as f:
config = rtyaml.load(f) config = rtyaml.load(f)
if not isinstance(config, dict): raise ValueError() # caught below if not isinstance(config, dict): raise ValueError # caught below
return config return config
except: except:
return { } return { }
@ -59,7 +59,7 @@ def sort_domains(domain_names, env):
# from shortest to longest since zones are always shorter than their # from shortest to longest since zones are always shorter than their
# subdomains. # subdomains.
zones = { } zones = { }
for domain in sorted(domain_names, key=lambda d : len(d)): for domain in sorted(domain_names, key=len):
for z in zones.values(): for z in zones.values():
if domain.endswith("." + z): if domain.endswith("." + z):
# We found a parent domain already in the list. # We found a parent domain already in the list.
@ -81,7 +81,7 @@ def sort_domains(domain_names, env):
)) ))
# Now sort the domain names that fall within each zone. # Now sort the domain names that fall within each zone.
domain_names = sorted(domain_names, return sorted(domain_names,
key = lambda d : ( key = lambda d : (
# First by zone. # First by zone.
zone_domains.index(zones[d]), zone_domains.index(zones[d]),
@ -96,24 +96,25 @@ def sort_domains(domain_names, env):
list(reversed(d.split("."))), list(reversed(d.split("."))),
)) ))
return domain_names
def sort_email_addresses(email_addresses, env): def sort_email_addresses(email_addresses, env):
email_addresses = set(email_addresses) email_addresses = set(email_addresses)
domains = set(email.split("@", 1)[1] for email in email_addresses if "@" in email) domains = {email.split("@", 1)[1] for email in email_addresses if "@" in email}
ret = [] ret = []
for domain in sort_domains(domains, env): for domain in sort_domains(domains, env):
domain_emails = set(email for email in email_addresses if email.endswith("@" + domain)) domain_emails = {email for email in email_addresses if email.endswith("@" + domain)}
ret.extend(sorted(domain_emails)) ret.extend(sorted(domain_emails))
email_addresses -= domain_emails email_addresses -= domain_emails
ret.extend(sorted(email_addresses)) # whatever is left ret.extend(sorted(email_addresses)) # whatever is left
return ret return ret
def shell(method, cmd_args, env={}, capture_stderr=False, return_bytes=False, trap=False, input=None): def shell(method, cmd_args, env=None, capture_stderr=False, return_bytes=False, trap=False, input=None):
# A safe way to execute processes. # A safe way to execute processes.
# Some processes like apt-get require being given a sane PATH. # Some processes like apt-get require being given a sane PATH.
import subprocess import subprocess
if env is None:
env = {}
env.update({ "PATH": "/sbin:/bin:/usr/sbin:/usr/bin" }) env.update({ "PATH": "/sbin:/bin:/usr/sbin:/usr/bin" })
kwargs = { kwargs = {
'env': env, 'env': env,
@ -149,7 +150,7 @@ def du(path):
# soft and hard links. # soft and hard links.
total_size = 0 total_size = 0
seen = set() seen = set()
for dirpath, dirnames, filenames in os.walk(path): for dirpath, _dirnames, filenames in os.walk(path):
for f in filenames: for f in filenames:
fp = os.path.join(dirpath, f) fp = os.path.join(dirpath, f)
try: try:

View File

@ -22,17 +22,17 @@ def get_web_domains(env, include_www_redirects=True, include_auto=True, exclude_
# Add 'www.' subdomains that we want to provide default redirects # Add 'www.' subdomains that we want to provide default redirects
# to the main domain for. We'll add 'www.' to any DNS zones, i.e. # to the main domain for. We'll add 'www.' to any DNS zones, i.e.
# the topmost of each domain we serve. # the topmost of each domain we serve.
domains |= set('www.' + zone for zone, zonefile in get_dns_zones(env)) domains |= {'www.' + zone for zone, zonefile in get_dns_zones(env)}
if include_auto: if include_auto:
# Add Autoconfiguration domains for domains that there are user accounts at: # Add Autoconfiguration domains for domains that there are user accounts at:
# 'autoconfig.' for Mozilla Thunderbird auto setup. # 'autoconfig.' for Mozilla Thunderbird auto setup.
# 'autodiscover.' for ActiveSync autodiscovery (Z-Push). # 'autodiscover.' for ActiveSync autodiscovery (Z-Push).
domains |= set('autoconfig.' + maildomain for maildomain in get_mail_domains(env, users_only=True)) domains |= {'autoconfig.' + maildomain for maildomain in get_mail_domains(env, users_only=True)}
domains |= set('autodiscover.' + maildomain for maildomain in get_mail_domains(env, users_only=True)) domains |= {'autodiscover.' + maildomain for maildomain in get_mail_domains(env, users_only=True)}
# 'mta-sts.' for MTA-STS support for all domains that have email addresses. # 'mta-sts.' for MTA-STS support for all domains that have email addresses.
domains |= set('mta-sts.' + maildomain for maildomain in get_mail_domains(env)) domains |= {'mta-sts.' + maildomain for maildomain in get_mail_domains(env)}
if exclude_dns_elsewhere: if exclude_dns_elsewhere:
# ...Unless the domain has an A/AAAA record that maps it to a different # ...Unless the domain has an A/AAAA record that maps it to a different
@ -45,15 +45,14 @@ def get_web_domains(env, include_www_redirects=True, include_auto=True, exclude_
domains.add(env['PRIMARY_HOSTNAME']) domains.add(env['PRIMARY_HOSTNAME'])
# Sort the list so the nginx conf gets written in a stable order. # Sort the list so the nginx conf gets written in a stable order.
domains = sort_domains(domains, env) return sort_domains(domains, env)
return domains
def get_domains_with_a_records(env): def get_domains_with_a_records(env):
domains = set() domains = set()
dns = get_custom_dns_config(env) dns = get_custom_dns_config(env)
for domain, rtype, value in dns: for domain, rtype, value in dns:
if rtype == "CNAME" or (rtype in ("A", "AAAA") and value not in ("local", env['PUBLIC_IP'])): if rtype == "CNAME" or (rtype in {"A", "AAAA"} and value not in {"local", env['PUBLIC_IP']}):
domains.add(domain) domains.add(domain)
return domains return domains
@ -63,7 +62,7 @@ def get_web_domains_with_root_overrides(env):
root_overrides = { } root_overrides = { }
nginx_conf_custom_fn = os.path.join(env["STORAGE_ROOT"], "www/custom.yaml") nginx_conf_custom_fn = os.path.join(env["STORAGE_ROOT"], "www/custom.yaml")
if os.path.exists(nginx_conf_custom_fn): if os.path.exists(nginx_conf_custom_fn):
with open(nginx_conf_custom_fn, 'r') as f: with open(nginx_conf_custom_fn, encoding='utf-8') as f:
custom_settings = rtyaml.load(f) custom_settings = rtyaml.load(f)
for domain, settings in custom_settings.items(): for domain, settings in custom_settings.items():
for type, value in [('redirect', settings.get('redirects', {}).get('/')), for type, value in [('redirect', settings.get('redirects', {}).get('/')),
@ -78,7 +77,7 @@ def do_web_update(env):
# Helper for reading config files and templates # Helper for reading config files and templates
def read_conf(conf_fn): def read_conf(conf_fn):
with open(os.path.join(os.path.dirname(__file__), "../conf", conf_fn), "r") as f: with open(os.path.join(os.path.dirname(__file__), "../conf", conf_fn), encoding='utf-8') as f:
return f.read() return f.read()
# Build an nginx configuration file. # Build an nginx configuration file.
@ -113,12 +112,12 @@ def do_web_update(env):
# Did the file change? If not, don't bother writing & restarting nginx. # Did the file change? If not, don't bother writing & restarting nginx.
nginx_conf_fn = "/etc/nginx/conf.d/local.conf" nginx_conf_fn = "/etc/nginx/conf.d/local.conf"
if os.path.exists(nginx_conf_fn): if os.path.exists(nginx_conf_fn):
with open(nginx_conf_fn) as f: with open(nginx_conf_fn, encoding='utf-8') as f:
if f.read() == nginx_conf: if f.read() == nginx_conf:
return "" return ""
# Save the file. # Save the file.
with open(nginx_conf_fn, "w") as f: with open(nginx_conf_fn, "w", encoding='utf-8') as f:
f.write(nginx_conf) f.write(nginx_conf)
# Kick nginx. Since this might be called from the web admin # Kick nginx. Since this might be called from the web admin
@ -150,13 +149,13 @@ def make_domain_config(domain, templates, ssl_certificates, env):
with open(filepath, 'rb') as f: with open(filepath, 'rb') as f:
sha1.update(f.read()) sha1.update(f.read())
return sha1.hexdigest() return sha1.hexdigest()
nginx_conf_extra += "\t# ssl files sha1: %s / %s\n" % (hashfile(tls_cert["private-key"]), hashfile(tls_cert["certificate"])) nginx_conf_extra += "\t# ssl files sha1: {} / {}\n".format(hashfile(tls_cert["private-key"]), hashfile(tls_cert["certificate"]))
# Add in any user customizations in YAML format. # Add in any user customizations in YAML format.
hsts = "yes" hsts = "yes"
nginx_conf_custom_fn = os.path.join(env["STORAGE_ROOT"], "www/custom.yaml") nginx_conf_custom_fn = os.path.join(env["STORAGE_ROOT"], "www/custom.yaml")
if os.path.exists(nginx_conf_custom_fn): if os.path.exists(nginx_conf_custom_fn):
with open(nginx_conf_custom_fn, 'r') as f: with open(nginx_conf_custom_fn, encoding='utf-8') as f:
yaml = rtyaml.load(f) yaml = rtyaml.load(f)
if domain in yaml: if domain in yaml:
yaml = yaml[domain] yaml = yaml[domain]
@ -196,16 +195,16 @@ def make_domain_config(domain, templates, ssl_certificates, env):
nginx_conf_extra += "\n\t\talias %s;" % alias nginx_conf_extra += "\n\t\talias %s;" % alias
nginx_conf_extra += "\n\t}\n" nginx_conf_extra += "\n\t}\n"
for path, url in yaml.get("redirects", {}).items(): for path, url in yaml.get("redirects", {}).items():
nginx_conf_extra += "\trewrite %s %s permanent;\n" % (path, url) nginx_conf_extra += f"\trewrite {path} {url} permanent;\n"
# override the HSTS directive type # override the HSTS directive type
hsts = yaml.get("hsts", hsts) hsts = yaml.get("hsts", hsts)
# Add the HSTS header. # Add the HSTS header.
if hsts == "yes": if hsts == "yes":
nginx_conf_extra += "\tadd_header Strict-Transport-Security \"max-age=15768000\" always;\n" nginx_conf_extra += '\tadd_header Strict-Transport-Security "max-age=15768000" always;\n'
elif hsts == "preload": elif hsts == "preload":
nginx_conf_extra += "\tadd_header Strict-Transport-Security \"max-age=15768000; includeSubDomains; preload\" always;\n" nginx_conf_extra += '\tadd_header Strict-Transport-Security "max-age=15768000; includeSubDomains; preload" always;\n'
# Add in any user customizations in the includes/ folder. # Add in any user customizations in the includes/ folder.
nginx_conf_custom_include = os.path.join(env["STORAGE_ROOT"], "www", safe_domain_name(domain) + ".conf") nginx_conf_custom_include = os.path.join(env["STORAGE_ROOT"], "www", safe_domain_name(domain) + ".conf")
@ -216,7 +215,7 @@ def make_domain_config(domain, templates, ssl_certificates, env):
# Combine the pieces. Iteratively place each template into the "# ADDITIONAL DIRECTIVES HERE" placeholder # Combine the pieces. Iteratively place each template into the "# ADDITIONAL DIRECTIVES HERE" placeholder
# of the previous template. # of the previous template.
nginx_conf = "# ADDITIONAL DIRECTIVES HERE\n" nginx_conf = "# ADDITIONAL DIRECTIVES HERE\n"
for t in templates + [nginx_conf_extra]: for t in [*templates, nginx_conf_extra]:
nginx_conf = re.sub("[ \t]*# ADDITIONAL DIRECTIVES HERE *\n", t, nginx_conf) nginx_conf = re.sub("[ \t]*# ADDITIONAL DIRECTIVES HERE *\n", t, nginx_conf)
# Replace substitution strings in the template & return. # Replace substitution strings in the template & return.
@ -225,9 +224,8 @@ def make_domain_config(domain, templates, ssl_certificates, env):
nginx_conf = nginx_conf.replace("$ROOT", root) nginx_conf = nginx_conf.replace("$ROOT", root)
nginx_conf = nginx_conf.replace("$SSL_KEY", tls_cert["private-key"]) nginx_conf = nginx_conf.replace("$SSL_KEY", tls_cert["private-key"])
nginx_conf = nginx_conf.replace("$SSL_CERTIFICATE", tls_cert["certificate"]) nginx_conf = nginx_conf.replace("$SSL_CERTIFICATE", tls_cert["certificate"])
nginx_conf = nginx_conf.replace("$REDIRECT_DOMAIN", re.sub(r"^www\.", "", domain)) # for default www redirects to parent domain return nginx_conf.replace("$REDIRECT_DOMAIN", re.sub(r"^www\.", "", domain)) # for default www redirects to parent domain
return nginx_conf
def get_web_root(domain, env, test_exists=True): def get_web_root(domain, env, test_exists=True):
# Try STORAGE_ROOT/web/domain_name if it exists, but fall back to STORAGE_ROOT/web/default. # Try STORAGE_ROOT/web/domain_name if it exists, but fall back to STORAGE_ROOT/web/default.

View File

@ -1,5 +1,5 @@
from daemon import app from daemon import app
import auth, utils import utils
app.logger.addHandler(utils.create_syslog_handler()) app.logger.addHandler(utils.create_syslog_handler())

View File

@ -9,6 +9,7 @@ import sys, os, os.path, glob, re, shutil
sys.path.insert(0, 'management') sys.path.insert(0, 'management')
from utils import load_environment, save_environment, shell from utils import load_environment, save_environment, shell
import contextlib
def migration_1(env): def migration_1(env):
# Re-arrange where we store SSL certificates. There was a typo also. # Re-arrange where we store SSL certificates. There was a typo also.
@ -31,10 +32,8 @@ def migration_1(env):
move_file(sslfn, domain_name, file_type) move_file(sslfn, domain_name, file_type)
# Move the old domains directory if it is now empty. # Move the old domains directory if it is now empty.
try: with contextlib.suppress(Exception):
os.rmdir(os.path.join( env["STORAGE_ROOT"], 'ssl/domains')) os.rmdir(os.path.join( env["STORAGE_ROOT"], 'ssl/domains'))
except:
pass
def migration_2(env): def migration_2(env):
# Delete the .dovecot_sieve script everywhere. This was formerly a copy of our spam -> Spam # Delete the .dovecot_sieve script everywhere. This was formerly a copy of our spam -> Spam
@ -168,7 +167,7 @@ def migration_12(env):
dropcmd = "DROP TABLE %s" % table dropcmd = "DROP TABLE %s" % table
c.execute(dropcmd) c.execute(dropcmd)
except: except:
print("Failed to drop table", table, e) print("Failed to drop table", table)
# Save. # Save.
conn.commit() conn.commit()
conn.close() conn.close()
@ -212,8 +211,8 @@ def run_migrations():
migration_id_file = os.path.join(env['STORAGE_ROOT'], 'mailinabox.version') migration_id_file = os.path.join(env['STORAGE_ROOT'], 'mailinabox.version')
migration_id = None migration_id = None
if os.path.exists(migration_id_file): if os.path.exists(migration_id_file):
with open(migration_id_file) as f: with open(migration_id_file, encoding='utf-8') as f:
migration_id = f.read().strip(); migration_id = f.read().strip()
if migration_id is None: if migration_id is None:
# Load the legacy location of the migration ID. We'll drop support # Load the legacy location of the migration ID. We'll drop support
@ -222,7 +221,7 @@ def run_migrations():
if migration_id is None: if migration_id is None:
print() print()
print("%s file doesn't exists. Skipping migration..." % (migration_id_file,)) print(f"{migration_id_file} file doesn't exists. Skipping migration...")
return return
ourver = int(migration_id) ourver = int(migration_id)
@ -253,7 +252,7 @@ def run_migrations():
# Write out our current version now. Do this sooner rather than later # Write out our current version now. Do this sooner rather than later
# in case of any problems. # in case of any problems.
with open(migration_id_file, "w") as f: with open(migration_id_file, "w", encoding='utf-8') as f:
f.write(str(ourver) + "\n") f.write(str(ourver) + "\n")
# Delete the legacy location of this field. # Delete the legacy location of this field.

View File

@ -6,12 +6,12 @@
# try to log in to. # try to log in to.
###################################################################### ######################################################################
import sys, os, time, functools import sys, os, time
# parse command line # parse command line
if len(sys.argv) != 4: if len(sys.argv) != 4:
print("Usage: tests/fail2ban.py \"ssh user@hostname\" hostname owncloud_user") print('Usage: tests/fail2ban.py "ssh user@hostname" hostname owncloud_user')
sys.exit(1) sys.exit(1)
ssh_command, hostname, owncloud_user = sys.argv[1:4] ssh_command, hostname, owncloud_user = sys.argv[1:4]
@ -24,7 +24,6 @@ socket.setdefaulttimeout(10)
class IsBlocked(Exception): class IsBlocked(Exception):
"""Tests raise this exception when it appears that a fail2ban """Tests raise this exception when it appears that a fail2ban
jail is in effect, i.e. on a connection refused error.""" jail is in effect, i.e. on a connection refused error."""
pass
def smtp_test(): def smtp_test():
import smtplib import smtplib
@ -33,13 +32,14 @@ def smtp_test():
server = smtplib.SMTP(hostname, 587) server = smtplib.SMTP(hostname, 587)
except ConnectionRefusedError: except ConnectionRefusedError:
# looks like fail2ban worked # looks like fail2ban worked
raise IsBlocked() raise IsBlocked
server.starttls() server.starttls()
server.ehlo_or_helo_if_needed() server.ehlo_or_helo_if_needed()
try: try:
server.login("fakeuser", "fakepassword") server.login("fakeuser", "fakepassword")
raise Exception("authentication didn't fail") msg = "authentication didn't fail"
raise Exception(msg)
except smtplib.SMTPAuthenticationError: except smtplib.SMTPAuthenticationError:
# athentication should fail # athentication should fail
pass pass
@ -57,11 +57,12 @@ def imap_test():
M = imaplib.IMAP4_SSL(hostname) M = imaplib.IMAP4_SSL(hostname)
except ConnectionRefusedError: except ConnectionRefusedError:
# looks like fail2ban worked # looks like fail2ban worked
raise IsBlocked() raise IsBlocked
try: try:
M.login("fakeuser", "fakepassword") M.login("fakeuser", "fakepassword")
raise Exception("authentication didn't fail") msg = "authentication didn't fail"
raise Exception(msg)
except imaplib.IMAP4.error: except imaplib.IMAP4.error:
# authentication should fail # authentication should fail
pass pass
@ -75,17 +76,18 @@ def pop_test():
M = poplib.POP3_SSL(hostname) M = poplib.POP3_SSL(hostname)
except ConnectionRefusedError: except ConnectionRefusedError:
# looks like fail2ban worked # looks like fail2ban worked
raise IsBlocked() raise IsBlocked
try: try:
M.user('fakeuser') M.user('fakeuser')
try: try:
M.pass_('fakepassword') M.pass_('fakepassword')
except poplib.error_proto as e: except poplib.error_proto:
# Authentication should fail. # Authentication should fail.
M = None # don't .quit() M = None # don't .quit()
return return
M.list() M.list()
raise Exception("authentication didn't fail") msg = "authentication didn't fail"
raise Exception(msg)
finally: finally:
if M: if M:
M.quit() M.quit()
@ -99,11 +101,12 @@ def managesieve_test():
M = imaplib.IMAP4(hostname, 4190) M = imaplib.IMAP4(hostname, 4190)
except ConnectionRefusedError: except ConnectionRefusedError:
# looks like fail2ban worked # looks like fail2ban worked
raise IsBlocked() raise IsBlocked
try: try:
M.login("fakeuser", "fakepassword") M.login("fakeuser", "fakepassword")
raise Exception("authentication didn't fail") msg = "authentication didn't fail"
raise Exception(msg)
except imaplib.IMAP4.error: except imaplib.IMAP4.error:
# authentication should fail # authentication should fail
pass pass
@ -129,17 +132,17 @@ def http_test(url, expected_status, postdata=None, qsargs=None, auth=None):
headers={'User-Agent': 'Mail-in-a-Box fail2ban tester'}, headers={'User-Agent': 'Mail-in-a-Box fail2ban tester'},
timeout=8, timeout=8,
verify=False) # don't bother with HTTPS validation, it may not be configured yet verify=False) # don't bother with HTTPS validation, it may not be configured yet
except requests.exceptions.ConnectTimeout as e: except requests.exceptions.ConnectTimeout:
raise IsBlocked() raise IsBlocked
except requests.exceptions.ConnectionError as e: except requests.exceptions.ConnectionError as e:
if "Connection refused" in str(e): if "Connection refused" in str(e):
raise IsBlocked() raise IsBlocked
raise # some other unexpected condition raise # some other unexpected condition
# return response status code # return response status code
if r.status_code != expected_status: if r.status_code != expected_status:
r.raise_for_status() # anything but 200 r.raise_for_status() # anything but 200
raise IOError("Got unexpected status code %s." % r.status_code) raise OSError("Got unexpected status code %s." % r.status_code)
# define how to run a test # define how to run a test
@ -149,7 +152,7 @@ def restart_fail2ban_service(final=False):
if not final: if not final:
# Stop recidive jails during testing. # Stop recidive jails during testing.
command += " && sudo fail2ban-client stop recidive" command += " && sudo fail2ban-client stop recidive"
os.system("%s \"%s\"" % (ssh_command, command)) os.system(f'{ssh_command} "{command}"')
def testfunc_runner(i, testfunc, *args): def testfunc_runner(i, testfunc, *args):
print(i+1, end=" ", flush=True) print(i+1, end=" ", flush=True)
@ -163,7 +166,6 @@ def run_test(testfunc, args, count, within_seconds, parallel):
# run testfunc sequentially and still get to count requests within # run testfunc sequentially and still get to count requests within
# the required time. So we split the requests across threads. # the required time. So we split the requests across threads.
import requests.exceptions
from multiprocessing import Pool from multiprocessing import Pool
restart_fail2ban_service() restart_fail2ban_service()
@ -179,7 +181,7 @@ def run_test(testfunc, args, count, within_seconds, parallel):
# Distribute the requests across the pool. # Distribute the requests across the pool.
asyncresults = [] asyncresults = []
for i in range(count): for i in range(count):
ar = p.apply_async(testfunc_runner, [i, testfunc] + list(args)) ar = p.apply_async(testfunc_runner, [i, testfunc, *list(args)])
asyncresults.append(ar) asyncresults.append(ar)
# Wait for all runs to finish. # Wait for all runs to finish.

View File

@ -7,7 +7,7 @@
# where ipaddr is the IP address of your Mail-in-a-Box # where ipaddr is the IP address of your Mail-in-a-Box
# and hostname is the domain name to check the DNS for. # and hostname is the domain name to check the DNS for.
import sys, re, difflib import sys, re
import dns.reversename, dns.resolver import dns.reversename, dns.resolver
if len(sys.argv) < 3: if len(sys.argv) < 3:
@ -27,10 +27,10 @@ def test(server, description):
("ns2." + primary_hostname, "A", ipaddr), ("ns2." + primary_hostname, "A", ipaddr),
("www." + hostname, "A", ipaddr), ("www." + hostname, "A", ipaddr),
(hostname, "MX", "10 " + primary_hostname + "."), (hostname, "MX", "10 " + primary_hostname + "."),
(hostname, "TXT", "\"v=spf1 mx -all\""), (hostname, "TXT", '"v=spf1 mx -all"'),
("mail._domainkey." + hostname, "TXT", "\"v=DKIM1; k=rsa; s=email; \" \"p=__KEY__\""), ("mail._domainkey." + hostname, "TXT", '"v=DKIM1; k=rsa; s=email; " "p=__KEY__"'),
#("_adsp._domainkey." + hostname, "TXT", "\"dkim=all\""), #("_adsp._domainkey." + hostname, "TXT", "\"dkim=all\""),
("_dmarc." + hostname, "TXT", "\"v=DMARC1; p=quarantine;\""), ("_dmarc." + hostname, "TXT", '"v=DMARC1; p=quarantine;"'),
] ]
return test2(tests, server, description) return test2(tests, server, description)
@ -59,7 +59,7 @@ def test2(tests, server, description):
response = ["[no value]"] response = ["[no value]"]
response = ";".join(str(r) for r in response) response = ";".join(str(r) for r in response)
response = re.sub(r"(\"p=).*(\")", r"\1__KEY__\2", response) # normalize DKIM key response = re.sub(r"(\"p=).*(\")", r"\1__KEY__\2", response) # normalize DKIM key
response = response.replace("\"\" ", "") # normalize TXT records (DNSSEC signing inserts empty text string components) response = response.replace('"" ', "") # normalize TXT records (DNSSEC signing inserts empty text string components)
# is it right? # is it right?
if response == expected_answer: if response == expected_answer:
@ -98,7 +98,7 @@ else:
# And if that's OK, also check reverse DNS (the PTR record). # And if that's OK, also check reverse DNS (the PTR record).
if not test_ptr("8.8.8.8", "Google Public DNS (Reverse DNS)"): if not test_ptr("8.8.8.8", "Google Public DNS (Reverse DNS)"):
print () print ()
print ("The reverse DNS for %s is not correct. Consult your ISP for how to set the reverse DNS (also called the PTR record) for %s to %s." % (hostname, hostname, ipaddr)) print (f"The reverse DNS for {hostname} is not correct. Consult your ISP for how to set the reverse DNS (also called the PTR record) for {hostname} to {ipaddr}.")
sys.exit(1) sys.exit(1)
else: else:
print ("And the reverse DNS for the domain is correct.") print ("And the reverse DNS for the domain is correct.")

View File

@ -30,15 +30,11 @@ print("IMAP login is OK.")
# Attempt to send a mail to ourself. # Attempt to send a mail to ourself.
mailsubject = "Mail-in-a-Box Automated Test Message " + uuid.uuid4().hex mailsubject = "Mail-in-a-Box Automated Test Message " + uuid.uuid4().hex
emailto = emailaddress emailto = emailaddress
msg = """From: {emailaddress} msg = f"""From: {emailaddress}
To: {emailto} To: {emailto}
Subject: {subject} Subject: {mailsubject}
This is a test message. It should be automatically deleted by the test script.""".format( This is a test message. It should be automatically deleted by the test script."""
emailaddress=emailaddress,
emailto=emailto,
subject=mailsubject,
)
# Connect to the server on the SMTP submission TLS port. # Connect to the server on the SMTP submission TLS port.
server = smtplib.SMTP_SSL(host) server = smtplib.SMTP_SSL(host)

View File

@ -6,11 +6,11 @@ if len(sys.argv) < 3:
sys.exit(1) sys.exit(1)
host, toaddr, fromaddr = sys.argv[1:4] host, toaddr, fromaddr = sys.argv[1:4]
msg = """From: %s msg = f"""From: {fromaddr}
To: %s To: {toaddr}
Subject: SMTP server test Subject: SMTP server test
This is a test message.""" % (fromaddr, toaddr) This is a test message."""
server = smtplib.SMTP(host, 25) server = smtplib.SMTP(host, 25)
server.set_debuglevel(1) server.set_debuglevel(1)

View File

@ -88,14 +88,14 @@ def sslyze(opts, port, ok_ciphers):
try: try:
# Execute SSLyze. # Execute SSLyze.
out = subprocess.check_output([SSLYZE] + common_opts + opts + [connection_string]) out = subprocess.check_output([SSLYZE, *common_opts, *opts, connection_string])
out = out.decode("utf8") out = out.decode("utf8")
# Trim output to make better for storing in git. # Trim output to make better for storing in git.
if "SCAN RESULTS FOR" not in out: if "SCAN RESULTS FOR" not in out:
# Failed. Just output the error. # Failed. Just output the error.
out = re.sub("[\w\W]*CHECKING HOST\(S\) AVAILABILITY\n\s*-+\n", "", out) # chop off header that shows the host we queried out = re.sub("[\\w\\W]*CHECKING HOST\\(S\\) AVAILABILITY\n\\s*-+\n", "", out) # chop off header that shows the host we queried
out = re.sub("[\w\W]*SCAN RESULTS FOR.*\n\s*-+\n", "", out) # chop off header that shows the host we queried out = re.sub("[\\w\\W]*SCAN RESULTS FOR.*\n\\s*-+\n", "", out) # chop off header that shows the host we queried
out = re.sub("SCAN COMPLETED IN .*", "", out) out = re.sub("SCAN COMPLETED IN .*", "", out)
out = out.rstrip(" \n-") + "\n" out = out.rstrip(" \n-") + "\n"
@ -105,8 +105,8 @@ def sslyze(opts, port, ok_ciphers):
# Pull out the accepted ciphers list for each SSL/TLS protocol # Pull out the accepted ciphers list for each SSL/TLS protocol
# version outputted. # version outputted.
accepted_ciphers = set() accepted_ciphers = set()
for ciphers in re.findall(" Accepted:([\w\W]*?)\n *\n", out): for ciphers in re.findall(" Accepted:([\\w\\W]*?)\n *\n", out):
accepted_ciphers |= set(re.findall("\n\s*(\S*)", ciphers)) accepted_ciphers |= set(re.findall("\n\\s*(\\S*)", ciphers))
# Compare to what Mozilla recommends, for a given modernness-level. # Compare to what Mozilla recommends, for a given modernness-level.
print(" Should Not Offer: " + (", ".join(sorted(accepted_ciphers-set(ok_ciphers))) or "(none -- good)")) print(" Should Not Offer: " + (", ".join(sorted(accepted_ciphers-set(ok_ciphers))) or "(none -- good)"))
@ -142,7 +142,7 @@ for cipher in csv.DictReader(io.StringIO(urllib.request.urlopen("https://raw.git
client_compatibility = json.loads(urllib.request.urlopen("https://raw.githubusercontent.com/mail-in-a-box/user-agent-tls-capabilities/master/clients.json").read().decode("utf8")) client_compatibility = json.loads(urllib.request.urlopen("https://raw.githubusercontent.com/mail-in-a-box/user-agent-tls-capabilities/master/clients.json").read().decode("utf8"))
cipher_clients = { } cipher_clients = { }
for client in client_compatibility: for client in client_compatibility:
if len(set(client['protocols']) & set(["TLS 1.0", "TLS 1.1", "TLS 1.2"])) == 0: continue # does not support TLS if len(set(client['protocols']) & {"TLS 1.0", "TLS 1.1", "TLS 1.2"}) == 0: continue # does not support TLS
for cipher in client['ciphers']: for cipher in client['ciphers']:
cipher_clients.setdefault(cipher_names.get(cipher), set()).add("/".join(x for x in [client['client']['name'], client['client']['version'], client['client']['platform']] if x)) cipher_clients.setdefault(cipher_names.get(cipher), set()).add("/".join(x for x in [client['client']['name'], client['client']['version'], client['client']['platform']] if x))

View File

@ -76,7 +76,7 @@ for setting in settings:
found = set() found = set()
buf = "" buf = ""
with open(filename, "r") as f: with open(filename, encoding="utf-8") as f:
input_lines = list(f) input_lines = list(f)
while len(input_lines) > 0: while len(input_lines) > 0:
@ -84,7 +84,7 @@ while len(input_lines) > 0:
# If this configuration file uses folded lines, append any folded lines # If this configuration file uses folded lines, append any folded lines
# into our input buffer. # into our input buffer.
if folded_lines and line[0] not in (comment_char, " ", ""): if folded_lines and line[0] not in {comment_char, " ", ""}:
while len(input_lines) > 0 and input_lines[0][0] in " \t": while len(input_lines) > 0 and input_lines[0][0] in " \t":
line += input_lines.pop(0) line += input_lines.pop(0)
@ -93,9 +93,9 @@ while len(input_lines) > 0:
# Check if this line contain this setting from the command-line arguments. # Check if this line contain this setting from the command-line arguments.
name, val = settings[i].split("=", 1) name, val = settings[i].split("=", 1)
m = re.match( m = re.match(
"(\s*)" r"(\s*)"
+ "(" + re.escape(comment_char) + "\s*)?" "(" + re.escape(comment_char) + r"\s*)?"
+ re.escape(name) + delimiter_re + "(.*?)\s*$", + re.escape(name) + delimiter_re + r"(.*?)\s*$",
line, re.S) line, re.S)
if not m: continue if not m: continue
indent, is_comment, existing_val = m.groups() indent, is_comment, existing_val = m.groups()
@ -144,7 +144,7 @@ for i in range(len(settings)):
if not testing: if not testing:
# Write out the new file. # Write out the new file.
with open(filename, "w") as f: with open(filename, "w", encoding="utf-8") as f:
f.write(buf) f.write(buf)
else: else:
# Just print the new file to stdout. # Just print the new file to stdout.

View File

@ -38,7 +38,7 @@ for date, ip in accesses:
# Since logs are rotated, store the statistics permanently in a JSON file. # Since logs are rotated, store the statistics permanently in a JSON file.
# Load in the stats from an existing file. # Load in the stats from an existing file.
if os.path.exists(outfn): if os.path.exists(outfn):
with open(outfn, "r") as f: with open(outfn, encoding="utf-8") as f:
existing_data = json.load(f) existing_data = json.load(f)
for date, count in existing_data: for date, count in existing_data:
if date not in by_date: if date not in by_date:
@ -51,5 +51,5 @@ by_date = sorted(by_date.items())
by_date.pop(-1) by_date.pop(-1)
# Write out. # Write out.
with open(outfn, "w") as f: with open(outfn, "w", encoding="utf-8") as f:
json.dump(by_date, f, sort_keys=True, indent=True) json.dump(by_date, f, sort_keys=True, indent=True)