This commit is contained in:
Alex Yatskov 2021-07-18 13:55:41 -07:00
parent 3a4814d3a8
commit fb2df21d01
3 changed files with 136 additions and 130 deletions

View File

@ -56,11 +56,6 @@ from . import web, util
class AnkiConnect:
def __init__(self):
self.log = None
logPath = util.setting('apiLogPath')
if logPath is not None:
self.log = open(logPath, 'w')
try:
self.server = web.WebServer(self.handler)
self.server.listen()
@ -76,21 +71,11 @@ class AnkiConnect:
)
def logEvent(self, name, data):
if self.log is not None:
self.log.write('[{}]\n'.format(name))
json.dump(data, self.log, indent=4, sort_keys=True)
self.log.write('\n\n')
self.log.flush()
def advance(self):
self.server.advance()
def handler(self, request):
self.logEvent('request', request)
name = request.get('action', '')
version = request.get('version', 4)
params = request.get('params', {})
@ -131,7 +116,6 @@ class AnkiConnect:
except Exception as e:
reply['error'] = str(e)
self.logEvent('reply', reply)
return reply

64
plugin/host.py Normal file
View File

@ -0,0 +1,64 @@
# Copyright 2016-2021 Alex Yatskov
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from . import web
class ApiHost:
def __init__(self, origins, key, interval, address, port):
self.key = key
self.modules = []
self.server = web.WebServer(self.handler, origins)
self.server.bindAndListen(address, port)
self.timer = QTimer()
self.timer.timeout.connect(self.advance)
self.timer.start(interval)
def register(self, module):
self.modules.append(module)
def handler(self, call):
action = call.get('action')
params = call.get('params', {})
allowed = call.get('allowed', False)
key = call.get('key')
try:
if key != self.key and action != 'requestPermission':
raise Exception('valid api key must be provided')
method = None
for module in self.modules:
for methodName, methodInstance in inspect.getmembers(module, predicate=inspect.ismethod):
if getattr(methodInstance, 'api', False):
method = methodInstance
break
if method:
return {'error': None, 'result': methodInstance(**params)}
else:
raise Exception('unsupported action')
except Exception as e:
return {'error': str(e), 'result': None}
def api(method):
setattr(method, 'api', True)
return decorator

View File

