diff --git a/plugin/__init__.py b/plugin/__init__.py
index 3afe6b8..7d3c2ab 100644
--- a/plugin/__init__.py
+++ b/plugin/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2016-2019 Alex Yatskov
+# Copyright 2016-2020 Alex Yatskov
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
@@ -13,251 +13,25 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
-
-import anki
-import aqt
import base64
import hashlib
import inspect
import json
import os
import os.path
+import random
import re
-import select
-import socket
-import sys
-
-from operator import itemgetter
-from time import time
-from unicodedata import normalize
-from random import choice
-from string import ascii_letters
-
-#
-# Constants
-#
-
-API_VERSION = 6
-API_LOG_PATH = None
-NET_CORS_ORIGIN = os.getenv('ANKICONNECT_CORS_ORIGIN', 'http://localhost')
-NET_ADDRESS = os.getenv('ANKICONNECT_BIND_ADDRESS', '127.0.0.1')
-NET_BACKLOG = 5
-NET_PORT = 8765
-TICK_INTERVAL = 25
-URL_TIMEOUT = 10
-URL_UPGRADE = 'https://raw.githubusercontent.com/FooSoft/anki-connect/master/AnkiConnect.py'
-
-config = aqt.mw.addonManager.getConfig('AnkiConnect')
-
-#
-# Helpers
-#
-
-from anki.sync import AnkiRequestsClient
-def download(url):
- contents = None
- client = AnkiRequestsClient()
- client.timeout = URL_TIMEOUT
- resp = client.get(url)
- if resp.status_code == 200:
- contents = client.streamContent(resp)
- return (resp.status_code, contents)
+import string
+import time
+import unicodedata
from PyQt5.QtCore import QTimer
from PyQt5.QtWidgets import QMessageBox
+import anki
+import aqt
-def api(*versions):
- def decorator(func):
- method = lambda *args, **kwargs: func(*args, **kwargs)
- setattr(method, 'versions', versions)
- setattr(method, 'api', True)
- return method
-
- return decorator
-
-
-#
-# WebRequest
-#
-
-class WebRequest:
- def __init__(self, headers, body):
- self.headers = headers
- self.body = body
-
-
-#
-# WebClient
-#
-
-class WebClient:
- def __init__(self, sock, handler):
- self.sock = sock
- self.handler = handler
- self.readBuff = bytes()
- self.writeBuff = bytes()
-
-
- def advance(self, recvSize=1024):
- if self.sock is None:
- return False
-
- rlist, wlist = select.select([self.sock], [self.sock], [], 0)[:2]
-
- if rlist:
- msg = self.sock.recv(recvSize)
- if not msg:
- self.close()
- return False
-
- self.readBuff += msg
-
- req, length = self.parseRequest(self.readBuff)
- if req is not None:
- self.readBuff = self.readBuff[length:]
- self.writeBuff += self.handler(req)
-
- if wlist and self.writeBuff:
- length = self.sock.send(self.writeBuff)
- self.writeBuff = self.writeBuff[length:]
- if not self.writeBuff:
- self.close()
- return False
-
- return True
-
-
- def close(self):
- if self.sock is not None:
- self.sock.close()
- self.sock = None
-
- self.readBuff = bytes()
- self.writeBuff = bytes()
-
-
- def parseRequest(self, data):
- parts = data.split('\r\n\r\n'.encode('utf-8'), 1)
- if len(parts) == 1:
- return None, 0
-
- headers = {}
- for line in parts[0].split('\r\n'.encode('utf-8')):
- pair = line.split(': '.encode('utf-8'))
- headers[pair[0].lower()] = pair[1] if len(pair) > 1 else None
-
- headerLength = len(parts[0]) + 4
- bodyLength = int(headers.get('content-length'.encode('utf-8'), 0))
- totalLength = headerLength + bodyLength
-
- if totalLength > len(data):
- return None, 0
-
- body = data[headerLength : totalLength]
- return WebRequest(headers, body), totalLength
-
-
-#
-# WebServer
-#
-
-class WebServer:
- def __init__(self, handler):
- self.handler = handler
- self.clients = []
- self.sock = None
- self.resetHeaders()
-
-
- def setHeader(self, name, value):
- self.headersOpt[name] = value
-
-
- def resetHeaders(self):
- self.headers = [
- ['HTTP/1.1 200 OK', None],
- ['Content-Type', 'text/json'],
- ['Access-Control-Allow-Origin', NET_CORS_ORIGIN]
- ]
- self.headersOpt = {}
-
-
- def getHeaders(self):
- headers = self.headers[:]
- for name in self.headersOpt:
- headers.append([name, self.headersOpt[name]])
-
- return headers
-
-
- def advance(self):
- if self.sock is not None:
- self.acceptClients()
- self.advanceClients()
-
-
- def acceptClients(self):
- rlist = select.select([self.sock], [], [], 0)[0]
- if not rlist:
- return
-
- clientSock = self.sock.accept()[0]
- if clientSock is not None:
- clientSock.setblocking(False)
- self.clients.append(WebClient(clientSock, self.handlerWrapper))
-
-
- def advanceClients(self):
- self.clients = list(filter(lambda c: c.advance(), self.clients))
-
-
- def listen(self):
- self.close()
-
- self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- self.sock.setblocking(False)
- self.sock.bind((NET_ADDRESS, NET_PORT))
- self.sock.listen(NET_BACKLOG)
-
-
- def handlerWrapper(self, req):
- if len(req.body) == 0:
- body = 'AnkiConnect v.{}'.format(API_VERSION).encode('utf-8')
- else:
- try:
- params = json.loads(req.body.decode('utf-8'))
- body = json.dumps(self.handler(params)).encode('utf-8')
- except ValueError:
- body = json.dumps(None).encode('utf-8')
-
- resp = bytes()
-
- self.setHeader('Content-Length', str(len(body)))
- headers = self.getHeaders()
-
- for key, value in headers:
- if value is None:
- resp += '{}\r\n'.format(key).encode('utf-8')
- else:
- resp += '{}: {}\r\n'.format(key, value).encode('utf-8')
-
- resp += '\r\n'.encode('utf-8')
- resp += body
-
- return resp
-
-
- def close(self):
- if self.sock is not None:
- self.sock.close()
- self.sock = None
-
- for client in self.clients:
- client.close()
-
- self.clients = []
+from AnkiConnect import web, util
#
@@ -266,34 +40,39 @@ class WebServer:
class AnkiConnect:
def __init__(self):
- self.server = WebServer(self.handler)
self.log = None
- if API_LOG_PATH is not None:
- self.log = open(API_LOG_PATH, 'w')
+ logPath = util.setting('apiLogPath')
+ if logPath is not None:
+ self.log = open(logPath, 'w')
try:
+ self.server = web.WebServer(self.handler)
self.server.listen()
self.timer = QTimer()
self.timer.timeout.connect(self.advance)
- self.timer.start(TICK_INTERVAL)
+ self.timer.start(util.setting('apiPollInterval'))
except:
QMessageBox.critical(
self.window(),
'AnkiConnect',
- 'Failed to listen on port {}.\nMake sure it is available and is not in use.'.format(NET_PORT)
+ 'Failed to listen on port {}.\nMake sure it is available and is not in use.'.format(util.setting('webBindPort'))
)
+ def logEvent(self, name, data):
+ if self.log is not None:
+ self.log.write('[{}]\n'.format(name))
+ json.dump(data, self.log, indent=4, sort_keys=True)
+ self.log.write('\n\n')
+
+
def advance(self):
self.server.advance()
def handler(self, request):
- if self.log is not None:
- self.log.write('[request]\n')
- json.dump(request, self.log, indent=4, sort_keys=True)
- self.log.write('\n\n')
+ self.logEvent('request', request)
name = request.get('action', '')
version = request.get('version', 4)
@@ -302,7 +81,6 @@ class AnkiConnect:
try:
method = None
-
for methodName, methodInst in inspect.getmembers(self, predicate=inspect.ismethod):
apiVersionLast = 0
apiNameLast = None
@@ -331,25 +109,10 @@ class AnkiConnect:
except Exception as e:
reply['error'] = str(e)
- if self.log is not None:
- self.log.write('[reply]\n')
- json.dump(reply, self.log, indent=4, sort_keys=True)
- self.log.write('\n\n')
-
+ self.logEvent('reply', reply)
return reply
- def download(self, url):
- try:
- (code, contents) = download(url)
- except Exception as e:
- raise Exception('{} download failed with error {}'.format(url, str(e)))
- if code == 200:
- return contents
- else:
- raise Exception('{} download failed with return code {}'.format(url, code))
-
-
def window(self):
return aqt.mw
@@ -455,40 +218,12 @@ class AnkiConnect:
# Miscellaneous
#
- @api()
+ @util.api()
def version(self):
- return API_VERSION
+ return util.setting('apiVersion')
- @api()
- def upgrade(self):
- response = QMessageBox.question(
- self.window(),
- 'AnkiConnect',
- 'Upgrade to the latest version?',
- QMessageBox.Yes | QMessageBox.No
- )
-
- if response == QMessageBox.Yes:
- try:
- data = self.download(URL_UPGRADE)
- path = os.path.splitext(__file__)[0] + '.py'
- with open(path, 'w') as fp:
- fp.write(data.decode('utf-8'))
- QMessageBox.information(
- self.window(),
- 'AnkiConnect',
- 'Upgraded to the latest version, please restart Anki.'
- )
- return True
- except Exception as e:
- QMessageBox.critical(self.window(), 'AnkiConnect', 'Failed to download latest version.')
- raise e
-
- return False
-
-
- @api()
+ @util.api()
def loadProfile(self, name):
if name not in self.window().pm.profiles():
return False
@@ -504,12 +239,12 @@ class AnkiConnect:
return True
- @api()
+ @util.api()
def sync(self):
self.window().onSync()
- @api()
+ @util.api()
def multi(self, actions):
return list(map(self.handler, actions))
@@ -518,12 +253,12 @@ class AnkiConnect:
# Decks
#
- @api()
+ @util.api()
def deckNames(self):
return self.decks().allNames()
- @api()
+ @util.api()
def deckNamesAndIds(self):
decks = {}
for deck in self.deckNames():
@@ -532,7 +267,7 @@ class AnkiConnect:
return decks
- @api()
+ @util.api()
def getDecks(self, cards):
decks = {}
for card in cards:
@@ -546,7 +281,7 @@ class AnkiConnect:
return decks
- @api()
+ @util.api()
def createDeck(self, deck):
try:
self.startEditing()
@@ -557,7 +292,7 @@ class AnkiConnect:
return did
- @api()
+ @util.api()
def changeDeck(self, cards, deck):
self.startEditing()
@@ -575,7 +310,7 @@ class AnkiConnect:
self.stopEditing()
- @api()
+ @util.api()
def deleteDecks(self, decks, cardsToo=False):
try:
self.startEditing()
@@ -587,7 +322,7 @@ class AnkiConnect:
self.stopEditing()
- @api()
+ @util.api()
def getDeckConfig(self, deck):
if not deck in self.deckNames():
return False
@@ -597,7 +332,7 @@ class AnkiConnect:
return collection.decks.confForDid(did)
- @api()
+ @util.api()
def saveDeckConfig(self, config):
collection = self.collection()
@@ -613,7 +348,7 @@ class AnkiConnect:
return True
- @api()
+ @util.api()
def setDeckConfigId(self, decks, configId):
configId = str(configId)
for deck in decks:
@@ -631,7 +366,7 @@ class AnkiConnect:
return True
- @api()
+ @util.api()
def cloneDeckConfigId(self, name, cloneFrom='1'):
configId = str(cloneFrom)
if not configId in self.collection().decks.dconf:
@@ -641,7 +376,7 @@ class AnkiConnect:
return self.collection().decks.confId(name, config)
- @api()
+ @util.api()
def removeDeckConfigId(self, configId):
configId = str(configId)
collection = self.collection()
@@ -652,16 +387,16 @@ class AnkiConnect:
return True
- @api()
+ @util.api()
def storeMediaFile(self, filename, data):
self.deleteMediaFile(filename)
self.media().writeData(filename, base64.b64decode(data))
- @api()
+ @util.api()
def retrieveMediaFile(self, filename):
filename = os.path.basename(filename)
- filename = normalize('NFC', filename)
+ filename = unicodedata.normalize('NFC', filename)
filename = self.media().stripIllegal(filename)
path = os.path.join(self.media().dir(), filename)
@@ -672,19 +407,19 @@ class AnkiConnect:
return False
- @api()
+ @util.api()
def deleteMediaFile(self, filename):
self.media().syncDelete(filename)
- @api()
+ @util.api()
def addNote(self, note):
ankiNote = self.createNote(note)
audio = note.get('audio')
if audio is not None and len(audio['fields']) > 0:
try:
- data = self.download(audio['url'])
+ data = util.download(audio['url'])
skipHash = audio.get('skipHash')
if skipHash is None:
skip = False
@@ -716,7 +451,7 @@ class AnkiConnect:
return ankiNote.id
- @api()
+ @util.api()
def canAddNote(self, note):
try:
return bool(self.createNote(note))
@@ -724,7 +459,7 @@ class AnkiConnect:
return False
- @api()
+ @util.api()
def updateNoteFields(self, note):
ankiNote = self.collection().getNote(note['id'])
if ankiNote is None:
@@ -737,24 +472,24 @@ class AnkiConnect:
ankiNote.flush()
- @api()
+ @util.api()
def addTags(self, notes, tags, add=True):
self.startEditing()
self.collection().tags.bulkAdd(notes, tags, add)
self.stopEditing()
- @api()
+ @util.api()
def removeTags(self, notes, tags):
return self.addTags(notes, tags, False)
- @api()
+ @util.api()
def getTags(self):
return self.collection().tags.all()
- @api()
+ @util.api()
def suspend(self, cards, suspend=True):
for card in cards:
if self.suspended(card) == suspend:
@@ -774,18 +509,18 @@ class AnkiConnect:
return True
- @api()
+ @util.api()
def unsuspend(self, cards):
self.suspend(cards, False)
- @api()
+ @util.api()
def suspended(self, card):
card = self.collection().getCard(card)
return card.queue == -1
- @api()
+ @util.api()
def areSuspended(self, cards):
suspended = []
for card in cards:
@@ -794,7 +529,7 @@ class AnkiConnect:
return suspended
- @api()
+ @util.api()
def areDue(self, cards):
due = []
for card in cards:
@@ -805,12 +540,12 @@ class AnkiConnect:
if ivl >= -1200:
due.append(bool(self.findCards('cid:{} is:due'.format(card))))
else:
- due.append(date - ivl <= time())
+ due.append(date - ivl <= time.time())
return due
- @api()
+ @util.api()
def getIntervals(self, cards, complete=False):
intervals = []
for card in cards:
@@ -826,12 +561,12 @@ class AnkiConnect:
- @api()
+ @util.api()
def modelNames(self):
return self.collection().models.allNames()
- @api()
+ @util.api()
def createModel(self, modelName, inOrderFields, cardTemplates, css = None):
# https://github.com/dae/anki/blob/b06b70f7214fb1f2ce33ba06d2b095384b81f874/anki/stdmodels.py
if (len(inOrderFields) == 0):
@@ -869,7 +604,7 @@ class AnkiConnect:
return m
- @api()
+ @util.api()
def modelNamesAndIds(self):
models = {}
for model in self.modelNames():
@@ -878,7 +613,7 @@ class AnkiConnect:
return models
- @api()
+ @util.api()
def modelNameFromId(self, modelId):
model = self.collection().models.get(modelId)
if model is None:
@@ -887,7 +622,7 @@ class AnkiConnect:
return model['name']
- @api()
+ @util.api()
def modelFieldNames(self, modelName):
model = self.collection().models.byName(modelName)
if model is None:
@@ -896,7 +631,7 @@ class AnkiConnect:
return [field['name'] for field in model['flds']]
- @api()
+ @util.api()
def modelFieldsOnTemplates(self, modelName):
model = self.collection().models.byName(modelName)
if model is None:
@@ -926,7 +661,7 @@ class AnkiConnect:
return templates
- @api()
+ @util.api()
def modelTemplates(self, modelName):
model = self.collection().models.byName(modelName)
if model is None:
@@ -939,7 +674,7 @@ class AnkiConnect:
return templates
- @api()
+ @util.api()
def modelStyling(self, modelName):
model = self.collection().models.byName(modelName)
if model is None:
@@ -948,7 +683,7 @@ class AnkiConnect:
return {'css': model['css']}
- @api()
+ @util.api()
def updateModelTemplates(self, model):
models = self.collection().models
ankiModel = models.byName(model['name'])
@@ -972,7 +707,7 @@ class AnkiConnect:
models.flush()
- @api()
+ @util.api()
def updateModelStyling(self, model):
models = self.collection().models
ankiModel = models.byName(model['name'])
@@ -985,7 +720,7 @@ class AnkiConnect:
models.flush()
- @api()
+ @util.api()
def deckNameFromId(self, deckId):
deck = self.collection().decks.get(deckId)
if deck is None:
@@ -994,7 +729,7 @@ class AnkiConnect:
return deck['name']
- @api()
+ @util.api()
def findNotes(self, query=None):
if query is None:
return []
@@ -1002,7 +737,7 @@ class AnkiConnect:
return self.collection().findNotes(query)
- @api()
+ @util.api()
def findCards(self, query=None):
if query is None:
return []
@@ -1010,7 +745,7 @@ class AnkiConnect:
return self.collection().findCards(query)
- @api()
+ @util.api()
def cardsInfo(self, cards):
result = []
for cid in cards:
@@ -1049,7 +784,7 @@ class AnkiConnect:
return result
- @api()
+ @util.api()
def notesInfo(self, notes):
result = []
for nid in notes:
@@ -1080,7 +815,7 @@ class AnkiConnect:
return result
- @api()
+ @util.api()
def deleteNotes(self, notes):
try:
self.collection().remNotes(notes)
@@ -1090,12 +825,12 @@ class AnkiConnect:
- @api()
+ @util.api()
def cardsToNotes(self, cards):
return self.collection().db.list('select distinct nid from cards where id in ' + anki.utils.ids2str(cards))
- @api()
+ @util.api()
def guiBrowse(self, query=None):
browser = aqt.dialogs.open('Browser', self.window())
browser.activateWindow()
@@ -1110,7 +845,7 @@ class AnkiConnect:
return browser.model.cards
- @api()
+ @util.api()
def guiAddCards(self, note=None):
if note is not None:
@@ -1140,8 +875,7 @@ class AnkiConnect:
addCards = None
if closeAfterAdding:
-
- randomString = ''.join(choice(ascii_letters) for _ in range(10))
+ randomString = ''.join(random.choice(string.ascii_letters) for _ in range(10))
windowName = 'AddCardsAndClose' + randomString
class AddCardsAndClose(aqt.addcards.AddCards):
@@ -1248,12 +982,12 @@ class AnkiConnect:
addCards = aqt.dialogs.open('AddCards', self.window())
addCards.activateWindow()
- @api()
+ @util.api()
def guiReviewActive(self):
return self.reviewer().card is not None and self.window().state == 'review'
- @api()
+ @util.api()
def guiCurrentCard(self):
if not self.guiReviewActive():
raise Exception('Gui review is not currently active.')
@@ -1286,7 +1020,7 @@ class AnkiConnect:
}
- @api()
+ @util.api()
def guiStartCardTimer(self):
if not self.guiReviewActive():
return False
@@ -1300,7 +1034,7 @@ class AnkiConnect:
return False
- @api()
+ @util.api()
def guiShowQuestion(self):
if self.guiReviewActive():
self.reviewer()._showQuestion()
@@ -1309,7 +1043,7 @@ class AnkiConnect:
return False
- @api()
+ @util.api()
def guiShowAnswer(self):
if self.guiReviewActive():
self.window().reviewer._showAnswer()
@@ -1318,7 +1052,7 @@ class AnkiConnect:
return False
- @api()
+ @util.api()
def guiAnswerCard(self, ease):
if not self.guiReviewActive():
return False
@@ -1333,7 +1067,7 @@ class AnkiConnect:
return True
- @api()
+ @util.api()
def guiDeckOverview(self, name):
collection = self.collection()
if collection is not None:
@@ -1346,12 +1080,12 @@ class AnkiConnect:
return False
- @api()
+ @util.api()
def guiDeckBrowser(self):
self.window().moveToState('deckBrowser')
- @api()
+ @util.api()
def guiDeckReview(self, name):
if self.guiDeckOverview(name):
self.window().moveToState('review')
@@ -1360,18 +1094,14 @@ class AnkiConnect:
return False
- @api()
+ @util.api()
def guiExitAnki(self):
timer = QTimer()
- def exitAnki():
- timer.stop()
- self.window().close()
- timer.timeout.connect(exitAnki)
+ timer.timeout.connect(self.window().close)
timer.start(1000) # 1s should be enough to allow the response to be sent.
-
- @api()
+ @util.api()
def addNotes(self, notes):
results = []
for note in notes:
@@ -1383,7 +1113,7 @@ class AnkiConnect:
return results
- @api()
+ @util.api()
def canAddNotes(self, notes):
results = []
for note in notes:
diff --git a/plugin/util.py b/plugin/util.py
new file mode 100644
index 0000000..740840e
--- /dev/null
+++ b/plugin/util.py
@@ -0,0 +1,63 @@
+# Copyright 2016-2020 Alex Yatskov
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+import os
+
+import anki
+import anki.sync
+import aqt
+
+
+#
+# Utilities
+#
+
+def download(url):
+ client = anki.sync.AnkiRequestsClient()
+ client.timeout = setting('webTimeout') / 1000
+ resp = client.get(url)
+ if resp.status_code == 200:
+ return client.streamContent(resp)
+ else:
+ raise Exception('{} download failed with return code {}'.format(url, resp.status_code))
+
+
+def api(*versions):
+ def decorator(func):
+ method = lambda *args, **kwargs: func(*args, **kwargs)
+ setattr(method, 'versions', versions)
+ setattr(method, 'api', True)
+ return method
+
+ return decorator
+
+
+def setting(key):
+ defaults = {
+ 'apiKey': None,
+ 'apiLogPath': None,
+ 'apiPollInterval': 25,
+ 'apiVersion': 6,
+ 'webBacklog': 5,
+ 'webBindAddress': os.getenv('ANKICONNECT_BIND_ADDRESS', '127.0.0.1'),
+ 'webBindPort': 8765,
+ 'webCorsOrigin': os.getenv('ANKICONNECT_CORS_ORIGIN', 'http://localhost'),
+ 'webTimeout': 10000,
+ }
+
+ try:
+ return aqt.mw.addonManager.getConfig(__name__).get(key, defaults[key])
+ except:
+ raise Exception('setting {} not found'.format(key))
diff --git a/plugin/web.py b/plugin/web.py
new file mode 100644
index 0000000..9156178
--- /dev/null
+++ b/plugin/web.py
@@ -0,0 +1,185 @@
+# Copyright 2016-2020 Alex Yatskov
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+import json
+import select
+import socket
+
+from AnkiConnect import web, util
+
+
+#
+# WebRequest
+#
+
+class WebRequest:
+ def __init__(self, headers, body):
+ self.headers = headers
+ self.body = body
+
+
+#
+# WebClient
+#
+
+class WebClient:
+ def __init__(self, sock, handler):
+ self.sock = sock
+ self.handler = handler
+ self.readBuff = bytes()
+ self.writeBuff = bytes()
+
+
+ def advance(self, recvSize=1024):
+ if self.sock is None:
+ return False
+
+ rlist, wlist = select.select([self.sock], [self.sock], [], 0)[:2]
+
+ if rlist:
+ msg = self.sock.recv(recvSize)
+ if not msg:
+ self.close()
+ return False
+
+ self.readBuff += msg
+
+ req, length = self.parseRequest(self.readBuff)
+ if req is not None:
+ self.readBuff = self.readBuff[length:]
+ self.writeBuff += self.handler(req)
+
+ if wlist and self.writeBuff:
+ length = self.sock.send(self.writeBuff)
+ self.writeBuff = self.writeBuff[length:]
+ if not self.writeBuff:
+ self.close()
+ return False
+
+ return True
+
+
+ def close(self):
+ if self.sock is not None:
+ self.sock.close()
+ self.sock = None
+
+ self.readBuff = bytes()
+ self.writeBuff = bytes()
+
+
+ def parseRequest(self, data):
+ parts = data.split('\r\n\r\n'.encode('utf-8'), 1)
+ if len(parts) == 1:
+ return None, 0
+
+ headers = {}
+ for line in parts[0].split('\r\n'.encode('utf-8')):
+ pair = line.split(': '.encode('utf-8'))
+ headers[pair[0].lower()] = pair[1] if len(pair) > 1 else None
+
+ headerLength = len(parts[0]) + 4
+ bodyLength = int(headers.get('content-length'.encode('utf-8'), 0))
+ totalLength = headerLength + bodyLength
+
+ if totalLength > len(data):
+ return None, 0
+
+ body = data[headerLength : totalLength]
+ return WebRequest(headers, body), totalLength
+
+
+#
+# WebServer
+#
+
+class WebServer:
+ def __init__(self, handler):
+ self.handler = handler
+ self.clients = []
+ self.sock = None
+
+
+ def advance(self):
+ if self.sock is not None:
+ self.acceptClients()
+ self.advanceClients()
+
+
+ def acceptClients(self):
+ rlist = select.select([self.sock], [], [], 0)[0]
+ if not rlist:
+ return
+
+ clientSock = self.sock.accept()[0]
+ if clientSock is not None:
+ clientSock.setblocking(False)
+ self.clients.append(WebClient(clientSock, self.handlerWrapper))
+
+
+ def advanceClients(self):
+ self.clients = list(filter(lambda c: c.advance(), self.clients))
+
+
+ def listen(self):
+ self.close()
+
+ self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ self.sock.setblocking(False)
+ self.sock.bind((util.setting('webBindAddress'), util.setting('webBindPort')))
+ self.sock.listen(util.setting('webBacklog'))
+
+
+ def handlerWrapper(self, req):
+ if len(req.body) == 0:
+ body = 'AnkiConnect'
+ else:
+ try:
+ params = json.loads(req.body.decode('utf-8'))
+ body = json.dumps(self.handler(params)).encode('utf-8')
+ except ValueError:
+ body = json.dumps(None).encode('utf-8')
+
+ headers = [
+ ['HTTP/1.1 200 OK', None],
+ ['Content-Type', 'text/json'],
+ ['Access-Control-Allow-Origin', util.setting('webCorsOrigin')],
+ ['Content-Length', str(len(body))]
+ ]
+
+ resp = bytes()
+
+ for key, value in headers:
+ if value is None:
+ resp += '{}\r\n'.format(key).encode('utf-8')
+ else:
+ resp += '{}: {}\r\n'.format(key, value).encode('utf-8')
+
+ resp += '\r\n'.encode('utf-8')
+ resp += body
+
+ return resp
+
+
+ def close(self):
+ if self.sock is not None:
+ self.sock.close()
+ self.sock = None
+
+ for client in self.clients:
+ client.close()
+
+ self.clients = []