mirror of
				https://github.com/noDRM/DeDRM_tools.git
				synced 2025-10-23 23:07:47 -04:00 
			
		
		
		
	
		
			
				
	
	
		
			1039 lines
		
	
	
	
		
			32 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			1039 lines
		
	
	
	
		
			32 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python3
 | |
| # -*- coding: utf-8 -*-
 | |
| 
 | |
| # ion.py
 | |
| # Copyright © 2013-2020 Apprentice Harper et al.
 | |
| 
 | |
| __license__ = 'GPL v3'
 | |
| __version__ = '3.0'
 | |
| 
 | |
| # Revision history:
 | |
| #  Pascal implementation by lulzkabulz.
 | |
| #  BinaryIon.pas + DrmIon.pas + IonSymbols.pas
 | |
| #  1.0   - Python translation by apprenticenaomi.
 | |
| #  1.1   - DeDRM integration by anon.
 | |
| #  1.2   - Added pylzma import fallback
 | |
| #  1.3   - Fixed lzma support for calibre 4.6+
 | |
| #  2.0   - VoucherEnvelope v2/v3 support by apprenticesakuya.
 | |
| #  3.0   - Added Python 3 compatibility for calibre 5.0
 | |
| 
 | |
| """
 | |
| Decrypt Kindle KFX files.
 | |
| """
 | |
| 
 | |
| import collections
 | |
| import hashlib
 | |
| import hmac
 | |
| import os
 | |
| import os.path
 | |
| import struct
 | |
| 
 | |
| from io import BytesIO
 | |
| 
 | |
| from Crypto.Cipher import AES
 | |
| from Crypto.Util.py3compat import bchr, bord
 | |
| 
 | |
| try:
 | |
|     # lzma library from calibre 4.6.0 or later
 | |
|     import calibre_lzma.lzma1 as calibre_lzma
 | |
| except ImportError:
 | |
|     calibre_lzma = None
 | |
|     # lzma library from calibre 2.35.0 or later
 | |
|     try:
 | |
|         import lzma.lzma1 as calibre_lzma
 | |
|     except ImportError:
 | |
|         calibre_lzma = None
 | |
|         try:
 | |
|             import lzma
 | |
|         except ImportError:
 | |
|             # Need pip backports.lzma on Python <3.3
 | |
|             try:
 | |
|                 from backports import lzma
 | |
|             except ImportError:
 | |
|                 # Windows-friendly choice: pylzma wheels
 | |
|                 import pylzma as lzma
 | |
| 
 | |
| 
 | |
# Ion binary type ids: the high nibble of a type-descriptor byte.
(TID_NULL, TID_BOOLEAN, TID_POSINT, TID_NEGINT,
 TID_FLOAT, TID_DECIMAL, TID_TIMESTAMP, TID_SYMBOL,
 TID_STRING, TID_CLOB, TID_BLOB, TID_LIST,
 TID_SEXP, TID_STRUCT, TID_TYPEDECL, TID_UNUSED) = range(16)


# Symbol ids. -1 means "no symbol"; 1..9 are the Ion 1.0 system symbols and
# SID_ION_1_0_MAX is one past the last of them.
SID_UNKNOWN = -1
(SID_ION, SID_ION_1_0, SID_ION_SYMBOL_TABLE, SID_NAME, SID_VERSION,
 SID_IMPORTS, SID_SYMBOLS, SID_MAX_ID, SID_ION_SHARED_SYMBOL_TABLE,
 SID_ION_1_0_MAX) = range(1, 11)


# Special values of the low (length) nibble of a type-descriptor byte.
LEN_IS_VAR_LEN = 0xE  # length follows as a VarUInt
LEN_IS_NULL = 0xF     # typed null, no payload


# Bytes expected after a zero-length type-declaration byte (the binary
# version marker checked by BinaryIonParser.checkversionmarker).
VERSION_MARKER = b"\x01\x00\xEA"
 | |
| 
 | |
| 
 | |
| # asserts must always raise exceptions for proper functioning
 | |
| def _assert(test, msg="Exception"):
 | |
|     if not test:
 | |
|         raise Exception(msg)
 | |
| 
 | |
| 
 | |
class SystemSymbols(object):
    """Text of the Ion 1.0 system symbols (installed at ids 1-9 by SymbolTable)."""

    # Version marker and symbol-table type names.
    ION = '$ion'
    ION_1_0 = '$ion_1_0'
    ION_SYMBOL_TABLE = '$ion_symbol_table'

    # Field names that appear inside symbol-table structs.
    NAME = 'name'
    VERSION = 'version'
    IMPORTS = 'imports'
    SYMBOLS = 'symbols'
    MAX_ID = 'max_id'

    ION_SHARED_SYMBOL_TABLE = '$ion_shared_symbol_table'
 | |
| 
 | |
| 
 | |
class IonCatalogItem(object):
    """A named, versioned list of symbol names that imports can refer to."""

    # Class-level defaults; __init__ always overrides all three per instance.
    name = ""
    version = 0
    symnames = []

    def __init__(self, name, version, symnames):
        self.name, self.version, self.symnames = name, version, symnames
 | |
| 
 | |
| 
 | |
class SymbolToken(object):
    """A symbol reference holding its text and/or its numeric symbol id.

    Raises:
        ValueError: if both text is "" and sid is 0.
    """

    text = ""
    sid = 0

    def __init__(self, text, sid):
        has_neither = text == "" and sid == 0
        if has_neither:
            raise ValueError("Symbol token must have Text or SID")
        self.text, self.sid = text, sid
 | |
| 
 | |
| 
 | |
