Cleanup
This commit is contained in:
parent
3a4814d3a8
commit
fb2df21d01
@ -56,11 +56,6 @@ from . import web, util
|
||||
|
||||
class AnkiConnect:
|
||||
def __init__(self):
|
||||
self.log = None
|
||||
logPath = util.setting('apiLogPath')
|
||||
if logPath is not None:
|
||||
self.log = open(logPath, 'w')
|
||||
|
||||
try:
|
||||
self.server = web.WebServer(self.handler)
|
||||
self.server.listen()
|
||||
@ -76,21 +71,11 @@ class AnkiConnect:
|
||||
)
|
||||
|
||||
|
||||
def logEvent(self, name, data):
|
||||
if self.log is not None:
|
||||
self.log.write('[{}]\n'.format(name))
|
||||
json.dump(data, self.log, indent=4, sort_keys=True)
|
||||
self.log.write('\n\n')
|
||||
self.log.flush()
|
||||
|
||||
|
||||
def advance(self):
|
||||
self.server.advance()
|
||||
|
||||
|
||||
def handler(self, request):
|
||||
self.logEvent('request', request)
|
||||
|
||||
name = request.get('action', '')
|
||||
version = request.get('version', 4)
|
||||
params = request.get('params', {})
|
||||
@ -131,7 +116,6 @@ class AnkiConnect:
|
||||
except Exception as e:
|
||||
reply['error'] = str(e)
|
||||
|
||||
self.logEvent('reply', reply)
|
||||
return reply
|
||||
|
||||
|
||||
|
64
plugin/host.py
Normal file
64
plugin/host.py
Normal file
@ -0,0 +1,64 @@
|
||||
# Copyright 2016-2021 Alex Yatskov
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from . import web
|
||||
|
||||
|
||||
class ApiHost:
|
||||
def __init__(self, origins, key, interval, address, port):
|
||||
self.key = key
|
||||
self.modules = []
|
||||
|
||||
self.server = web.WebServer(self.handler, origins)
|
||||
self.server.bindAndListen(address, port)
|
||||
|
||||
self.timer = QTimer()
|
||||
self.timer.timeout.connect(self.advance)
|
||||
self.timer.start(interval)
|
||||
|
||||
|
||||
def register(self, module):
|
||||
self.modules.append(module)
|
||||
|
||||
|
||||
def handler(self, call):
|
||||
action = call.get('action')
|
||||
params = call.get('params', {})
|
||||
allowed = call.get('allowed', False)
|
||||
key = call.get('key')
|
||||
|
||||
try:
|
||||
if key != self.key and action != 'requestPermission':
|
||||
raise Exception('valid api key must be provided')
|
||||
|
||||
method = None
|
||||
for module in self.modules:
|
||||
for methodName, methodInstance in inspect.getmembers(module, predicate=inspect.ismethod):
|
||||
if getattr(methodInstance, 'api', False):
|
||||
method = methodInstance
|
||||
break
|
||||
|
||||
if method:
|
||||
return {'error': None, 'result': methodInstance(**params)}
|
||||
else:
|
||||
raise Exception('unsupported action')
|
||||
|
||||
except Exception as e:
|
||||
return {'error': str(e), 'result': None}
|
||||
|
||||
|
||||
def api(method):
|
||||
setattr(method, 'api', True)
|
||||
return decorator
|
186
plugin/web.py
186
plugin/web.py
@ -17,8 +17,6 @@ import json
|
||||
import select
|
||||
import socket
|
||||
|
||||
from . import util
|
||||
|
||||
|
||||
#
|
||||
# WebRequest
|
||||
@ -36,40 +34,40 @@ class WebRequest:
|
||||
|
||||
class WebClient:
|
||||
def __init__(self, sock, handler):
|
||||
self.sock = sock
|
||||
self.socket = sock
|
||||
self.handler = handler
|
||||
self.readBuff = bytes()
|
||||
self.writeBuff = bytes()
|
||||
|
||||
|
||||
def advance(self, recvSize=1024):
|
||||
if self.sock is None:
|
||||
def advance(self):
|
||||
if not self.socket:
|
||||
return False
|
||||
|
||||
rlist, wlist = select.select([self.sock], [self.sock], [], 0)[:2]
|
||||
self.sock.settimeout(5.0)
|
||||
rlist, wlist = select.select([self.socket], [self.socket], [], 0)[:2]
|
||||
self.socket.settimeout(5.0)
|
||||
|
||||
if rlist:
|
||||
while True:
|
||||
try:
|
||||
msg = self.sock.recv(recvSize)
|
||||
except (ConnectionResetError, socket.timeout):
|
||||
data = self.socket.recv(1024)
|
||||
if not data:
|
||||
raise Exception('failed to get data from socket')
|
||||
except:
|
||||
self.close()
|
||||
return False
|
||||
if not msg:
|
||||
self.close()
|
||||
return False
|
||||
self.readBuff += msg
|
||||
|
||||
req, length = self.parseRequest(self.readBuff)
|
||||
if req is not None:
|
||||
self.readBuff += data
|
||||
|
||||
request, length = self.parseRequest(self.readBuff.decode('utf-8'))
|
||||
if request:
|
||||
self.readBuff = self.readBuff[length:]
|
||||
self.writeBuff += self.handler(req)
|
||||
self.writeBuff += self.handler(request).encode('utf-8')
|
||||
break
|
||||
|
||||
if wlist and self.writeBuff:
|
||||
try:
|
||||
length = self.sock.send(self.writeBuff)
|
||||
length = self.socket.send(self.writeBuff)
|
||||
self.writeBuff = self.writeBuff[length:]
|
||||
if not self.writeBuff:
|
||||
self.close()
|
||||
@ -77,36 +75,39 @@ class WebClient:
|
||||
except:
|
||||
self.close()
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def close(self):
|
||||
if self.sock is not None:
|
||||
self.sock.close()
|
||||
self.sock = None
|
||||
if self.socket:
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
|
||||
self.readBuff = bytes()
|
||||
self.writeBuff = bytes()
|
||||
|
||||
|
||||
def parseRequest(self, data):
|
||||
parts = data.split('\r\n\r\n'.encode('utf-8'), 1)
|
||||
parts = data.split('\r\n\r\n', 1)
|
||||
if len(parts) == 1:
|
||||
return None, 0
|
||||
|
||||
headers = {}
|
||||
for line in parts[0].split('\r\n'.encode('utf-8')):
|
||||
pair = line.split(': '.encode('utf-8'))
|
||||
headers[pair[0].lower()] = pair[1] if len(pair) > 1 else None
|
||||
for line in parts[0].split('\r\n'):
|
||||
pair = line.split(': ', 2)
|
||||
if len(pair) == 2:
|
||||
headers[pair[0].lower()] = pair[1]
|
||||
else:
|
||||
headers[pair[0]] = None
|
||||
|
||||
headerLength = len(parts[0]) + 4
|
||||
bodyLength = int(headers.get('content-length'.encode('utf-8'), 0))
|
||||
bodyLength = int(headers.get('content-length', '0'))
|
||||
totalLength = headerLength + bodyLength
|
||||
|
||||
if totalLength > len(data):
|
||||
return None, 0
|
||||
|
||||
body = data[headerLength : totalLength]
|
||||
body = data[headerLength:totalLength]
|
||||
return WebRequest(headers, body), totalLength
|
||||
|
||||
|
||||
@ -115,126 +116,83 @@ class WebClient:
|
||||
#
|
||||
|
||||
class WebServer:
|
||||
def __init__(self, handler):
|
||||
def __init__(self, handler, origins):
|
||||
self.handler = handler
|
||||
self.origins = origins
|
||||
self.clients = []
|
||||
self.sock = None
|
||||
self.socket = None
|
||||
|
||||
|
||||
def advance(self):
|
||||
if self.sock is not None:
|
||||
if self.socket:
|
||||
self.acceptClients()
|
||||
self.advanceClients()
|
||||
|
||||
|
||||
def acceptClients(self):
|
||||
rlist = select.select([self.sock], [], [], 0)[0]
|
||||
rlist = select.select([self.socket], [], [], 0)[0]
|
||||
if not rlist:
|
||||
return
|
||||
|
||||
clientSock = self.sock.accept()[0]
|
||||
if clientSock is not None:
|
||||
clientSock.setblocking(False)
|
||||
self.clients.append(WebClient(clientSock, self.handlerWrapper))
|
||||
socket = self.socket.accept()[0]
|
||||
if socket:
|
||||
socket.setblocking(False)
|
||||
self.clients.append(WebClient(socket, self.handlerWrapper))
|
||||
|
||||
|
||||
def advanceClients(self):
|
||||
self.clients = list(filter(lambda c: c.advance(), self.clients))
|
||||
|
||||
|
||||
def listen(self):
|
||||
self.close()
|
||||
|
||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.sock.setblocking(False)
|
||||
self.sock.bind((util.setting('webBindAddress'), util.setting('webBindPort')))
|
||||
self.sock.listen(util.setting('webBacklog'))
|
||||
def bindAndListen(self, address, port):
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.socket.setblocking(False)
|
||||
self.socket.bind((address, port))
|
||||
self.socket.listen(5)
|
||||
|
||||
|
||||
def handlerWrapper(self, req):
|
||||
# handle multiple cors origins by checking the 'origin'-header against the allowed origin list from the config
|
||||
webCorsOriginList = util.setting('webCorsOriginList')
|
||||
|
||||
# keep support for deprecated 'webCorsOrigin' field, as long it is not removed
|
||||
webCorsOrigin = util.setting('webCorsOrigin')
|
||||
if webCorsOrigin:
|
||||
webCorsOriginList.append(webCorsOrigin)
|
||||
|
||||
allowed = False
|
||||
corsOrigin = 'http://localhost'
|
||||
allowAllCors = '*' in webCorsOriginList # allow CORS for all domains
|
||||
|
||||
if allowAllCors:
|
||||
corsOrigin = '*'
|
||||
def handlerWrapper(self, request):
|
||||
if '*' in self.origins:
|
||||
origin = '*'
|
||||
allowed = True
|
||||
elif b'origin' in req.headers:
|
||||
originStr = req.headers[b'origin'].decode()
|
||||
if originStr in webCorsOriginList :
|
||||
corsOrigin = originStr
|
||||
allowed = True
|
||||
elif 'http://localhost' in webCorsOriginList and (
|
||||
originStr == 'http://127.0.0.1' or originStr == 'https://127.0.0.1' or # allow 127.0.0.1 if localhost allowed
|
||||
originStr.startswith('http://127.0.0.1:') or originStr.startswith('http://127.0.0.1:') or
|
||||
originStr.startswith('chrome-extension://') or originStr.startswith('moz-extension://') ) : # allow chrome and firefox extension if localhost allowed
|
||||
corsOrigin = originStr
|
||||
allowed = True
|
||||
else:
|
||||
allowed = True
|
||||
origin = request.headers.get('origin')
|
||||
allowed = origin in self.origins
|
||||
|
||||
if not allowed:
|
||||
origin = 'http://127.0.0.1'
|
||||
|
||||
try:
|
||||
params = json.loads(req.body.decode('utf-8'))
|
||||
paramsError = False
|
||||
except ValueError:
|
||||
body = json.dumps(None).encode('utf-8')
|
||||
paramsError = True
|
||||
|
||||
if allowed or not paramsError and params.get('action', '') == 'requestPermission':
|
||||
if len(req.body) == 0:
|
||||
body = 'AnkiConnect v.{}'.format(util.setting('apiVersion')).encode('utf-8')
|
||||
call = json.loads(request.body)
|
||||
if call:
|
||||
call['allowed'] = allowed
|
||||
call['origin'] = origin
|
||||
body = json.dumps(self.handler(call))
|
||||
else:
|
||||
if params.get('action', '') == 'requestPermission':
|
||||
params['params'] = params.get('params', {})
|
||||
params['params']['allowed'] = allowed
|
||||
params['params']['origin'] = b'origin' in req.headers and req.headers[b'origin'].decode() or ''
|
||||
if not allowed :
|
||||
corsOrigin = params['params']['origin']
|
||||
body = 'AnkiConnect'
|
||||
except Exception as e:
|
||||
body = str(e)
|
||||
|
||||
body = json.dumps(self.handler(params)).encode('utf-8')
|
||||
|
||||
headers = [
|
||||
['HTTP/1.1 200 OK', None],
|
||||
['Content-Type', 'text/json'],
|
||||
['Access-Control-Allow-Origin', corsOrigin],
|
||||
['Access-Control-Allow-Headers', '*'],
|
||||
['Content-Length', str(len(body))]
|
||||
]
|
||||
else :
|
||||
headers = [
|
||||
['HTTP/1.1 403 Forbidden', None],
|
||||
['Access-Control-Allow-Origin', corsOrigin],
|
||||
['Access-Control-Allow-Headers', '*']
|
||||
]
|
||||
body = ''.encode('utf-8');
|
||||
|
||||
resp = bytes()
|
||||
headers = [
|
||||
['HTTP/1.1 200 OK', None],
|
||||
['Content-Type', 'text/json'],
|
||||
['Access-Control-Allow-Origin', origin],
|
||||
['Access-Control-Allow-Headers', '*'],
|
||||
['Content-Length', len(body.encode('utf-8'))]
|
||||
]
|
||||
|
||||
header = bytes()
|
||||
for key, value in headers:
|
||||
if value is None:
|
||||
resp += '{}\r\n'.format(key).encode('utf-8')
|
||||
else:
|
||||
resp += '{}: {}\r\n'.format(key, value).encode('utf-8')
|
||||
header += f'{key}: {value}\r\n'
|
||||
|
||||
resp += '\r\n'.encode('utf-8')
|
||||
resp += body
|
||||
|
||||
return resp
|
||||
return header + '\r\n' + body
|
||||
|
||||
|
||||
def close(self):
|
||||
if self.sock is not None:
|
||||
self.sock.close()
|
||||
self.sock = None
|
||||
if self.socket:
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
|
||||
for client in self.clients:
|
||||
client.close()
|
||||
|
Loading…
Reference in New Issue
Block a user