diff --git a/plugin/web.py b/plugin/web.py index a682d20..e386c6a 100644 --- a/plugin/web.py +++ b/plugin/web.py @@ -24,7 +24,8 @@ from . import util # class WebRequest: - def __init__(self, headers, body): + def __init__(self, method, headers, body): + self.method = method self.headers = headers self.body = body @@ -95,8 +96,15 @@ class WebClient: if len(parts) == 1: return None, 0 + lines = parts[0].split('\r\n'.encode('utf-8')) + method = None + + if len(lines) > 0: + request_line_parts = lines[0].split(' '.encode('utf-8')) + method = request_line_parts[0].upper() if len(request_line_parts) > 0 else None + headers = {} - for line in parts[0].split('\r\n'.encode('utf-8')): + for line in lines[1:]: pair = line.split(': '.encode('utf-8')) headers[pair[0].lower()] = pair[1] if len(pair) > 1 else None @@ -108,8 +116,7 @@ class WebClient: return None, 0 body = data[headerLength : totalLength] - return WebRequest(headers, body), totalLength - + return WebRequest(method, headers, body), totalLength # # WebServer @@ -154,7 +161,54 @@ class WebServer: def handlerWrapper(self, req): + allowed, corsOrigin = self.allowOrigin(req) + if req.method == b'OPTIONS': + body = ''.encode('utf-8') + headers = self.buildHeaders(corsOrigin, body) + + if b'access-control-request-private-network' in req.headers and ( + req.headers[b'access-control-request-private-network'] == b'true'): + # include this header so that if a public origin is included in the whitelist, + # then browsers won't fail requests due to the private network access check + headers.append(['Access-Control-Allow-Private-Network', 'true']) + + return self.buildResponse(headers, body) + + paramsError = False + + try: + params = json.loads(req.body.decode('utf-8')) + except ValueError: + body = json.dumps(None).encode('utf-8') + paramsError = True + + if allowed or not paramsError and params.get('action', '') == 'requestPermission': + if len(req.body) == 0: + body = 'AnkiConnect v.{}'.format(util.setting('apiVersion')).encode('utf-8') + else: + if params.get('action', '') == 'requestPermission': + params['params'] = params.get('params', {}) + params['params']['allowed'] = allowed + params['params']['origin'] = b'origin' in req.headers and req.headers[b'origin'].decode() or '' + if not allowed : + corsOrigin = params['params']['origin'] + + body = json.dumps(self.handler(params)).encode('utf-8') + + headers = self.buildHeaders(corsOrigin, body) + else : + headers = [ + ['HTTP/1.1 403 Forbidden', None], + ['Access-Control-Allow-Origin', corsOrigin], + ['Access-Control-Allow-Headers', '*'] + ] + body = ''.encode('utf-8') + + return self.buildResponse(headers, body) + + + def allowOrigin(self, req): # handle multiple cors origins by checking the 'origin'-header against the allowed origin list from the config webCorsOriginList = util.setting('webCorsOriginList') @@ -183,43 +237,22 @@ class WebServer: allowed = True else: allowed = True + + return allowed, corsOrigin + + def buildHeaders(self, corsOrigin, body): + return [ + ['HTTP/1.1 200 OK', None], + ['Content-Type', 'text/json'], + ['Access-Control-Allow-Origin', corsOrigin], + ['Access-Control-Allow-Headers', '*'], + ['Content-Length', str(len(body))] + ] + + + def buildResponse(self, headers, body): resp = bytes() - paramsError = False - try: - params = json.loads(req.body.decode('utf-8')) - except ValueError: - body = json.dumps(None).encode('utf-8') - paramsError = True - - if allowed or not paramsError and params.get('action', '') == 'requestPermission': - if len(req.body) == 0: - body = 'AnkiConnect v.{}'.format(util.setting('apiVersion')).encode('utf-8') - else: - if params.get('action', '') == 'requestPermission': - params['params'] = params.get('params', {}) - params['params']['allowed'] = allowed - params['params']['origin'] = b'origin' in req.headers and req.headers[b'origin'].decode() or '' - if not allowed : - corsOrigin = params['params']['origin'] - - body = json.dumps(self.handler(params)).encode('utf-8') - - headers = [ - ['HTTP/1.1 200 OK', None], - ['Content-Type', 'text/json'], - ['Access-Control-Allow-Origin', corsOrigin], - ['Access-Control-Allow-Headers', '*'], - ['Content-Length', str(len(body))] - ] - else : - headers = [ - ['HTTP/1.1 403 Forbidden', None], - ['Access-Control-Allow-Origin', corsOrigin], - ['Access-Control-Allow-Headers', '*'] - ] - body = ''.encode('utf-8') - for key, value in headers: if value is None: resp += '{}\r\n'.format(key).encode('utf-8') @@ -228,7 +261,6 @@ class WebServer: resp += '\r\n'.encode('utf-8') resp += body - return resp