class SymbolTable(object):
    """Lookup from symbol id to symbol text for one parser.

    Slot 0 is unused (None) and slots 1-9 hold the Ion 1.0 system symbols.
    Imported tables are appended after them in import order, so a name's
    position in the list is its symbol id.
    """

    table = None

    def __init__(self):
        self.table = [None] * SID_ION_1_0_MAX
        system_symbols = (
            (SID_ION, SystemSymbols.ION),
            (SID_ION_1_0, SystemSymbols.ION_1_0),
            (SID_ION_SYMBOL_TABLE, SystemSymbols.ION_SYMBOL_TABLE),
            (SID_NAME, SystemSymbols.NAME),
            (SID_VERSION, SystemSymbols.VERSION),
            (SID_IMPORTS, SystemSymbols.IMPORTS),
            (SID_SYMBOLS, SystemSymbols.SYMBOLS),
            (SID_MAX_ID, SystemSymbols.MAX_ID),
            (SID_ION_SHARED_SYMBOL_TABLE, SystemSymbols.ION_SHARED_SYMBOL_TABLE),
        )
        for sid, text in system_symbols:
            self.table[sid] = text

    def findbyid(self, sid):
        """Return the text for *sid*, or "" when *sid* is past the end of the table.

        Raises:
            ValueError: if *sid* is less than 1.
        """
        if sid < 1:
            raise ValueError("Invalid symbol id")
        return self.table[sid] if sid < len(self.table) else ""

    def import_(self, table, maxid):
        """Append the first *maxid* names of catalog item *table*."""
        self.table.extend(table.symnames[i] for i in range(maxid))

    def importunknown(self, name, maxid):
        """Append *maxid* placeholders named "<name>#1" .. "<name>#<maxid>"."""
        self.table.extend("%s#%d" % (name, n) for n in range(1, maxid + 1))
 | |
| 
 | |
| 
 | |
class ParserState:
    """States of the BinaryIonParser state machine (plain int constants)."""

    Invalid = 1
    BeforeField = 2
    BeforeTID = 3
    BeforeValue = 4
    AfterValue = 5
    EOF = 6


# Saved context of an enclosing container:
# - nextpos: the stream offset just after the container;
# - tid: the parent's type id;
# - remaining: the parent's remaining byte budget (-1 means unbounded).
ContainerRec = collections.namedtuple("ContainerRec", ["nextpos", "tid", "remaining"])
 | |
| 
 | |
| 
 | |
