Cleanup
This commit is contained in:
parent
3a4814d3a8
commit
fb2df21d01
@ -56,11 +56,6 @@ from . import web, util
|
|||||||
|
|
||||||
class AnkiConnect:
|
class AnkiConnect:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.log = None
|
|
||||||
logPath = util.setting('apiLogPath')
|
|
||||||
if logPath is not None:
|
|
||||||
self.log = open(logPath, 'w')
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.server = web.WebServer(self.handler)
|
self.server = web.WebServer(self.handler)
|
||||||
self.server.listen()
|
self.server.listen()
|
||||||
@ -76,21 +71,11 @@ class AnkiConnect:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def logEvent(self, name, data):
|
|
||||||
if self.log is not None:
|
|
||||||
self.log.write('[{}]\n'.format(name))
|
|
||||||
json.dump(data, self.log, indent=4, sort_keys=True)
|
|
||||||
self.log.write('\n\n')
|
|
||||||
self.log.flush()
|
|
||||||
|
|
||||||
|
|
||||||
def advance(self):
|
def advance(self):
|
||||||
self.server.advance()
|
self.server.advance()
|
||||||
|
|
||||||
|
|
||||||
def handler(self, request):
|
def handler(self, request):
|
||||||
self.logEvent('request', request)
|
|
||||||
|
|
||||||
name = request.get('action', '')
|
name = request.get('action', '')
|
||||||
version = request.get('version', 4)
|
version = request.get('version', 4)
|
||||||
params = request.get('params', {})
|
params = request.get('params', {})
|
||||||
@ -131,7 +116,6 @@ class AnkiConnect:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
reply['error'] = str(e)
|
reply['error'] = str(e)
|
||||||
|
|
||||||
self.logEvent('reply', reply)
|
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
|
|
||||||
|
64
plugin/host.py
Normal file
64
plugin/host.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
# Copyright 2016-2021 Alex Yatskov
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU General Public License
|
||||||
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
from . import web
|
||||||
|
|
||||||
|
|
||||||
|
class ApiHost:
|
||||||
|
def __init__(self, origins, key, interval, address, port):
|
||||||
|
self.key = key
|
||||||
|
self.modules = []
|
||||||
|
|
||||||
|
self.server = web.WebServer(self.handler, origins)
|
||||||
|
self.server.bindAndListen(address, port)
|
||||||
|
|
||||||
|
self.timer = QTimer()
|
||||||
|
self.timer.timeout.connect(self.advance)
|
||||||
|
self.timer.start(interval)
|
||||||
|
|
||||||
|
|
||||||
|
def register(self, module):
|
||||||
|
self.modules.append(module)
|
||||||
|
|
||||||
|
|
||||||
|
def handler(self, call):
|
||||||
|
action = call.get('action')
|
||||||
|
params = call.get('params', {})
|
||||||
|
allowed = call.get('allowed', False)
|
||||||
|
key = call.get('key')
|
||||||
|
|
||||||
|
try:
|
||||||
|
if key != self.key and action != 'requestPermission':
|
||||||
|
raise Exception('valid api key must be provided')
|
||||||
|
|
||||||
|
method = None
|
||||||
|
for module in self.modules:
|
||||||
|
for methodName, methodInstance in inspect.getmembers(module, predicate=inspect.ismethod):
|
||||||
|
if getattr(methodInstance, 'api', False):
|
||||||
|
method = methodInstance
|
||||||
|
break
|
||||||
|
|
||||||
|
if method:
|
||||||
|
return {'error': None, 'result': methodInstance(**params)}
|
||||||
|
else:
|
||||||
|
raise Exception('unsupported action')
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {'error': str(e), 'result': None}
|
||||||
|
|
||||||
|
|
||||||
|
def api(method):
|
||||||
|
setattr(method, 'api', True)
|
||||||
|
return decorator
|
176
plugin/web.py
176
plugin/web.py
@ -17,8 +17,6 @@ import json
|
|||||||
import select
|
import select
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
from . import util
|
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# WebRequest
|
# WebRequest
|
||||||
@ -36,40 +34,40 @@ class WebRequest:
|
|||||||
|
|
||||||
class WebClient:
|
class WebClient:
|
||||||
def __init__(self, sock, handler):
|
def __init__(self, sock, handler):
|
||||||
self.sock = sock
|
self.socket = sock
|
||||||
self.handler = handler
|
self.handler = handler
|
||||||
self.readBuff = bytes()
|
self.readBuff = bytes()
|
||||||
self.writeBuff = bytes()
|
self.writeBuff = bytes()
|
||||||
|
|
||||||
|
|
||||||
def advance(self, recvSize=1024):
|
def advance(self):
|
||||||
if self.sock is None:
|
if not self.socket:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
rlist, wlist = select.select([self.sock], [self.sock], [], 0)[:2]
|
rlist, wlist = select.select([self.socket], [self.socket], [], 0)[:2]
|
||||||
self.sock.settimeout(5.0)
|
self.socket.settimeout(5.0)
|
||||||
|
|
||||||
if rlist:
|
if rlist:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
msg = self.sock.recv(recvSize)
|
data = self.socket.recv(1024)
|
||||||
except (ConnectionResetError, socket.timeout):
|
if not data:
|
||||||
|
raise Exception('failed to get data from socket')
|
||||||
|
except:
|
||||||
self.close()
|
self.close()
|
||||||
return False
|
return False
|
||||||
if not msg:
|
|
||||||
self.close()
|
|
||||||
return False
|
|
||||||
self.readBuff += msg
|
|
||||||
|
|
||||||
req, length = self.parseRequest(self.readBuff)
|
self.readBuff += data
|
||||||
if req is not None:
|
|
||||||
|
request, length = self.parseRequest(self.readBuff.decode('utf-8'))
|
||||||
|
if request:
|
||||||
self.readBuff = self.readBuff[length:]
|
self.readBuff = self.readBuff[length:]
|
||||||
self.writeBuff += self.handler(req)
|
self.writeBuff += self.handler(request).encode('utf-8')
|
||||||
break
|
break
|
||||||
|
|
||||||
if wlist and self.writeBuff:
|
if wlist and self.writeBuff:
|
||||||
try:
|
try:
|
||||||
length = self.sock.send(self.writeBuff)
|
length = self.socket.send(self.writeBuff)
|
||||||
self.writeBuff = self.writeBuff[length:]
|
self.writeBuff = self.writeBuff[length:]
|
||||||
if not self.writeBuff:
|
if not self.writeBuff:
|
||||||
self.close()
|
self.close()
|
||||||
@ -77,36 +75,39 @@ class WebClient:
|
|||||||
except:
|
except:
|
||||||
self.close()
|
self.close()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self.sock is not None:
|
if self.socket:
|
||||||
self.sock.close()
|
self.socket.close()
|
||||||
self.sock = None
|
self.socket = None
|
||||||
|
|
||||||
self.readBuff = bytes()
|
self.readBuff = bytes()
|
||||||
self.writeBuff = bytes()
|
self.writeBuff = bytes()
|
||||||
|
|
||||||
|
|
||||||
def parseRequest(self, data):
|
def parseRequest(self, data):
|
||||||
parts = data.split('\r\n\r\n'.encode('utf-8'), 1)
|
parts = data.split('\r\n\r\n', 1)
|
||||||
if len(parts) == 1:
|
if len(parts) == 1:
|
||||||
return None, 0
|
return None, 0
|
||||||
|
|
||||||
headers = {}
|
headers = {}
|
||||||
for line in parts[0].split('\r\n'.encode('utf-8')):
|
for line in parts[0].split('\r\n'):
|
||||||
pair = line.split(': '.encode('utf-8'))
|
pair = line.split(': ', 2)
|
||||||
headers[pair[0].lower()] = pair[1] if len(pair) > 1 else None
|
if len(pair) == 2:
|
||||||
|
headers[pair[0].lower()] = pair[1]
|
||||||
|
else:
|
||||||
|
headers[pair[0]] = None
|
||||||
|
|
||||||
headerLength = len(parts[0]) + 4
|
headerLength = len(parts[0]) + 4
|
||||||
bodyLength = int(headers.get('content-length'.encode('utf-8'), 0))
|
bodyLength = int(headers.get('content-length', '0'))
|
||||||
totalLength = headerLength + bodyLength
|
totalLength = headerLength + bodyLength
|
||||||
|
|
||||||
if totalLength > len(data):
|
if totalLength > len(data):
|
||||||
return None, 0
|
return None, 0
|
||||||
|
|
||||||
body = data[headerLength : totalLength]
|
body = data[headerLength:totalLength]
|
||||||
return WebRequest(headers, body), totalLength
|
return WebRequest(headers, body), totalLength
|
||||||
|
|
||||||
|
|
||||||
@ -115,126 +116,83 @@ class WebClient:
|
|||||||
#
|
#
|
||||||
|
|
||||||
class WebServer:
|
class WebServer:
|
||||||
def __init__(self, handler):
|
def __init__(self, handler, origins):
|
||||||
self.handler = handler
|
self.handler = handler
|
||||||
|
self.origins = origins
|
||||||
self.clients = []
|
self.clients = []
|
||||||
self.sock = None
|
self.socket = None
|
||||||
|
|
||||||
|
|
||||||
def advance(self):
|
def advance(self):
|
||||||
if self.sock is not None:
|
if self.socket:
|
||||||
self.acceptClients()
|
self.acceptClients()
|
||||||
self.advanceClients()
|
self.advanceClients()
|
||||||
|
|
||||||
|
|
||||||
def acceptClients(self):
|
def acceptClients(self):
|
||||||
rlist = select.select([self.sock], [], [], 0)[0]
|
rlist = select.select([self.socket], [], [], 0)[0]
|
||||||
if not rlist:
|
if not rlist:
|
||||||
return
|
return
|
||||||
|
|
||||||
clientSock = self.sock.accept()[0]
|
socket = self.socket.accept()[0]
|
||||||
if clientSock is not None:
|
if socket:
|
||||||
clientSock.setblocking(False)
|
socket.setblocking(False)
|
||||||
self.clients.append(WebClient(clientSock, self.handlerWrapper))
|
self.clients.append(WebClient(socket, self.handlerWrapper))
|
||||||
|
|
||||||
|
|
||||||
def advanceClients(self):
|
def advanceClients(self):
|
||||||
self.clients = list(filter(lambda c: c.advance(), self.clients))
|
self.clients = list(filter(lambda c: c.advance(), self.clients))
|
||||||
|
|
||||||
|
|
||||||
def listen(self):
|
def bindAndListen(self, address, port):
|
||||||
self.close()
|
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
self.socket.setblocking(False)
|
||||||
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
self.socket.bind((address, port))
|
||||||
self.sock.setblocking(False)
|
self.socket.listen(5)
|
||||||
self.sock.bind((util.setting('webBindAddress'), util.setting('webBindPort')))
|
|
||||||
self.sock.listen(util.setting('webBacklog'))
|
|
||||||
|
|
||||||
|
|
||||||
def handlerWrapper(self, req):
|
def handlerWrapper(self, request):
|
||||||
# handle multiple cors origins by checking the 'origin'-header against the allowed origin list from the config
|
if '*' in self.origins:
|
||||||
webCorsOriginList = util.setting('webCorsOriginList')
|
origin = '*'
|
||||||
|
|
||||||
# keep support for deprecated 'webCorsOrigin' field, as long it is not removed
|
|
||||||
webCorsOrigin = util.setting('webCorsOrigin')
|
|
||||||
if webCorsOrigin:
|
|
||||||
webCorsOriginList.append(webCorsOrigin)
|
|
||||||
|
|
||||||
allowed = False
|
|
||||||
corsOrigin = 'http://localhost'
|
|
||||||
allowAllCors = '*' in webCorsOriginList # allow CORS for all domains
|
|
||||||
|
|
||||||
if allowAllCors:
|
|
||||||
corsOrigin = '*'
|
|
||||||
allowed = True
|
|
||||||
elif b'origin' in req.headers:
|
|
||||||
originStr = req.headers[b'origin'].decode()
|
|
||||||
if originStr in webCorsOriginList :
|
|
||||||
corsOrigin = originStr
|
|
||||||
allowed = True
|
|
||||||
elif 'http://localhost' in webCorsOriginList and (
|
|
||||||
originStr == 'http://127.0.0.1' or originStr == 'https://127.0.0.1' or # allow 127.0.0.1 if localhost allowed
|
|
||||||
originStr.startswith('http://127.0.0.1:') or originStr.startswith('http://127.0.0.1:') or
|
|
||||||
originStr.startswith('chrome-extension://') or originStr.startswith('moz-extension://') ) : # allow chrome and firefox extension if localhost allowed
|
|
||||||
corsOrigin = originStr
|
|
||||||
allowed = True
|
allowed = True
|
||||||
else:
|
else:
|
||||||
allowed = True
|
origin = request.headers.get('origin')
|
||||||
|
allowed = origin in self.origins
|
||||||
|
|
||||||
|
if not allowed:
|
||||||
|
origin = 'http://127.0.0.1'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
params = json.loads(req.body.decode('utf-8'))
|
call = json.loads(request.body)
|
||||||
paramsError = False
|
if call:
|
||||||
except ValueError:
|
call['allowed'] = allowed
|
||||||
body = json.dumps(None).encode('utf-8')
|
call['origin'] = origin
|
||||||
paramsError = True
|
body = json.dumps(self.handler(call))
|
||||||
|
|
||||||
if allowed or not paramsError and params.get('action', '') == 'requestPermission':
|
|
||||||
if len(req.body) == 0:
|
|
||||||
body = 'AnkiConnect v.{}'.format(util.setting('apiVersion')).encode('utf-8')
|
|
||||||
else:
|
else:
|
||||||
if params.get('action', '') == 'requestPermission':
|
body = 'AnkiConnect'
|
||||||
params['params'] = params.get('params', {})
|
except Exception as e:
|
||||||
params['params']['allowed'] = allowed
|
body = str(e)
|
||||||
params['params']['origin'] = b'origin' in req.headers and req.headers[b'origin'].decode() or ''
|
|
||||||
if not allowed :
|
|
||||||
corsOrigin = params['params']['origin']
|
|
||||||
|
|
||||||
body = json.dumps(self.handler(params)).encode('utf-8')
|
|
||||||
|
|
||||||
headers = [
|
headers = [
|
||||||
['HTTP/1.1 200 OK', None],
|
['HTTP/1.1 200 OK', None],
|
||||||
['Content-Type', 'text/json'],
|
['Content-Type', 'text/json'],
|
||||||
['Access-Control-Allow-Origin', corsOrigin],
|
['Access-Control-Allow-Origin', origin],
|
||||||
['Access-Control-Allow-Headers', '*'],
|
['Access-Control-Allow-Headers', '*'],
|
||||||
['Content-Length', str(len(body))]
|
['Content-Length', len(body.encode('utf-8'))]
|
||||||
]
|
]
|
||||||
else :
|
|
||||||
headers = [
|
|
||||||
['HTTP/1.1 403 Forbidden', None],
|
|
||||||
['Access-Control-Allow-Origin', corsOrigin],
|
|
||||||
['Access-Control-Allow-Headers', '*']
|
|
||||||
]
|
|
||||||
body = ''.encode('utf-8');
|
|
||||||
|
|
||||||
resp = bytes()
|
|
||||||
|
|
||||||
|
header = bytes()
|
||||||
for key, value in headers:
|
for key, value in headers:
|
||||||
if value is None:
|
header += f'{key}: {value}\r\n'
|
||||||
resp += '{}\r\n'.format(key).encode('utf-8')
|
|
||||||
else:
|
|
||||||
resp += '{}: {}\r\n'.format(key, value).encode('utf-8')
|
|
||||||
|
|
||||||
resp += '\r\n'.encode('utf-8')
|
return header + '\r\n' + body
|
||||||
resp += body
|
|
||||||
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self.sock is not None:
|
if self.socket:
|
||||||
self.sock.close()
|
self.socket.close()
|
||||||
self.sock = None
|
self.socket = None
|
||||||
|
|
||||||
for client in self.clients:
|
for client in self.clients:
|
||||||
client.close()
|
client.close()
|
||||||
|
Loading…
Reference in New Issue
Block a user