This commit is contained in:
Alex Yatskov 2020-01-05 15:42:08 -08:00
parent 6118d01de1
commit 173e43700b
3 changed files with 338 additions and 360 deletions

View File

@ -1,4 +1,4 @@
# Copyright 2016-2019 Alex Yatskov # Copyright 2016-2020 Alex Yatskov
# #
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by # it under the terms of the GNU General Public License as published by
@ -13,251 +13,25 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import anki
import aqt
import base64 import base64
import hashlib import hashlib
import inspect import inspect
import json import json
import os import os
import os.path import os.path
import random
import re import re
import select import string
import socket import time
import sys import unicodedata
from operator import itemgetter
from time import time
from unicodedata import normalize
from random import choice
from string import ascii_letters
#
# Constants
#
API_VERSION = 6
API_LOG_PATH = None
NET_CORS_ORIGIN = os.getenv('ANKICONNECT_CORS_ORIGIN', 'http://localhost')
NET_ADDRESS = os.getenv('ANKICONNECT_BIND_ADDRESS', '127.0.0.1')
NET_BACKLOG = 5
NET_PORT = 8765
TICK_INTERVAL = 25
URL_TIMEOUT = 10
URL_UPGRADE = 'https://raw.githubusercontent.com/FooSoft/anki-connect/master/AnkiConnect.py'
config = aqt.mw.addonManager.getConfig('AnkiConnect')
#
# Helpers
#
from anki.sync import AnkiRequestsClient
def download(url):
contents = None
client = AnkiRequestsClient()
client.timeout = URL_TIMEOUT
resp = client.get(url)
if resp.status_code == 200:
contents = client.streamContent(resp)
return (resp.status_code, contents)
from PyQt5.QtCore import QTimer from PyQt5.QtCore import QTimer
from PyQt5.QtWidgets import QMessageBox from PyQt5.QtWidgets import QMessageBox
import anki
import aqt
def api(*versions): from AnkiConnect import web, util
def decorator(func):
method = lambda *args, **kwargs: func(*args, **kwargs)
setattr(method, 'versions', versions)
setattr(method, 'api', True)
return method
return decorator
#
# WebRequest
#
class WebRequest:
def __init__(self, headers, body):
self.headers = headers
self.body = body
#
# WebClient
#
class WebClient:
def __init__(self, sock, handler):
self.sock = sock
self.handler = handler
self.readBuff = bytes()
self.writeBuff = bytes()
def advance(self, recvSize=1024):
if self.sock is None:
return False
rlist, wlist = select.select([self.sock], [self.sock], [], 0)[:2]
if rlist:
msg = self.sock.recv(recvSize)
if not msg:
self.close()
return False
self.readBuff += msg
req, length = self.parseRequest(self.readBuff)
if req is not None:
self.readBuff = self.readBuff[length:]
self.writeBuff += self.handler(req)
if wlist and self.writeBuff:
length = self.sock.send(self.writeBuff)
self.writeBuff = self.writeBuff[length:]
if not self.writeBuff:
self.close()
return False
return True
def close(self):
if self.sock is not None:
self.sock.close()
self.sock = None
self.readBuff = bytes()
self.writeBuff = bytes()
def parseRequest(self, data):
parts = data.split('\r\n\r\n'.encode('utf-8'), 1)
if len(parts) == 1:
return None, 0
headers = {}
for line in parts[0].split('\r\n'.encode('utf-8')):
pair = line.split(': '.encode('utf-8'))
headers[pair[0].lower()] = pair[1] if len(pair) > 1 else None
headerLength = len(parts[0]) + 4
bodyLength = int(headers.get('content-length'.encode('utf-8'), 0))
totalLength = headerLength + bodyLength
if totalLength > len(data):
return None, 0
body = data[headerLength : totalLength]
return WebRequest(headers, body), totalLength
#
# WebServer
#
class WebServer:
def __init__(self, handler):
self.handler = handler
self.clients = []
self.sock = None
self.resetHeaders()
def setHeader(self, name, value):
self.headersOpt[name] = value
def resetHeaders(self):
self.headers = [
['HTTP/1.1 200 OK', None],
['Content-Type', 'text/json'],
['Access-Control-Allow-Origin', NET_CORS_ORIGIN]
]
self.headersOpt = {}
def getHeaders(self):
headers = self.headers[:]
for name in self.headersOpt:
headers.append([name, self.headersOpt[name]])
return headers
def advance(self):
if self.sock is not None:
self.acceptClients()
self.advanceClients()
def acceptClients(self):
rlist = select.select([self.sock], [], [], 0)[0]
if not rlist:
return
clientSock = self.sock.accept()[0]
if clientSock is not None:
clientSock.setblocking(False)
self.clients.append(WebClient(clientSock, self.handlerWrapper))
def advanceClients(self):
self.clients = list(filter(lambda c: c.advance(), self.clients))
def listen(self):
self.close()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.setblocking(False)
self.sock.bind((NET_ADDRESS, NET_PORT))
self.sock.listen(NET_BACKLOG)
def handlerWrapper(self, req):
if len(req.body) == 0:
body = 'AnkiConnect v.{}'.format(API_VERSION).encode('utf-8')
else:
try:
params = json.loads(req.body.decode('utf-8'))
body = json.dumps(self.handler(params)).encode('utf-8')
except ValueError:
body = json.dumps(None).encode('utf-8')
resp = bytes()
self.setHeader('Content-Length', str(len(body)))
headers = self.getHeaders()
for key, value in headers:
if value is None:
resp += '{}\r\n'.format(key).encode('utf-8')
else:
resp += '{}: {}\r\n'.format(key, value).encode('utf-8')
resp += '\r\n'.encode('utf-8')
resp += body
return resp
def close(self):
if self.sock is not None:
self.sock.close()
self.sock = None
for client in self.clients:
client.close()
self.clients = []
# #
@ -266,34 +40,39 @@ class WebServer:
class AnkiConnect: class AnkiConnect:
def __init__(self): def __init__(self):
self.server = WebServer(self.handler)
self.log = None self.log = None
if API_LOG_PATH is not None: logPath = util.setting('apiLogPath')
self.log = open(API_LOG_PATH, 'w') if logPath is not None:
self.log = open(logPath, 'w')
try: try:
self.server = web.WebServer(self.handler)
self.server.listen() self.server.listen()
self.timer = QTimer() self.timer = QTimer()
self.timer.timeout.connect(self.advance) self.timer.timeout.connect(self.advance)
self.timer.start(TICK_INTERVAL) self.timer.start(util.setting('apiPollInterval'))
except: except:
QMessageBox.critical( QMessageBox.critical(
self.window(), self.window(),
'AnkiConnect', 'AnkiConnect',
'Failed to listen on port {}.\nMake sure it is available and is not in use.'.format(NET_PORT) 'Failed to listen on port {}.\nMake sure it is available and is not in use.'.format(util.setting('webBindPort'))
) )
def logEvent(self, name, data):
if self.log is not None:
self.log.write('[{}]\n'.format(name))
json.dump(data, self.log, indent=4, sort_keys=True)
self.log.write('\n\n')
def advance(self): def advance(self):
self.server.advance() self.server.advance()
def handler(self, request): def handler(self, request):
if self.log is not None: self.logEvent('request', request)
self.log.write('[request]\n')
json.dump(request, self.log, indent=4, sort_keys=True)
self.log.write('\n\n')
name = request.get('action', '') name = request.get('action', '')
version = request.get('version', 4) version = request.get('version', 4)
@ -302,7 +81,6 @@ class AnkiConnect:
try: try:
method = None method = None
for methodName, methodInst in inspect.getmembers(self, predicate=inspect.ismethod): for methodName, methodInst in inspect.getmembers(self, predicate=inspect.ismethod):
apiVersionLast = 0 apiVersionLast = 0
apiNameLast = None apiNameLast = None
@ -331,25 +109,10 @@ class AnkiConnect:
except Exception as e: except Exception as e:
reply['error'] = str(e) reply['error'] = str(e)
if self.log is not None: self.logEvent('reply', reply)
self.log.write('[reply]\n')
json.dump(reply, self.log, indent=4, sort_keys=True)
self.log.write('\n\n')
return reply return reply
def download(self, url):
try:
(code, contents) = download(url)
except Exception as e:
raise Exception('{} download failed with error {}'.format(url, str(e)))
if code == 200:
return contents
else:
raise Exception('{} download failed with return code {}'.format(url, code))
def window(self): def window(self):
return aqt.mw return aqt.mw
@ -455,40 +218,12 @@ class AnkiConnect:
# Miscellaneous # Miscellaneous
# #
@api() @util.api()
def version(self): def version(self):
return API_VERSION return util.setting('apiVersion')
@api() @util.api()
def upgrade(self):
response = QMessageBox.question(
self.window(),
'AnkiConnect',
'Upgrade to the latest version?',
QMessageBox.Yes | QMessageBox.No
)
if response == QMessageBox.Yes:
try:
data = self.download(URL_UPGRADE)
path = os.path.splitext(__file__)[0] + '.py'
with open(path, 'w') as fp:
fp.write(data.decode('utf-8'))
QMessageBox.information(
self.window(),
'AnkiConnect',
'Upgraded to the latest version, please restart Anki.'
)
return True
except Exception as e:
QMessageBox.critical(self.window(), 'AnkiConnect', 'Failed to download latest version.')
raise e
return False
@api()
def loadProfile(self, name): def loadProfile(self, name):
if name not in self.window().pm.profiles(): if name not in self.window().pm.profiles():
return False return False
@ -504,12 +239,12 @@ class AnkiConnect:
return True return True
@api() @util.api()
def sync(self): def sync(self):
self.window().onSync() self.window().onSync()
@api() @util.api()
def multi(self, actions): def multi(self, actions):
return list(map(self.handler, actions)) return list(map(self.handler, actions))
@ -518,12 +253,12 @@ class AnkiConnect:
# Decks # Decks
# #
@api() @util.api()
def deckNames(self): def deckNames(self):
return self.decks().allNames() return self.decks().allNames()
@api() @util.api()
def deckNamesAndIds(self): def deckNamesAndIds(self):
decks = {} decks = {}
for deck in self.deckNames(): for deck in self.deckNames():
@ -532,7 +267,7 @@ class AnkiConnect:
return decks return decks
@api() @util.api()
def getDecks(self, cards): def getDecks(self, cards):
decks = {} decks = {}
for card in cards: for card in cards:
@ -546,7 +281,7 @@ class AnkiConnect:
return decks return decks
@api() @util.api()
def createDeck(self, deck): def createDeck(self, deck):
try: try:
self.startEditing() self.startEditing()
@ -557,7 +292,7 @@ class AnkiConnect:
return did return did
@api() @util.api()
def changeDeck(self, cards, deck): def changeDeck(self, cards, deck):
self.startEditing() self.startEditing()
@ -575,7 +310,7 @@ class AnkiConnect:
self.stopEditing() self.stopEditing()
@api() @util.api()
def deleteDecks(self, decks, cardsToo=False): def deleteDecks(self, decks, cardsToo=False):
try: try:
self.startEditing() self.startEditing()
@ -587,7 +322,7 @@ class AnkiConnect:
self.stopEditing() self.stopEditing()
@api() @util.api()
def getDeckConfig(self, deck): def getDeckConfig(self, deck):
if not deck in self.deckNames(): if not deck in self.deckNames():
return False return False
@ -597,7 +332,7 @@ class AnkiConnect:
return collection.decks.confForDid(did) return collection.decks.confForDid(did)
@api() @util.api()
def saveDeckConfig(self, config): def saveDeckConfig(self, config):
collection = self.collection() collection = self.collection()
@ -613,7 +348,7 @@ class AnkiConnect:
return True return True
@api() @util.api()
def setDeckConfigId(self, decks, configId): def setDeckConfigId(self, decks, configId):
configId = str(configId) configId = str(configId)
for deck in decks: for deck in decks:
@ -631,7 +366,7 @@ class AnkiConnect:
return True return True
@api() @util.api()
def cloneDeckConfigId(self, name, cloneFrom='1'): def cloneDeckConfigId(self, name, cloneFrom='1'):
configId = str(cloneFrom) configId = str(cloneFrom)
if not configId in self.collection().decks.dconf: if not configId in self.collection().decks.dconf:
@ -641,7 +376,7 @@ class AnkiConnect:
return self.collection().decks.confId(name, config) return self.collection().decks.confId(name, config)
@api() @util.api()
def removeDeckConfigId(self, configId): def removeDeckConfigId(self, configId):
configId = str(configId) configId = str(configId)
collection = self.collection() collection = self.collection()
@ -652,16 +387,16 @@ class AnkiConnect:
return True return True
@api() @util.api()
def storeMediaFile(self, filename, data): def storeMediaFile(self, filename, data):
self.deleteMediaFile(filename) self.deleteMediaFile(filename)
self.media().writeData(filename, base64.b64decode(data)) self.media().writeData(filename, base64.b64decode(data))
@api() @util.api()
def retrieveMediaFile(self, filename): def retrieveMediaFile(self, filename):
filename = os.path.basename(filename) filename = os.path.basename(filename)
filename = normalize('NFC', filename) filename = unicodedata.normalize('NFC', filename)
filename = self.media().stripIllegal(filename) filename = self.media().stripIllegal(filename)
path = os.path.join(self.media().dir(), filename) path = os.path.join(self.media().dir(), filename)
@ -672,19 +407,19 @@ class AnkiConnect:
return False return False
@api() @util.api()
def deleteMediaFile(self, filename): def deleteMediaFile(self, filename):
self.media().syncDelete(filename) self.media().syncDelete(filename)
@api() @util.api()
def addNote(self, note): def addNote(self, note):
ankiNote = self.createNote(note) ankiNote = self.createNote(note)
audio = note.get('audio') audio = note.get('audio')
if audio is not None and len(audio['fields']) > 0: if audio is not None and len(audio['fields']) > 0:
try: try:
data = self.download(audio['url']) data = util.download(audio['url'])
skipHash = audio.get('skipHash') skipHash = audio.get('skipHash')
if skipHash is None: if skipHash is None:
skip = False skip = False
@ -716,7 +451,7 @@ class AnkiConnect:
return ankiNote.id return ankiNote.id
@api() @util.api()
def canAddNote(self, note): def canAddNote(self, note):
try: try:
return bool(self.createNote(note)) return bool(self.createNote(note))
@ -724,7 +459,7 @@ class AnkiConnect:
return False return False
@api() @util.api()
def updateNoteFields(self, note): def updateNoteFields(self, note):
ankiNote = self.collection().getNote(note['id']) ankiNote = self.collection().getNote(note['id'])
if ankiNote is None: if ankiNote is None:
@ -737,24 +472,24 @@ class AnkiConnect:
ankiNote.flush() ankiNote.flush()
@api() @util.api()
def addTags(self, notes, tags, add=True): def addTags(self, notes, tags, add=True):
self.startEditing() self.startEditing()
self.collection().tags.bulkAdd(notes, tags, add) self.collection().tags.bulkAdd(notes, tags, add)
self.stopEditing() self.stopEditing()
@api() @util.api()
def removeTags(self, notes, tags): def removeTags(self, notes, tags):
return self.addTags(notes, tags, False) return self.addTags(notes, tags, False)
@api() @util.api()
def getTags(self): def getTags(self):
return self.collection().tags.all() return self.collection().tags.all()
@api() @util.api()
def suspend(self, cards, suspend=True): def suspend(self, cards, suspend=True):
for card in cards: for card in cards:
if self.suspended(card) == suspend: if self.suspended(card) == suspend:
@ -774,18 +509,18 @@ class AnkiConnect:
return True return True
@api() @util.api()
def unsuspend(self, cards): def unsuspend(self, cards):
self.suspend(cards, False) self.suspend(cards, False)
@api() @util.api()
def suspended(self, card): def suspended(self, card):
card = self.collection().getCard(card) card = self.collection().getCard(card)
return card.queue == -1 return card.queue == -1
@api() @util.api()
def areSuspended(self, cards): def areSuspended(self, cards):
suspended = [] suspended = []
for card in cards: for card in cards:
@ -794,7 +529,7 @@ class AnkiConnect:
return suspended return suspended
@api() @util.api()
def areDue(self, cards): def areDue(self, cards):
due = [] due = []
for card in cards: for card in cards:
@ -805,12 +540,12 @@ class AnkiConnect:
if ivl >= -1200: if ivl >= -1200:
due.append(bool(self.findCards('cid:{} is:due'.format(card)))) due.append(bool(self.findCards('cid:{} is:due'.format(card))))
else: else:
due.append(date - ivl <= time()) due.append(date - ivl <= time.time())
return due return due
@api() @util.api()
def getIntervals(self, cards, complete=False): def getIntervals(self, cards, complete=False):
intervals = [] intervals = []
for card in cards: for card in cards:
@ -826,12 +561,12 @@ class AnkiConnect:
@api() @util.api()
def modelNames(self): def modelNames(self):
return self.collection().models.allNames() return self.collection().models.allNames()
@api() @util.api()
def createModel(self, modelName, inOrderFields, cardTemplates, css = None): def createModel(self, modelName, inOrderFields, cardTemplates, css = None):
# https://github.com/dae/anki/blob/b06b70f7214fb1f2ce33ba06d2b095384b81f874/anki/stdmodels.py # https://github.com/dae/anki/blob/b06b70f7214fb1f2ce33ba06d2b095384b81f874/anki/stdmodels.py
if (len(inOrderFields) == 0): if (len(inOrderFields) == 0):
@ -869,7 +604,7 @@ class AnkiConnect:
return m return m
@api() @util.api()
def modelNamesAndIds(self): def modelNamesAndIds(self):
models = {} models = {}
for model in self.modelNames(): for model in self.modelNames():
@ -878,7 +613,7 @@ class AnkiConnect:
return models return models
@api() @util.api()
def modelNameFromId(self, modelId): def modelNameFromId(self, modelId):
model = self.collection().models.get(modelId) model = self.collection().models.get(modelId)
if model is None: if model is None:
@ -887,7 +622,7 @@ class AnkiConnect:
return model['name'] return model['name']
@api() @util.api()
def modelFieldNames(self, modelName): def modelFieldNames(self, modelName):
model = self.collection().models.byName(modelName) model = self.collection().models.byName(modelName)
if model is None: if model is None:
@ -896,7 +631,7 @@ class AnkiConnect:
return [field['name'] for field in model['flds']] return [field['name'] for field in model['flds']]
@api() @util.api()
def modelFieldsOnTemplates(self, modelName): def modelFieldsOnTemplates(self, modelName):
model = self.collection().models.byName(modelName) model = self.collection().models.byName(modelName)
if model is None: if model is None:
@ -926,7 +661,7 @@ class AnkiConnect:
return templates return templates
@api() @util.api()
def modelTemplates(self, modelName): def modelTemplates(self, modelName):
model = self.collection().models.byName(modelName) model = self.collection().models.byName(modelName)
if model is None: if model is None:
@ -939,7 +674,7 @@ class AnkiConnect:
return templates return templates
@api() @util.api()
def modelStyling(self, modelName): def modelStyling(self, modelName):
model = self.collection().models.byName(modelName) model = self.collection().models.byName(modelName)
if model is None: if model is None:
@ -948,7 +683,7 @@ class AnkiConnect:
return {'css': model['css']} return {'css': model['css']}
@api() @util.api()
def updateModelTemplates(self, model): def updateModelTemplates(self, model):
models = self.collection().models models = self.collection().models
ankiModel = models.byName(model['name']) ankiModel = models.byName(model['name'])
@ -972,7 +707,7 @@ class AnkiConnect:
models.flush() models.flush()
@api() @util.api()
def updateModelStyling(self, model): def updateModelStyling(self, model):
models = self.collection().models models = self.collection().models
ankiModel = models.byName(model['name']) ankiModel = models.byName(model['name'])
@ -985,7 +720,7 @@ class AnkiConnect:
models.flush() models.flush()
@api() @util.api()
def deckNameFromId(self, deckId): def deckNameFromId(self, deckId):
deck = self.collection().decks.get(deckId) deck = self.collection().decks.get(deckId)
if deck is None: if deck is None:
@ -994,7 +729,7 @@ class AnkiConnect:
return deck['name'] return deck['name']
@api() @util.api()
def findNotes(self, query=None): def findNotes(self, query=None):
if query is None: if query is None:
return [] return []
@ -1002,7 +737,7 @@ class AnkiConnect:
return self.collection().findNotes(query) return self.collection().findNotes(query)
@api() @util.api()
def findCards(self, query=None): def findCards(self, query=None):
if query is None: if query is None:
return [] return []
@ -1010,7 +745,7 @@ class AnkiConnect:
return self.collection().findCards(query) return self.collection().findCards(query)
@api() @util.api()
def cardsInfo(self, cards): def cardsInfo(self, cards):
result = [] result = []
for cid in cards: for cid in cards:
@ -1049,7 +784,7 @@ class AnkiConnect:
return result return result
@api() @util.api()
def notesInfo(self, notes): def notesInfo(self, notes):
result = [] result = []
for nid in notes: for nid in notes:
@ -1080,7 +815,7 @@ class AnkiConnect:
return result return result
@api() @util.api()
def deleteNotes(self, notes): def deleteNotes(self, notes):
try: try:
self.collection().remNotes(notes) self.collection().remNotes(notes)
@ -1090,12 +825,12 @@ class AnkiConnect:
@api() @util.api()
def cardsToNotes(self, cards): def cardsToNotes(self, cards):
return self.collection().db.list('select distinct nid from cards where id in ' + anki.utils.ids2str(cards)) return self.collection().db.list('select distinct nid from cards where id in ' + anki.utils.ids2str(cards))
@api() @util.api()
def guiBrowse(self, query=None): def guiBrowse(self, query=None):
browser = aqt.dialogs.open('Browser', self.window()) browser = aqt.dialogs.open('Browser', self.window())
browser.activateWindow() browser.activateWindow()
@ -1110,7 +845,7 @@ class AnkiConnect:
return browser.model.cards return browser.model.cards
@api() @util.api()
def guiAddCards(self, note=None): def guiAddCards(self, note=None):
if note is not None: if note is not None:
@ -1140,8 +875,7 @@ class AnkiConnect:
addCards = None addCards = None
if closeAfterAdding: if closeAfterAdding:
randomString = ''.join(random.choice(string.ascii_letters) for _ in range(10))
randomString = ''.join(choice(ascii_letters) for _ in range(10))
windowName = 'AddCardsAndClose' + randomString windowName = 'AddCardsAndClose' + randomString
class AddCardsAndClose(aqt.addcards.AddCards): class AddCardsAndClose(aqt.addcards.AddCards):
@ -1248,12 +982,12 @@ class AnkiConnect:
addCards = aqt.dialogs.open('AddCards', self.window()) addCards = aqt.dialogs.open('AddCards', self.window())
addCards.activateWindow() addCards.activateWindow()
@api() @util.api()
def guiReviewActive(self): def guiReviewActive(self):
return self.reviewer().card is not None and self.window().state == 'review' return self.reviewer().card is not None and self.window().state == 'review'
@api() @util.api()
def guiCurrentCard(self): def guiCurrentCard(self):
if not self.guiReviewActive(): if not self.guiReviewActive():
raise Exception('Gui review is not currently active.') raise Exception('Gui review is not currently active.')
@ -1286,7 +1020,7 @@ class AnkiConnect:
} }
@api() @util.api()
def guiStartCardTimer(self): def guiStartCardTimer(self):
if not self.guiReviewActive(): if not self.guiReviewActive():
return False return False
@ -1300,7 +1034,7 @@ class AnkiConnect:
return False return False
@api() @util.api()
def guiShowQuestion(self): def guiShowQuestion(self):
if self.guiReviewActive(): if self.guiReviewActive():
self.reviewer()._showQuestion() self.reviewer()._showQuestion()
@ -1309,7 +1043,7 @@ class AnkiConnect:
return False return False
@api() @util.api()
def guiShowAnswer(self): def guiShowAnswer(self):
if self.guiReviewActive(): if self.guiReviewActive():
self.window().reviewer._showAnswer() self.window().reviewer._showAnswer()
@ -1318,7 +1052,7 @@ class AnkiConnect:
return False return False
@api() @util.api()
def guiAnswerCard(self, ease): def guiAnswerCard(self, ease):
if not self.guiReviewActive(): if not self.guiReviewActive():
return False return False
@ -1333,7 +1067,7 @@ class AnkiConnect:
return True return True
@api() @util.api()
def guiDeckOverview(self, name): def guiDeckOverview(self, name):
collection = self.collection() collection = self.collection()
if collection is not None: if collection is not None:
@ -1346,12 +1080,12 @@ class AnkiConnect:
return False return False
@api() @util.api()
def guiDeckBrowser(self): def guiDeckBrowser(self):
self.window().moveToState('deckBrowser') self.window().moveToState('deckBrowser')
@api() @util.api()
def guiDeckReview(self, name): def guiDeckReview(self, name):
if self.guiDeckOverview(name): if self.guiDeckOverview(name):
self.window().moveToState('review') self.window().moveToState('review')
@ -1360,18 +1094,14 @@ class AnkiConnect:
return False return False
@api() @util.api()
def guiExitAnki(self): def guiExitAnki(self):
timer = QTimer() timer = QTimer()
def exitAnki(): timer.timeout.connect(self.window().close)
timer.stop()
self.window().close()
timer.timeout.connect(exitAnki)
timer.start(1000) # 1s should be enough to allow the response to be sent. timer.start(1000) # 1s should be enough to allow the response to be sent.
@util.api()
@api()
def addNotes(self, notes): def addNotes(self, notes):
results = [] results = []
for note in notes: for note in notes:
@ -1383,7 +1113,7 @@ class AnkiConnect:
return results return results
@api() @util.api()
def canAddNotes(self, notes): def canAddNotes(self, notes):
results = [] results = []
for note in notes: for note in notes:

63
plugin/util.py Normal file
View File

@ -0,0 +1,63 @@
# Copyright 2016-2020 Alex Yatskov
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
import anki
import anki.sync
import aqt
#
# Utilities
#
def download(url):
client = anki.sync.AnkiRequestsClient()
client.timeout = setting('webTimeout') / 1000
resp = client.get(url)
if resp.status_code == 200:
return client.streamContent(resp)
else:
raise Exception('{} download failed with return code {}'.format(url, resp.status_code))
def api(*versions):
def decorator(func):
method = lambda *args, **kwargs: func(*args, **kwargs)
setattr(method, 'versions', versions)
setattr(method, 'api', True)
return method
return decorator
def setting(key):
defaults = {
'apiKey': None,
'apiLogPath': None,
'apiPollInterval': 25,
'apiVersion': 6,
'webBacklog': 5,
'webBindAddress': os.getenv('ANKICONNECT_BIND_ADDRESS', '127.0.0.1'),
'webBindPort': 8765,
'webCorsOrigin': os.getenv('ANKICONNECT_CORS_ORIGIN', 'http://localhost'),
'webTimeout': 10000,
}
try:
return aqt.mw.addonManager.getConfig(__name__).get(key, defaults[key])
except:
raise Exception('setting {} not found'.format(key))