class BinaryIonParser(object):
    """Pull parser for the Amazon Ion binary encoding.

    Walks a seekable byte stream one value at a time:
    - hasnext() / next() advance to the next value;
    - stepin() / stepout() enter and leave containers (struct/list/sexp);
    - the *value() accessors read the current value.

    Version markers and top-level local symbol tables (structs annotated with
    $ion_symbol_table) are consumed by hasnext() itself. Shared tables
    that a local table imports are looked up among the items registered with
    addtocatalog(); imports that are not found get placeholder names.

    localremaining is the number of bytes left in the current container.
    It is -1 at top level, where reads are unbounded.
    """

    eof = False
    state = None
    localremaining = 0
    needhasnext = False
    isinstruct = False
    valuetid = 0
    valuefieldid = 0
    parenttid = 0
    valuelen = 0
    valueisnull = False
    valueistrue = False
    value = None
    didimports = False

    def __init__(self, stream):
        """Create a parser over *stream*, starting at its current position.

        Args:
            stream: binary file-like object supporting read(), tell() and
                seek() (e.g. io.BytesIO).
        """
        self.annotations = []
        self.catalog = []

        self.stream = stream
        # reset() seeks back to this offset.
        self.initpos = stream.tell()
        self.reset()
        self.symbols = SymbolTable()

    def reset(self):
        """Rewind the stream to its initial position and restart at top level.

        The symbol table, the catalog and the didimports flag are left untouched.
        """
        self.state = ParserState.BeforeTID
        self.needhasnext = True
        self.localremaining = -1
        self.eof = False
        self.isinstruct = False
        self.containerstack = []
        self.stream.seek(self.initpos)

    def addtocatalog(self, name, version, symbols):
        """Register a shared symbol table that imports can refer to by *name*."""
        self.catalog.append(IonCatalogItem(name, version, symbols))

    def hasnext(self):
        """Position on the next value, if any, and return whether one exists.

        At top level, $ion_1_0 version markers are skipped. Structs annotated
        with $ion_symbol_table are processed and skipped. Callers therefore
        only ever see user values.
        """
        while self.needhasnext and not self.eof:
            self.hasnextraw()
            if len(self.containerstack) == 0 and not self.valueisnull:
                if self.valuetid == TID_SYMBOL:
                    if self.value == SID_ION_1_0:
                        self.needhasnext = True
                elif self.valuetid == TID_STRUCT:
                    for a in self.annotations:
                        if a == SID_ION_SYMBOL_TABLE:
                            self.parsesymboltable()
                            self.needhasnext = True
                            break
        return not self.eof

    def hasnextraw(self):
        """Advance the state machine to the next value header in the current container.

        Any unread payload of the previous value is skipped. Sets eof when the
        container or the stream has no more values.
        """
        self.clearvalue()
        while self.valuetid == -1 and not self.eof:
            self.needhasnext = False
            if self.state == ParserState.BeforeField:
                _assert(self.valuefieldid == SID_UNKNOWN)

                self.valuefieldid = self.readfieldid()
                if self.valuefieldid != SID_UNKNOWN:
                    self.state = ParserState.BeforeTID
                else:
                    self.eof = True

            elif self.state == ParserState.BeforeTID:
                self.state = ParserState.BeforeValue
                self.valuetid = self.readtypeid()
                if self.valuetid == -1:
                    self.state = ParserState.EOF
                    self.eof = True
                    break

                if self.valuetid == TID_TYPEDECL:
                    # A zero-length type declaration is the version marker.
                    # Any other length is an annotation wrapper around the
                    # real value.
                    if self.valuelen == 0:
                        self.checkversionmarker()
                    else:
                        self.loadannotations()

            elif self.state == ParserState.BeforeValue:
                # The caller did not consume the payload, so skip over it.
                self.skip(self.valuelen)
                self.state = ParserState.AfterValue

            elif self.state == ParserState.AfterValue:
                if self.isinstruct:
                    self.state = ParserState.BeforeField
                else:
                    self.state = ParserState.BeforeTID

            else:
                _assert(self.state == ParserState.EOF)

    def next(self):
        """Move to the next value and return its type id, or -1 if there is none."""
        if self.hasnext():
            self.needhasnext = True
            return self.valuetid
        else:
            return -1

    def push(self, typeid, nextposition, nextremaining):
        """Save an enclosing container's context on the container stack."""
        self.containerstack.append(ContainerRec(nextpos=nextposition, tid=typeid, remaining=nextremaining))

    def stepin(self):
        """Enter the current struct/list/sexp so that its children can be iterated.

        Saves the parent's resume position and byte budget for stepout().
        """
        _assert(self.valuetid in [TID_STRUCT, TID_LIST, TID_SEXP] and not self.eof,
                "valuetid=%s eof=%s" % (self.valuetid, self.eof))
        # A null container has already been consumed (AfterValue). A non-null
        # one must still be positioned at its payload (BeforeValue).
        _assert((not self.valueisnull or self.state == ParserState.AfterValue) and
                (self.valueisnull or self.state == ParserState.BeforeValue))

        nextrem = self.localremaining
        if nextrem != -1:
            nextrem -= self.valuelen
            if nextrem < 0:
                nextrem = 0
        self.push(self.parenttid, self.stream.tell() + self.valuelen, nextrem)

        self.isinstruct = (self.valuetid == TID_STRUCT)
        if self.isinstruct:
            self.state = ParserState.BeforeField
        else:
            self.state = ParserState.BeforeTID

        self.localremaining = self.valuelen
        self.parenttid = self.valuetid
        self.clearvalue()
        self.needhasnext = True

    def stepout(self):
        """Leave the current container, skipping unread children, and resume in the parent."""
        rec = self.containerstack.pop()

        self.eof = False
        self.parenttid = rec.tid
        if self.parenttid == TID_STRUCT:
            self.isinstruct = True
            self.state = ParserState.BeforeField
        else:
            self.isinstruct = False
            self.state = ParserState.BeforeTID
        self.needhasnext = True

        self.clearvalue()
        curpos = self.stream.tell()
        if rec.nextpos > curpos:
            self.skip(rec.nextpos - curpos)
        else:
            _assert(rec.nextpos == curpos)

        self.localremaining = rec.remaining

    def read(self, count=1):
        """Read *count* bytes and charge them against the current container.

        Raises:
            EOFError: if bytes were requested but the stream is exhausted.
            Exception: (via _assert) if the read overruns the container.
        """
        if self.localremaining != -1:
            self.localremaining -= count
            _assert(self.localremaining >= 0)

        result = self.stream.read(count)
        # Zero-length reads are legitimate (for example an empty string or
        # blob). Only an empty result for a non-empty request means EOF.
        if count > 0 and len(result) == 0:
            raise EOFError()
        return result

    def readfieldid(self):
        """Read a struct field's symbol id, or return -1 at the end of the container or stream."""
        if self.localremaining != -1 and self.localremaining < 1:
            return -1

        try:
            return self.readvaruint()
        except EOFError:
            return -1

    def readtypeid(self):
        """Read a type-descriptor byte plus any length that follows it.

        Returns the type id (the high nibble), or -1 at the end of the
        container or stream. Sets valuelen to the payload length. For nulls
        and booleans the payload is empty, so state moves straight to
        AfterValue.
        """
        if self.localremaining != -1:
            if self.localremaining < 1:
                return -1
            self.localremaining -= 1

        b = self.stream.read(1)
        if len(b) < 1:
            return -1
        # ord() yields the integer value of a 1-byte bytes object.
        b = ord(b)
        result = b >> 4
        ln = b & 0xF

        if ln == LEN_IS_VAR_LEN:
            ln = self.readvaruint()
        elif ln == LEN_IS_NULL:
            # Low nibble 0xF encodes the typed null of this type.
            ln = 0
            self.valueisnull = True
            self.state = ParserState.AfterValue
        elif result == TID_NULL:
            # Must have LEN_IS_NULL
            _assert(False)
        elif result == TID_BOOLEAN:
            _assert(ln <= 1)
            self.valueistrue = (ln == 1)
            ln = 0
            self.state = ParserState.AfterValue
        elif result == TID_STRUCT:
            if ln == 1:
                ln = self.readvaruint()

        self.valuelen = ln
        return result

    def readvarint(self):
        """Read a signed VarInt.

        Bit 0x40 of the first byte is the sign. Bit 0x80 marks the last byte.
        At most 5 bytes are accepted.
        """
        b = ord(self.read())
        negative = ((b & 0x40) != 0)
        result = (b & 0x3F)

        i = 0
        while (b & 0x80) == 0 and i < 4:
            b = ord(self.read())
            result = (result << 7) | (b & 0x7F)
            i += 1

        _assert(i < 4 or (b & 0x80) != 0, "int overflow")

        if negative:
            return -result
        return result

    def readvaruint(self):
        """Read an unsigned VarUInt (7 bits per byte, 0x80 marks the last byte, at most 5 bytes)."""
        b = ord(self.read())
        result = (b & 0x7F)

        i = 0
        while (b & 0x80) == 0 and i < 4:
            b = ord(self.read())
            result = (result << 7) | (b & 0x7F)
            i += 1

        _assert(i < 4 or (b & 0x80) != 0, "int overflow")

        return result

    def readdecimal(self):
        """Decode the current decimal value from its valuelen payload bytes.

        The payload is a VarInt exponent followed by a big-endian coefficient
        (at most 8 bytes) whose top bit is the sign. Returns
        coefficient * 10 ** exponent.
        """
        if self.valuelen == 0:
            return 0.

        # Confine reads to this value's payload and remember the parent's
        # remaining budget. The unbounded marker -1 must be kept as-is rather
        # than turned into a negative count.
        if self.localremaining == -1:
            rem = -1
        else:
            rem = self.localremaining - self.valuelen
        self.localremaining = self.valuelen
        exponent = self.readvarint()

        _assert(self.localremaining > 0, "Only exponent in ReadDecimal")
        _assert(self.localremaining <= 8, "Decimal overflow")

        signed = False
        b = bytearray(self.read(self.localremaining))
        if (b[0] & 0x80) != 0:
            b[0] = b[0] & 0x7F
            signed = True

        # Fold the big-endian (network order) magnitude into an int.
        v = 0
        for byte in b:
            v = (v << 8) | byte

        result = v * (10 ** exponent)
        if signed:
            result = -result

        self.localremaining = rem
        return result

    def skip(self, count):
        """Seek *count* bytes forward.

        Raises:
            EOFError: if the skip overruns the current container.
        """
        if self.localremaining != -1:
            self.localremaining -= count
            if self.localremaining < 0:
                raise EOFError()

        self.stream.seek(count, os.SEEK_CUR)

    def parsesymboltable(self):
        """Process a local symbol table struct; shared-table imports are applied only once."""
        # needhasnext is False at this point, so next() does not advance. It
        # only marks the current struct as consumed and returns its type.
        self.next()

        _assert(self.valuetid == TID_STRUCT)

        if self.didimports:
            return

        self.stepin()

        fieldtype = self.next()
        while fieldtype != -1:
            if not self.valueisnull:
                _assert(self.valuefieldid == SID_IMPORTS, "Unsupported symbol table field id")

                if fieldtype == TID_LIST:
                    self.gatherimports()

            fieldtype = self.next()

        self.stepout()
        self.didimports = True

    def gatherimports(self):
        """Walk the imports list and process each import struct."""
        self.stepin()

        t = self.next()
        while t != -1:
            if not self.valueisnull and t == TID_STRUCT:
                self.readimport()

            t = self.next()

        self.stepout()

    def readimport(self):
        """Read one import struct (name/version/max_id) and extend the symbol table.

        Names come from the matching catalog item when there is one.
        Otherwise maxid placeholder names are added.
        """
        version = -1
        maxid = -1
        name = ""

        self.stepin()

        t = self.next()
        while t != -1:
            if not self.valueisnull and self.valuefieldid != SID_UNKNOWN:
                if self.valuefieldid == SID_NAME:
                    name = self.stringvalue()
                elif self.valuefieldid == SID_VERSION:
                    version = self.intvalue()
                elif self.valuefieldid == SID_MAX_ID:
                    maxid = self.intvalue()

            t = self.next()

        self.stepout()

        if name == "" or name == SystemSymbols.ION:
            return

        if version < 1:
            version = 1

        table = self.findcatalogitem(name)
        if maxid < 0:
            _assert(table is not None and version == table.version, "Import %s lacks maxid" % name)
            maxid = len(table.symnames)

        if table is not None:
            self.symbols.import_(table, min(maxid, len(table.symnames)))
        else:
            self.symbols.importunknown(name, maxid)

    def intvalue(self):
        """Return the current positive or negative int value (None if null)."""
        _assert(self.valuetid in [TID_POSINT, TID_NEGINT], "Not an int")

        self.preparevalue()
        return self.value

    def stringvalue(self):
        """Return the current string value, or "" for a null string."""
        _assert(self.valuetid == TID_STRING, "Not a string")

        if self.valueisnull:
            return ""

        self.preparevalue()
        return self.value

    def symbolvalue(self):
        """Return the text of the current symbol, or "SYMBOL#<sid>" if it is unknown."""
        _assert(self.valuetid == TID_SYMBOL, "Not a symbol")

        self.preparevalue()
        result = self.symbols.findbyid(self.value)
        if result == "":
            result = "SYMBOL#%d" % self.value
        return result

    def lobvalue(self):
        """Return the raw bytes of the current clob/blob, or None if it is null."""
        _assert(self.valuetid in [TID_CLOB, TID_BLOB], "Not a LOB type: %s" % self.getfieldname())

        if self.valueisnull:
            return None

        result = self.read(self.valuelen)
        self.state = ParserState.AfterValue
        return result

    def decimalvalue(self):
        """Return the current decimal value as int or float."""
        _assert(self.valuetid == TID_DECIMAL, "Not a decimal")

        self.preparevalue()
        return self.value

    def preparevalue(self):
        """Decode the current scalar unless it has already been decoded."""
        if self.value is None:
            self.loadscalarvalue()

    def loadscalarvalue(self):
        """Decode the current scalar payload into self.value.

        Strings, ints, symbols and decimals are decoded. Payloads of other
        scalar types are skipped so that the stream stays in sync.
        """
        if self.valuetid not in [TID_NULL, TID_BOOLEAN, TID_POSINT, TID_NEGINT,
                                 TID_FLOAT, TID_DECIMAL, TID_TIMESTAMP,
                                 TID_SYMBOL, TID_STRING]:
            return

        if self.valueisnull:
            self.value = None
            return

        if self.valuetid == TID_STRING:
            self.value = self.read(self.valuelen).decode("UTF-8")

        elif self.valuetid in (TID_POSINT, TID_NEGINT, TID_SYMBOL):
            if self.valuelen == 0:
                self.value = 0
            else:
                _assert(self.valuelen <= 4, "int too long: %d" % self.valuelen)
                v = 0
                for i in range(self.valuelen - 1, -1, -1):
                    v = (v | (ord(self.read()) << (i * 8)))

                if self.valuetid == TID_NEGINT:
                    self.value = -v
                else:
                    self.value = v

        elif self.valuetid == TID_DECIMAL:
            self.value = self.readdecimal()

        else:
            # No decoder for this type (bool payloads are empty; float and
            # timestamp are not decoded). Skip the payload so the next value
            # header is read from the right position.
            self.skip(self.valuelen)

        self.state = ParserState.AfterValue

    def clearvalue(self):
        """Forget the current value's type, field id, annotations and decoded value."""
        self.valuetid = -1
        self.value = None
        self.valueisnull = False
        self.valuefieldid = SID_UNKNOWN
        self.annotations = []

    def loadannotations(self):
        """Read an annotation wrapper's symbol ids, then the wrapped value's type descriptor."""
        ln = self.readvaruint()
        maxpos = self.stream.tell() + ln
        while self.stream.tell() < maxpos:
            self.annotations.append(self.readvaruint())
        self.valuetid = self.readtypeid()

    def checkversionmarker(self):
        """Consume the rest of the Ion version marker and present it as a $ion_1_0 symbol.

        Raises:
            Exception: (via _assert) if the bytes do not match VERSION_MARKER.
        """
        # Compare integer byte values. Iterating the bytes constant yields
        # ints, and comparing an int with a 1-byte bytes object is always
        # False on Python 3.
        for expected in bytearray(VERSION_MARKER):
            _assert(ord(self.read()) == expected, "Unknown version marker")

        self.valuelen = 0
        self.valuetid = TID_SYMBOL
        self.value = SID_ION_1_0
        self.valueisnull = False
        self.valuefieldid = SID_UNKNOWN
        self.state = ParserState.AfterValue

    def findcatalogitem(self, name):
        """Return the first catalog item called *name*, or None."""
        for result in self.catalog:
            if result.name == name:
                return result
        return None

    def forceimport(self, symbols):
        """Append *symbols* directly to the symbol table, bypassing the catalog."""
        item = IonCatalogItem("Forced", 1, symbols)
        self.symbols.import_(item, len(symbols))

    def getfieldname(self):
        """Return the current struct field's name, or "" outside a field."""
        if self.valuefieldid == SID_UNKNOWN:
            return ""
        return self.symbols.findbyid(self.valuefieldid)

    def getfieldnamesymbol(self):
        """Return the current field name as a SymbolToken."""
        return SymbolToken(self.getfieldname(), self.valuefieldid)

    def gettypename(self):
        """Return the text of the current value's first annotation, or ""."""
        if len(self.annotations) == 0:
            return ""

        return self.symbols.findbyid(self.annotations[0])

    @staticmethod
    def printlob(b):
        """Format LOB bytes as space-separated hex pairs ("null" for None)."""
        if b is None:
            return "null"

        return " ".join("%02x" % bord(i) for i in b)

    def ionwalk(self, supert, indent, lst):
        """Append a textual dump of the remaining values to *lst*, recursing into structs and lists."""
        while self.hasnext():
            if supert == TID_STRUCT:
                L = self.getfieldname() + ":"
            else:
                L = ""

            t = self.next()
            if t in [TID_STRUCT, TID_LIST]:
                if L != "":
                    lst.append(indent + L)
                L = self.gettypename()
                if L != "":
                    lst.append(indent + L + "::")
                if t == TID_STRUCT:
                    lst.append(indent + "{")
                else:
                    lst.append(indent + "[")

                self.stepin()
                self.ionwalk(t, indent + "  ", lst)
                self.stepout()

                if t == TID_STRUCT:
                    lst.append(indent + "}")
                else:
                    lst.append(indent + "]")

            else:
                if t == TID_STRING:
                    L += ('"%s"' % self.stringvalue())
                elif t in [TID_CLOB, TID_BLOB]:
                    L += ("{%s}" % self.printlob(self.lobvalue()))
                elif t == TID_POSINT:
                    L += str(self.intvalue())
                elif t == TID_SYMBOL:
                    tn = self.gettypename()
                    if tn != "":
                        tn += "::"
                    L += tn + self.symbolvalue()
                elif t == TID_DECIMAL:
                    L += str(self.decimalvalue())
                else:
                    L += ("TID %d" % t)
                lst.append(indent + L)

    def print_(self, lst):
        """Rewind and append a textual dump of the whole stream to *lst*."""
        self.reset()
        self.ionwalk(-1, "", lst)
 | |
