This commit is contained in:
Alex Yatskov 2021-07-18 13:55:41 -07:00
parent 3a4814d3a8
commit fb2df21d01
3 changed files with 136 additions and 130 deletions

View File

@ -56,11 +56,6 @@ from . import web, util
class AnkiConnect: class AnkiConnect:
def __init__(self): def __init__(self):
self.log = None
logPath = util.setting('apiLogPath')
if logPath is not None:
self.log = open(logPath, 'w')
try: try:
self.server = web.WebServer(self.handler) self.server = web.WebServer(self.handler)
self.server.listen() self.server.listen()
@ -76,21 +71,11 @@ class AnkiConnect:
) )
def logEvent(self, name, data):
if self.log is not None:
self.log.write('[{}]\n'.format(name))
json.dump(data, self.log, indent=4, sort_keys=True)
self.log.write('\n\n')
self.log.flush()
def advance(self): def advance(self):
self.server.advance() self.server.advance()
def handler(self, request): def handler(self, request):
self.logEvent('request', request)
name = request.get('action', '') name = request.get('action', '')
version = request.get('version', 4) version = request.get('version', 4)
params = request.get('params', {}) params = request.get('params', {})
@ -131,7 +116,6 @@ class AnkiConnect:
except Exception as e: except Exception as e:
reply['error'] = str(e) reply['error'] = str(e)
self.logEvent('reply', reply)
return reply return reply

64
plugin/host.py Normal file
View File

@ -0,0 +1,64 @@
# Copyright 2016-2021 Alex Yatskov
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from . import web
class ApiHost:
def __init__(self, origins, key, interval, address, port):
self.key = key
self.modules = []
self.server = web.WebServer(self.handler, origins)
self.server.bindAndListen(address, port)
self.timer = QTimer()
self.timer.timeout.connect(self.advance)
self.timer.start(interval)
def register(self, module):
self.modules.append(module)
def handler(self, call):
action = call.get('action')
params = call.get('params', {})
allowed = call.get('allowed', False)
key = call.get('key')
try:
if key != self.key and action != 'requestPermission':
raise Exception('valid api key must be provided')
method = None
for module in self.modules:
for methodName, methodInstance in inspect.getmembers(module, predicate=inspect.ismethod):
if getattr(methodInstance, 'api', False):
method = methodInstance
break
if method:
return {'error': None, 'result': methodInstance(**params)}
else:
raise Exception('unsupported action')
except Exception as e:
return {'error': str(e), 'result': None}
def api(method):
setattr(method, 'api', True)
return decorator

View File

