diff --git a/plugin/__init__.py b/plugin/__init__.py index 898f81b..3ff0de2 100644 --- a/plugin/__init__.py +++ b/plugin/__init__.py @@ -43,6 +43,7 @@ from anki.notes import Note from anki.errors import NotFoundError from aqt.qt import Qt, QTimer, QMessageBox, QCheckBox +from .web import format_exception_reply, format_success_reply from .edit import Edit from . import web, util @@ -99,7 +100,6 @@ class AnkiConnect: version = request.get('version', 4) params = request.get('params', {}) key = request.get('key') - reply = {'result': None, 'error': None} try: if key != util.setting('apiKey') and name != 'requestPermission': @@ -126,14 +126,12 @@ class AnkiConnect: if method is None: raise Exception('unsupported action') - else: - reply['result'] = methodInst(**params) - if version <= 4: - reply = reply['result'] + api_return_value = methodInst(**params) + reply = format_success_reply(version, api_return_value) except Exception as e: - reply['error'] = str(e) + reply = format_exception_reply(version, e) self.logEvent('reply', reply) return reply diff --git a/plugin/web.py b/plugin/web.py index e386c6a..5300bfc 100644 --- a/plugin/web.py +++ b/plugin/web.py @@ -175,27 +175,29 @@ class WebServer: return self.buildResponse(headers, body) - paramsError = False - try: params = json.loads(req.body.decode('utf-8')) - except ValueError: - body = json.dumps(None).encode('utf-8') - paramsError = True - - if allowed or not paramsError and params.get('action', '') == 'requestPermission': - if len(req.body) == 0: - body = 'AnkiConnect v.{}'.format(util.setting('apiVersion')).encode('utf-8') + except ValueError as e: + if allowed: + if len(req.body) == 0: + body = f"AnkiConnect v.{util.setting('apiVersion')}".encode() + else: + reply = format_exception_reply(util.setting('apiVersion'), e) + body = json.dumps(reply).encode('utf-8') + headers = self.buildHeaders(corsOrigin, body) + return self.buildResponse(headers, body) else: - if params.get('action', '') == 'requestPermission': - params['params'] = params.get('params', {}) - params['params']['allowed'] = allowed - params['params']['origin'] = b'origin' in req.headers and req.headers[b'origin'].decode() or '' - if not allowed : - corsOrigin = params['params']['origin'] + params = {} # trigger the 403 response below + + if allowed or params.get('action', '') == 'requestPermission': + if params.get('action', '') == 'requestPermission': + params['params'] = params.get('params', {}) + params['params']['allowed'] = allowed + params['params']['origin'] = b'origin' in req.headers and req.headers[b'origin'].decode() or '' + if not allowed : + corsOrigin = params['params']['origin'] - body = json.dumps(self.handler(params)).encode('utf-8') - + body = json.dumps(self.handler(params)).encode('utf-8') headers = self.buildHeaders(corsOrigin, body) else : headers = [ @@ -273,3 +275,14 @@ class WebServer: client.close() self.clients = [] + + +def format_success_reply(api_version, result): + if api_version <= 4: + return result + else: + return {"result": result, "error": None} + + +def format_exception_reply(_api_version, exception): + return {"result": None, "error": str(exception)} diff --git a/tests/test_server.py b/tests/test_server.py index 1f33dcf..1f13157 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -46,11 +46,14 @@ class Client: return {"action": action, "params": params, "version": 6} def send_request(self, action, **params): - request_url = f"http://localhost:{self.port}" request_data = self.make_request(action, **params) - request_json = json.dumps(request_data).encode("utf-8") - request = urllib.request.Request(request_url, request_json) - response = json.load(urllib.request.urlopen(request)) + json_bytes = json.dumps(request_data).encode("utf-8") + return json.loads(self.send_bytes(json_bytes)) + + def send_bytes(self, bytes, headers={}): # noqa + request_url = f"http://localhost:{self.port}" + request = urllib.request.Request(request_url, bytes, headers) + response = urllib.request.urlopen(request).read() return response def wait_for_web_server_to_come_live(self, at_most_seconds=30): @@ -137,6 +140,11 @@ def test_multi_request(external_anki): } +def test_request_with_empty_body_returns_version_banner(external_anki): + response = external_anki.send_bytes(b"") + assert response == b"AnkiConnect v.6" + + def test_failing_request_due_to_bad_arguments(external_anki): response = external_anki.send_request("addNote", bad="request") assert response["result"] is None @@ -147,3 +155,24 @@ def test_failing_request_due_to_anki_raising_exception(external_anki): response = external_anki.send_request("suspend", cards=[-123]) assert response["result"] is None assert "Card was not found" in response["error"] + + +def test_failing_request_due_to_bad_encoding(external_anki): + response = json.loads(external_anki.send_bytes(b"\xe7\x8c")) + assert response["result"] is None + assert "can't decode" in response["error"] + + +def test_failing_request_due_to_bad_json(external_anki): + response = json.loads(external_anki.send_bytes(b'{1: 2}')) + assert response["result"] is None + assert "in double quotes" in response["error"] + + +def test_403_in_case_of_disallowed_origin(external_anki): + with pytest.raises(urllib.error.HTTPError, match="403"): # good request/json + json_bytes = json.dumps(Client.make_request("version")).encode("utf-8") + external_anki.send_bytes(json_bytes, headers={b"origin": b"foo"}) + + with pytest.raises(urllib.error.HTTPError, match="403"): # bad json + external_anki.send_bytes(b'{1: 2}', headers={b"origin": b"foo"})