| 
 | |
| 
 | |
# Symbol names of the "ProtectedData" shared symbol table (registered with a
# parser by addprottable()). Order is significant: SymbolTable.import_()
# appends these names in sequence, and findbyid() indexes the table by symbol
# id. Reordering or inserting entries therefore changes which id maps to which
# name. New names must only be appended at the end.
SYM_NAMES = [ 'com.amazon.drm.Envelope@1.0',
              'com.amazon.drm.EnvelopeMetadata@1.0', 'size', 'page_size',
              'encryption_key', 'encryption_transformation',
              'encryption_voucher', 'signing_key', 'signing_algorithm',
              'signing_voucher', 'com.amazon.drm.EncryptedPage@1.0',
              'cipher_text', 'cipher_iv', 'com.amazon.drm.Signature@1.0',
              'data', 'com.amazon.drm.EnvelopeIndexTable@1.0', 'length',
              'offset', 'algorithm', 'encoded', 'encryption_algorithm',
              'hashing_algorithm', 'expires', 'format', 'id',
              'lock_parameters', 'strategy', 'com.amazon.drm.Key@1.0',
              'com.amazon.drm.KeySet@1.0', 'com.amazon.drm.PIDv3@1.0',
              'com.amazon.drm.PlainTextPage@1.0',
              'com.amazon.drm.PlainText@1.0', 'com.amazon.drm.PrivateKey@1.0',
              'com.amazon.drm.PublicKey@1.0', 'com.amazon.drm.SecretKey@1.0',
              'com.amazon.drm.Voucher@1.0', 'public_key', 'private_key',
              'com.amazon.drm.KeyPair@1.0', 'com.amazon.drm.ProtectedData@1.0',
              'doctype', 'com.amazon.drm.EnvelopeIndexTableOffset@1.0',
              'enddoc', 'license_type', 'license', 'watermark', 'key', 'value',
              'com.amazon.drm.License@1.0', 'category', 'metadata',
              'categorized_metadata', 'com.amazon.drm.CategorizedMetadata@1.0',
              'com.amazon.drm.VoucherEnvelope@1.0', 'mac', 'voucher',
              'com.amazon.drm.ProtectedData@2.0',
              'com.amazon.drm.Envelope@2.0',
              'com.amazon.drm.EnvelopeMetadata@2.0',
              'com.amazon.drm.EncryptedPage@2.0',
              'com.amazon.drm.PlainText@2.0', 'compression_algorithm',
              'com.amazon.drm.Compressed@1.0', 'page_index_table',
              'com.amazon.drm.VoucherEnvelope@2.0', 'com.amazon.drm.VoucherEnvelope@3.0' ]
 | |