@ -17,8 +17,6 @@ import json
import select import select
import socket import socket
from . import util
# #
# WebRequest # WebRequest
@ -36,40 +34,40 @@ class WebRequest:
class WebClient: class WebClient:
def __init__(self, sock, handler): def __init__(self, sock, handler):
self.sock = sock self.socket = sock
self.handler = handler self.handler = handler
self.readBuff = bytes() self.readBuff = bytes()
self.writeBuff = bytes() self.writeBuff = bytes()
def advance(self, recvSize=1024): def advance(self):
if self.sock is None: if not self.socket:
return False return False
rlist, wlist = select.select([self.sock], [self.sock], [], 0)[:2] rlist, wlist = select.select([self.socket], [self.socket], [], 0)[:2]
self.sock.settimeout(5.0) self.socket.settimeout(5.0)
if rlist: if rlist:
while True: while True:
try: try:
msg = self.sock.recv(recvSize) data = self.socket.recv(1024)
except (ConnectionResetError, socket.timeout): if not data:
raise Exception('failed to get data from socket')
except:
self.close() self.close()
return False return False
if not msg:
self.close()
return False
self.readBuff += msg
req, length = self.parseRequest(self.readBuff) self.readBuff += data
if req is not None:
request, length = self.parseRequest(self.readBuff.decode('utf-8'))
if request:
self.readBuff = self.readBuff[length:] self.readBuff = self.readBuff[length:]
self.writeBuff += self.handler(req) self.writeBuff += self.handler(request).encode('utf-8')
break break
if wlist and self.writeBuff: if wlist and self.writeBuff:
try: try:
length = self.sock.send(self.writeBuff) length = self.socket.send(self.writeBuff)
self.writeBuff = self.writeBuff[length:] self.writeBuff = self.writeBuff[length:]
if not self.writeBuff: if not self.writeBuff:
self.close() self.close()
@ -77,32 +75,35 @@ class WebClient:
except: except:
self.close() self.close()
return False return False
return True return True
def close(self): def close(self):
if self.sock is not None: if self.socket:
self.sock.close() self.socket.close()
self.sock = None self.socket = None
self.readBuff = bytes() self.readBuff = bytes()
self.writeBuff = bytes() self.writeBuff = bytes()
def parseRequest(self, data): def parseRequest(self, data):
parts = data.split('\r\n\r\n'.encode('utf-8'), 1) parts = data.split('\r\n\r\n', 1)
if len(parts) == 1: if len(parts) == 1:
return None, 0 return None, 0
headers = {} headers = {}
for line in parts[0].split('\r\n'.encode('utf-8')): for line in parts[0].split('\r\n'):
pair = line.split(': '.encode('utf-8')) pair = line.split(': ', 2)
headers[pair[0].lower()] = pair[1] if len(pair) > 1 else None if len(pair) == 2:
headers[pair[0].lower()] = pair[1]
else:
headers[pair[0]] = None
headerLength = len(parts[0]) + 4 headerLength = len(parts[0]) + 4
bodyLength = int(headers.get('content-length'.encode('utf-8'), 0)) bodyLength = int(headers.get('content-length', '0'))
totalLength = headerLength + bodyLength totalLength = headerLength + bodyLength
if totalLength > len(data): if totalLength > len(data):
return None, 0 return None, 0
@ -115,126 +116,83 @@ class WebClient:
# #
class WebServer: class WebServer:
def __init__(self, handler): def __init__(self, handler, origins):
self.handler = handler self.handler = handler
self.origins = origins
self.clients = [] self.clients = []
self.sock = None self.socket = None
def advance(self): def advance(self):
if self.sock is not None: if self.socket:
self.acceptClients() self.acceptClients()
self.advanceClients() self.advanceClients()
def acceptClients(self): def acceptClients(self):
rlist = select.select([self.sock], [], [], 0)[0] rlist = select.select([self.socket], [], [], 0)[0]
if not rlist: if not rlist:
return return
clientSock = self.sock.accept()[0] socket = self.socket.accept()[0]
if clientSock is not None: if socket:
clientSock.setblocking(False) socket.setblocking(False)
self.clients.append(WebClient(clientSock, self.handlerWrapper)) self.clients.append(WebClient(socket, self.handlerWrapper))
def advanceClients(self): def advanceClients(self):
self.clients = list(filter(lambda c: c.advance(), self.clients)) self.clients = list(filter(lambda c: c.advance(), self.clients))
def listen(self): def bindAndListen(self, address, port):
self.close() self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket.setblocking(False)
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.bind((address, port))
self.sock.setblocking(False) self.socket.listen(5)
self.sock.bind((util.setting('webBindAddress'), util.setting('webBindPort')))
self.sock.listen(util.setting('webBacklog'))
def handlerWrapper(self, req): def handlerWrapper(self, request):
# handle multiple cors origins by checking the 'origin'-header against the allowed origin list from the config if '*' in self.origins:
webCorsOriginList = util.setting('webCorsOriginList') origin = '*'
# keep support for deprecated 'webCorsOrigin' field, as long it is not removed
webCorsOrigin = util.setting('webCorsOrigin')
if webCorsOrigin:
webCorsOriginList.append(webCorsOrigin)
allowed = False
corsOrigin = 'http://localhost'
allowAllCors = '*' in webCorsOriginList # allow CORS for all domains
if allowAllCors:
corsOrigin = '*'
allowed = True
elif b'origin' in req.headers:
originStr = req.headers[b'origin'].decode()
if originStr in webCorsOriginList :
corsOrigin = originStr
allowed = True
elif 'http://localhost' in webCorsOriginList and (
originStr == 'http://127.0.0.1' or originStr == 'https://127.0.0.1' or # allow 127.0.0.1 if localhost allowed
originStr.startswith('http://127.0.0.1:') or originStr.startswith('http://127.0.0.1:') or
originStr.startswith('chrome-extension://') or originStr.startswith('moz-extension://') ) : # allow chrome and firefox extension if localhost allowed
corsOrigin = originStr
allowed = True allowed = True
else: else:
allowed = True origin = request.headers.get('origin')
allowed = origin in self.origins
if not allowed:
origin = 'http://127.0.0.1'
try: try:
params = json.loads(req.body.decode('utf-8')) call = json.loads(request.body)
paramsError = False if call:
except ValueError: call['allowed'] = allowed
body = json.dumps(None).encode('utf-8') call['origin'] = origin
paramsError = True body = json.dumps(self.handler(call))
if allowed or not paramsError and params.get('action', '') == 'requestPermission':
if len(req.body) == 0:
body = 'AnkiConnect v.{}'.format(util.setting('apiVersion')).encode('utf-8')
else: else:
if params.get('action', '') == 'requestPermission': body = 'AnkiConnect'
params['params'] = params.get('params', {}) except Exception as e:
params['params']['allowed'] = allowed body = str(e)
params['params']['origin'] = b'origin' in req.headers and req.headers[b'origin'].decode() or ''
if not allowed :
corsOrigin = params['params']['origin']
body = json.dumps(self.handler(params)).encode('utf-8')
headers = [ headers = [
['HTTP/1.1 200 OK', None], ['HTTP/1.1 200 OK', None],
['Content-Type', 'text/json'], ['Content-Type', 'text/json'],
['Access-Control-Allow-Origin', corsOrigin], ['Access-Control-Allow-Origin', origin],
['Access-Control-Allow-Headers', '*'], ['Access-Control-Allow-Headers', '*'],
['Content-Length', str(len(body))] ['Content-Length', len(body.encode('utf-8'))]
] ]
else :
headers = [
['HTTP/1.1 403 Forbidden', None],
['Access-Control-Allow-Origin', corsOrigin],
['Access-Control-Allow-Headers', '*']
]
body = ''.encode('utf-8');
resp = bytes()
header = bytes()
for key, value in headers: for key, value in headers:
if value is None: header += f'{key}: {value}\r\n'
resp += '{}\r\n'.format(key).encode('utf-8')
else:
resp += '{}: {}\r\n'.format(key, value).encode('utf-8')
resp += '\r\n'.encode('utf-8') return header + '\r\n' + body
resp += body
return resp
def close(self): def close(self):
if self.sock is not None: if self.socket:
self.sock.close() self.socket.close()
self.sock = None self.socket = None
for client in self.clients: for client in self.clients:
client.close() client.close()