Fixed PLR6201 (literal-membership): Use a `set` literal when testing for membership

This commit is contained in:
Teal Dulcet 2023-12-22 07:10:25 -08:00 committed by Joshua Tauberer
parent cb922ec286
commit 49124cc9ca
11 changed files with 32 additions and 32 deletions

View File

@ -48,7 +48,7 @@ class AuthService:
return username, password return username, password
username, password = parse_http_authorization_basic(request.headers.get('Authorization', '')) username, password = parse_http_authorization_basic(request.headers.get('Authorization', ''))
if username in (None, ""): if username in {None, ""}:
raise ValueError("Authorization header invalid.") raise ValueError("Authorization header invalid.")
if username.strip() == "" and password.strip() == "": if username.strip() == "" and password.strip() == "":

View File

@ -556,7 +556,7 @@ def backup_set_custom(env, target, target_user, target_pass, min_age):
# Validate. # Validate.
try: try:
if config["target"] not in ("off", "local"): if config["target"] not in {"off", "local"}:
# these aren't supported by the following function, which expects a full url in the target key, # these aren't supported by the following function, which expects a full url in the target key,
# which is what is there except when loading the config prior to saving # which is what is there except when loading the config prior to saving
list_target_files(config) list_target_files(config)

View File

@ -91,7 +91,7 @@ elif sys.argv[1] == "user" and len(sys.argv) == 2:
print("*", end='') print("*", end='')
print() print()
elif sys.argv[1] == "user" and sys.argv[2] in ("add", "password"): elif sys.argv[1] == "user" and sys.argv[2] in {"add", "password"}:
if len(sys.argv) < 5: if len(sys.argv) < 5:
if len(sys.argv) < 4: if len(sys.argv) < 4:
email = input("email: ") email = input("email: ")
@ -109,7 +109,7 @@ elif sys.argv[1] == "user" and sys.argv[2] in ("add", "password"):
elif sys.argv[1] == "user" and sys.argv[2] == "remove" and len(sys.argv) == 4: elif sys.argv[1] == "user" and sys.argv[2] == "remove" and len(sys.argv) == 4:
print(mgmt("/mail/users/remove", { "email": sys.argv[3] })) print(mgmt("/mail/users/remove", { "email": sys.argv[3] }))
elif sys.argv[1] == "user" and sys.argv[2] in ("make-admin", "remove-admin") and len(sys.argv) == 4: elif sys.argv[1] == "user" and sys.argv[2] in {"make-admin", "remove-admin"} and len(sys.argv) == 4:
if sys.argv[2] == "make-admin": if sys.argv[2] == "make-admin":
action = "add" action = "add"
else: else:
@ -132,7 +132,7 @@ elif sys.argv[1] == "user" and len(sys.argv) == 5 and sys.argv[2:4] == ["mfa", "
for mfa in status["enabled_mfa"]: for mfa in status["enabled_mfa"]:
W.writerow([mfa["id"], mfa["type"], mfa["label"]]) W.writerow([mfa["id"], mfa["type"], mfa["label"]])
elif sys.argv[1] == "user" and len(sys.argv) in (5, 6) and sys.argv[2:4] == ["mfa", "disable"]: elif sys.argv[1] == "user" and len(sys.argv) in {5, 6} and sys.argv[2:4] == ["mfa", "disable"]:
# Disable MFA (all or a particular device) for a user. # Disable MFA (all or a particular device) for a user.
print(mgmt("/mfa/disable", { "user": sys.argv[4], "mfa-id": sys.argv[5] if len(sys.argv) == 6 else None })) print(mgmt("/mfa/disable", { "user": sys.argv[4], "mfa-id": sys.argv[5] if len(sys.argv) == 6 else None }))

View File

@ -90,7 +90,7 @@ def authorized_personnel_only(viewfunc):
status = 403 status = 403
headers = None headers = None
if request.headers.get('Accept') in (None, "", "*/*"): if request.headers.get('Accept') in {None, "", "*/*"}:
# Return plain text output. # Return plain text output.
return Response(error+"\n", status=status, mimetype='text/plain', headers=headers) return Response(error+"\n", status=status, mimetype='text/plain', headers=headers)
else: else:
@ -355,9 +355,9 @@ def dns_set_record(qname, rtype="A"):
# Get the existing records matching the qname and rtype. # Get the existing records matching the qname and rtype.
return dns_get_records(qname, rtype) return dns_get_records(qname, rtype)
elif request.method in ("POST", "PUT"): elif request.method in {"POST", "PUT"}:
# There is a default value for A/AAAA records. # There is a default value for A/AAAA records.
if rtype in ("A", "AAAA") and value == "": if rtype in {"A", "AAAA"} and value == "":
value = request.environ.get("HTTP_X_FORWARDED_FOR") # normally REMOTE_ADDR but we're behind nginx as a reverse proxy value = request.environ.get("HTTP_X_FORWARDED_FOR") # normally REMOTE_ADDR but we're behind nginx as a reverse proxy
# Cannot add empty records. # Cannot add empty records.

View File

@ -364,7 +364,7 @@ def build_zone(domain, domain_properties, additional_records, env, is_zone=True)
# non-mail domain and also may include qnames from custom DNS records. # non-mail domain and also may include qnames from custom DNS records.
# Do this once at the end of generating a zone. # Do this once at the end of generating a zone.
if is_zone: if is_zone:
qnames_with_a = set(qname for (qname, rtype, value, explanation) in records if rtype in ("A", "AAAA")) qnames_with_a = set(qname for (qname, rtype, value, explanation) in records if rtype in {"A", "AAAA"})
qnames_with_mx = set(qname for (qname, rtype, value, explanation) in records if rtype == "MX") qnames_with_mx = set(qname for (qname, rtype, value, explanation) in records if rtype == "MX")
for qname in qnames_with_a - qnames_with_mx: for qname in qnames_with_a - qnames_with_mx:
# Mark this domain as not sending mail with hard-fail SPF and DMARC records. # Mark this domain as not sending mail with hard-fail SPF and DMARC records.
@ -921,12 +921,12 @@ def set_custom_dns_record(qname, rtype, value, action, env):
if not re.search(DOMAIN_RE, qname): if not re.search(DOMAIN_RE, qname):
raise ValueError("Invalid name.") raise ValueError("Invalid name.")
if rtype in ("A", "AAAA"): if rtype in {"A", "AAAA"}:
if value != "local": # "local" is a special flag for us if value != "local": # "local" is a special flag for us
v = ipaddress.ip_address(value) # raises a ValueError if there's a problem v = ipaddress.ip_address(value) # raises a ValueError if there's a problem
if rtype == "A" and not isinstance(v, ipaddress.IPv4Address): raise ValueError("That's an IPv6 address.") if rtype == "A" and not isinstance(v, ipaddress.IPv4Address): raise ValueError("That's an IPv6 address.")
if rtype == "AAAA" and not isinstance(v, ipaddress.IPv6Address): raise ValueError("That's an IPv4 address.") if rtype == "AAAA" and not isinstance(v, ipaddress.IPv6Address): raise ValueError("That's an IPv4 address.")
elif rtype in ("CNAME", "NS"): elif rtype in {"CNAME", "NS"}:
if rtype == "NS" and qname == zone: if rtype == "NS" and qname == zone:
raise ValueError("NS records can only be set for subdomains.") raise ValueError("NS records can only be set for subdomains.")
@ -936,7 +936,7 @@ def set_custom_dns_record(qname, rtype, value, action, env):
if not re.search(DOMAIN_RE, value): if not re.search(DOMAIN_RE, value):
raise ValueError("Invalid value.") raise ValueError("Invalid value.")
elif rtype in ("CNAME", "TXT", "SRV", "MX", "SSHFP", "CAA"): elif rtype in {"CNAME", "TXT", "SRV", "MX", "SSHFP", "CAA"}:
# anything goes # anything goes
pass pass
else: else:
@ -979,7 +979,7 @@ def set_custom_dns_record(qname, rtype, value, action, env):
# Preserve this record. # Preserve this record.
newconfig.append((_qname, _rtype, _value)) newconfig.append((_qname, _rtype, _value))
if action in ("add", "set") and needs_add and value is not None: if action in {"add", "set"} and needs_add and value is not None:
newconfig.append((qname, rtype, value)) newconfig.append((qname, rtype, value))
made_change = True made_change = True

View File

@ -376,9 +376,9 @@ def scan_mail_log_line(line, collector):
elif service == "postfix/smtpd": elif service == "postfix/smtpd":
if SCAN_BLOCKED: if SCAN_BLOCKED:
scan_postfix_smtpd_line(date, log, collector) scan_postfix_smtpd_line(date, log, collector)
elif service in ("postfix/qmgr", "postfix/pickup", "postfix/cleanup", "postfix/scache", elif service in {"postfix/qmgr", "postfix/pickup", "postfix/cleanup", "postfix/scache",
"spampd", "postfix/anvil", "postfix/master", "opendkim", "postfix/lmtp", "spampd", "postfix/anvil", "postfix/master", "opendkim", "postfix/lmtp",
"postfix/tlsmgr", "anvil"): "postfix/tlsmgr", "anvil"}:
# nothing to look at # nothing to look at
return True return True
else: else:
@ -500,7 +500,7 @@ def add_login(user, date, protocol_name, host, collector):
data["totals_by_protocol"][protocol_name] += 1 data["totals_by_protocol"][protocol_name] += 1
data["totals_by_protocol_and_host"][(protocol_name, host)] += 1 data["totals_by_protocol_and_host"][(protocol_name, host)] += 1
if host not in ("127.0.0.1", "::1") or True: if host not in {"127.0.0.1", "::1"} or True:
data["activity-by-hour"][protocol_name][date.hour] += 1 data["activity-by-hour"][protocol_name][date.hour] += 1
collector["logins"][user] = data collector["logins"][user] = data
@ -684,7 +684,7 @@ def print_user_table(users, data=None, sub_data=None, activity=None, latest=None
data_accum[col] += d[row] data_accum[col] += d[row]
try: try:
if None not in [latest, earliest]: if None not in {latest, earliest}:
vert_pos = len(line) vert_pos = len(line)
e = earliest[row] e = earliest[row]
l = latest[row] l = latest[row]
@ -740,7 +740,7 @@ def print_user_table(users, data=None, sub_data=None, activity=None, latest=None
else: else:
header += l.rjust(max(5, len(l) + 1, col_widths[col])) header += l.rjust(max(5, len(l) + 1, col_widths[col]))
if None not in (latest, earliest): if None not in {latest, earliest}:
header += " │ timespan " header += " │ timespan "
lines.insert(0, header.rstrip()) lines.insert(0, header.rstrip())
@ -765,7 +765,7 @@ def print_user_table(users, data=None, sub_data=None, activity=None, latest=None
footer += temp.format(data_accum[row]) footer += temp.format(data_accum[row])
try: try:
if None not in [latest, earliest]: if None not in {latest, earliest}:
max_l = max(latest) max_l = max(latest)
min_e = min(earliest) min_e = min(earliest)
timespan = relativedelta(max_l, min_e) timespan = relativedelta(max_l, min_e)

View File

@ -588,7 +588,7 @@ def kick(env, mail_result=None):
# They are now stored in the auto_aliases table. # They are now stored in the auto_aliases table.
for address, forwards_to, permitted_senders, auto in get_mail_aliases(env): for address, forwards_to, permitted_senders, auto in get_mail_aliases(env):
user, domain = address.split("@") user, domain = address.split("@")
if user in ("postmaster", "admin", "abuse") \ if user in {"postmaster", "admin", "abuse"} \
and address not in required_aliases \ and address not in required_aliases \
and forwards_to == get_system_administrator(env) \ and forwards_to == get_system_administrator(env) \
and not auto: and not auto:

View File

@ -637,7 +637,7 @@ def load_pem(pem):
if pem_type is None: if pem_type is None:
raise ValueError("File is not a valid PEM-formatted file.") raise ValueError("File is not a valid PEM-formatted file.")
pem_type = pem_type.group(1) pem_type = pem_type.group(1)
if pem_type in (b"RSA PRIVATE KEY", b"PRIVATE KEY"): if pem_type in {b"RSA PRIVATE KEY", b"PRIVATE KEY"}:
return serialization.load_pem_private_key(pem, password=None, backend=default_backend()) return serialization.load_pem_private_key(pem, password=None, backend=default_backend())
if pem_type == b"CERTIFICATE": if pem_type == b"CERTIFICATE":
return load_pem_x509_certificate(pem, default_backend()) return load_pem_x509_certificate(pem, default_backend())

View File

@ -151,7 +151,7 @@ def check_service(i, service, env):
output.print_error("%s is not running (port %d)." % (service['name'], service['port'])) output.print_error("%s is not running (port %d)." % (service['name'], service['port']))
# Why is nginx not running? # Why is nginx not running?
if not running and service["port"] in (80, 443): if not running and service["port"] in {80, 443}:
output.print_line(shell('check_output', ['nginx', '-t'], capture_stderr=True, trap=True)[1].strip()) output.print_line(shell('check_output', ['nginx', '-t'], capture_stderr=True, trap=True)[1].strip())
else: else:
@ -340,7 +340,7 @@ def run_domain_checks(rounded_time, env, output, pool, domains_to_check=None):
domains_to_check = [ domains_to_check = [
d for d in domains_to_check d for d in domains_to_check
if not ( if not (
d.split(".", 1)[0] in ("www", "autoconfig", "autodiscover", "mta-sts") d.split(".", 1)[0] in {"www", "autoconfig", "autodiscover", "mta-sts"}
and len(d.split(".", 1)) == 2 and len(d.split(".", 1)) == 2
and d.split(".", 1)[1] in domains_to_check and d.split(".", 1)[1] in domains_to_check
) )
@ -467,7 +467,7 @@ def check_primary_hostname_dns(domain, env, output, dns_domains, dns_zonefiles):
# a DNS zone if it is a subdomain of another domain we have a zone for. # a DNS zone if it is a subdomain of another domain we have a zone for.
existing_rdns_v4 = query_dns(dns.reversename.from_address(env['PUBLIC_IP']), "PTR") existing_rdns_v4 = query_dns(dns.reversename.from_address(env['PUBLIC_IP']), "PTR")
existing_rdns_v6 = query_dns(dns.reversename.from_address(env['PUBLIC_IPV6']), "PTR") if env.get("PUBLIC_IPV6") else None existing_rdns_v6 = query_dns(dns.reversename.from_address(env['PUBLIC_IPV6']), "PTR") if env.get("PUBLIC_IPV6") else None
if existing_rdns_v4 == domain and existing_rdns_v6 in (None, domain): if existing_rdns_v4 == domain and existing_rdns_v6 in {None, domain}:
output.print_ok("Reverse DNS is set correctly at ISP. [%s%s]" % (my_ips, env['PRIMARY_HOSTNAME'])) output.print_ok("Reverse DNS is set correctly at ISP. [%s%s]" % (my_ips, env['PRIMARY_HOSTNAME']))
elif existing_rdns_v4 == existing_rdns_v6 or existing_rdns_v6 is None: elif existing_rdns_v4 == existing_rdns_v6 or existing_rdns_v6 is None:
output.print_error("""Your box's reverse DNS is currently %s, but it should be %s. Your ISP or cloud provider will have instructions output.print_error("""Your box's reverse DNS is currently %s, but it should be %s. Your ISP or cloud provider will have instructions
@ -636,7 +636,7 @@ def check_dnssec(domain, env, output, dns_zonefiles, is_checking_primary=False):
if set(r[1] for r in matched_ds) == { '13' } and set(r[2] for r in matched_ds) <= { '2', '4' }: # all are alg 13 and digest type 2 or 4 if set(r[1] for r in matched_ds) == { '13' } and set(r[2] for r in matched_ds) <= { '2', '4' }: # all are alg 13 and digest type 2 or 4
output.print_ok("DNSSEC 'DS' record is set correctly at registrar.") output.print_ok("DNSSEC 'DS' record is set correctly at registrar.")
return return
elif len([r for r in matched_ds if r[1] == '13' and r[2] in ( '2', '4' )]) > 0: # some but not all are alg 13 elif len([r for r in matched_ds if r[1] == '13' and r[2] in { '2', '4' }]) > 0: # some but not all are alg 13
output.print_ok("DNSSEC 'DS' record is set correctly at registrar. (Records using algorithm other than ECDSAP256SHA256 and digest types other than SHA-256/384 should be removed.)") output.print_ok("DNSSEC 'DS' record is set correctly at registrar. (Records using algorithm other than ECDSAP256SHA256 and digest types other than SHA-256/384 should be removed.)")
return return
else: # no record uses alg 13 else: # no record uses alg 13
@ -825,7 +825,7 @@ def query_dns(qname, rtype, nxdomain='[Not Set]', at=None, as_list=False):
# be expressed in equivalent string forms. Canonicalize the form before # be expressed in equivalent string forms. Canonicalize the form before
# returning them. The caller should normalize any IP addresses the result # returning them. The caller should normalize any IP addresses the result
# of this method is compared with. # of this method is compared with.
if rtype in ("A", "AAAA"): if rtype in {"A", "AAAA"}:
response = [normalize_ip(str(r)) for r in response] response = [normalize_ip(str(r)) for r in response]
if as_list: if as_list:
@ -841,7 +841,7 @@ def check_ssl_cert(domain, rounded_time, ssl_certificates, env, output):
# Check that TLS certificate is signed. # Check that TLS certificate is signed.
# Skip the check if the A record is not pointed here. # Skip the check if the A record is not pointed here.
if query_dns(domain, "A", None) not in (env['PUBLIC_IP'], None): return if query_dns(domain, "A", None) not in {env['PUBLIC_IP'], None}: return
# Where is the certificate file stored? # Where is the certificate file stored?
tls_cert = get_domain_ssl_files(domain, ssl_certificates, env, allow_missing_cert=True) tls_cert = get_domain_ssl_files(domain, ssl_certificates, env, allow_missing_cert=True)
@ -1002,14 +1002,14 @@ def run_and_output_changes(env, pool):
out.add_heading(category + " -- Previously:") out.add_heading(category + " -- Previously:")
elif op == "delete": elif op == "delete":
out.add_heading(category + " -- Removed") out.add_heading(category + " -- Removed")
if op in ("replace", "delete"): if op in {"replace", "delete"}:
BufferedOutput(with_lines=prev_lines[i1:i2]).playback(out) BufferedOutput(with_lines=prev_lines[i1:i2]).playback(out)
if op == "replace": if op == "replace":
out.add_heading(category + " -- Currently:") out.add_heading(category + " -- Currently:")
elif op == "insert": elif op == "insert":
out.add_heading(category + " -- Added") out.add_heading(category + " -- Added")
if op in ("replace", "insert"): if op in {"replace", "insert"}:
BufferedOutput(with_lines=cur_lines[j1:j2]).playback(out) BufferedOutput(with_lines=cur_lines[j1:j2]).playback(out)
for category, prev_lines in prev_status.items(): for category, prev_lines in prev_status.items():
@ -1095,7 +1095,7 @@ class BufferedOutput:
def __init__(self, with_lines=None): def __init__(self, with_lines=None):
self.buf = [] if not with_lines else with_lines self.buf = [] if not with_lines else with_lines
def __getattr__(self, attr): def __getattr__(self, attr):
if attr not in ("add_heading", "print_ok", "print_error", "print_warning", "print_block", "print_line"): if attr not in {"add_heading", "print_ok", "print_error", "print_warning", "print_block", "print_line"}:
raise AttributeError raise AttributeError
# Return a function that just records the call & arguments to our buffer. # Return a function that just records the call & arguments to our buffer.
def w(*args, **kwargs): def w(*args, **kwargs):

View File

@ -53,7 +53,7 @@ def get_domains_with_a_records(env):
domains = set() domains = set()
dns = get_custom_dns_config(env) dns = get_custom_dns_config(env)
for domain, rtype, value in dns: for domain, rtype, value in dns:
if rtype == "CNAME" or (rtype in ("A", "AAAA") and value not in ("local", env['PUBLIC_IP'])): if rtype == "CNAME" or (rtype in {"A", "AAAA"} and value not in {"local", env['PUBLIC_IP']}):
domains.add(domain) domains.add(domain)
return domains return domains

View File

@ -84,7 +84,7 @@ while len(input_lines) > 0:
# If this configuration file uses folded lines, append any folded lines # If this configuration file uses folded lines, append any folded lines
# into our input buffer. # into our input buffer.
if folded_lines and line[0] not in (comment_char, " ", ""): if folded_lines and line[0] not in {comment_char, " ", ""}:
while len(input_lines) > 0 and input_lines[0][0] in " \t": while len(input_lines) > 0 and input_lines[0][0] in " \t":
line += input_lines.pop(0) line += input_lines.pop(0)