185
plugin/web.py Normal file
View File

@ -0,0 +1,185 @@
# Copyright 2016-2020 Alex Yatskov
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import json
import select
import socket
from AnkiConnect import web, util
#
# WebRequest
#
class WebRequest:
def __init__(self, headers, body):
self.headers = headers
self.body = body
#
# WebClient
#
class WebClient:
def __init__(self, sock, handler):
self.sock = sock
self.handler = handler
self.readBuff = bytes()
self.writeBuff = bytes()
def advance(self, recvSize=1024):
if self.sock is None:
return False
rlist, wlist = select.select([self.sock], [self.sock], [], 0)[:2]
if rlist:
msg = self.sock.recv(recvSize)
if not msg:
self.close()
return False
self.readBuff += msg
req, length = self.parseRequest(self.readBuff)
if req is not None:
self.readBuff = self.readBuff[length:]
self.writeBuff += self.handler(req)
if wlist and self.writeBuff:
length = self.sock.send(self.writeBuff)
self.writeBuff = self.writeBuff[length:]
if not self.writeBuff:
self.close()
return False
return True
def close(self):
if self.sock is not None:
self.sock.close()
self.sock = None
self.readBuff = bytes()
self.writeBuff = bytes()
def parseRequest(self, data):
parts = data.split('\r\n\r\n'.encode('utf-8'), 1)
if len(parts) == 1:
return None, 0
headers = {}
for line in parts[0].split('\r\n'.encode('utf-8')):
pair = line.split(': '.encode('utf-8'))
headers[pair[0].lower()] = pair[1] if len(pair) > 1 else None
headerLength = len(parts[0]) + 4
bodyLength = int(headers.get('content-length'.encode('utf-8'), 0))
totalLength = headerLength + bodyLength
if totalLength > len(data):
return None, 0
body = data[headerLength : totalLength]
return WebRequest(headers, body), totalLength
#
# WebServer
#
class WebServer:
def __init__(self, handler):
self.handler = handler
self.clients = []
self.sock = None
def advance(self):
if self.sock is not None:
self.acceptClients()
self.advanceClients()
def acceptClients(self):
rlist = select.select([self.sock], [], [], 0)[0]
if not rlist:
return
clientSock = self.sock.accept()[0]
if clientSock is not None:
clientSock.setblocking(False)
self.clients.append(WebClient(clientSock, self.handlerWrapper))
def advanceClients(self):
self.clients = list(filter(lambda c: c.advance(), self.clients))
def listen(self):
self.close()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.setblocking(False)
self.sock.bind((util.setting('webBindAddress'), util.setting('webBindPort')))
self.sock.listen(util.setting('webBacklog'))
def handlerWrapper(self, req):
if len(req.body) == 0:
body = 'AnkiConnect'
else:
try:
params = json.loads(req.body.decode('utf-8'))
body = json.dumps(self.handler(params)).encode('utf-8')
except ValueError:
body = json.dumps(None).encode('utf-8')
headers = [
['HTTP/1.1 200 OK', None],
['Content-Type', 'text/json'],
['Access-Control-Allow-Origin', util.setting('webCorsOrigin')],
['Content-Length', str(len(body))]
]
resp = bytes()
for key, value in headers:
if value is None:
resp += '{}\r\n'.format(key).encode('utf-8')
else:
resp += '{}: {}\r\n'.format(key, value).encode('utf-8')
resp += '\r\n'.encode('utf-8')
resp += body
return resp
def close(self):
if self.sock is not None:
self.sock.close()
self.sock = None
for client in self.clients:
client.close()
self.clients = []