diff --git a/plugin/__init__.py b/plugin/__init__.py index d18dff4..3c602ac 100644 --- a/plugin/__init__.py +++ b/plugin/__init__.py @@ -56,11 +56,6 @@ from . import web, util class AnkiConnect: def __init__(self): - self.log = None - logPath = util.setting('apiLogPath') - if logPath is not None: - self.log = open(logPath, 'w') - try: self.server = web.WebServer(self.handler) self.server.listen() @@ -76,21 +71,11 @@ class AnkiConnect: ) - def logEvent(self, name, data): - if self.log is not None: - self.log.write('[{}]\n'.format(name)) - json.dump(data, self.log, indent=4, sort_keys=True) - self.log.write('\n\n') - self.log.flush() - - def advance(self): self.server.advance() def handler(self, request): - self.logEvent('request', request) - name = request.get('action', '') version = request.get('version', 4) params = request.get('params', {}) @@ -131,7 +116,6 @@ class AnkiConnect: except Exception as e: reply['error'] = str(e) - self.logEvent('reply', reply) return reply diff --git a/plugin/host.py b/plugin/host.py new file mode 100644 index 0000000..3ae5400 --- /dev/null +++ b/plugin/host.py @@ -0,0 +1,64 @@ +# Copyright 2016-2021 Alex Yatskov +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from . import web + + +class ApiHost: + def __init__(self, origins, key, interval, address, port): + self.key = key + self.modules = [] + + self.server = web.WebServer(self.handler, origins) + self.server.bindAndListen(address, port) + + self.timer = QTimer() + self.timer.timeout.connect(self.advance) + self.timer.start(interval) + + + def register(self, module): + self.modules.append(module) + + + def handler(self, call): + action = call.get('action') + params = call.get('params', {}) + allowed = call.get('allowed', False) + key = call.get('key') + + try: + if key != self.key and action != 'requestPermission': + raise Exception('valid api key must be provided') + + method = None + for module in self.modules: + for methodName, methodInstance in inspect.getmembers(module, predicate=inspect.ismethod): + if getattr(methodInstance, 'api', False): + method = methodInstance + break + + if method: + return {'error': None, 'result': methodInstance(**params)} + else: + raise Exception('unsupported action') + + except Exception as e: + return {'error': str(e), 'result': None} + + +def api(method): + setattr(method, 'api', True) + return decorator diff --git a/plugin/web.py b/plugin/web.py index 9eb8e6b..70848d6 100644 --- a/plugin/web.py +++ b/plugin/web.py @@ -17,8 +17,6 @@ import json import select import socket -from . import util - # # WebRequest @@ -36,40 +34,40 @@ class WebRequest: class WebClient: def __init__(self, sock, handler): - self.sock = sock + self.socket = sock self.handler = handler self.readBuff = bytes() self.writeBuff = bytes() - def advance(self, recvSize=1024): - if self.sock is None: + def advance(self): + if not self.socket: return False - rlist, wlist = select.select([self.sock], [self.sock], [], 0)[:2] - self.sock.settimeout(5.0) + rlist, wlist = select.select([self.socket], [self.socket], [], 0)[:2] + self.socket.settimeout(5.0) if rlist: while True: try: - msg = self.sock.recv(recvSize) - except (ConnectionResetError, socket.timeout): + data = self.socket.recv(1024) + if not data: + raise Exception('failed to get data from socket') + except: self.close() return False - if not msg: - self.close() - return False - self.readBuff += msg - req, length = self.parseRequest(self.readBuff) - if req is not None: + self.readBuff += data + + request, length = self.parseRequest(self.readBuff.decode('utf-8')) + if request: self.readBuff = self.readBuff[length:] - self.writeBuff += self.handler(req) + self.writeBuff += self.handler(request).encode('utf-8') break if wlist and self.writeBuff: try: - length = self.sock.send(self.writeBuff) + length = self.socket.send(self.writeBuff) self.writeBuff = self.writeBuff[length:] if not self.writeBuff: self.close() @@ -77,36 +75,39 @@ class WebClient: except: self.close() return False + return True def close(self): - if self.sock is not None: - self.sock.close() - self.sock = None + if self.socket: + self.socket.close() + self.socket = None self.readBuff = bytes() self.writeBuff = bytes() def parseRequest(self, data): - parts = data.split('\r\n\r\n'.encode('utf-8'), 1) + parts = data.split('\r\n\r\n', 1) if len(parts) == 1: return None, 0 headers = {} - for line in parts[0].split('\r\n'.encode('utf-8')): - pair = line.split(': '.encode('utf-8')) - headers[pair[0].lower()] = pair[1] if len(pair) > 1 else None + for line in parts[0].split('\r\n'): + pair = line.split(': ', 2) + if len(pair) == 2: + headers[pair[0].lower()] = pair[1] + else: + headers[pair[0]] = None headerLength = len(parts[0]) + 4 - bodyLength = int(headers.get('content-length'.encode('utf-8'), 0)) + bodyLength = int(headers.get('content-length', '0')) totalLength = headerLength + bodyLength - if totalLength > len(data): return None, 0 - body = data[headerLength : totalLength] + body = data[headerLength:totalLength] return WebRequest(headers, body), totalLength @@ -115,126 +116,83 @@ class WebClient: # class WebServer: - def __init__(self, handler): + def __init__(self, handler, origins): self.handler = handler + self.origins = origins self.clients = [] - self.sock = None + self.socket = None def advance(self): - if self.sock is not None: + if self.socket: self.acceptClients() self.advanceClients() def acceptClients(self): - rlist = select.select([self.sock], [], [], 0)[0] + rlist = select.select([self.socket], [], [], 0)[0] if not rlist: return - clientSock = self.sock.accept()[0] - if clientSock is not None: - clientSock.setblocking(False) - self.clients.append(WebClient(clientSock, self.handlerWrapper)) + socket = self.socket.accept()[0] + if socket: + socket.setblocking(False) + self.clients.append(WebClient(socket, self.handlerWrapper)) def advanceClients(self): self.clients = list(filter(lambda c: c.advance(), self.clients)) - def listen(self): - self.close() - - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.sock.setblocking(False) - self.sock.bind((util.setting('webBindAddress'), util.setting('webBindPort'))) - self.sock.listen(util.setting('webBacklog')) + def bindAndListen(self, address, port): + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.socket.setblocking(False) + self.socket.bind((address, port)) + self.socket.listen(5) - def handlerWrapper(self, req): - # handle multiple cors origins by checking the 'origin'-header against the allowed origin list from the config - webCorsOriginList = util.setting('webCorsOriginList') - - # keep support for deprecated 'webCorsOrigin' field, as long it is not removed - webCorsOrigin = util.setting('webCorsOrigin') - if webCorsOrigin: - webCorsOriginList.append(webCorsOrigin) - - allowed = False - corsOrigin = 'http://localhost' - allowAllCors = '*' in webCorsOriginList # allow CORS for all domains - - if allowAllCors: - corsOrigin = '*' + def handlerWrapper(self, request): + if '*' in self.origins: + origin = '*' allowed = True - elif b'origin' in req.headers: - originStr = req.headers[b'origin'].decode() - if originStr in webCorsOriginList : - corsOrigin = originStr - allowed = True - elif 'http://localhost' in webCorsOriginList and ( - originStr == 'http://127.0.0.1' or originStr == 'https://127.0.0.1' or # allow 127.0.0.1 if localhost allowed - originStr.startswith('http://127.0.0.1:') or originStr.startswith('http://127.0.0.1:') or - originStr.startswith('chrome-extension://') or originStr.startswith('moz-extension://') ) : # allow chrome and firefox extension if localhost allowed - corsOrigin = originStr - allowed = True else: - allowed = True + origin = request.headers.get('origin') + allowed = origin in self.origins + + if not allowed: + origin = 'http://127.0.0.1' try: - params = json.loads(req.body.decode('utf-8')) - paramsError = False - except ValueError: - body = json.dumps(None).encode('utf-8') - paramsError = True - - if allowed or not paramsError and params.get('action', '') == 'requestPermission': - if len(req.body) == 0: - body = 'AnkiConnect v.{}'.format(util.setting('apiVersion')).encode('utf-8') + call = json.loads(request.body) + if call: + call['allowed'] = allowed + call['origin'] = origin + body = json.dumps(self.handler(call)) else: - if params.get('action', '') == 'requestPermission': - params['params'] = params.get('params', {}) - params['params']['allowed'] = allowed - params['params']['origin'] = b'origin' in req.headers and req.headers[b'origin'].decode() or '' - if not allowed : - corsOrigin = params['params']['origin'] + body = 'AnkiConnect' + except Exception as e: + body = str(e) - body = json.dumps(self.handler(params)).encode('utf-8') - - headers = [ - ['HTTP/1.1 200 OK', None], - ['Content-Type', 'text/json'], - ['Access-Control-Allow-Origin', corsOrigin], - ['Access-Control-Allow-Headers', '*'], - ['Content-Length', str(len(body))] - ] - else : - headers = [ - ['HTTP/1.1 403 Forbidden', None], - ['Access-Control-Allow-Origin', corsOrigin], - ['Access-Control-Allow-Headers', '*'] - ] - body = ''.encode('utf-8'); - - resp = bytes() + headers = [ + ['HTTP/1.1 200 OK', None], + ['Content-Type', 'text/json'], + ['Access-Control-Allow-Origin', origin], + ['Access-Control-Allow-Headers', '*'], + ['Content-Length', len(body.encode('utf-8'))] + ] + header = bytes() for key, value in headers: - if value is None: - resp += '{}\r\n'.format(key).encode('utf-8') - else: - resp += '{}: {}\r\n'.format(key, value).encode('utf-8') + header += f'{key}: {value}\r\n' - resp += '\r\n'.encode('utf-8') - resp += body - - return resp + return header + '\r\n' + body def close(self): - if self.sock is not None: - self.sock.close() - self.sock = None + if self.socket: + self.socket.close() + self.socket = None for client in self.clients: client.close()