| 
 | |
def addprottable(ion):
    """Register the "ProtectedData" (version 1) shared symbol table with parser *ion*."""
    catalog_name, catalog_version = "ProtectedData", 1
    ion.addtocatalog(catalog_name, catalog_version, SYM_NAMES)
 | |
| 
 | |
| 
 | |
def pkcs7pad(msg, blocklen):
    """Return *msg* with PKCS#7 padding appended up to a multiple of *blocklen*.

    Each padding byte holds the padding length. A full block of padding is
    added when *msg* is already aligned.
    """
    count = blocklen - (len(msg) % blocklen)
    return msg + bchr(count) * count
 | |
| 
 | |
| 
 | |
def pkcs7unpad(msg, blocklen):
    """Strip and validate PKCS#7 padding from *msg*.

    Args:
        msg: data whose length should be a multiple of *blocklen*.
        blocklen: cipher block size in bytes.

    Returns:
        *msg* without its padding bytes.

    Raises:
        Exception: (via _assert) if *msg* is not block aligned, is empty, or
            has malformed padding. After a decryption, malformed padding
            usually means the key was wrong.
    """
    _assert(len(msg) % blocklen == 0)
    # An empty buffer has no padding byte to inspect. Fail with the module's
    # usual exception rather than an IndexError from msg[-1].
    _assert(len(msg) > 0, "Incorrect padding - empty message")

    paddinglen = bord(msg[-1])
    _assert(paddinglen > 0 and paddinglen <= blocklen, "Incorrect padding - Wrong key")
    _assert(msg[-paddinglen:] == bchr(paddinglen) * paddinglen, "Incorrect padding - Wrong key")

    return msg[:-paddinglen]
 | |
