149 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
			
		
		
	
	
			149 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
#!/usr/bin/env python3
 | 
						|
 | 
						|
import sys
 | 
						|
from pprint import pformat
 | 
						|
import requests
 | 
						|
from pyquery import PyQuery as pq
 | 
						|
import ssl
 | 
						|
import OpenSSL
 | 
						|
from urllib.parse import urlparse, urljoin
 | 
						|
from datetime import datetime, timedelta
 | 
						|
from pytz import UTC
 | 
						|
import logging
 | 
						|
# Report warnings and errors only; change the level to logging.DEBUG when
# troubleshooting.
logging.basicConfig(level=logging.WARNING)
log = logging.getLogger()
 | 
						|
 | 
						|
# FIXME: relative resource URLs are resolved against the URL that was passed
# in, so they will be wrong if that URL redirects somewhere else.
 | 
						|
 | 
						|
class CertificateProblem(Exception):
    """A site's TLS certificate is outside its validity period or expires soon."""
 | 
						|
 | 
						|
class ReachabilityProblem(Exception):
    """A URL could not be fetched with an HTTP 200 response."""
 | 
						|
 | 
						|
class SSLCert(object):
    """Convenience wrapper around a certificate object.

    The wrapped object must provide ``get_notBefore()``, ``get_notAfter()``
    and ``get_subject().get_components()``; in this file it is the result of
    ``OpenSSL.crypto.load_certificate``.

    All timestamps handled here are naive ``datetime`` objects in UTC, which
    is what the ``YYYYMMDDhhmmssZ`` strings parsed below denote.
    """

    def __init__(self, c):
        """Wrap the certificate object ``c``."""
        self.c = c

    @staticmethod
    def _to_text(value):
        """Return ``value`` as ``str``, decoding it as UTF-8 if it is bytes."""
        if isinstance(value, bytes):
            return value.decode('utf-8')
        return value

    @staticmethod
    def _utcnow():
        """Return the current time as a naive UTC ``datetime``."""
        # datetime.utcnow() is deprecated; build an aware "now" and strip the
        # tzinfo so it stays comparable with the naive values parsed from the
        # certificate.
        from datetime import timezone
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def decode_ossl_time(self, t):
        """Parse a ``YYYYMMDDhhmmssZ`` timestamp given as bytes or str."""
        return datetime.strptime(self._to_text(t), '%Y%m%d%H%M%SZ')

    def notBefore(self):
        """Return the start of the certificate's validity period."""
        return self.decode_ossl_time(self.c.get_notBefore())

    def notAfter(self):
        """Return the end of the certificate's validity period."""
        return self.decode_ossl_time(self.c.get_notAfter())

    def commonName(self):
        """Return the subject's CN as ``str``, or ``None`` if there is none."""
        for name, value in self.c.get_subject().get_components():
            # Component names and values may arrive as bytes, so a direct
            # comparison with the str "CN" would never match; normalise
            # both sides to text first.
            if self._to_text(name) == "CN":
                return self._to_text(value)
        return None

    def expired(self):
        """Return True if the validity period has already ended."""
        return self._utcnow() > self.notAfter()

    def tooEarly(self):
        """Return True if the validity period has not started yet."""
        return self._utcnow() < self.notBefore()

    def validTime(self):
        """Return True if the current time lies within the validity period."""
        return not self.expired() and not self.tooEarly()

    def expiresSoon(self, days=7):
        """Return True if the certificate ends within ``days`` days.

        An already expired certificate also counts as expiring soon.
        """
        return self._utcnow() + timedelta(days=days) > self.notAfter()
 | 
						|
 | 
						|
