This commit is contained in:
Alex Yatskov 2021-07-18 20:19:37 -07:00
parent 2165c6e0ca
commit 86e1258ebb
3 changed files with 34 additions and 28 deletions

View File

@ -17,8 +17,9 @@ from . import util
@util.api @util.api
def guiBrowse(self, query=None): def guiBrowse(query=None):
print(query) print(query)
return 'hi'
# browser = aqt.dialogs.open('Browser', self.window()) # browser = aqt.dialogs.open('Browser', self.window())
# browser.activateWindow() # browser.activateWindow()
# #

View File

@ -13,6 +13,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import inspect
import PyQt5 import PyQt5
from . import web from . import web
@ -20,9 +21,10 @@ from . import web
class ApiHost: class ApiHost:
def __init__(self, origins, key, address, port): def __init__(self, origins, key, address, port):
self.origins = origins
self.key = key self.key = key
self.modules = [] self.modules = []
self.server = web.WebServer(self.handler, origins) self.server = web.WebServer(self.handler)
self.server.bindAndListen(address, port) self.server.bindAndListen(address, port)
@ -39,19 +41,36 @@ class ApiHost:
def handler(self, call): def handler(self, call):
action = call.get('action') action = call.get('action')
params = call.get('params', {}) params = call.get('params', {})
allowed = call.get('allowed', False)
key = call.get('key')
try: try:
if key != self.key and action != 'requestPermission':
raise Exception('valid api key must be provided')
for module in self.modules: for module in self.modules:
for methodName, methodInstance in inspect.getmembers(module, predicate=inspect.ismethod): for funcName, funcInstance in inspect.getmembers(module, predicate=inspect.isfunction):
if methodName == action and getattr(methodInstance, 'api', False): if funcName == action and getattr(funcInstance, 'api', False):
return {'error': None, 'result': methodInstance(**params)} return {'error': None, 'result': funcInstance(**params)}
else: else:
raise Exception('unsupported action') raise Exception('unsupported action')
except Exception as e: except Exception as e:
return {'error': str(e), 'result': None} return {'error': str(e), 'result': None}
# if '*' in self.origins:
# origin = '*'
# allowed = True
# else:
# origin = request.headers.get('origin', 'http://127.0.0.1:')
# for prefix in self.origins:
# if origin.startswith(prefix):
# allowed = True
# break
#
# try:
# if request.body:
# call = json.loads(request.body)
# call['allowed'] = allowed
# call['origin'] = origin
# body = json.dumps(self.handler(call))
# else:
# body = 'AnkiConnect'
# except Exception as e:
# body = str(e)

View File

@ -116,9 +116,8 @@ class WebClient:
# #
class WebServer: class WebServer:
def __init__(self, handler, origins): def __init__(self, handler):
self.handler = handler self.handler = handler
self.origins = origins
self.clients = [] self.clients = []
self.socket = None self.socket = None
@ -153,22 +152,9 @@ class WebServer:
def handlerWrapper(self, request): def handlerWrapper(self, request):
if '*' in self.origins:
origin = '*'
allowed = True
else:
origin = request.headers.get('origin', 'http://127.0.0.1:')
for prefix in self.origins:
if origin.startswith(prefix):
allowed = True
break
try: try:
if request.body: if request.body:
call = json.loads(request.body) body = json.dumps(self.handler(json.loads(request.body)))
call['allowed'] = allowed
call['origin'] = origin
body = json.dumps(self.handler(call))
else: else:
body = 'AnkiConnect' body = 'AnkiConnect'
except Exception as e: except Exception as e:
@ -177,8 +163,8 @@ class WebServer:
headers = [ headers = [
['HTTP/1.1 200 OK', None], ['HTTP/1.1 200 OK', None],
['Content-Type', 'text/json'], ['Content-Type', 'text/json'],
['Access-Control-Allow-Origin', origin], # ['Access-Control-Allow-Origin', origin],
['Access-Control-Allow-Headers', '*'], # ['Access-Control-Allow-Headers', '*'],
['Content-Length', len(body.encode('utf-8'))] ['Content-Length', len(body.encode('utf-8'))]
] ]