| 
 | |
| 
 | |
# Per VoucherEnvelope version: [word, magic]. The word's SHA-256 digest and
# the magic number are the parameters obfuscate() uses to scramble the
# shared secret.
VOUCHER_VERSION_INFOS = {
    2: [b'Antidisestablishmentarianism', 5],
    3: [b'Floccinaucinihilipilification', 8]
}


def obfuscate(secret, version):
    """Scramble the shared secret as required by VoucherEnvelope *version*.

    Version 1 returns *secret* unchanged. For later versions, *secret* is
    zero-padded to a multiple of the version's magic number, its bytes are
    transposed, and each output byte is XORed with the first 16 bytes of
    SHA-256(word).

    Returns:
        *secret* itself for version 1, otherwise a new bytearray.

    Raises:
        KeyError: for a version with no entry in VOUCHER_VERSION_INFOS.
    """
    if version == 1:
        return secret

    word, magic = VOUCHER_VERSION_INFOS[version]

    shortfall = -len(secret) % magic
    if shortfall:
        secret = secret + b'\x00' * shortfall

    data = bytearray(secret)
    out = bytearray(len(data))
    digest = bytearray(hashlib.sha256(word).digest())

    # Treat the data as magic columns of length `stride`: byte `pos` moves to
    # row + magic * col, where (row, col) = divmod(pos, stride).
    stride = len(data) // magic
    for pos, byte in enumerate(data):
        row, col = divmod(pos, stride)
        dest = row + magic * col
        out[dest] = byte ^ digest[dest % 16]

    return out
 | |
| 
 | |
| 
 | |
| class DrmIonVoucher(object):
 | |
|     envelope = None
 | |
|     version = None
 | |
|     voucher = None
 | |
|     drmkey = None
 | |
|     license_type = "Unknown"
 | |
| 
 | |
|     encalgorithm = ""
 | |
|     enctransformation = ""
 | |
|     hashalgorithm = ""
 | |
| 
 | |
|     lockparams = None
 | |
| 
 | |
|     ciphertext = b""
 | |
|     cipheriv = b""
 | |
|     secretkey = b""
 | |
| 
 | |
|     def __init__(self, voucherenv, dsn, secret):
 | |
|         self.dsn,self.secret = dsn,secret
 | |
| 
 | |
|         self.lockparams = []
 | |
| 
 | |
|         self.envelope = BinaryIonParser(voucherenv)
 | |
|         addprottable(self.envelope)
 | |
| 
 | |
|     def decryptvoucher(self):
 | |
|         shared = "PIDv3" + self.encalgorithm + self.enctransformation + self.hashalgorithm
 | |
| 
 | |
|         self.lockparams.sort()
 | |
|         for param in self.lockparams:
 | |
|             if param == "ACCOUNT_SECRET":
 | |
|                 shared += param + self.secret
 | |
|             elif param == "CLIENT_ID":
 | |
|                 shared += param + self.dsn
 | |
|             else:
 | |
|                 _assert(False, "Unknown lock parameter: %s" % param)
 | |
| 
 | |
|         sharedsecret = obfuscate(shared.encode('ASCII'), self.version)
 | |
| 
 | |
|         key = hmac.new(sharedsecret, "PIDv3", digestmod=hashlib.sha256).digest()
 | |
|         aes = AES.new(key[:32], AES.MODE_CBC, self.cipheriv[:16])
 | |
|         b = aes.decrypt(self.ciphertext)
 | |
|         b = pkcs7unpad(b, 16)
 | |
| 
 | |
|         self.drmkey = BinaryIonParser(BytesIO(b))
 | |
|         addprottable(self.drmkey)
 | |