@ -17,8 +17,6 @@ import json
import select
import socket
from . import util
#
# WebRequest
@ -36,40 +34,40 @@ class WebRequest:
class WebClient:
def __init__(self, sock, handler):
self.sock = sock
self.socket = sock
self.handler = handler
self.readBuff = bytes()
self.writeBuff = bytes()
def advance(self, recvSize=1024):
if self.sock is None:
def advance(self):
if not self.socket:
return False
rlist, wlist = select.select([self.sock], [self.sock], [], 0)[:2]
self.sock.settimeout(5.0)
rlist, wlist = select.select([self.socket], [self.socket], [], 0)[:2]
self.socket.settimeout(5.0)
if rlist:
while True:
try:
msg = self.sock.recv(recvSize)
except (ConnectionResetError, socket.timeout):
data = self.socket.recv(1024)
if not data:
raise Exception('failed to get data from socket')
except:
self.close()
return False
if not msg:
self.close()
return False
self.readBuff += msg
req, length = self.parseRequest(self.readBuff)
if req is not None:
self.readBuff += data
request, length = self.parseRequest(self.readBuff.decode('utf-8'))
if request:
self.readBuff = self.readBuff[length:]
self.writeBuff += self.handler(req)
self.writeBuff += self.handler(request).encode('utf-8')
break
if wlist and self.writeBuff:
try:
length = self.sock.send(self.writeBuff)
length = self.socket.send(self.writeBuff)
self.writeBuff = self.writeBuff[length:]
if not self.writeBuff:
self.close()
@ -77,36 +75,39 @@ class WebClient:
except:
self.close()
return False
return True
def close(self):
if self.sock is not None:
self.sock.close()
self.sock = None
if self.socket:
self.socket.close()
self.socket = None
self.readBuff = bytes()
self.writeBuff = bytes()
def parseRequest(self, data):
parts = data.split('\r\n\r\n'.encode('utf-8'), 1)
parts = data.split('\r\n\r\n', 1)
if len(parts) == 1:
return None, 0
headers = {}
for line in parts[0].split('\r\n'.encode('utf-8')):
pair = line.split(': '.encode('utf-8'))
headers[pair[0].lower()] = pair[1] if len(pair) > 1 else None
for line in parts[0].split('\r\n'):
pair = line.split(': ', 2)
if len(pair) == 2:
headers[pair[0].lower()] = pair[1]
else:
headers[pair[0]] = None
headerLength = len(parts[0]) + 4
bodyLength = int(headers.get('content-length'.encode('utf-8'), 0))
bodyLength = int(headers.get('content-length', '0'))
totalLength = headerLength + bodyLength
if totalLength > len(data):
return None, 0
body = data[headerLength : totalLength]
body = data[headerLength:totalLength]
return WebRequest(headers, body), totalLength
@ -115,126 +116,83 @@ class WebClient:
#
class WebServer:
def __init__(self, handler):
def __init__(self, handler, origins):
self.handler = handler
self.origins = origins
self.clients = []
self.sock = None
self.socket = None
def advance(self):
if self.sock is not None:
if self.socket:
self.acceptClients()
self.advanceClients()
def acceptClients(self):
rlist = select.select([self.sock], [], [], 0)[0]
rlist = select.select([self.socket], [], [], 0)[0]
if not rlist:
return
clientSock = self.sock.accept()[0]
if clientSock is not None:
clientSock.setblocking(False)
self.clients.append(WebClient(clientSock, self.handlerWrapper))
socket = self.socket.accept()[0]
if socket:
socket.setblocking(False)
self.clients.append(WebClient(socket, self.handlerWrapper))
def advanceClients(self):
self.clients = list(filter(lambda c: c.advance(), self.clients))
def listen(self):
self.close()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.setblocking(False)
self.sock.bind((util.setting('webBindAddress'), util.setting('webBindPort')))
self.sock.listen(util.setting('webBacklog'))
def bindAndListen(self, address, port):
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.setblocking(False)
self.socket.bind((address, port))
self.socket.listen(5)
def handlerWrapper(self, req):
# handle multiple cors origins by checking the 'origin'-header against the allowed origin list from the config
webCorsOriginList = util.setting('webCorsOriginList')
# keep support for deprecated 'webCorsOrigin' field, as long it is not removed
webCorsOrigin = util.setting('webCorsOrigin')
if webCorsOrigin:
webCorsOriginList.append(webCorsOrigin)
allowed = False
corsOrigin = 'http://localhost'
allowAllCors = '*' in webCorsOriginList # allow CORS for all domains
if allowAllCors:
corsOrigin = '*'
def handlerWrapper(self, request):
if '*' in self.origins:
origin = '*'
allowed = True
elif b'origin' in req.headers:
originStr = req.headers[b'origin'].decode()
if originStr in webCorsOriginList :
corsOrigin = originStr
allowed = True
elif 'http://localhost' in webCorsOriginList and (
originStr == 'http://127.0.0.1' or originStr == 'https://127.0.0.1' or # allow 127.0.0.1 if localhost allowed
originStr.startswith('http://127.0.0.1:') or originStr.startswith('http://127.0.0.1:') or
originStr.startswith('chrome-extension://') or originStr.startswith('moz-extension://') ) : # allow chrome and firefox extension if localhost allowed
corsOrigin = originStr
allowed = True
else:
allowed = True
origin = request.headers.get('origin')
allowed = origin in self.origins
if not allowed:
origin = 'http://127.0.0.1'
try:
params = json.loads(req.body.decode('utf-8'))
paramsError = False
except ValueError:
body = json.dumps(None).encode('utf-8')
paramsError = True
if allowed or not paramsError and params.get('action', '') == 'requestPermission':
if len(req.body) == 0:
body = 'AnkiConnect v.{}'.format(util.setting('apiVersion')).encode('utf-8')
call = json.loads(request.body)
if call:
call['allowed'] = allowed
call['origin'] = origin
body = json.dumps(self.handler(call))
else:
if params.get('action', '') == 'requestPermission':
params['params'] = params.get('params', {})
params['params']['allowed'] = allowed
params['params']['origin'] = b'origin' in req.headers and req.headers[b'origin'].decode() or ''
if not allowed :
corsOrigin = params['params']['origin']
body = 'AnkiConnect'
except Exception as e:
body = str(e)
body = json.dumps(self.handler(params)).encode('utf-8')
headers = [
['HTTP/1.1 200 OK', None],
['Content-Type', 'text/json'],
['Access-Control-Allow-Origin', corsOrigin],
['Access-Control-Allow-Headers', '*'],
['Content-Length', str(len(body))]
]
else :
headers = [
['HTTP/1.1 403 Forbidden', None],
['Access-Control-Allow-Origin', corsOrigin],
['Access-Control-Allow-Headers', '*']
]
body = ''.encode('utf-8');
resp = bytes()
headers = [
['HTTP/1.1 200 OK', None],
['Content-Type', 'text/json'],
['Access-Control-Allow-Origin', origin],
['Access-Control-Allow-Headers', '*'],
['Content-Length', len(body.encode('utf-8'))]
]
header = bytes()
for key, value in headers:
if value is None:
resp += '{}\r\n'.format(key).encode('utf-8')
else:
resp += '{}: {}\r\n'.format(key, value).encode('utf-8')
header += f'{key}: {value}\r\n'
resp += '\r\n'.encode('utf-8')
resp += body
return resp
return header + '\r\n' + body
def close(self):
if self.sock is not None:
self.sock.close()
self.sock = None
if self.socket:
self.socket.close()
self.socket = None
for client in self.clients:
client.close()