diff --git a/checkcert/checkcert b/checkcert/checkcert index d6e94ae..82c1b14 100755 --- a/checkcert/checkcert +++ b/checkcert/checkcert @@ -1,31 +1,21 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 import sys from pprint import pformat import requests -import pyquery +from pyquery import PyQuery as pq import ssl import OpenSSL -from urlparse import urlparse, urljoin +from urllib.parse import urlparse, urljoin from datetime import datetime, timedelta from pytz import UTC import logging -logging.getLogger().setLevel(logging.DEBUG) +#logging.basicConfig(level=logging.DEBUG) +logging.basicConfig(level=logging.WARNING) +log = logging.getLogger() - -class AttrDict(dict): - def __init__(self, *a, **kw): - dict.__init__(self, *a, **kw) - self.__dict__ = self - -derp = AttrDict() - -def die(reason): - raise SystemExit(reason) - -def interact(): - import code - code.InteractiveConsole(locals=globals()).interact() +# FIXME: relative url stuff will not work if the url passed in redirects +# somewhere else class CertificateProblem(Exception): pass @@ -33,19 +23,18 @@ class CertificateProblem(Exception): class ReachabilityProblem(Exception): pass -def decode_ossl_time(t): - f = '%Y%m%d%H%M%SZ' - return datetime.strptime(t, f) - class SSLCert(object): - def __init__(self,x): - self.x = x + def __init__(self,c): + self.c = c + def decode_ossl_time(self,t): + f = '%Y%m%d%H%M%SZ' + return datetime.strptime(t.decode('utf-8'), f) def notBefore(self): - return decode_ossl_time(self.x.get_notBefore()) + return self.decode_ossl_time(self.c.get_notBefore()) def notAfter(self): - return decode_ossl_time(self.x.get_notAfter()) + return self.decode_ossl_time(self.c.get_notAfter()) def commonName(self): - t = self.x.get_subject().get_components() + t = self.c.get_subject().get_components() for x in t: if x[0] == "CN": return x[1] @@ -61,43 +50,90 @@ class SSLCert(object): week = timedelta(days=7) then = datetime.utcnow() + week return then > self.notAfter() - -def check_url(url): - r = requests.get(url,verify=True) - if r.status_code is not 200: - raise ReachabilityProblem - cert = cert_for_url(url) - -def cert_for_url(url): - o = urlparse(url) - if o.scheme != 'https': - return None - - if not o.port: - p = 443 - else: - p = o.port - c = ssl.get_server_certificate((o.hostname, p)) - - return SSLCert( - OpenSSL.crypto.load_certificate( - OpenSSL.crypto.FILETYPE_PEM, - c +class Website(object): + def __init__(self,url): + self.url = urlparse(url) + if not self.url.scheme: + self.url = urlparse('http://' + url) + self.cert = None + self.res = {} + self.r = requests.get(self.urlstring(),verify=True) + def contentType(self): + if ';' in self.r.headers['content-type']: + return self.r.headers['content-type'].split(';')[0] + else: + return self.r.headers['content-type'] + def resources(self): + if self.contentType() != 'text/html': + return [] + d = pq(self.r.text) + #import pdb; pdb.set_trace() + res = [] + for e in d('link'): + if 'openid' in e.attrib.get('rel'): + continue + res.append(e.attrib.get('href')) + for e in d('script'): + res.append(e.attrib.get('src')) + res = [ + urljoin(self.urlstring(),x) if not urlparse(x).netloc else x + for x in res + ] + res = [ + self.url.scheme + ':' + x if not urlparse(x).scheme else x + for x in res + ] + res = {x: 1 for x in res} + self.res = res.keys() + return self.res + def resources_by_host(self): + out = {} + for r in self.res: + if not out.get(urlparse(r).netloc): + out[urlparse(r).netloc] = [] + out[urlparse(r).netloc].append(r) + return out + def is_tls(self): + return self.url.scheme == 'https' + def urlstring(self): + return self.url.geturl() + def check(self): + if self.r.status_code is not 200: + raise ReachabilityProblem + if self.is_tls(): + self._get_cert() + if self.cert.expiresSoon() or not self.cert.validTime(): + raise CertificateProblem( + "cert for %s expires soon: %s" % ( + self.urlstring(), + self.cert.notAfter() + ) + ) + def _get_cert(self): + if not self.url.port: + p = 443 + else: + p = self.url.port + c = ssl.get_server_certificate( + (self.url.hostname, p), + ssl_version=ssl.PROTOCOL_TLSv1 ) - ) - + self.cert = SSLCert( + OpenSSL.crypto.load_certificate( + OpenSSL.crypto.FILETYPE_PEM, + c + ) + ) + def main(): if len(sys.argv) < 2: - print "usage: %s " % sys.argv[0] + print("usage: %s " % sys.argv[0]) sys.exit(1) - do_checks(sys.argv[1]) - -def do_checks(starturl): - urlqueue = [] - urlqueue.append(starturl) - while len(urlqueue): - urlqueue.extend(check_url(urlqueue.pop())) + s = Website(sys.argv[1]) + s.check() + for rurl in s.resources(): + Website(rurl).check() if __name__ == '__main__': main()