| 
 | |
|         _assert(self.drmkey.hasnext() and self.drmkey.next() == TID_LIST and self.drmkey.gettypename() == "com.amazon.drm.KeySet@1.0",
 | |
|                 "Expected KeySet, got %s" % self.drmkey.gettypename())
 | |
| 
 | |
|         self.drmkey.stepin()
 | |
|         while self.drmkey.hasnext():
 | |
|             self.drmkey.next()
 | |
|             if self.drmkey.gettypename() != "com.amazon.drm.SecretKey@1.0":
 | |
|                 continue
 | |
| 
 | |
|             self.drmkey.stepin()
 | |
|             while self.drmkey.hasnext():
 | |
|                 self.drmkey.next()
 | |
|                 if self.drmkey.getfieldname() == "algorithm":
 | |
|                     _assert(self.drmkey.stringvalue() == "AES", "Unknown cipher algorithm: %s" % self.drmkey.stringvalue())
 | |
|                 elif self.drmkey.getfieldname() == "format":
 | |
|                     _assert(self.drmkey.stringvalue() == "RAW", "Unknown key format: %s" % self.drmkey.stringvalue())
 | |
|                 elif self.drmkey.getfieldname() == "encoded":
 | |
|                     self.secretkey = self.drmkey.lobvalue()
 | |
| 
 | |
|             self.drmkey.stepout()
 | |
|             break
 | |
| 
 | |
|         self.drmkey.stepout()
 | |
| 
 | |
|     def parse(self):
 | |
|         self.envelope.reset()
 | |
|         _assert(self.envelope.hasnext(), "Envelope is empty")
 | |
|         _assert(self.envelope.next() == TID_STRUCT and str.startswith(self.envelope.gettypename(), "com.amazon.drm.VoucherEnvelope@"),
 | |
|                 "Unknown type encountered in envelope, expected VoucherEnvelope")
 | |
|         self.version = int(self.envelope.gettypename().split('@')[1][:-2])
 | |
| 
 | |
|         self.envelope.stepin()
 | |
|         while self.envelope.hasnext():
 | |
|             self.envelope.next()
 | |
|             field = self.envelope.getfieldname()
 | |
|             if field == "voucher":
 | |
|                 self.voucher = BinaryIonParser(BytesIO(self.envelope.lobvalue()))
 | |
|                 addprottable(self.voucher)
 | |
|                 continue
 | |
|             elif field != "strategy":
 | |
|                 continue
 | |
| 
 | |
|             _assert(self.envelope.gettypename() == "com.amazon.drm.PIDv3@1.0", "Unknown strategy: %s" % self.envelope.gettypename())
 | |
| 
 | |
|             self.envelope.stepin()
 | |
|             while self.envelope.hasnext():
 | |
|                 self.envelope.next()
 | |
|                 field = self.envelope.getfieldname()
 | |
|                 if field == "encryption_algorithm":
 | |
|                     self.encalgorithm = self.envelope.stringvalue()
 | |
|                 elif field == "encryption_transformation":
 | |
|                     self.enctransformation = self.envelope.stringvalue()
 | |
|                 elif field == "hashing_algorithm":
 | |
|                     self.hashalgorithm = self.envelope.stringvalue()
 | |
|                 elif field == "lock_parameters":
 | |
|                     self.envelope.stepin()
 | |
|                     while self.envelope.hasnext():
 | |
|                         _assert(self.envelope.next() == TID_STRING, "Expected string list for lock_parameters")
 | |
|                         self.lockparams.append(self.envelope.stringvalue())
 | |
|                     self.envelope.stepout()
 | |
| 
 | |
|             self.envelope.stepout()
 | |
| 
 | |
|         self.parsevoucher()
 | |
| 
 | |
|     def parsevoucher(self):
 | |
|         _assert(self.voucher.hasnext(), "Voucher is empty")
 | |
|         _assert(self.voucher.next() == TID_STRUCT and self.voucher.gettypename() == "com.amazon.drm.Voucher@1.0",
 | |
|                 "Unknown type, expected Voucher")
 | |
| 
 | |
|         self.voucher.stepin()
 | |
|         while self.voucher.hasnext():
 | |
|             self.voucher.next()
 | |
| 
 | |
|             if self.voucher.getfieldname() == "cipher_iv":
 | |
|                 self.cipheriv = self.voucher.lobvalue()
 | |
|             elif self.voucher.getfieldname() == "cipher_text":
 | |
|                 self.ciphertext = self.voucher.lobvalue()
 | |
|             elif self.voucher.getfieldname() == "license":
 | |
|                 _assert(self.voucher.gettypename() == "com.amazon.drm.License@1.0",
 | |
|                         "Unknown license: %s" % self.voucher.gettypename())
 | |
|                 self.voucher.stepin()
 | |
|                 while self.voucher.hasnext():
 | |
|                     self.voucher.next()
 | |
|                     if self.voucher.getfieldname() == "license_type":
 | |
|                         self.license_type = self.voucher.stringvalue()
 | |
|                 self.voucher.stepout()
 | |
| 
 | |
|     def printenvelope(self, lst):
 | |
|         self.envelope.print_(lst)
 | |
| 
 | |
|     def printkey(self, lst):
 | |
|         if self.voucher is None:
 | |
|             self.parse()
 | |
|         if self.drmkey is None:
 | |
|             self.decryptvoucher()
 | |
| 
 | |
|         self.drmkey.print_(lst)
 | |
| 
 | |
|     def printvoucher(self, lst):
 | |
|         if self.voucher is None:
 | |
|             self.parse()
 | |
| 
 | |
|         self.voucher.print_(lst)
 | |
| 
 | |
|     def getlicensetype(self):
 | |
|         return self.license_type
 | |
| 
 | |
| 
 | |