class Website(object):
    """A fetched URL with reachability and TLS certificate checks.

    Creating an instance performs the HTTP GET straight away; the response is
    kept in ``self.r`` and used by the other methods.
    """

    def __init__(self, url):
        """Fetch ``url``; a URL without a scheme is treated as ``http://``."""
        self.url = urlparse(url)
        if not self.url.scheme:
            self.url = urlparse('http://' + url)
        self.cert = None
        self.res = {}
        # NOTE(review): no timeout is passed, so a server that stops
        # responding will block this call indefinitely.
        self.r = requests.get(self.urlstring(), verify=True)

    def contentType(self):
        """Return the response's media type without parameters.

        The value is lower-cased and stripped (``text/html; charset=utf-8``
        becomes ``text/html``). An empty string is returned when the response
        carries no Content-Type header.
        """
        header = self.r.headers.get('content-type') or ''
        return header.split(';')[0].strip().lower()

    def resources(self):
        """Return the de-duplicated URLs of linked stylesheets and scripts.

        Collects ``href`` of every ``<link>`` (except those whose ``rel``
        mentions ``openid``) and ``src`` of every ``<script>``. References
        without a host are resolved against the URL this object was created
        with, and scheme-less ones inherit that URL's scheme. The result is
        also stored in ``self.res``. Non-HTML responses yield ``[]`` and
        leave ``self.res`` untouched.
        """
        if self.contentType() != 'text/html':
            return []
        d = pq(self.r.text)
        found = []
        for e in d('link'):
            # ``rel`` is optional; treat a missing one as empty instead of
            # failing on ``in None``.
            if 'openid' in (e.attrib.get('rel') or ''):
                continue
            found.append(e.attrib.get('href'))
        for e in d('script'):
            found.append(e.attrib.get('src'))

        base = self.urlstring()
        absolute = []
        for ref in found:
            # Inline scripts and links without href have nothing to fetch.
            if not ref:
                continue
            if not urlparse(ref).netloc:
                ref = urljoin(base, ref)
            if not urlparse(ref).scheme:
                # Protocol-relative reference such as ``//cdn.example/x.js``.
                ref = self.url.scheme + ':' + ref
            absolute.append(ref)

        # dict keys drop duplicates while keeping first-seen order.
        self.res = dict.fromkeys(absolute).keys()
        return self.res

    def resources_by_host(self):
        """Group the URLs in ``self.res`` by their network location.

        Returns a dict mapping each host to the list of URLs on it. It is
        empty until ``resources()`` has populated ``self.res``.
        """
        out = {}
        for r in self.res:
            out.setdefault(urlparse(r).netloc, []).append(r)
        return out

    def is_tls(self):
        """Return True if the URL uses the ``https`` scheme."""
        return self.url.scheme == 'https'

    def urlstring(self):
        """Return the URL as a string."""
        return self.url.geturl()

    def check(self):
        """Validate the fetched response and, for https, the certificate.

        Raises:
            ReachabilityProblem: the response status is not 200.
            CertificateProblem: the certificate is outside its validity
                period or expires soon.
        """
        # Compare by value; ``is not`` on an int literal tests identity.
        if self.r.status_code != 200:
            raise ReachabilityProblem("can't access: '%s'" % self.urlstring())
        if self.is_tls():
            self._get_cert()
            if not self.cert.validTime():
                raise CertificateProblem(
                    "cert for %s is invalid: %s to %s" % (
                        self.urlstring(),
                        self.cert.notBefore(),
                        self.cert.notAfter()
                    )
                )
            if self.cert.expiresSoon():
                raise CertificateProblem(
                    "cert for %s expires soon: %s" % (
                        self.urlstring(),
                        self.cert.notAfter()
                    )
                )

    def _get_cert(self):
        """Download the server certificate and store it in ``self.cert``."""
        port = self.url.port or 443
        # Use the library's default protocol selection rather than pinning
        # TLS 1.0, which is deprecated and refused by many servers.
        pem = ssl.get_server_certificate((self.url.hostname, port))
        self.cert = SSLCert(
            OpenSSL.crypto.load_certificate(
                OpenSSL.crypto.FILETYPE_PEM,
                pem
            )
        )
 | 
						|
 
 | 
						|
def main():
    """Check every URL given on the command line and each page's resources.

    Prints a usage line and exits with status 1 when no URL is given. For
    each URL the page itself is checked first, then every resource it links
    to. Exceptions raised by the checks are not caught here, so the first
    problem found aborts the run.
    """
    if len(sys.argv) < 2:
        print("usage: %s <url> [url2] [url3] [...]" % sys.argv[0])
        sys.exit(1)
    for site in sys.argv[1:]:
        s = Website(site)
        s.check()
        for u in s.resources():
            Website(u).check()
 | 
						|
        for u in s.resources():
 | 
						|
            Website(u).check()
 | 
						|
 | 
						|
# Run the checks only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
 |