class DrmIon(object):
    """Decrypt the encrypted pages of a DRMION container.

    The DRMION stream is read with BinaryIonParser. An EnvelopeMetadata entry
    names the voucher required for decryption. The voucher is obtained once
    through the ``onvoucherrequired`` callback, and its ``secretkey`` becomes
    the AES key. Each EncryptedPage is then decrypted, LZMA-decompressed when
    marked Compressed, and written to the ``outpages`` stream given to parse().
    """

    # Accepted Ion type annotations for the container elements.
    _ENVELOPE_TYPES = ("com.amazon.drm.Envelope@1.0", "com.amazon.drm.Envelope@2.0")
    _METADATA_TYPES = ("com.amazon.drm.EnvelopeMetadata@1.0", "com.amazon.drm.EnvelopeMetadata@2.0")
    _PAGE_TYPES = ("com.amazon.drm.EncryptedPage@1.0", "com.amazon.drm.EncryptedPage@2.0")

    ion = None                # BinaryIonParser over the DRMION stream
    voucher = None            # object returned by onvoucherrequired()
    vouchername = ""          # voucher name from the first "encryption_voucher" field
    key = b""                 # AES key taken from voucher.secretkey
    onvoucherrequired = None  # callback: voucher name -> voucher object

    def __init__(self, ionstream, onvoucherrequired):
        """Prepare to parse *ionstream*.

        :param ionstream: binary stream holding the DRMION data.
        :param onvoucherrequired: callable taking the voucher name (a string
            read from the envelope metadata) and returning an object with a
            ``secretkey`` attribute.
        """
        self.ion = BinaryIonParser(ionstream)
        addprottable(self.ion)
        self.onvoucherrequired = onvoucherrequired

    def parse(self, outpages):
        """Walk the DRMION envelopes and write every decrypted page to *outpages*.

        :param outpages: writable binary stream that receives page plaintext
            in document order.

        Structural errors are reported through ``_assert``.
        """
        self.ion.reset()

        _assert(self.ion.hasnext(), "DRMION envelope is empty")
        _assert(self.ion.next() == TID_SYMBOL and self.ion.gettypename() == "doctype", "Expected doctype symbol")
        _assert(self.ion.next() == TID_LIST and self.ion.gettypename() in self._ENVELOPE_TYPES,
                "Unknown type encountered in DRMION envelope, expected Envelope, got %s" % self.ion.gettypename())

        # Top-level values are processed until an "enddoc" symbol or the end
        # of the stream.
        while True:
            if self.ion.gettypename() == "enddoc":
                break

            self.ion.stepin()
            while self.ion.hasnext():
                self.ion.next()
                typename = self.ion.gettypename()
                if typename in self._METADATA_TYPES:
                    self._readmetadata()
                elif typename in self._PAGE_TYPES:
                    self._readpage(outpages)
            self.ion.stepout()

            if not self.ion.hasnext():
                break
            self.ion.next()

    def _readmetadata(self):
        """Handle one EnvelopeMetadata struct: resolve the voucher and key on first sight."""
        self.ion.stepin()
        while self.ion.hasnext():
            self.ion.next()
            if self.ion.getfieldname() != "encryption_voucher":
                continue

            if self.vouchername == "":
                self.vouchername = self.ion.stringvalue()
                self.voucher = self.onvoucherrequired(self.vouchername)
                self.key = self.voucher.secretkey
                # An empty key is as useless as None: no page could be decrypted.
                _assert(self.key is not None and len(self.key) > 0, "Unable to obtain secret key from voucher")
            else:
                # Every metadata block in one file must name the same voucher.
                _assert(self.vouchername == self.ion.stringvalue(),
                        "Unexpected: Different vouchers required for same file?")
        self.ion.stepout()

    def _readpage(self, outpages):
        """Handle one EncryptedPage struct: decrypt it if both ciphertext and IV are present."""
        decompress = False
        ct = None
        civ = None
        self.ion.stepin()
        while self.ion.hasnext():
            self.ion.next()
            if self.ion.gettypename() == "com.amazon.drm.Compressed@1.0":
                decompress = True
            if self.ion.getfieldname() == "cipher_text":
                ct = self.ion.lobvalue()
            elif self.ion.getfieldname() == "cipher_iv":
                civ = self.ion.lobvalue()

        if ct is not None and civ is not None:
            self.processpage(ct, civ, outpages, decompress)
        self.ion.stepout()

    def print_(self, lst):
        """Dump the raw DRMION Ion structure by forwarding *lst* to the parser's print_()."""
        self.ion.print_(lst)

    def processpage(self, ct, civ, outpages, decompress):
        """Decrypt one page and write the plaintext to *outpages*.

        :param ct: AES-CBC ciphertext of the page (PKCS#7 padded).
        :param civ: IV for the page; only the first 16 bytes are used.
        :param outpages: writable binary stream.
        :param decompress: when true, the plaintext must start with a 0x00
            byte followed by LZMA "alone"-format data, which is decompressed
            before writing.
        """
        aes = AES.new(self.key[:16], AES.MODE_CBC, civ[:16])
        msg = pkcs7unpad(aes.decrypt(ct), 16)

        if not decompress:
            outpages.write(msg)
            return

        # Compare a one-byte slice. On Python 3, msg[0] is an int and would
        # never equal b"\x00", which rejected every compressed page.
        _assert(msg[:1] == b"\x00", "LZMA UseFilter not supported")
        payload = msg[1:]

        if calibre_lzma is not None:
            with calibre_lzma.decompress(payload, bufsize=0x1000000) as f:
                f.seek(0)
                outpages.write(f.read())
            return

        decomp = lzma.LZMADecompressor(format=lzma.FORMAT_ALONE)
        while not decomp.eof:
            segment = decomp.decompress(payload)
            payload = b""  # unconsumed input is buffered inside the decompressor
            outpages.write(segment)
            if not segment and not decomp.eof:
                # No output, no end-of-stream and no input left: the data is
                # truncated. Stop rather than loop forever on empty input.
                break
        _assert(decomp.eof, "LZMA stream ended unexpectedly")
 | 
