Python Cookbook

数据结构与算法

小顶堆

import heapq

nums = [1, 8, 2, 23, 7, -4, 18, 23, 42, 37, 2]
heap = list(nums)
heapq.heapify(heap)
heap
# [-4, 2, 1, 23, 7, 2, 18, 23, 42, 37, 8]
print(heapq.nlargest(3, nums)) # Prints [42, 37, 23]
print(heapq.nsmallest(3, nums)) # Prints [-4, 1, 2]

nlargest()nsmallest() 适合查找元素数量少的情况,如果查找数量较多,通常先排序后切片更快。sorted(items)[:N]

优先队列

import heapq

class PriorityQueue(object):

    def __init__(self):
        self._queue = []
        self._index = 0

    def push(self, item, priority):
        heapq.heappush(self._queue, (-priority, self._index, item))
        self._index += 1

    def pop(self):
        return heapq.heappop(self._queue)[-1]

heapq 是小顶堆,队列中包含 (-priority, index, item) 元组,优先级为负数的目的是使得元素按照优先级从高到低排序,index 变量的作用是保证同等优先级元素的正确排序。通过保存一个不断增加的 index 下标变量,可以确保元素按照它们插入的顺序排序。

multidict 多值字典

from collections import defaultdict

d = defaultdict(set)
d['a'].add(1)
d['a'].add(2)
d['b'].add(4)

# 传统字典模拟

d = {} # 一个普通的字典
d.setdefault('a', []).append(1)
d.setdefault('a', []).append(2)
d.setdefault('b', []).append(4)

使用场景对于没有初值的字典进行操作,如统计列表中数字的个数,需要对未出现的值进行判断,多值字典默认具有初始化工厂获取初值,无需进行判断。

ordereddict 排序字典

from collections import OrderedDict

d = OrderedDict()
d['foo'] = 1
d['bar'] = 2
d['spam'] = 3
d['grok'] = 4
for key in d:
    print(key, d[key])
# Outputs "foo 1", "bar 2", "spam 3", "grok 4"

默认的字典中元素根据 hash 值进行排序,如果需要保留元素插入的顺序,使用 OrderedDict 是个很好的选择,排序字典内部维护着一个根据键插入顺序排序的双向链表,对已存在的键复制不会改变顺序,内存空间是原有字典的两倍。

字典运算


prices = {
    'ACME': 45.23,
    'AAPL': 612.78,
    'IBM': 205.55,
    'HPQ': 37.20,
    'FB': 10.75
}

min_price = min(zip(prices.values(), prices.keys()))
# min_price is (10.75, 'FB')
max_price = max(zip(prices.values(), prices.keys()))
# max_price is (612.78, 'AAPL')

Note:通过 zip 获取最小值或最大值的键值对,需要注意的是在计算操作中使用到了 (值,键) 对。当多个实体拥有相同的值的时候,键会决定返回结果。

prices = { 'AAA' : 45.23, 'ZZZ': 45.23 }
min(zip(prices.values(), prices.keys()))
# (45.23, 'AAA')
max(zip(prices.values(), prices.keys()))
# (45.23, 'ZZZ')

查找字典相同键值

a = {
    'x' : 1,
    'y' : 2,
    'z' : 3
}

b = {
    'w' : 10,
    'x' : 11,
    'y' : 2
}

# Find keys in common
a.keys() & b.keys() # { 'x', 'y' }
# Find keys in a that are not in b
a.keys() - b.keys() # { 'z' }
# Find (key,value) pairs in common
a.items() & b.items() # { ('y', 2) }

Note

  • 字典是一个键集合与值集合的映射关系,字典的 keys() 方法返回一个展现键集合的键视图对象。键视图的一个很少被了解的特性就是它们也支持集合操作,比如集合并、交、差运算。
  • 字典的 items() 方法返回一个包含 (键,值) 对的元素视图对象。这个对象同样也支持集合操作,并且可以被用来查找两个字典有哪些相同的键值对。
  • 尽管字典的 values() 方法也是类似,但是它并不支持集合操作。

删除序列重复元素并保持顺序

def dedupe(items, key=None):
    seen = set()
    for item in items:
        val = item if key is None else key(item)
        if val not in seen:
            yield item
            seen.add(val)

通过 key 参数指定函数,将序列元素如字典转换成 hashable 类型。

a = [ {'x':1, 'y':2}, {'x':1, 'y':3}, {'x':1, 'y':2}, {'x':2, 'y':4}]
list(dedupe(a, key=lambda d: (d['x'],d['y'])))
# [{'x': 1, 'y': 2}, {'x': 1, 'y': 3}, {'x': 2, 'y': 4}]
list(dedupe(a, key=lambda d: d['x']))
# [{'x': 1, 'y': 2}, {'x': 2, 'y': 4}]

命名切片

items = [0, 1, 2, 3, 4, 5, 6]
a = slice(2, 4)
items[2:4]
# [2, 3]
items[a]
# [2, 3]

# indices 映射到指定大小的序列上
a = slice(5, 50, 2)
s = 'HelloWorld'
a.indices(len(s))
# (5, 10, 2)

序列中出现次数最多的元素

words = [
    'look', 'into', 'my', 'eyes', 'look', 'into', 'my', 'eyes',
    'the', 'eyes', 'the', 'eyes', 'the', 'eyes', 'not', 'around', 'the',
    'eyes', "don't", 'look', 'around', 'the', 'eyes', 'look', 'into',
    'my', 'eyes', "you're", 'under'
]
from collections import Counter
word_counts = Counter(words)
# 出现频率最高的3个单词
top_three = word_counts.most_common(3)
print(top_three)
# Outputs [('eyes', 8), ('the', 5), ('look', 4)]

作为输入, Counter 对象可以接受任意的由可哈希(hashable)元素构成的序列对象。 在底层实现上,一个 Counter 对象就是一个字典,将元素映射到它出现的次数上。

Counter 实例一个鲜为人知的特性是它们可以很容易的跟数学运算操作相结合。Counter 对象在几乎所有需要制表或者计数数据的场合是非常有用的工具。

字典排序

from operator import itemgetter

rows = [
    {'fname': 'Brian', 'lname': 'Jones', 'uid': 1003},
    {'fname': 'David', 'lname': 'Beazley', 'uid': 1002},
    {'fname': 'John', 'lname': 'Cleese', 'uid': 1001},
    {'fname': 'Big', 'lname': 'Jones', 'uid': 1004}
]
from operator import itemgetter
rows_by_fname = sorted(rows, key=itemgetter('fname'))
rows_by_uid = sorted(rows, key=itemgetter('uid'))
print(rows_by_fname)
print(rows_by_uid)
# [{'fname': 'Big', 'uid': 1004, 'lname': 'Jones'},
# {'fname': 'Brian', 'uid': 1003, 'lname': 'Jones'},
# {'fname': 'David', 'uid': 1002, 'lname': 'Beazley'},
# {'fname': 'John', 'uid': 1001, 'lname': 'Cleese'}]
# [{'fname': 'John', 'uid': 1001, 'lname': 'Cleese'},
# {'fname': 'David', 'uid': 1002, 'lname': 'Beazley'},
# {'fname': 'Brian', 'uid': 1003, 'lname': 'Jones'},
# {'fname': 'Big', 'uid': 1004, 'lname': 'Jones'}]

itemgetter() 可多个键进行处理,比 key 指定的匿名函数的方式效率高。

排序不支持比较的对象

class User:
    def __init__(self, user_id):
        self.user_id = user_id

    def __repr__(self):
        return 'User({})'.format(self.user_id)


def sort_notcompare():
    users = [User(23), User(3), User(99)]
    print(users)
    print(sorted(users, key=lambda u: u.user_id))

from operator import attrgetter
sorted(users, key=attrgetter('user_id'))
# [User(3), User(23), User(99)]

数据分组

itertools.groupby() 函数对于这样的数据分组操作非常实用,指定字段排序后分组。groupby() 函数扫描整个序列并且查找连续相同值(或者根据指定 key 函数返回值相同)的元素序列。

一个非常重要的准备步骤是要根据指定的字段将数据排序。 因为 groupby() 仅仅检查连续的元素,如果事先并没有排序完成的话,分组函数将得不到想要的结果。

过滤列表数据

mylist = [1, 4, -5, 10, -7, 2, 3, -1]
# 列表推导
[n for n in mylist if n > 0]
# [1, 4, 10, 2, 3]
# 生成表达式
(n for n in mylist if n > 0)
# filter() 函数
values = ['1', '2', '-3', '-', '4', 'N/A', '5']
def is_int(val):
    try:
        x = int(val)
        return True
    except ValueError:
        return False
ivals = list(filter(is_int, values))
print(ivals)
# Outputs ['1', '2', '-3', '4', '5']
# compress() 函数
addresses = [
    '5412 N CLARK',
    '5148 N CLARK',
    '5800 E 58TH',
    '2122 N CLARK',
    '5645 N RAVENSWOOD',
    '1060 W ADDISON',
    '4801 N BROADWAY',
    '1039 W GRANVILLE',
]
counts = [ 0, 3, 10, 4, 1, 7, 6, 1]
# 取出 count > 5 的值
from itertools import compress
more5 = [n > 5 for n in counts]
more5
# [False, False, True, False, False, True, True, False]
list(compress(addresses, more5))
# ['5800 E 58TH', '1060 W ADDISON', '4801 N BROADWAY']

这里的关键点在于先创建一个 Boolean 序列,指示哪些元素符合条件。 然后 compress() 函数根据这个序列去选择输出对应位置为 True 的元素。

filter() 函数类似,compress() 也是返回的一个迭代器。因此,如果你需要得到一个列表, 那么你需要使用 list() 来将结果转换为列表类型。

字典子集

prices = {
    'ACME': 45.23,
    'AAPL': 612.78,
    'IBM': 205.55,
    'HPQ': 37.20,
    'FB': 10.75
}
# Make a dictionary of all prices over 200
p1 = {key: value for key, value in prices.items() if value > 200}
# Make a dictionary of tech stocks
tech_names = {'AAPL', 'IBM', 'HPQ', 'MSFT'}
p2 = {key: value for key, value in prices.items() if key in tech_names}
p2 = { key:prices[key] for key in prices.keys() & tech_names }

映射名称到序列元素

from collections import namedtuple
Subscriber = namedtuple('Subscriber', ['addr', 'joined'])
sub = Subscriber('jonesy@example.com', '2012-10-19')
sub
# Subscriber(addr='jonesy@example.com', joined='2012-10-19')
sub.addr
# 'jonesy@example.com'
sub.joined
# '2012-10-19'
sub = sub._replace(joined='2017-10-19')
# Subscriber(addr='jonesy@example.com', joined='2017-10-19')

命名元组的一个主要用途是将你的代码从下标操作中解脱出来。命名元组另一个用途就是作为字典的替代,因为字典存储需要更多的内存空间。 如果你需要构建一个非常大的包含字典的数据结构,那么使用命名元组会更加高效。 但是需要注意的是,命名元组是不可更改的

如果你真的需要改变属性的值,那么可以使用命名元组实例的 _replace() 方法, 它会创建一个全新的命名元组并将对应的字段用新的值取代。

转化并计算数据

s = sum((x * x for x in nums)) # 显式的传递一个生成器表达式对象
s = sum(x * x for x in nums) # 更加优雅的实现方式,省略了括号

合并词典或映射

from collections import ChainMap
c = ChainMap(a,b)
print(c['x']) # Outputs 1 (from a)
print(c['y']) # Outputs 2 (from b)
print(c['z']) # Outputs 3 (from a)

Note:对于字典的更新或删除操作总是影响的是列表中第一个字典

作为 ChainMap 的替代,你可能会考虑使用 update() 方法将两个字典合并。它需要你创建一个完全不同的字典对象(或者是破坏现有字典结构)。 同时,如果原字典做了更新,这种改变不会反应到新的合并字典中去

字符串和文本

多个界定符分割字符串

line = 'asdf fjdk; afed, fjek,asdf, foo'
import re
re.split(r'[;,\s]\s*', line)
# ['asdf', 'fjdk', 'afed', 'fjek', 'asdf', 'foo']
# 包含分割字符串
fields = re.split(r'(;|,|\s)\s*', line)
fields
# ['asdf', ' ', 'fjdk', ';', 'afed', ',', 'fjek', ',', 'asdf', ',', 'foo']
values = fields[::2]
delimiters = fields[1::2] + ['']
values
# ['asdf', 'fjdk', 'afed', 'fjek', 'asdf', 'foo']
delimiters
# [' ', ';', ',', ',', ',', '']
# Reform the line using the same delimiters
''.join(v+d for v,d in zip(values, delimiters))
# 'asdf fjdk;afed,fjek,asdf,foo'
# 不想保留分割字符串,使用括号分组
re.split(r'(?:,|;|\s)\s*', line)
# ['asdf', 'fjdk', 'afed', 'fjek', 'asdf', 'foo']

shell 通配符匹配字符串

from fnmatch import fnmatch, fnmatchcase
fnmatch('foo.txt', '*.txt')
# True
fnmatch('foo.txt', '?oo.txt')
# True
fnmatch('Dat45.csv', 'Dat[0-9]*')
# True
names = ['Dat1.csv', 'Dat2.csv', 'config.ini', 'foo.py']
# [name for name in names if fnmatch(name, 'Dat*.csv')]
# ['Dat1.csv', 'Dat2.csv']

字符串匹配搜索

如果你想匹配的是字面字符串,那么你通常只需要调用基本字符串方法就行, 比如 str.find() , str.endswith(), str.startswith()

对于复杂的匹配需要使用正则表达式和 re 模块。match() 总是从字符串开始去匹配,如果你想查找字符串任意部分的模式出现位置, 使用 findall() 方法去代替。

findall() 方法会搜索文本并以列表形式返回所有的匹配。 如果你想以迭代方式返回匹配,可以使用 finditer() 方法来代替。

字符串搜索替换

对于简单的字面模式,直接使用 str.replace() 方法,对于复杂的模式,请使用 re 模块中的 sub() 函数,sub() 函数中的第一个参数是被匹配的模式,第二个参数是替换模式。

如果除了替换后的结果外,你还想知道有多少替换发生了,可以使用 re.subn() 来代替。

字符串忽略大小写的搜索替换

为了在文本操作时忽略大小写,你需要在使用 re 模块的时候给这些操作提供 re.IGNORECASE 标志参数。

text = 'UPPER PYTHON, lower python, Mixed Python'
re.findall('python', text, flags=re.IGNORECASE)
# ['PYTHON', 'python', 'Python']
re.sub('python', 'snake', text, flags=re.IGNORECASE)
# 'UPPER snake, lower snake, Mixed snake'
# 替换字符串并不会自动跟被匹配字符串的大小写保持一致。 为了修复这个,你可能需要一个辅助函数
def matchcase(word):
    def replace(m):
        text = m.group()
        if text.isupper():
            return word.upper()
        elif text.islower():
            return word.lower()
        elif text[0].isupper():
            return word.capitalize()
        else:
            return word
    return replace

re.sub('python', matchcase('snake'), text, flags=re.IGNORECASE)
# 'UPPER SNAKE, lower snake, Mixed Snake'

最短匹配模式

正则 ? 非贪婪模式,模式中的*操作符后面加上?修饰符

str_pat = re.compile(r'"(.*?)"')

多行匹配模式

跨行匹配文本

comment = re.compile(r'/\*(.*?)\*/')
text1 = '/* this is a comment */'
text2 = '''/* this is a
multiline comment */
'''

comment.findall(text1)
# [' this is a comment ']
comment.findall(text2)
# []


为了修正这个问题你可以修改模式字符串增加对换行的支持比如

comment = re.compile(r'/\*((?:.|\n)*?)\*/')
comment.findall(text2)
# [' this is a\n multiline comment ']

在这个模式中, (?:.|\n) 指定了一个非捕获组 (也就是它定义了一个仅仅用来做匹配,而不能通过单独捕获或者编号的组)。

re.compile() 函数接受一个标志参数叫 re.DOTALL ,在这里非常有用。 它可以让正则表达式中的点(.)匹配包括换行符在内的任意字符。

Unicode文本标准化

使用 unicodedata 模块先将文本标准化

>>> import unicodedata
>>> t1 = unicodedata.normalize('NFC', s1)
>>> t2 = unicodedata.normalize('NFC', s2)
>>> t1 == t2
True
>>> print(ascii(t1))
'Spicy Jalape\xf1o'
>>> t3 = unicodedata.normalize('NFD', s1)
>>> t4 = unicodedata.normalize('NFD', s2)
>>> t3 == t4
True
>>> print(ascii(t3))
'Spicy Jalapen\u0303o'
>>>

删除不需要的字符

strip() 方法能用于删除开始或结尾的字符。 lstrip()rstrip() 分别从左和从右执行删除操作。但是需要注意的是去除操作不会对字符串的中间的文本产生任何影响。

审查字符

第一步是清理空白字符。为了这样做,先创建一个小的转换表格然后使用 translate() 方法:

s = 'pýtĥöñ\fis\tawesome\r\n'
remap = {
ord('\t') : ' ',
ord('\f') : ' ',
ord('\r') : None # Deleted
}
a = s.translate(remap)
a
# 'pýtĥöñ is awesome\n'

对于简单的替换操作, str.replace() 方法通常是最快的,如果你需要执行任何复杂字符对字符的重新映射或者删除操作的话, tanslate() 方法会非常的快。

字符串对齐

对于基本的字符串对齐操作,可以使用字符串的 ljust() , rjust()center() 方法。

text = 'Hello World'
text.ljust(20)
# 'Hello World         '
text.rjust(20)
# '         Hello World'
text.center(20)
# '    Hello World     '

函数 format() 同样可以用来很容易的对齐字符串。 你要做的就是使用 <,> 或者 ^ 字符后面紧跟一个指定的宽度。

format(text, '>20')
# '         Hello World'
format(text, '<20')
# 'Hello World         '
format(text, '^20')
# '    Hello World     '

合并拼接字符串

如果你想要合并的字符串是在一个序列或者 iterable 中,那么最快的方式就是使用 join() 方法。

parts = ['Is', 'Chicago', 'Not', 'Chicago?']
' '.join(parts)
# 'Is Chicago Not Chicago?'
','.join(parts)
# 'Is,Chicago,Not,Chicago?'

初看起来,这种语法看上去会比较怪,但是 join() 被指定为字符串的一个方法。 这样做的部分原因是你想去连接的对象可能来自各种不同的数据序列(比如列表,元组,字典,文件,集合或生成器等), 如果在所有这些对象上都定义一个 join() 方法明显是冗余的。 因此你只需要指定你想要的分割字符串并调用他的 join() 方法去将文本片段组合起来。

最重要的需要引起注意的是,当我们使用加号(+)操作符去连接大量的字符串的时候是非常低效率的, 因为加号连接会引起内存复制以及垃圾回收操作。

字符串中插入变量

Python并没有对在字符串中简单替换变量值提供直接的支持。 但是通过使用字符串的 format() 方法来解决这个问题。

format()format_map() 的一个缺陷就是它们并不能很好的处理变量缺失的情况。

# format_map
s.format_map(vars())
# 'Guido has 37 messages.'

'%(name) has %(n) messages.' % vars()
# 'Guido has 37 messages.'
# string template
import string
s = string.Template('$name has $n messages.')
s.substitute(vars())
# 'Guido has 37 messages.'

format()format_map() 相比较 Template 方案而已更加先进,因此应该被优先选择。 使用 format() 方法还有一个好处就是你可以获得对字符串格式化的所有支持(对齐,填充,数字格式化等待), 而这些特性是使用像模板字符串之类的方案不可能获得的。

指定列宽格式化字符串

想以指定的列宽将它们重新格式化,使用 textwrap 模块来格式化字符串的输出。

字符串令牌解析

你有一个字符串,想从左至右将其解析为一个令牌流。

text = 'foo = 23 + 42 * 10'
tokens = [('NAME', 'foo'), ('EQ','='), ('NUM', '23'), ('PLUS','+'),
          ('NUM', '42'), ('TIMES', '*'), ('NUM', '10')]
import re
NAME = r'(?P<NAME>[a-zA-Z_][a-zA-Z_0-9]*)'
NUM = r'(?P<NUM>\d+)'
PLUS = r'(?P<PLUS>\+)'
TIMES = r'(?P<TIMES>\*)'
EQ = r'(?P<EQ>=)'
WS = r'(?P<WS>\s+)'

master_pat = re.compile('|'.join([NAME, NUM, PLUS, TIMES, EQ, WS]))

在上面的模式中, ?P<TOKENNAME> 用于给一个模式命名,供后面使用。

为了令牌化,使用模式对象很少被人知道的 scanner() 方法。 这个方法会创建一个 scanner 对象, 在这个对象上不断的调用 match() 方法会一步步的扫描目标文本,每步一个匹配。

def generate_tokens(pat, text):
    Token = namedtuple('Token', ['type', 'value'])
    scanner = pat.scanner(text)
    for m in iter(scanner.match, None):
        yield Token(m.lastgroup, m.group())

# Example use
for tok in generate_tokens(master_pat, 'foo = 42'):
    print(tok)
# Produces output
# Token(type='NAME', value='foo')
# Token(type='WS', value=' ')
# Token(type='EQ', value='=')
# Token(type='WS', value=' ')
# Token(type='NUM', value='42')

Note

  • 正则表达式指定了所有输入中可能出现的文本序列
  • 长模式写在前面

令牌字符串解析

递归下降分析器

根据一组语法规则解析文本并执行命令,或者构造一个代表输入的抽象语法树。

// TODO 编译原理

递归下降解析

数字日期与时间

数字四舍五入

round(value, ndigits) 函数。

Note 当一个值刚好在两个边界的中间的时候, round 函数返回离它最近的偶数。 也就是说,对1.5或者2.5的舍入运算都会得到2。

精确浮点数计算

decimal 模块的一个主要特征是允许你控制计算的每一方面,包括数字位数和四舍五入运算。decimal 模块主要用在涉及到金融的领域。

数字的格式化输出

格式化输出单个数字的时候,可以使用内置的 format() 函数,指定宽度和精度的一般形式是 [<>^]?width[,]?(.digits)? , 其中 width 和 digits 为整数,?代表可选部分。同样的格式也被用在字符串的 format() 方法中。

进制整数

为了将整数转换为二进制、八进制或十六进制的文本串, 可以分别使用 bin() , oct()hex() 函数,如果不想输出 0b , 0o 或者 0x 的前缀的话,可以使用 format() 函数。

字节到大整数的打包与解包

data = b'\x00\x124V\x00x\x90\xab\x00\xcd\xef\x01\x00#\x004'
len(data)
# 16
# bytes 解析为整数
int.from_bytes(data, 'little')
# 69120565665751139577663547927094891008
int.from_bytes(data, 'big')
# 94522842520747284487117727783387188
# 整数转换为 bytes
x = 94522842520747284487117727783387188
x.to_bytes(16, 'big')
# b'\x00\x124V\x00x\x90\xab\x00\xcd\xef\x01\x00#\x004'
x.to_bytes(16, 'little')
# b'4\x00#\x00\x01\xef\xcd\x00\xab\x90x\x00V4\x12\x00'

复数运算

复数可以用使用函数 complex(real, imag) 或者是带有后缀j的浮点数来指定。

a = complex(2, 4)
b = 3 -5j
a
# (2+4j)
b
# (3-5j)

如果要执行其他的复数函数比如正弦、余弦或平方根,使用 cmath 模块。

无穷大与 NaN

Python并没有特殊的语法来表示这些特殊的浮点值,但是可以使用 float() 来创建它们。

a = float('inf')
b = float('-inf')
c = float('nan')
a
# inf
b
# -inf
c
# nan
# 测试值
math.isinf(a)
# True
math.isnan(c)
# True

Note

  • 无穷大数在执行数学计算的时候会传播
  • 有些操作时未定义的并会返回一个 NaN 结果
  • NaN 值会在所有操作中传播,而不会产生异常
  • NaN 值的一个特别的地方时,它们之间的比较操作总是返回 False

分数计算

fractions 模块可以被用来执行包含分数的数学运算。

from fractions import Fraction
a = Fraction(5, 4)
b = Fraction(7, 16)
print(a + b)
# 27/16
print(a * b)
# 35/64

# Getting numerator/denominator
c = a * b
c.numerator
# 35
c.denominator
# 64

# Converting to a float
float(c)
# 0.546875

# Limiting the denominator of a value
print(c.limit_denominator(8))
# 4/7

# Converting a float to a fraction
x = 3.75
y = Fraction(*x.as_integer_ratio())
y
# Fraction(15, 4)

Numpy 模块

// TODO

随机选择

random 模块有大量的函数用来产生随机数和随机选择元素。

import random
values = [1, 2, 3, 4, 5, 6]
random.choice(values)
# 2
# 随机抽取 n 个样本
random.sample(values, 2)
# [6, 2]
# 打乱顺序
random.shuffle(values)
# [2, 4, 6, 5, 3, 1]
# 生成随机数
random.randint(0,10)
# 2
# 生成0到1范围内均匀分布的浮点数
random.random()
# 0.9406677561675867
# 获取N位随机位(二进制)的整数
random.getrandbits(200)
# 335837000776573622800628485064121869519521710558559406913275

random 模块使用 Mersenne Twister 算法来计算生成随机数。

计算最后一个周五的日期

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
Topic: 最后的周五
"""
from datetime import datetime, timedelta

weekdays = ['Monday', 'Tuesday', 'Wednesday', 'Thursday',
            'Friday', 'Saturday', 'Sunday']


def get_previous_by_day(dayname, start_date=None):
    if start_date is None:
        start_date = datetime.today()
    day_num = start_date.weekday()
    day_num_target = weekdays.index(dayname)
    days_ago = (7 + day_num - day_num_target) % 7
    if days_ago == 0:
        days_ago = 7
    target_date = start_date - timedelta(days=days_ago)
    return target_date

上面的算法原理是这样的:先将开始日期和目标日期映射到星期数组的位置上(星期一索引为0), 然后通过模运算计算出目标日期要经过多少天才能到达开始日期。然后用开始日期减去那个时间差即得到结果日期。

计算当前月份的日期范围

from datetime import datetime, date, timedelta
import calendar

def get_month_range(start_date=None):
    if start_date is None:
        start_date = date.today().replace(day=1)
    _, days_in_month = calendar.monthrange(start_date.year, start_date.month)
    end_date = start_date + timedelta(days=days_in_month)
    return (start_date, end_date)

计算出一个对应月份第一天的日期,使用 date 或 datetime 对象的 replace() 方法简单的将 days 属性设置成1即可。 replace() 方法一个好处就是它会创建和你开传入对象类型相同的对象。

字符串转换为日期

from datetime import datetime
text = '2012-09-20'
y = datetime.strptime(text, '%Y-%m-%d')
z = datetime.now()
diff = z - y
diff
datetime.timedelta(3, 77824, 177393)

Note strptime() 的性能要比你想象中的差很多, 因为它是使用纯Python实现,并且必须处理所有的系统本地设置。如果你要在代码中需要解析大量的日期并且已经知道了日期字符串的确切格式,可以自己实现一套解析方案来获取更好的性能。

迭代器与生成器

手动遍历迭代器

为了手动的遍历可迭代对象,使用 next() 函数并在代码中捕获 StopIteration 异常。

def manual_iter():
    with open('/etc/passwd') as f:
        try:
            while True:
                line = next(f)
                print(line, end='')
        except StopIteration:
            pass

通常来讲, StopIteration 用来指示迭代的结尾。

代理迭代

需要定义一个 __iter__() 方法,将迭代操作代理到容器内部的对象上去。

class Node:
    def __init__(self, value):
        self._value = value
        self._children = []

    def __repr__(self):
        return 'Node({!r})'.format(self._value)

    def add_child(self, node):
        self._children.append(node)

    def __iter__(self):
        return iter(self._children)

# Example
if __name__ == '__main__':
    root = Node(0)
    child1 = Node(1)
    child2 = Node(2)
    root.add_child(child1)
    root.add_child(child2)
    # Outputs Node(1), Node(2)
    for ch in root:
        print(ch)

Python的迭代器协议需要 __iter__() 方法返回一个实现了 __next__() 方法的迭代器对象。

这里的 iter() 函数的使用简化了代码, iter(s) 只是简单的通过调用 s.__iter__() 方法来返回对应的迭代器对象, 就跟 len(s) 会调用 s.__len__() 原理是一样的。

使用生成器创建新的迭代模式

如果你想实现一种新的迭代模式,使用一个生成器函数来定义它。

def frange(start, stop, increment):
    x = start
    while x < stop:
        yield x
        x += increment

for n in frange(0, 4, 0.5):
    print(n)

# 0
# 0.5
# 1.0
# 1.5
# 2.0
# 2.5
# 3.0
# 3.5

一个生成器函数主要特征是它只会回应在迭代中使用到的 next 操作。 一旦生成器函数返回退出,迭代终止。

实现迭代器协议

// TODO 复习

生成器版本

实现一个以深度优先方式遍历树形节点的生成器。

class Node:
    def __init__(self, value):
        self._value = value
        self._children = []

    def __repr__(self):
        return 'Node({!r})'.format(self._value)

    def add_child(self, node):
        self._children.append(node)

    def __iter__(self):
        return iter(self._children)

    def depth_first(self):
        yield self
        for c in self:
            yield from c.depth_first()

# Example
if __name__ == '__main__':
    root = Node(0)
    child1 = Node(1)
    child2 = Node(2)
    root.add_child(child1)
    root.add_child(child2)
    child1.add_child(Node(3))
    child1.add_child(Node(4))
    child2.add_child(Node(5))

    for ch in root.depth_first():
        print(ch)
    # Outputs Node(0), Node(1), Node(3), Node(4), Node(2), Node(5)

它首先返回自己本身并迭代每一个子节点并通过调用子节点的 depth_first() 方法(使用 yield from 语句)返回对应元素。

迭代器版本

class Node2:
    def __init__(self, value):
        self._value = value
        self._children = []

    def __repr__(self):
        return 'Node({!r})'.format(self._value)

    def add_child(self, node):
        self._children.append(node)

    def __iter__(self):
        return iter(self._children)

    def depth_first(self):
        return DepthFirstIterator(self)


class DepthFirstIterator(object):
    '''
    Depth-first traversal
    '''

    def __init__(self, start_node):
        self._node = start_node
        self._children_iter = None
        self._child_iter = None

    def __iter__(self):
        return self

    def __next__(self):
        # Return myself if just started; create an iterator for children
        if self._children_iter is None:
            self._children_iter = iter(self._node)
            return self._node
        # If processing a child, return its next item
        # 迭代子节点
        elif self._child_iter:
            try:
                nextchild = next(self._child_iter)
                return nextchild
            except StopIteration:
                self._child_iter = None
                return next(self)
        # Advance to the next child and start its iteration
        else:
            self._child_iter = next(self._children_iter).depth_first()
            return next(self)

Python的迭代协议要求一个 __iter__() 方法返回一个特殊的迭代器对象, 这个迭代器对象实现了 __next__() 方法并通过 StopIteration 异常标识迭代的完成。

反向迭代

使用内置的 reversed() 函数。反向迭代仅仅当对象的大小可预先确定或者对象实现了 __reversed__() 的特殊方法时才能生效。 如果两者都不符合,那你必须先将对象转换为一个列表才行。

带有外部状态的生成器函数

如果你想让你的生成器暴露外部状态给用户, 别忘了你可以简单的将它实现为一个类,然后把生成器函数放到 __iter__() 方法中过去。

from collections import deque

class linehistory:
    def __init__(self, lines, histlen=3):
        self.lines = lines
        self.history = deque(maxlen=histlen)

    def __iter__(self):
        for lineno, line in enumerate(self.lines, 1):
            self.history.append((lineno, line))
            yield line

    def clear(self):
        self.history.clear()

迭代器切片

函数 itertools.islice() 正好适用于在迭代器和生成器上做切片操作。

这里要着重强调的一点是 islice() 会消耗掉传入的迭代器中的数据。 必须考虑到迭代器是不可逆的这个事实。

跳过可迭代对象的开始部分

itertools.dropwhile() 函数。使用时,你给它传递一个函数对象和一个可迭代对象。 它会返回一个迭代器对象,丢弃原有序列中直到函数返回Flase之前的所有元素,然后返回后面所有元素。

# 去除开始注释行
from itertools import dropwhile
with open('/etc/passwd') as f:
    for line in dropwhile(lambda line: line.startswith('#'), f):
        print(line, end='')

排列组合的迭代

itertools.permutations() 接受一个集合并产生一个元组序列,每个元组由集合中所有元素的一个可能排列组成。

itertools.combinations() 可得到输入集合中元素的所有的组合。

在计算组合的时候,一旦元素被选取就会从候选中剔除掉(比如如果元素’a’已经被选取了,那么接下来就不会再考虑它了)。 而函数 itertools.combinations_with_replacement() 允许同一个元素被选择多次。

同时迭代多个序列

为了同时迭代多个序列,使用 zip() 函数。zip() 会生成一个可返回元组 (x, y) 的迭代器,迭代长度跟参数中最短序列长度一致

如果需要最长的序列遍历,那么可以使用 itertools.zip_longest() 函数来代替,fillvalue 参数指定没有对应值的默认值。

不同集合上元素的迭代

itertools.chain() 接受一个或多个可迭代对象作为输入参数。 然后创建一个迭代器,依次连续的返回每个可迭代对象中的元素。 这种方式要比先将序列合并再迭代要高效的多。

from itertools import chain
a = [1, 2, 3, 4]
b = ['x', 'y', 'z']
for x in chain(a, b):
print(x)
# 1
# 2
# 3
# 4
# x
# y
# z

数据处理管道

目的:以数据管道(类似Unix管道)的方式迭代处理数据。

import os
import fnmatch
import gzip
import bz2
import re

def gen_find(filepat, top):
    '''
    Find all filenames in a directory tree that match a shell wildcard pattern
    '''
    for path, dirlist, filelist in os.walk(top):
        for name in fnmatch.filter(filelist, filepat):
            yield os.path.join(path,name)

def gen_opener(filenames):
    '''
    Open a sequence of filenames one at a time producing a file object.
    The file is closed immediately when proceeding to the next iteration.
    '''
    for filename in filenames:
        if filename.endswith('.gz'):
            f = gzip.open(filename, 'rt')
        elif filename.endswith('.bz2'):
            f = bz2.open(filename, 'rt')
        else:
            f = open(filename, 'rt')
        yield f
        f.close()

def gen_concatenate(iterators):
    '''
    Chain a sequence of iterators together into a single sequence.
    '''
    for it in iterators:
        # 展开生成的生成器
        yield from it
        # = for i in it:
        #       yield i

def gen_grep(pattern, lines):
    '''
    Look for a regex pattern in a sequence of lines
    '''
    pat = re.compile(pattern)
    for line in lines:
        if pat.search(line):
            yield line

yield 语句作为数据的生产者而 for 循环语句作为数据的消费者。gen_concatenate() 函数中出现 yield from 语句,目的是将输入序列拼接成一个很长的行序列,并将 yield 操作代理到父生成器上去。 语句 yield from it 简单的返回生成器 it 所产生的所有值,yield from 在生成器中调用其他生成器作为子例程的时候非常有用。

嵌套的序列

from collections import Iterable

def flatten(items, ignore_types=(str, bytes)):
    for x in items:
        if isinstance(x, Iterable) and not isinstance(x, ignore_types):
            yield from flatten(x)
#            for i in flatten(x):
#                yield i
        else:
            yield x

items = [1, 2, [3, 4, [5, 6], 7], 8]
# Produces 1 2 3 4 5 6 7 8
for x in flatten(items):
    print(x)

语句 yield from 在生成器中调用其他生成器作为子例程的时候非常有用。如果不使用它,那么就必须写额外的 for 循环了。

顺序迭代合并后的排序迭代对象

有一系列排序序列,想将它们合并后得到一个排序序列并在上面迭代遍历。

import heapq
a = [1, 4, 7, 10]
b = [2, 5, 6, 11]
for c in heapq.merge(a, b):
    print(c)
# 1
# 2
# 4
# 5
# 6
# 7
# 10
# 11

heapq.merge 可迭代特性意味着它不会立马读取所有序列。 这就意味着你可以在非常长的序列中使用它,而不会有太大的开销。

heapq.merge() 需要所有输入序列必须是排序。 特别的,它并不会预先读取所有数据到堆栈中或者预先排序,也不会对输入做任何的排序检测。 它仅仅是检查所有序列的开始部分并返回最小的那个,这个过程一直会持续直到所有输入序列中的元素都被遍历完。

迭代器代替 while 无限循环

CHUNKSIZE = 8192

def reader(s):
    while True:
        data = s.recv(CHUNKSIZE)
        if data == b'':
            break
        process_data(data)
# 迭代器操作
def reader2(s):
    for chunk in iter(lambda: s.recv(CHUNKSIZE), b''):
        pass
        # process_data(data)

iter 函数一个鲜为人知的特性是它接受一个可选的 callable 对象和一个标记(结尾)值作为输入参数。这种特殊的方法对于一些特定的会被重复调用的函数很有效果,比如涉及到I/O调用的函数。

文件与 IO

字符串的I/O操作

使用 io.StringIO()io.BytesIO() 类来创建类文件对象操作字符串数据。

s = io.StringIO()
s.write('Hello World\n')
# 12
print('This is a test', file=s)
# 15
# Get all of the data written so far
s.getvalue()
# 'Hello World\nThis is a test\n'

当你想模拟一个普通的文件的时候 StringIO 和 BytesIO 类是很有用的。需要注意的是, StringIO 和 BytesIO 实例并没有整数类型的文件描述符。 因此,它们不能在那些需要使用真实的系统级文件如文件,管道或者是套接字的程序中使用。

固定大小记录的文件迭代

from functools import partial

RECORD_SIZE = 32

with open('somefile.data', 'rb') as f:
    records = iter(partial(f.read, RECORD_SIZE), b'')
    for r in records:
        ...

records 对象是一个可迭代对象,它会不断的产生固定大小的数据块,直到文件末尾。 要注意的是如果总记录大小不是块大小的整数倍的话,最后一个返回元素的字节数会比期望值少

读取二进制数据到可变缓冲区中

import os.path

def read_into_buffer(filename):
    buf = bytearray(os.path.getsize(filename))
    with open(filename, 'rb') as f:
        f.readinto(buf)
    return buf

和普通 read() 方法不同的是, readinto() 填充已存在的缓冲区而不是为新对象重新分配内存再返回它们。

另外有一个有趣特性就是 memoryview , 它可以通过零复制的方式对已存在的缓冲区执行切片操作,甚至还能修改它的内容。

使用 f.readinto() 时需要注意的是,你必须检查它的返回值,也就是实际读取的字节数。如果字节数小于缓冲区大小,表明数据被截断或者被破坏了(比如你期望每次读取指定数量的字节)。

内存映射的二进制文件

import os
import mmap

def memory_map(filename, access=mmap.ACCESS_WRITE):
    size = os.path.getsize(filename)
    fd = os.open(filename, os.O_RDWR)
    return mmap.mmap(fd, size, access=access)

# 为了使用这个函数,你需要有一个已创建并且内容不为空的文件。 下面是一个例子,教你怎样初始创建一个文件并将其内容扩充到指定大小:

size = 1000000
with open('data', 'wb') as f:
    f.seek(size-1)
    f.write(b'\x00')
m = memory_map('data')
len(m)
# 1000000
# Memoryview of unsigned integers
v = memoryview(m).cast('I')
v[0] = 7
m[0:4]
# b'\x07\x00\x00\x00'

内存映射一个文件并不会导致整个文件被读取到内存中。 也就是说,文件并没有被复制到内存缓存或数组中。相反,操作系统仅仅为文件内容保留了一段虚拟内存。当你访问文件的不同区域时,这些区域的内容才根据需要被读取并映射到内存区域中。 而那些从没被访问到的部分还是留在磁盘上。

打印不合法的文件名

详细

问题:打印文件名的时候程序崩溃, 出现了 UnicodeEncodeError 异常和一条奇怪的消息—— surrogates not allowed 。

def bad_filename(filename):
    return repr(filename)[1:-1]

# 重新编码

def bad_filename2(filename):
    temp = filename.encode(sys.getfilesystemencoding(), errors='surrogateescape')
    return temp.decode('latin-1')

try:
    print(filename)
except UnicodeEncodeError:
    print(bad_filename(filename))
surrogateescape:
这种是Python在绝大部分面向OS的API中所使用的错误处理器,
它能以一种优雅的方式处理由操作系统提供的数据的编码问题。
在解码出错时会将出错字节存储到一个很少被使用到的Unicode编码范围内。
在编码时将那些隐藏值又还原回原先解码失败的字节序列。
它不仅对于OS API非常有用,也能很容易的处理其他情况下的编码错误。

当执行类似 os.listdir() 这样的函数时,这些不合规范的文件名就会让 Python 陷入困境。 一方面,它不能仅仅只是丢弃这些不合格的名字。而另一方面,它又不能将这些文件名转换为正确的文本字符串。 Python 对这个问题的解决方案是从文件名中获取未解码的字节值比如 \xhh 并将它映射成 Unicode 字符 \udchh 表示的所谓的”代理编码”。当你想要输出文件名时才会碰到些麻烦(比如打印输出到屏幕或日志文件等).

增加或改变已打开文件的编码

如果你想给一个以二进制模式打开的文件添加 Unicode 编码/解码方式, 可以使用 io.TextIOWrapper() 对象包装它。

import urllib.request
import io

u = urllib.request.urlopen('http://www.python.org')
f = io.TextIOWrapper(u, encoding='utf-8')
text = f.read()

将字节写入文本文件

能够通过读取文本文件的 buffer 属性来读取二进制数据。

I/O系统以层级结构的形式构建而成。 文本文件是通过在一个拥有缓冲的二进制模式文件上增加一个Unicode编码/解码层来创建。 buffer 属性指向对应的底层文件。如果你直接访问它的话就会绕过文本编码/解码层。

>>> f = open('sample.txt','w')
>>> f
<_io.TextIOWrapper name='sample.txt' mode='w' encoding='UTF-8'>
>>> f.buffer
<_io.BufferedWriter name='sample.txt'>
>>> f.buffer.raw
<_io.FileIO name='sample.txt' mode='wb'>

io.TextIOWrapper 是一个编码和解码 Unicode 的文本处理层, io.BufferedWriter 是一个处理二进制数据的带缓冲的I/O层, io.FileIO 是一个表示操作系统底层文件描述符的原始文件。 增加或改变文本编码会涉及增加或改变最上面的 io.TextIOWrapper 层。

文件描述符包装成文件对象

一个文件描述符和一个打开的普通文件是不一样的。 文件描述符仅仅是一个由操作系统指定的整数,用来指代某个系统的I/O通道。 如果你碰巧有这么一个文件描述符,你可以通过使用 open() 函数来将其包装为一个Python的文件对象。 你仅仅只需要使用这个整数值的文件描述符作为第一个参数来代替文件名即可。

创建临时文件和文件夹

TemporaryFile()NamedTemporaryFile()TemporaryDirectory() 函数应该是处理临时文件目录的最简单的方式了,因为它们会自动处理所有的创建和清理步骤。

序列化Python对象

对于序列化最普遍的做法就是使用 pickle 模块。

dump()dumps()load()loads()函数。

有些类型的对象是不能被序列化的。这些通常是那些依赖外部系统状态的对象, 比如打开的文件,网络连接,线程,进程,栈帧等等。 用户自定义类可以通过提供 __getstate__()__setstate__() 方法来绕过这些限制。

数据编码和处理

读写JSON数据

json 模块提供了一种很简单的方式来编码和解码JSON数据。 其中两个主要的函数是 json.dumps()json.loads()

JSON 编码的格式对于 Python 语法而已几乎是完全一样的,除了一些小的差异之外。 比如,True 会被映射为 true,False 被映射为 false,而 None 会被映射为 null。

JSON字典与 Python 对象转换

def serialize_instance(obj):
    d = { '__classname__' : type(obj).__name__ }
    d.update(vars(obj))
    return d

# Dictionary mapping names to known classes
classes = {
    'Point' : Point
}

def unserialize_object(d):
    clsname = d.pop('__classname__', None)
    if clsname:
        cls = classes[clsname]
        obj = cls.__new__(cls) # Make instance without calling __init__
        for key, value in d.items():
            setattr(obj, key, value)
        return obj
    else:
        return d

p = Point(2,3)
s = json.dumps(p, default=serialize_instance)
s
# '{"__classname__": "Point", "y": 3, "x": 2}'
a = json.loads(s, object_hook=unserialize_object)
a
# main__.Point object at 0x1017577d0>

XML 操作

from urllib.request import urlopen
from xml.etree.ElementTree import parse


# Download the RSS feed and parse it
u = urlopen('http://planet.python.org/rss20.xml')
doc = parse(u)

# Extract and output tags of interest
for item in doc.iterfind('channel/item'):
    title = item.findtext('title')
    date = item.findtext('pubDate')
    link = item.findtext('link')
    des = item.findtext('description')

    print(title)
    print(date)
    print(link)
    print(des)

增量解析大型 XML 文件

from xml.etree.ElementTree import iterparse

def parse_and_remove(filename, path):
    path_parts = path.split('/')
    doc = iterparse(filename, ('start', 'end'))
    # Skip the root element
    next(doc)

    tag_stack = []
    elem_stack = []
    for event, elem in doc:
        if event == 'start':
            tag_stack.append(elem.tag)
            elem_stack.append(elem)
        elif event == 'end':
            if tag_stack == path_parts:
                yield elem
                elem_stack[-2].remove(elem)
            try:
                tag_stack.pop()
                elem_stack.pop()
            except IndexError:
                pass

第一,iterparse() 方法允许对XML文档进行增量操作。 使用时,你需要提供文件名和一个包含下面一种或多种类型的事件列表: start , end, start-ns 和 end-ns 。 由 iterparse() 创建的迭代器会产生形如 (event, elem) 的元组, 其中 event 是上述事件列表中的某一个,而 elem 是相应的XML元素。

start 事件在某个元素第一次被创建并且还没有被插入其他数据(如子元素)时被创建。 而 end 事件在某个元素已经完成时被创建。

修改文档

>>> from xml.etree.ElementTree import parse, Element
>>> doc = parse('pred.xml')
>>> root = doc.getroot()
>>> root
<Element 'stop' at 0x100770cb0>

>>> # Remove a few elements
>>> root.remove(root.find('sri'))
>>> root.remove(root.find('cr'))
>>> # Insert a new element after <nm>...</nm>
>>> root.getchildren().index(root.find('nm'))
1
>>> e = Element('spam')
>>> e.text = 'This is a test'
>>> root.insert(2, e)

>>> # Write back to a file
>>> doc.write('newpred.xml', xml_declaration=True)

命名空间解析XML文档

class XMLNamespaces:
    def __init__(self, **kwargs):
        self.namespaces = {}
        for name, uri in kwargs.items():
            self.register(name, uri)
    def register(self, name, uri):
        self.namespaces[name] = '{'+uri+'}'
    def __call__(self, path):
        return path.format_map(self.namespaces)

>>> ns = XMLNamespaces(html='http://www.w3.org/1999/xhtml')
>>> doc.find(ns('content/{html}html'))
<Element '{http://www.w3.org/1999/xhtml}html' at 0x1007767e0>
>>> doc.findtext(ns('content/{html}html/{html}head/{html}title'))
'Hello World'

编码和解码十六进制数

如果你只是简单的解码或编码一个十六进制的原始字符串,可以使用 binascii 模块。

# Initial byte string
s = b'hello'
# Encode as hex
import binascii
h = binascii.b2a_hex(s)
h
# b'68656c6c6f'
# Decode back to bytes
binascii.a2b_hex(h)
# b'hello'

>>> import base64
>>> h = base64.b16encode(s)
>>> h
b'68656C6C6F'
>>> base64.b16decode(h)
b'hello'

函数 base64.b16decode()base64.b16encode() 只能操作大写形式的十六进制字母, 而 binascii 模块中的函数大小写都能处理。

当解码Base64的时候,字节字符串和Unicode文本都可以作为参数。 但是,Unicode字符串只能包含ASCII字符。

编码解码Base64数据

# Some byte data
s = b'hello'
import base64

# Encode as Base64
a = base64.b64encode(s)
a
# b'aGVsbG8='

# Decode from Base64
base64.b64decode(a)
# b'hello'

Base64编码仅仅用于面向字节的数据比如字节字符串和字节数组。 此外,编码处理的输出结果总是一个字节字符串。

当解码Base64的时候,字节字符串和Unicode文本都可以作为参数。 但是,Unicode字符串只能包含ASCII字符。

Struct

struct 模块处理二进制数据。

from struct import Struct
def write_records(records, format, f):
    '''
    Write a sequence of tuples to a binary file of structures.
    '''
    record_struct = Struct(format)
    for r in records:
        f.write(record_struct.pack(*r))

# Example
if __name__ == '__main__':
    records = [ (1, 2.3, 4.5),
                (6, 7.8, 9.0),
                (12, 13.4, 56.7) ]
    with open('data.b', 'wb') as f:
        write_records(records, '<idd', f)

# 增量读取文件

def read_records(format, f):
    record_struct = Struct(format)
    chunks = iter(lambda: f.read(record_struct.size), b'')
    return (record_struct.unpack(chunk) for chunk in chunks)

# Example
if __name__ == '__main__':
    with open('data.b','rb') as f:
        for rec in read_records('<idd', f):
            # Process rec

# 一次性读取

def unpack_records(format, data):
    record_struct = Struct(format)
    return (record_struct.unpack_from(data, offset)
            for offset in range(0, len(data), record_struct.size))
# Example
if __name__ == '__main__':
    with open('data.b', 'rb') as f:
        data = f.read()
    for rec in unpack_records('<idd', data):

产生的 Struct 实例有很多属性和方法用来操作相应类型的结构。 size 属性包含了结构的字节数,这在I/O操作时非常有用。 pack()unpack() 方法被用来打包和解包数据。

unpack_from() 对于从一个大型二进制数组中提取二进制数据非常有用, 因为它不会产生任何的临时对象或者进行内存复制操作。

可变长度二进制数据

struct 模块可被用来编码/解码几乎所有类型的二进制的数据结构。

数据结构定义

polys = [
    [ (1.0, 2.5), (3.5, 4.0), (2.5, 1.5) ],
    [ (7.0, 1.2), (5.1, 3.0), (0.5, 7.5), (0.8, 9.0) ],
    [ (3.4, 6.3), (1.2, 0.5), (4.6, 9.2) ],
]

# 头部文件
+------+--------+------------------------------------+
|Byte  | Type   |  Description                       |
+======+========+====================================+
|0     | int    |  文件代码0x1234小端          |
+------+--------+------------------------------------+
|4     | double |  x 的最小值小端                |
+------+--------+------------------------------------+
|12    | double |  y 的最小值小端                |
+------+--------+------------------------------------+
|20    | double |  x 的最大值小端                |
+------+--------+------------------------------------+
|28    | double |  y 的最大值小端                |
+------+--------+------------------------------------+
|36    | int    |  三角形数量小端                |
+------+--------+------------------------------------+

# 多边形记录
+------+--------+-------------------------------------------+
|Byte  | Type   |  Description                              |
+======+========+===========================================+
|0     | int    |  记录长度N字节                        |
+------+--------+-------------------------------------------+
|4-N   | Points |  (X,Y) 坐标以浮点数表示                 |
+------+--------+-------------------------------------------+

直接代码解析

import struct
import itertools

def write_polys(filename, polys):
    # Determine bounding box
    flattened = list(itertools.chain(*polys))
    min_x = min(x for x, y in flattened)
    max_x = max(x for x, y in flattened)
    min_y = min(y for x, y in flattened)
    max_y = max(y for x, y in flattened)
    with open(filename, 'wb') as f:
        f.write(struct.pack('<iddddi', 0x1234,
                            min_x, min_y,
                            max_x, max_y,
                            len(polys)))
        for poly in polys:
            size = len(poly) * struct.calcsize('<dd')
            f.write(struct.pack('<i', size + 4))
            for pt in poly:
                f.write(struct.pack('<dd', *pt))

def read_polys(filename):
    with open(filename, 'rb') as f:
        # Read the header
        header = f.read(40)
        file_code, min_x, min_y, max_x, max_y, num_polys = \
            struct.unpack('<iddddi', header)
        polys = []
        for n in range(num_polys):
            pbytes, = struct.unpack('<i', f.read(4))
            poly = []
            for m in range(pbytes // 16):
                pt = struct.unpack('<dd', f.read(16))
                poly.append(pt)
            polys.append(poly)
    return polys

高级抽象形式

# 数据结构类
import struct

class StructField:
    '''
    Descriptor representing a simple structure field
    '''
    def __init__(self, format, offset):
        self.format = format
        self.offset = offset
    def __get__(self, instance, cls):
        if instance is None:
            return self
        else:
            r = struct.unpack_from(self.format, instance._buffer, self.offset)
            return r[0] if len(r) == 1 else r

class Structure:
    def __init__(self, bytedata):
        self._buffer = memoryview(bytedata)

# 头部信息
class PolyHeader(Structure):
    file_code = StructField('<i', 0)
    min_x = StructField('<d', 4)
    min_y = StructField('<d', 12)
    max_x = StructField('<d', 20)
    max_y = StructField('<d', 28)
    num_polys = StructField('<i', 36)

>>> f = open('polys.bin', 'rb')
>>> phead = PolyHeader(f.read(40))
>>> phead.file_code == 0x1234
True
>>> phead.min_x
0.5
>>> phead.min_y
0.5
>>> phead.max_x
7.0
>>> phead.max_y
9.2
>>> phead.num_polys
3

# 元类方法自动计算结构体偏移
class StructureMeta(type):
    '''
    Metaclass that automatically creates StructField descriptors
    '''
    def __init__(self, clsname, bases, clsdict):
        fields = getattr(self, '_fields_', [])
        byte_order = ''
        offset = 0
        for format, fieldname in fields:
            if format.startswith(('<','>','!','@')):
                byte_order = format[0]
                format = format[1:]
            format = byte_order + format
            setattr(self, fieldname, StructField(format, offset))
            offset += struct.calcsize(format)
        setattr(self, 'struct_size', offset)

class Structure(metaclass=StructureMeta):
    def __init__(self, bytedata):
        self._buffer = bytedata

    @classmethod
    def from_file(cls, f):
        return cls(f.read(cls.struct_size))

class PolyHeader(Structure):
    _fields_ = [
        ('<i', 'file_code'),
        ('d', 'min_x'),
        ('d', 'min_y'),
        ('d', 'max_x'),
        ('d', 'max_y'),
        ('i', 'num_polys')
    ]

# 嵌套字节结构

class NestedStruct:
    '''
    Descriptor representing a nested structure
    '''
    def __init__(self, name, struct_type, offset):
        self.name = name
        self.struct_type = struct_type
        self.offset = offset

    def __get__(self, instance, cls):
        if instance is None:
            return self
        else:
            data = instance._buffer[self.offset:
                            self.offset+self.struct_type.struct_size]
            result = self.struct_type(data)
            # Save resulting structure back on instance to avoid
            # further recomputation of this step
            setattr(instance, self.name, result)
            return result

class StructureMeta(type):
    '''
    Metaclass that automatically creates StructField descriptors
    '''
    def __init__(self, clsname, bases, clsdict):
        fields = getattr(self, '_fields_', [])
        byte_order = ''
        offset = 0
        for format, fieldname in fields:
            if isinstance(format, StructureMeta):
                setattr(self, fieldname,
                        NestedStruct(fieldname, format, offset))
                offset += format.struct_size
            else:
                if format.startswith(('<','>','!','@')):
                    byte_order = format[0]
                    format = format[1:]
                format = byte_order + format
                setattr(self, fieldname, StructField(format, offset))
                offset += struct.calcsize(format)
        setattr(self, 'struct_size', offset)

class Point(Structure):
    _fields_ = [
        ('<d', 'x'),
        ('d', 'y')
    ]

class PolyHeader(Structure):
    _fields_ = [
        ('<i', 'file_code'),
        (Point, 'min'), # nested struct
        (Point, 'max'), # nested struct
        ('i', 'num_polys')
    ]

>>> f = open('polys.bin', 'rb')
>>> phead = PolyHeader.from_file(f)
>>> phead.file_code == 0x1234
True
>>> phead.min # Nested structure
<__main__.Point object at 0x1006a48d0>
>>> phead.min.x
0.5
>>> phead.min.y
0.5
>>> phead.max.x
7.0
>>> phead.max.y
9.2
>>> phead.num_polys
3

# 变长数据解析

class SizedRecord:
    def __init__(self, bytedata):
        self._buffer = memoryview(bytedata)

    @classmethod
    def from_file(cls, f, size_fmt, includes_size=True):
        sz_nbytes = struct.calcsize(size_fmt)
        sz_bytes = f.read(sz_nbytes)
        sz, = struct.unpack(size_fmt, sz_bytes)
        buf = f.read(sz - includes_size * sz_nbytes)
        return cls(buf)

    def iter_as(self, code):
        if isinstance(code, str):
            s = struct.Struct(code)
            for off in range(0, len(self._buffer), s.size):
                yield s.unpack_from(self._buffer, off)
        elif isinstance(code, StructureMeta):
            size = code.struct_size
            for off in range(0, len(self._buffer), size):
                data = self._buffer[off:off+size]
                yield code(data)

# 结合
class Point(Structure):
    _fields_ = [
        ('<d', 'x'),
        ('d', 'y')
    ]

class PolyHeader(Structure):
    _fields_ = [
        ('<i', 'file_code'),
        (Point, 'min'),
        (Point, 'max'),
        ('i', 'num_polys')
    ]

def read_polys(filename):
    polys = []
    with open(filename, 'rb') as f:
        phead = PolyHeader.from_file(f)
        for n in range(phead.num_polys):
            rec = SizedRecord.from_file(f, '<i')
            poly = [ (p.x, p.y) for p in rec.iter_as(Point) ]
            polys.append(poly)
    return polys

当一个 Structure 实例被创建时, __init__() 仅仅只是创建一个字节数据的内存视图,没有做其他任何事。

为了实现懒解包和打包,需要使用 StructField 描述器类。 用户在 _fields_ 中列出来的每个属性都会被转化成一个 StructField 描述器, 它将相关结构格式码和偏移值保存到存储缓存中。元类 StructureMeta 在多个结构类被定义时自动创建了这些描述器。

StructureMeta 的一个很微妙的地方就是它会固定字节数据顺序。也就是说,如果任意的属性指定了一个字节顺序(<表示低位优先 或者 >表示高位优先), 那后面所有字段的顺序都以这个顺序为准。

函数

匿名函数捕获变量值

如果你想让某个匿名函数在定义时就捕获到值,可以将那个参数值定义成默认参数即可。

x = 10
a = lambda y, x=x: x + y
x = 20
b = lambda y, x=x: x + y
a(10)
# 20
b(10)
# 30

带额外状态信息的回调函数

def apply_asnc(func, args, *, callback):
    result = func(*args)
    callback(result)

# 回调函数

def print_result(result):
    print('Got:', result)

def add(x, y):
    return x + y

apply_async(add, (2, 3), callback=print_result)
# Got: 5
apply_async(add, ('hello', 'world'), callback=print_result)
# Got: helloworld

使回调函数访问其他变量或者特定环境的变量值

类实例

class ResultHandler(object):

    def __init__(self):
        self.sequence = 0

    def handler(self, result):
        self.sequence += 1
        print('[{}] Got: {}'.format(self.sequenceee, result))

r = ResultHandler()
apply_async(add, (2, 3), callback=r.handler)
# [1] Got: 5
apply_async(add, ('hello', 'world'), callback=r.handler)
# [2] Got: helloworld

函数闭包

def make_handler():
    sequence = 0
    def handler(result):
        nonlocal sequence
        sequence += 1
        print('[{}] Got: {}'.format(sequence, result))
    return handler

handler = make_handler()
apply_async(add, (2, 3), callback=handler)
# [1] Got: 5
apply_async(add, ('hello', 'world'), callback=handler)
# [2] Got: helloworld

协程方法

def make_handler():
    sequence = 0
    while True:
        result = yield
        sequence += 1
        print('[{}] Got: {}'.format(sequence, result))

handler = make_handler()
next(handler) # Advance to the yield
apply_async(add, (2, 3), callback=handler.send)
# [1] Got: 5
apply_async(add, ('hello', 'world'), callback=handler.send)
# [2] Got: helloworld

# 仅仅只需要给回调函数传递额外的值

apply_asyncadd, (2, 3), callback=lambda r: handler(r, seq))
# [1] Got: 5

至少有两种主要方式来捕获和保存状态信息,你可以在一个对象实例(通过一个绑定方法)或者在一个闭包中保存它。 两种方式相比,闭包或许是更加轻量级和自然一点,因为它们可以很简单的通过函数来构造。

内联回调函数

当你编写使用回调函数的代码的时候,担心很多小函数的扩张可能会弄乱程序控制流。

from queue import Queue
from functools import wraps

def apply_async(func, args, *, callback):
    # Compute the result
    result = func(*args)

    # Invoke the callback with the result
    callback(result)

class Async:
    def __init__(self, func, args):
        self.func = func
        self.args = args

def inlined_async(func):
    @wraps(func)
    def wrapper(*args):
        f = func(*args)
        result_queue = Queue()
        result_queue.put(None)
        while True:
            result = result_queue.get()
            try:
                a = f.send(result)
                apply_async(a.func, a.args, callback=result_queue.put)
            except StopIteration:
                break
    return wrapper

def add(x, y):
    return x + y

@inlined_async
def test():
    r = yield Async(add, (2, 3))
    print(r)
    r = yield Async(add, ('hello', 'world'))
    print(r)
    for n in range(10):
        r = yield Async(add, (n, n))
        print(r)
    print('Goodbye')

test()
# 5
# helloworld
# 0
# 2
# 4
# 6
# 8
# 10
# 12
# 14
# 16
# 18
# Goodbye

yield 操作会使一个生成器函数产生一个值并暂停。 接下来调用生成器的 __next__()send() 方法又会让它从暂停处继续执行。

核心就在 inline_async() 装饰器函数中了。 关键点就是,装饰器会逐步遍历生成器函数的所有 yield 语句,每一次一个。 为了这样做,刚开始的时候创建了一个 result 队列并向里面放入一个 None 值。 然后开始一个循环操作,从队列中取出结果值并发送给生成器,它会持续到下一个 yield 语句, 在这里一个 Async 的实例被接受到。然后循环开始检查函数和参数,并开始进行异步计算 apply_async() 。 然而,这个计算有个最诡异部分是它并没有使用一个普通的回调函数,而是用队列的 put() 方法来回调。

将复杂的控制流隐藏到生成器函数背后的例子在标准库和第三方包中都能看到。 比如,在 contextlib 中的 @contextmanager 装饰器使用了一个令人费解的技巧, 通过一个 yield 语句将进入和离开上下文管理器粘合在一起。

访问闭包中定义的变量

扩展函数中的某个闭包,允许它能访问和修改函数的内部变量。

import sys
class ClosureInstance:
    def __init__(self, locals=None):
        if locals is None:
            locals = sys._getframe(1).f_locals

        # Update instance dictionary with callables
        self.__dict__.update((key,value) for key, value in locals.items()
                            if callable(value) )
    # Redirect special methods
    def __len__(self):
        return self.__dict__['__len__']()

# Example use
def Stack():
    items = []
    def push(item):
        items.append(item)

    def pop():
        return items.pop()

    def __len__():
        return len(items)

    return ClosureInstance()

类与对象

让对象支持上下文管理协议

为了让一个对象兼容 with 语句,你需要实现 __enter__()__exit__() 方法。

from socket import socket, AF_INET, SOCK_STREAM

class LazyConnection:
    def __init__(self, address, family=AF_INET, type=SOCK_STREAM):
        self.address = address
        self.family = family
        self.type = type
        self.sock = None

    def __enter__(self):
        if self.sock is not None:
            raise RuntimeError('Already connected')
        self.sock = socket(self.family, self.type)
        self.sock.connect(self.address)
        return self.sock

    def __exit__(self, exc_ty, exc_val, tb):
        self.sock.close()
        self.sock = None

from functools import partial

conn = LazyConnection(('www.python.org', 80))
# Connection closed
with conn as s:
    # conn.__enter__() executes: connection open
    s.send(b'GET /index.html HTTP/1.0\r\n')
    s.send(b'Host: www.python.org\r\n')
    s.send(b'\r\n')
    resp = b''.join(iter(partial(s.recv, 8192), b''))
    # conn.__exit__() executes: connection closed

# 支持嵌套版本

from socket import socket, AF_INET, SOCK_STREAM

class LazyConnection:
    def __init__(self, address, family=AF_INET, type=SOCK_STREAM):
        self.address = address
        self.family = family
        self.type = type
        self.connections = []

    def __enter__(self):
        sock = socket(self.family, self.type)
        sock.connect(self.address)
        self.connections.append(sock)
        return sock

    def __exit__(self, exc_ty, exc_val, tb):
        self.connections.pop().close()

# Example use
from functools import partial

conn = LazyConnection(('www.python.org', 80))
with conn as s1:
    pass
    with conn as s2:
        pass
        # s1 and s2 are independent sockets

编写上下文管理器的主要原理是你的代码会放到 with 语句块中执行。 当出现 with 语句的时候,对象的 __enter__() 方法被触发, 它返回的值(如果有的话)会被赋值给 as 声明的变量。然后,with 语句块里面的代码开始执行。 最后,__exit__() 方法被触发进行清理工作,__exit__() 方法的第三个参数包含了异常类型、异常值和追溯信息(如果有的话)。

创建大量对象节省内存方法

对于主要是用来当成简单的数据结构的类而言,你可以通过给类添加 slots 属性来极大的减少实例所占的内存。

关于 slots 的一个常见误区是它可以作为一个封装工具来防止用户给实例增加新的属性。 尽管使用slots可以达到这样的目的,但是这个并不是它的初衷。 slots 更多的是用来作为一个内存优化工具。

在类中封装属性名

大多数而言,你应该让你的非公共名称以单下划线开头。但是,如果代码会涉及到子类, 并且有些内部属性应该在子类中隐藏起来,那么才考虑使用双下划线方案,这种属性通过继承是无法被覆盖的。定义的一个变量和某个保留关键字冲突,这时候可以使用单下划线作为后缀。

创建可管理的属性

增加除访问与修改之外的其他处理逻辑。

class Person:
    def __init__(self, first_name):
        self.first_name = first_name

    # Getter function
    @property
    def first_name(self):
        return self._first_name

    # Setter function
    @first_name.setter
    def first_name(self, value):
        if not isinstance(value, str):
            raise TypeError('Expected a string')
        self._first_name = value

    # Deleter function (optional)
    @first_name.deleter
    def first_name(self):
        raise AttributeError("Can't delete attribute")

property 的一个关键特征是它看上去跟普通的 attribute 没什么两样, 但是访问它的时候会自动触发 getter 、setter 和 deleter 方法。

上述代码中有三个相关联的方法,这三个方法的名字都必须一样。 第一个方法是一个 getter 函数,它使得 first_name 成为一个属性。在实现一个property的时候,底层数据(如果有的话)仍然需要存储在某个地方。 因此,在 get 和 set 方法中,你会看到对 _first_name 属性的操作,这也是实际数据保存的地方。

调用父类方法

class A:
    def __init__(self):
        self.x = 0

class B(A):
    def __init__(self):
        super().__init__()
        self.y = 1

class Proxy:
    def __init__(self, obj):
        self._obj = obj

    # Delegate attribute lookup to internal obj
    def __getattr__(self, name):
        return getattr(self._obj, name)

    # Delegate attribute assignment
    def __setattr__(self, name, value):
        if name.startswith('_'):
            super().__setattr__(name, value) # Call original __setattr__
        else:
            setattr(self._obj, name, value)
  • super() 函数的一个常见用法是在 init() 方法中确保父类被正确的初始化了
  • super() 的另外一个常见用法出现在覆盖 Python 特殊方法的代码

当你使用 super() 函数时,Python会在MRO列表上继续搜索下一个类。 只要每个重定义的方法统一使用 super() 并只调用它一次, 那么控制流最终会遍历完整个MRO列表,每个方法也只会被调用一次。

super() 有个令人吃惊的地方是它并不一定去查找某个类在MRO中下一个直接父类, 你甚至可以在一个没有直接父类的类中使用它。

class A:
    def spam(self):
        print('A.spam')
        super().spam()

>>> a = A()
>>> a.spam()
A.spam
Traceback (most recent call last):
    File "<stdin>", line 1, in <module>
    File "<stdin>", line 4, in spam
AttributeError: 'super' object has no attribute 'spam'
>>>

>>> class B:
...     def spam(self):
...         print('B.spam')
...
>>> class C(A,B):
...     pass
...
>>> c = C()
>>> c.spam()
A.spam
B.spam
>>>

准则

首先,确保在继承体系中所有相同名字的方法拥有可兼容的参数签名(比如相同的参数个数和参数名称)。 这样可以确保 super() 调用一个非直接父类方法时不会出错。 其次,最好确保最顶层的类提供了这个方法的实现,这样的话在MRO上面的查找链肯定可以找到某个确定的方法。

子类中拓展 property

class Person:
    def __init__(self, name):
        self.name = name

    # Getter function
    @property
    def name(self):
        return self._name

    # Setter function
    @name.setter
    def name(self, value):
        if not isinstance(value, str):
            raise TypeError('Expected a string')
        self._name = value

    # Deleter function
    @name.deleter
    def name(self):
        raise AttributeError("Can't delete attribute")

class SubPerson(Person):
    @property
    def name(self):
        print('Getting name')
        return super().name

    @name.setter
    def name(self, value):
        print('Setting name to', value)
        super(SubPerson, SubPerson).name.__set__(self, value)

    @name.deleter
    def name(self):
        print('Deleting name')
        super(SubPerson, SubPerson).name.__delete__(self)
# 仅修改一个方法

class SubPerson(Person):
    @Person.name.getter
    def name(self):
        print('Getting name')
        return super().name

property其实是 getter、setter 和 deleter 方法的集合,而不是单个方法。 因此,当你扩展一个property的时候,你需要先确定你是否要重新定义所有的方法还是说只修改其中某一个。如果你只想重定义其中一个方法,那只使用 @property 本身是不够的。

class SubPerson(Person):
    @property  # Doesn't work
    def name(self):
        print('Getting name')
        return super().name

>>> s = SubPerson('Guido')
Traceback (most recent call last):
    File "<stdin>", line 1, in <module>
    File "example.py", line 5, in __init__
        self.name = name
AttributeError: can't set attribute

# 正确方式

class SubPerson(Person):
    @Person.name.getter
    def name(self):
        print('Getting name')
        return super().name

描述符方式

# A descriptor
class String:
    def __init__(self, name):
        self.name = name

    def __get__(self, instance, cls):
        if instance is None:
            return self
        return instance.__dict__[self.name]

    def __set__(self, instance, value):
        if not isinstance(value, str):
            raise TypeError('Expected a string')
        instance.__dict__[self.name] = value

# A class with a descriptor
class Person:
    name = String('name')

    def __init__(self, name):
        self.name = name

# Extending a descriptor with a property
class SubPerson(Person):
    @property
    def name(self):
        print('Getting name')
        return super().name

    @name.setter
    def name(self, value):
        print('Setting name to', value)
        super(SubPerson, SubPerson).name.__set__(self, value)

    @name.deleter
    def name(self):
        print('Deleting name')
        super(SubPerson, SubPerson).name.__delete__(self)

创建新的类或实例属性

目的:创建一个新的拥有一些额外功能的实例属性类型,比如类型检查。

一个描述器就是一个实现了三个核心的属性访问操作(get, set, delete)的类, 分别为 __get__()__set__()__delete__() 这三个特殊的方法。

# Descriptor attribute for an integer type-checked attribute
class Integer:
    def __init__(self, name):
        self.name = name

    def __get__(self, instance, cls):
        if instance is None:
            return self
        else:
            # getattr(instance, self.name) 会造成递归
            return instance.__dict__[self.name]

    def __set__(self, instance, value):
        if not isinstance(value, int):
            raise TypeError('Expected an int')
        instance.__dict__[self.name] = value

    def __delete__(self, instance):
        del instance.__dict__[self.name]

# 使用
class Point:
    x = Integer('x')
    y = Integer('y')

    def __init__(self, x, y):
        # self.x = Integer('x') # No! Must be a class variable
        # self.y = Integer('y')
        self.x = x
        self.y = y

描述器的一个比较困惑的地方是它只能在类级别被定义,而不能为每个实例单独定义。

__get__() 看上去有点复杂的原因归结于实例变量和类变量的不同。 如果一个描述器被当做一个类变量来访问,那么 instance 参数被设置成 None 。

类型检查描述符

# Descriptor for a type-checked attribute
class Typed:
    def __init__(self, name, expected_type):
        self.name = name
        self.expected_type = expected_type
    def __get__(self, instance, cls):
        if instance is None:
            return self
        else:
            return instance.__dict__[self.name]

    def __set__(self, instance, value):
        if not isinstance(value, self.expected_type):
            raise TypeError('Expected ' + str(self.expected_type))
        instance.__dict__[self.name] = value
    def __delete__(self, instance):
        del instance.__dict__[self.name]

# Class decorator that applies it to selected attributes
def typeassert(**kwargs):
    def decorate(cls):
        for name, expected_type in kwargs.items():
            # Attach a Typed descriptor to the class
            setattr(cls, name, Typed(name, expected_type))
        return cls
    return decorate

# Example use
@typeassert(name=str, shares=int, price=float)
class Stock:
    def __init__(self, name, shares, price):
        self.name = name
        self.shares = shares
        self.price = price

延迟计算属性

目的:将一个只读属性定义成一个property,并且只在访问的时候才会计算结果。 但是一旦被访问后,你希望结果值被缓存起来,不用每次都去计算。

class lazyproperty:
    def __init__(self, func):
        self.func = func

    def __get__(self, instance, cls):
        if instance is None:
            return self
        else:
            value = self.func(instance)
            setattr(instance, self.func.__name__, value)
            return value

import math

class Circle:
    def __init__(self, radius):
        self.radius = radius

    @lazyproperty
    def area(self):
        print('Computing area')
        return math.pi * self.radius ** 2

    @lazyproperty
    def perimeter(self):
        print('Computing perimeter')
        return 2 * math.pi * self.radius

方案有一个小缺陷就是计算出的值被创建后是可以被修改的。

def lazyproperty(func):
    name = '_lazy_' + func.__name__
    @property
    def lazy(self):
        if hasattr(self, name):
            return getattr(self, name)
        else:
            value = func(self)
            setattr(self, name, value)
            return value
    return lazy

然而,这种方案有一个缺点就是所有get操作都必须被定向到属性的 getter 函数上去,相对而言效率要低。

简化数据结构初始化

import math

class Structure1:
    _fields = []

    def __init__(self, *args, **kwargs):
        if len(args) > len(self._fields):
            raise TypeError('Expected {} arguments'.format(len(self._fields)))

        # Set all of the positional arguments
        for name, value in zip(self._fields, args):
            setattr(self, name, value)
        # ==         self.__dict__.update(zip(self._fields,args))

        # Set the remaining keyword arguments
        # 支持关键词参数
        for name in self._fields[len(args):]:
            setattr(self, name, kwargs.pop(name))

        # Check for any remaining unknown arguments
        if kwargs:
            raise TypeError('Invalid argument(s): {}'.format(','.join(kwargs)))

# Example class definitions
class Stock(Structure1):
    _fields = ['name', 'shares', 'price']

class Point(Structure1):
    _fields = ['x', 'y']

class Circle(Structure1):
    _fields = ['radius']

    def area(self):
        return math.pi * self.radius ** 2

# 支持额外的参数

class Structure2:
    # Class variable that specifies expected fields
    _fields = []

    def __init__(self, *args, **kwargs):
        if len(args) != len(self._fields):
            raise TypeError('Expected {} arguments'.format(len(self._fields)))

        # Set the arguments
        for name, value in zip(self._fields, args):
            setattr(self, name, value)

        # Set the additional arguments (if any)
        extra_args = kwargs.keys() - self._fields
        for name in extra_args:
            setattr(self, name, kwargs.pop(name))

        if kwargs:
            raise TypeError('Duplicate values for {}'.format(','.join(kwargs)))

定义接口或者抽象基类

from abc import ABCMeta, abstractmethod

class IStream(metaclass=ABCMeta):
    @abstractmethod
    def read(self, maxbytes=-1):
        pass

    @abstractmethod
    def write(self, data):
        pass

# 注册方法
import io

# Register the built-in I/O classes as supporting our interface
IStream.register(io.IOBase)

# Open a normal file and type check
f = open('foo.txt')
isinstance(f, IStream) # Returns True

抽象类的目的就是让别的类继承它并实现特定的抽象方法,除了继承这种方式外,还可以通过注册方式来让某个类实现抽象基类。

抽象基类的一个主要用途是在代码中检查某些类是否为特定类型,实现了特定接口。

Note:@abstractmethod 还能注解静态方法、类方法和 properties 。 你只需保证这个注解紧靠在函数定义前即可

标准库中有很多用到抽象基类的地方。collections 模块定义了很多跟容器和迭代器(序列、映射、集合等)有关的抽象基类。 numbers 库定义了跟数字对象(整数、浮点数、有理数等)有关的基类。io 库定义了很多跟I/O操作相关的基类。

实现数据模型的类型约束

目的:限制某些在属性赋值上的数据结构

// TODO 理解

描述符方式

# 设置属性描述符
# Base class. Uses a descriptor to set a value
class Descriptor:
    def __init__(self, name=None, **opts):
        self.name = name
        for key, value in opts.items():
            setattr(self, key, value)

    def __set__(self, instance, value):
        instance.__dict__[self.name] = value

# 类型检查描述符
# Descriptor for enforcing types
class Typed(Descriptor):
    expected_type = type(None)

    def __set__(self, instance, value):
        if not isinstance(value, self.expected_type):
            raise TypeError('expected ' + str(self.expected_type))
        super().__set__(instance, value)

# 参数检查描述符
# Descriptor for enforcing values
class Unsigned(Descriptor):
    def __set__(self, instance, value):
        if value < 0:
            raise ValueError('Expected >= 0')
        super().__set__(instance, value)


class MaxSized(Descriptor):
    def __init__(self, name=None, **opts):
        if 'size' not in opts:
            raise TypeError('missing size option')
        super().__init__(name, **opts)

    def __set__(self, instance, value):
        if len(value) >= self.size:
            raise ValueError('size must be < ' + str(self.size))
        super().__set__(instance, value)

# 定义数据类型
class Integer(Typed):
    expected_type = int

class UnsignedInteger(Integer, Unsigned):
    pass

class Float(Typed):
    expected_type = float

class UnsignedFloat(Float, Unsigned):
    pass

class String(Typed):
    expected_type = str

class SizedString(String, MaxSized):
    pass

# 定义类
class Stock:
    # Specify constraints
    name = SizedString('name', size=8)
    shares = UnsignedInteger('shares')
    price = UnsignedFloat('price')

    def __init__(self, name, shares, price):
        self.name = name
        self.shares = shares
        self.price = price

s = Stock('ACME', 75, 0)
s.name
# 'ACME'
s.shares = 75
s.shares = -10
# Traceback (most recent call last):
#     File "<stdin>", line 1, in <module>
#     File "example.py", line 17, in __set__
#         super().__set__(instance, value)
#     File "example.py", line 23, in __set__
#         raise ValueError('Expected >= 0')
# ValueError: Expected >= 0

类型检查类装饰器方式

# Class decorator to apply constraints
def check_attributes(**kwargs):
    def decorate(cls):
        for key, value in kwargs.items():
            # 实例对象
            if isinstance(value, Descriptor):
                # name 名称
                value.name = key
                setattr(cls, key, value)
            else:
            # 非实例
                setattr(cls, key, value(key))
        return cls

    return decorate

# Example
@check_attributes(name=SizedString(size=8),
                  shares=UnsignedInteger,
                  price=UnsignedFloat)
class Stock:
    def __init__(self, name, shares, price):
        self.name = name
        self.shares = shares
        self.price = price

类型检查元类方式

# A metaclass that applies checking
class checkedmeta(type):
    def __new__(cls, clsname, bases, methods):
        # Attach attribute names to the descriptors
        for key, value in methods.items():
            if isinstance(value, Descriptor):
                value.name = key
        return type.__new__(cls, clsname, bases, methods)

# Example
class Stock2(metaclass=checkedmeta):
    name = SizedString(size=8)
    shares = UnsignedInteger()
    price = UnsignedFloat()

    def __init__(self, name, shares, price):
        self.name = name
        self.shares = shares
        self.price = price

所有方法中,类装饰器方案应该是最灵活和最高明的。 首先,它并不依赖任何其他新的技术,比如元类。其次,装饰器可以很容易的添加或删除。最后,装饰器还能作为混入类的替代技术来实现同样的效果。

类装饰器方式

# Decorator for applying type checking
def Typed(expected_type, cls=None):
    if cls is None:
        return lambda cls: Typed(expected_type, cls)
    super_set = cls.__set__

    def __set__(self, instance, value):
        if not isinstance(value, expected_type):
            raise TypeError('expected ' + str(expected_type))
        super_set(self, instance, value)

    cls.__set__ = __set__
    return cls


# Decorator for unsigned values
def Unsigned(cls):
    super_set = cls.__set__

    def __set__(self, instance, value):
        if value < 0:
            raise ValueError('Expected >= 0')
        super_set(self, instance, value)

    cls.__set__ = __set__
    return cls


# Decorator for allowing sized values
def MaxSized(cls):
    super_init = cls.__init__

    def __init__(self, name=None, **opts):
        if 'size' not in opts:
            raise TypeError('missing size option')
        super_init(self, name, **opts)

    cls.__init__ = __init__

    super_set = cls.__set__

    def __set__(self, instance, value):
        if len(value) >= self.size:
            raise ValueError('size must be < ' + str(self.size))
        super_set(self, instance, value)

    cls.__set__ = __set__
    return cls


# Specialized descriptors
@Typed(int)
class Integer(Descriptor):
    pass
@Unsigned
class UnsignedInteger(Integer):
    pass
@Typed(float)
class Float(Descriptor):
    pass
@Unsigned
class UnsignedFloat(Float):
    pass
@Typed(str)
class String(Descriptor):
    pass
@MaxSized
class SizedString(String):
    pass

自定义容器

目的:自定义的类来模拟内置的容器类功能,继承容器模块中的抽象接口。

import collections
import bisect


class SortedItems(collections.Sequence):
    def __init__(self, initial=None):
        self._items = sorted(initial) if initial is not None else []

    # Required sequence methods
    def __getitem__(self, index):
        return self._items[index]

    def __len__(self):
        return len(self._items)

    # Method for adding an item in the right location
    def add(self, item):
        bisect.insort(self._items, item)

items = SortedItems([5, 1, 3])
print(list(items))
print(items[0], items[-1])
items.add(2)
print(list(items))

这里面使用到了 bisect 模块,它是一个在排序列表中插入元素的高效方式。可以保证元素插入后还保持顺序。

使用 collections 中的抽象基类可以确保你自定义的容器实现了所有必要的方法。并且还能简化类型检查。

items = SortedItems()
import collections
isinstance(items, collections.Iterable)
# True
isinstance(items, collections.Sequence)
# True
isinstance(items, collections.Container)
# True
isinstance(items, collections.Sized)
# True
isinstance(items, collections.Mapping)
# False

属性的代理访问

代理是一种编程模式,它将某个操作转移给另外一个对象来实现,目的可能是作为继承的一个替代方法或者实现代理模式。

class A:
    def spam(self, x):
        pass

    def foo(self):
        pass

class B2:
    """使用__getattr__的代理,代理方法比较多时候"""

    def __init__(self):
        self._a = A()

    def bar(self):
        pass

    # Expose all of the methods defined on class A
    def __getattr__(self, name):
        """这个方法在访问的attribute不存在的时候被调用
        the __getattr__() method is actually a fallback method
        that only gets called when an attribute is not found"""
        return getattr(self._a, name)

代理类方式

# A proxy class that wraps around another object, but
# exposes its public attributes
class Proxy:
    def __init__(self, obj):
        self._obj = obj

    # Delegate attribute lookup to internal obj
    def __getattr__(self, name):
        print('getattr:', name)
        return getattr(self._obj, name)

    # Delegate attribute assignment
    def __setattr__(self, name, value):
        if name.startswith('_'):
            super().__setattr__(name, value)
        else:
            print('setattr:', name, value)
            setattr(self._obj, name, value)

    # Delegate attribute deletion
    def __delattr__(self, name):
        if name.startswith('_'):
            super().__delattr__(name)
        else:
            print('delattr:', name)
            delattr(self._obj, name)

当实现代理模式时,还有些细节需要注意。 首先,__getattr__() 实际是一个后备方法,只有在属性不存在时才会调用。 因此,如果代理类实例本身有这个属性的话,那么不会触发这个方法的。 另外,__setattr__()__delattr__() 需要额外的魔法来区分代理实例和被代理实例 _obj 的属性。 通常的约定是只代理那些不以下划线 _ 开头的属性。

__getattr__() 对于大部分以双下划线(__)开始和结尾的属性并不适用,为了让它支持这些方法,你必须手动的实现这些方法代理。

class ListLike:
    """__getattr__对于双下划线开始和结尾的方法是不能用的,需要一个个去重定义"""

    def __init__(self):
        self._items = []

    def __getattr__(self, name):
        return getattr(self._items, name)

    # Added special methods to support certain list operations
    def __len__(self):
        return len(self._items)

    def __getitem__(self, index):
        return self._items[index]

    def __setitem__(self, index, value):
        self._items[index] = value

    def __delitem__(self, index):
        del self._items[index]

类中定义多个构造函数

import time

class Date:
    """方法一:使用类方法"""
    # Primary constructor
    def __init__(self, year, month, day):
        self.year = year
        self.month = month
        self.day = day

    # Alternate constructor
    @classmethod
    def today(cls):
        t = time.localtime()
        return cls(t.tm_year, t.tm_mon, t.tm_mday)

类方法的一个主要用途就是定义多个构造器。

利用 Mixins 扩展类功能

Mixin

class LoggedMappingMixin:
    """
    Add logging to get/set/delete operations for debugging.
    """
    __slots__ = ()  # 混入类都没有实例变量,因为直接实例化混入类没有任何意义

    def __getitem__(self, key):
        print('Getting ' + str(key))
        return super().__getitem__(key)

    def __setitem__(self, key, value):
        print('Setting {} = {!r}'.format(key, value))
        return super().__setitem__(key, value)

    def __delitem__(self, key):
        print('Deleting ' + str(key))
        return super().__delitem__(key)

class SetOnceMappingMixin:
    '''
    Only allow a key to be set once.
    '''
    __slots__ = ()

    def __setitem__(self, key, value):
        if key in self:
            raise KeyError(str(key) + ' already set')
        return super().__setitem__(key, value)

class StringKeysMappingMixin:
    '''
    Restrict keys to strings only
    '''
    __slots__ = ()

    def __setitem__(self, key, value):
        if not isinstance(key, str):
            raise TypeError('keys must be strings')
        return super().__setitem__(key, value)

class LoggedDict(LoggedMappingMixin, dict):
    pass

from collections import defaultdict

class SetOnceDefaultDict(SetOnceMappingMixin, defaultdict):
    pass

这些类单独使用起来没有任何意义,事实上如果你去实例化任何一个类,除了产生异常外没任何作用。 它们是用来通过多继承来和其他映射对象混入使用的。

对于混入类,有几点需要记住。首先是,混入类不能直接被实例化使用。 其次,混入类没有自己的状态信息,也就是说它们并没有定义 __init__() 方法,并且没有实例属性。 这也是为什么我们在上面明确定义了 __slots__ = ()

类装饰器

def LoggedMapping(cls):
    """第二种方式:使用类装饰器"""
    cls_getitem = cls.__getitem__
    cls_setitem = cls.__setitem__
    cls_delitem = cls.__delitem__

    def __getitem__(self, key):
        print('Getting ' + str(key))
        return cls_getitem(self, key)

    def __setitem__(self, key, value):
        print('Setting {} = {!r}'.format(key, value))
        return cls_setitem(self, key, value)

    def __delitem__(self, key):
        print('Deleting ' + str(key))
        return cls_delitem(self, key)

    cls.__getitem__ = __getitem__
    cls.__setitem__ = __setitem__
    cls.__delitem__ = __delitem__
    return cls


@LoggedMapping
class LoggedDict(dict):
    pass

实现状态对象或者状态机(State行为模式)

目的:实现一个状态机或者是在不同状态下执行操作的对象,但是又不想在代码中出现太多的条件判断语句。

ConnectionState <|-- ClosedConnectionState
ConnectionState <|-- OpenConnectionState
Connection "1" *-- "n" ConnectionState
Connection <.. ClosedConnectionState
Connection <.. OpenConnectionState
class Connection:
    """新方案——对每个状态定义一个类"""
    # 成员变量构成组合
    def __init__(self):
        self.new_state(ClosedConnectionState)

    def new_state(self, newstate):
        self._state = newstate
        # Delegate to the state class

    def read(self):
        return self._state.read(self)

    def write(self, data):
        return self._state.write(self, data)

    def open(self):
        return self._state.open(self)

    def close(self):
        return self._state.close(self)


# Connection state base class
class ConnectionState:
    @staticmethod
    def read(conn):
        raise NotImplementedError()

    @staticmethod
    def write(conn, data):
        raise NotImplementedError()

    @staticmethod
    def open(conn):
        raise NotImplementedError()

    @staticmethod
    def close(conn):
        raise NotImplementedError()


# Implementation of different states
class ClosedConnectionState(ConnectionState):
    @staticmethod
    def read(conn):
        raise RuntimeError('Not open')

    @staticmethod
    def write(conn, data):
        raise RuntimeError('Not open')

    # 形参构成依赖
    @staticmethod
    def open(conn):
        conn.new_state(OpenConnectionState)

    @staticmethod
    def close(conn):
        raise RuntimeError('Already closed')


class OpenConnectionState(ConnectionState):
    @staticmethod
    def read(conn):
        print('reading')

    @staticmethod
    def write(conn, data):
        print('writing')

    @staticmethod
    def open(conn):
        raise RuntimeError('Already open')

    # 形参构成依赖
    @staticmethod
    def close(conn):
        conn.new_state(ClosedConnectionState)

c = Connection()
c._state
# <class '__main__.ClosedConnectionState'>
c.read()
# Traceback (most recent call last):
#     File "<stdin>", line 1, in <module>
#     File "example.py", line 10, in read
#         return self._state.read(self)
#     File "example.py", line 43, in read
#         raise RuntimeError('Not open')
# RuntimeError: Not open
c.open()
c._state
# <class '__main__.OpenConnectionState'>
c.read()
# reading
c.write('hello')
# writing
c.close()
c._state
# <class '__main__.ClosedConnectionState'>

每个状态对象都只有静态方法,并没有存储任何的实例属性数据。 实际上,所有状态信息都只存储在 Connection 实例中。

通过字符串调用对象方法(反射自省)

import math

class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return 'Point({!r:},{!r:})'.format(self.x, self.y)

    def distance(self, x, y):
        return math.hypot(self.x - x, self.y - y)


p = Point(2, 3)
d = getattr(p, 'distance')(0, 0)  # Calls p.distance(0, 0)

# 使用 operator.methodcaller()

import operator
operator.methodcaller('distance', 0, 0)(p)

实现访问者模式

// TODO 理解

目的:处理由大量不同类型的对象组成的复杂数据结构,每一个对象都需要进行不同的处理。

Node <|-- Number
Node <|-- UnaryOperator
Node <|-- BinaryOperator
BinaryOperator <|-- Add
BinaryOperator <|-- Sub
BinaryOperator <|-- Mul
BinaryOperator <|-- Div
UnaryOperator <|-- Negate
NodeVistor <|-- Evaluater
NodeVistor <|-- StackCode
Evaluater ..> Node
StackCode ..> Node
class Node:
    pass

class UnaryOperator(Node):
    def __init__(self, operand):
        self.operand = operand

class BinaryOperator(Node):
    def __init__(self, left, right):
        self.left = left
        self.right = right

class Add(BinaryOperator):
    pass

class Sub(BinaryOperator):
    pass

class Mul(BinaryOperator):
    pass

class Div(BinaryOperator):
    pass

class Negate(UnaryOperator):
    pass

class Number(Node):
    def __init__(self, value):
        self.value = value

t1 = Sub(Number(3), Number(4))
t2 = Mul(Number(2), t1)
t3 = Div(t2, Number(5))
t4 = Add(Number(1), t3)

访问者模式

class NodeVisitor:

    # 依赖node
    def visit(self, node):
        methname = 'visit_' + type(node).__name__
        meth = getattr(self, methname, None)
        if meth is None:
            meth = self.generic_visit
        return meth(node)

    def generic_visit(self, node):
        raise RuntimeError('No {} method'.format('visit_' + type(node).__name__))

# 为了使用这个类,可以定义一个类继承它并且实现各种 `visit_Name()` 方法,其中Name是node类型。

class Evaluator(NodeVisitor):
    def visit_Number(self, node):
        return node.value

    def visit_Add(self, node):
        return self.visit(node.left) + self.visit(node.right)

    def visit_Sub(self, node):
        return self.visit(node.left) - self.visit(node.right)

    def visit_Mul(self, node):
        return self.visit(node.left) * self.visit(node.right)

    def visit_Div(self, node):
        return self.visit(node.left) / self.visit(node.right)

    def visit_Negate(self, node):
        return -node.operand

e = Evaluator()
e.visit(t4)

将一个表达式转换成多个操作序列

class StackCode(NodeVisitor):
    def generate_code(self, node):
        self.instructions = []
        self.visit(node)
        return self.instructions

    def visit_Number(self, node):
        self.instructions.append(('PUSH', node.value))

    def binop(self, node, instruction):
        self.visit(node.left)
        self.visit(node.right)
        self.instructions.append((instruction,))

    def visit_Add(self, node):
        self.binop(node, 'ADD')

    def visit_Sub(self, node):
        self.binop(node, 'SUB')

    def visit_Mul(self, node):
        self.binop(node, 'MUL')

    def visit_Div(self, node):
        self.binop(node, 'DIV')

    def unaryop(self, node, instruction):
        self.visit(node.operand)
        self.instructions.append((instruction,))

    def visit_Negate(self, node):
        self.unaryop(node, 'NEG')

s = StackCode()
s.generate_code(t4)
# [('PUSH', 1), ('PUSH', 2), ('PUSH', 3), ('PUSH', 4), ('SUB',), ('MUL',), ('PUSH', 5), ('DIV',), ('ADD',)]

这种技术也是实现其他语言中switch或case语句的方式。

访问者模式一个缺点就是它严重依赖递归,如果数据结构嵌套层次太深可能会有问题。

不用递归实现访问者模式

// TODO 理解

使用生成器可以在树遍历或搜索算法中消除递归。避免递归的一个通常方法是使用一个栈或队列的数据结构

import types

class Node:
    pass

class NodeVisitor:
    def visit(self, node):
        stack = [node]
        last_result = None
        while stack:
            try:
                last = stack[-1]
                if isinstance(last, types.GeneratorType):
                    # 生成器运行
                    stack.append(last.send(last_result))
                    last_result = None
                elif isinstance(last, Node):
                    # 访问节点
                    stack.append(self._visit(stack.pop()))
                else:
                    # 返回结果
                    last_result = stack.pop()
            except StopIteration:
                # 抛弃已结束生成器
                stack.pop()

        return last_result

    def _visit(self, node):
        methname = 'visit_' + type(node).__name__
        meth = getattr(self, methname, None)
        if meth is None:
            meth = self.generic_visit
        return meth(node)

    def generic_visit(self, node):
        raise RuntimeError('No {} method'.format('visit_' + type(node).__name__))

# 替换 yield 关键词

class Evaluator(NodeVisitor):
    def visit_Number(self, node):
        return node.value

    def visit_Add(self, node):
        # return self.visit(node.left) + self.visit(node.right)
        yield (yield node.left) + (yield node.right)

    def visit_Sub(self, node):
        yield (yield node.left) - (yield node.right)

    def visit_Mul(self, node):
        yield (yield node.left) * (yield node.right)

    def visit_Div(self, node):
        yield (yield node.left) / (yield node.right)

    def visit_Negate(self, node):
        yield - (yield node.operand)

e = Evaluator()
e.visit(t4)

让类支持比较操作

from functools import total_ordering

class Room:
    def __init__(self, name, length, width):
        self.name = name
        self.length = length
        self.width = width
        self.square_feet = self.length * self.width

@total_ordering
class House:
    def __init__(self, name, style):
        self.name = name
        self.style = style
        self.rooms = list()

    @property
    def living_space_footage(self):
        return sum(r.square_feet for r in self.rooms)

    def add_room(self, room):
        self.rooms.append(room)

    def __str__(self):
        return '{}: {} square foot {}'.format(self.name,
                self.living_space_footage,
                self.style)

    def __eq__(self, other):
        return self.living_space_footage == other.living_space_footage

    def __lt__(self, other):
        return self.living_space_footage < other.living_space_footage

# Build a few houses, and add rooms to them
h1 = House('h1', 'Cape')
h1.add_room(Room('Master Bedroom', 14, 21))
h1.add_room(Room('Living Room', 18, 20))
h1.add_room(Room('Kitchen', 12, 16))
h1.add_room(Room('Office', 12, 12))
h2 = House('h2', 'Ranch')
h2.add_room(Room('Master Bedroom', 14, 21))
h2.add_room(Room('Living Room', 18, 20))
h2.add_room(Room('Kitchen', 12, 16))
h3 = House('h3', 'Split')
h3.add_room(Room('Master Bedroom', 14, 21))
h3.add_room(Room('Living Room', 18, 20))
h3.add_room(Room('Office', 12, 16))
h3.add_room(Room('Kitchen', 15, 17))
houses = [h1, h2, h3]
print('Is h1 bigger than h2?', h1 > h2) # prints True
print('Is h2 smaller than h3?', h2 < h3) # prints True
print('Is h2 greater than or equal to h1?', h2 >= h1) # Prints False
print('Which one is biggest?', max(houses)) # Prints 'h3: 1101-square-foot Split'
print('Which is smallest?', min(houses)) # Prints 'h2: 846-square-foot Ranch'

装饰器 functools.total_ordering 就是用来简化这个处理的。 使用它来装饰一个来,你只需定义一个 __eq__() 方法, 外加其他方法(lt, le, gt, or ge)中的一个即可。 然后装饰器会自动为你填充其它比较方法。它就是定义了一个从每个比较支持方法到所有需要定义的其他方法的一个映射而已。

创建缓存实例

目的:在创建一个类的对象时,如果之前使用同样参数创建过这个对象, 你想返回它的缓存引用。

# 工厂函数
# The class in question
class Spam:
    def __init__(self, name):
        self.name = name

# Caching support
import weakref
_spam_cache = weakref.WeakValueDictionary()
def get_spam(name):
    if name not in _spam_cache:
        s = Spam(name)
        _spam_cache[name] = s
    else:
        s = _spam_cache[name]
    return s

a = get_spam('foo')
b = get_spam('bar')
a is b
# False
c = get_spam('foo')
a is c
# True

一个 WeakValueDictionary 实例只会保存那些在其它地方还在被使用的实例。 否则的话,只要实例不再被使用了,它就从字典中被移除了。

应用缓存管理器

import weakref

class CachedSpamManager:
    def __init__(self):
        self._cache = weakref.WeakValueDictionary()

    def get_spam(self, name):
        if name not in self._cache:
            s = Spam(name)
            self._cache[name] = s
        else:
            s = self._cache[name]
        return s

    def clear(self):
            self._cache.clear()

class Spam:
    manager = CachedSpamManager()
    def __init__(self, name):
        self.name = name

    def get_spam(name):
        return Spam.manager.get_spam(name)

# 隐藏实例化类方法
# 第一个是将类的名字修改为以下划线(_)开头,提示用户别直接调用它。 第二种就是让这个类的 __init__() 方法抛出一个异常,让它不能被初始化

class Spam2:
    def __init__(self, *args, **kwargs):
        raise RuntimeError("Can't instantiate directly")

    # Alternate constructor
    @classmethod
    def _new(cls, name):
        self = cls.__new__(cls)
        self.name = name
        return self

# ------------------------最后的方案------------------------
class CachedSpamManager2:
    def __init__(self):
        self._cache = weakref.WeakValueDictionary()

    def get_spam(self, name):
        if name not in self._cache:
            temp = Spam3._new(name)  # Modified creation
            self._cache[name] = temp
        else:
            temp = self._cache[name]
        return temp

    def clear(self):
            self._cache.clear()

class Spam3:
    def __init__(self, *args, **kwargs):
        raise RuntimeError("Can't instantiate directly")

    # Alternate constructor
    @classmethod
    def _new(cls, name):
        self = cls.__new__(cls)
        self.name = name
        return self

元编程

函数装饰器

在函数上添加包装器,增加额外的操作处理。

import timeit
import logging
from functolls import wraps

def timethis(func):
    """
    Decorator that report the execution time.
    """
    @wraps(func)
    def _wrapper(*args, *kwargs):
        start = timeit.default_timer()
        ret = func(*args, *kwargs)
        end = timeit.defaulr_timer()
        logging.getLogger(__name__)
        logging.debug('{} running time: {}'.format(func.__name__, end-start))
        return ret
    return _wrapper

@timethis
def countdown(n):
    '''
    Counts down
    '''
    while n > 0:
        n -= 1

countdown(100000)

@wraps 装饰器来注解底层包装函数。保留函数元数据。@wraps 有一个重要特征是它能让你通过属性 __wrapped__ 直接访问被包装函数。如果有多个包装器,那么访问 __wrapped__ 属性的行为是不可预知的,应该避免这样做。

带参数装饰器

import logging
from functools import wraps

def logged(level, *, name=None, message=None):
    def decorate(func):
        logname = name if name else func.__module__
        log = logging.getLogger(logname)
        logmsg = message if message else func.__name__
        @wraps(func)
        def wrapper(*args, **kwargs):
            log.log(level, logmsg)
            return func(*args, **kwargs)
        return wrapper
    return decorate
# Example use
@logged(logging.DEBUG)
def add(x, y):
    return x + y

logged() 的返回结果必须是一个可调用对象,它接受一个函数作为参数并包装它。

可自定义属性的装饰器

装饰器包装一个函数,并且允许用户提供参数在运行时控制装饰器行为。

from functools import wraps, partial
import logging
# Utility decorator to attach a function as an attribute of obj
def attach_wrapper(obj, func=None):
    if func is None:
        return partial(attach_wrapper, obj)
    setattr(obj, func.__name__, func)
    return func

def logged(level, name=None, message=None):
    '''
    Add logging to a function. level is the logging
    level, name is the logger name, and message is the
    log message. If name and message aren't specified,
    they default to the function's module and name.
    '''
    def decorate(func):
        logname = name if name else func.__module__
        log = logging.getLogger(logname)
        logmsg = message if message else func.__name__

        @wraps(func)
        def wrapper(*args, **kwargs):
            log.log(level, logmsg)
            return func(*args, **kwargs)

        # Attach setter functions
        @attach_wrapper(wrapper)
        def set_level(newlevel):
            nonlocal level
            level = newlevel

        @attach_wrapper(wrapper)
        def set_message(newmsg):
            nonlocal logmsg
            logmsg = newmsg

        return wrapper

    return decorate

# Example use
@logged(logging.DEBUG)
def add(x, y):
    return x + y

@logged(logging.CRITICAL, 'example')
def spam():
    print('Spam!')

>>> import logging
>>> logging.basicConfig(level=logging.DEBUG)
>>> add(2, 3)
DEBUG:__main__:add
5
>>> # Change the log message
>>> add.set_message('Add called')
>>> add(2, 3)
DEBUG:__main__:Add called
5
>>> # Change the log level
>>> add.set_level(logging.WARNING)
>>> add(2, 3)
WARNING:__main__:Add called
5
>>>

# 直接修改属性(这个方法也可能正常工作,但前提是它必须是最外层的装饰器才行。)
@wraps(func)
def wrapper(*args, **kwargs):
    wrapper.log.log(wrapper.level, wrapper.logmsg)
    return func(*args, **kwargs)

# Attach adjustable attributes
wrapper.level = level
wrapper.logmsg = logmsg
wrapper.log = log

关键点在于访问函数(如 set_message()set_level() ),它们被作为属性赋给包装器。 每个访问函数允许使用 nonlocal 来修改函数内部的变量。还有一个令人吃惊的地方是访问函数会在多层装饰器间传播(如果你的装饰器都使用了 @functools.wraps 注解)。

带可选参数的装饰器

你想写一个装饰器,既可以不传参数给它,也可以传递可选参数给它。

import logging
from functools import wraps, partial

def logged(func=None, *, level=logging.DEBUG, name=None, message=None):
    if func is None:
        return partial(logged, level=level, name=name, message=message)

    logname = name if name else func.__module__
    log = logging.getLogger(logname)
    logmsg = message if message else func.__name__

    @wraps(func)
    def wrapper(*args, **kwargs):
        log.log(level, logmsg)
        return func(*args, **kwargs)

    return wrapper

@logged
def add(x, y):
    return x + y
# add = logged(add)
@logged(level=logging.CRITICAL, name='example')
def spam():
    print('Spam!')
# spam = logged(level=logging.CRITICAL, name='example')(spam)

利用装饰器强制函数上的类型检查

from inspect import signature
from functools import wraps

def typeassert(*ty_args, **ty_kwargs):
    def decorate(func):
        # If in optimized mode, disable type checking
        if not __debug__:
            return func

        # Map function argument names to supplied types
        sig = signature(func)
        bound_types = sig.bind_partial(*ty_args, **ty_kwargs).arguments

        @wraps(func)
        def wrapper(*args, **kwargs):
            bound_values = sig.bind(*args, **kwargs)
            # Enforce type assertions across supplied arguments
            for name, value in bound_values.arguments.items():
                if name in bound_types:
                    if not isinstance(value, bound_types[name]):
                        raise TypeError(
                            'Argument {} must be {}'.format(name, bound_types[name])
                            )
            return func(*args, **kwargs)
        return wrapper
    return decorate

这个方案还有点小瑕疵,它对于有默认值的参数并不适用。

inspect.signature() 提取一个可调用对象的参数签名信息。

from inspect import signature
def spam(x, y, z=42):
    pass
sig = signature(spam)
print(sig)
# (x, y, z=42)
sig.parameters
# mappingproxy(OrderedDict([('x', <Parameter at 0x10077a050 'x'>),
# ('y', <Parameter at 0x10077a158 'y'>), ('z', <Parameter at 0x10077a1b0 'z'>)]))

bind_partial() 方法来执行从指定类型到名称的部分绑定。

bound_types = sig.bind_partial(int,z=int)
bound_types
# <inspect.BoundArguments object at 0x10069bb50>
bound_types.arguments
# OrderedDict([('x', <class 'int'>), ('z', <class 'int'>)])

装饰器与函数注解的方法比较,如果注解被用来做类型检查就不能做其他事情了。而且 @typeassert 不能再用于使用注解做其他事情的函数了。

类内定义装饰器

from functools import wraps

class A:
    # Decorator as an instance method
    def decorator1(self, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            print('Decorator 1')
            return func(*args, **kwargs)
        return wrapper

    # Decorator as a class method
    @classmethod
    def decorator2(cls, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            print('Decorator 2')
            return func(*args, **kwargs)
        return wrapper

# As an instance method
a = A()
@a.decorator1
def spam():
    pass
# As a class method
@A.decorator2
def grok():
    pass

尽管最外层的装饰器函数比如 decorator1()decorator2() 需要提供一个 self 或 cls 参数, 但是在两个装饰器内部被创建的 wrapper() 函数并不需要包含这个 self 参数。 你唯一需要这个参数是在你确实要访问包装器中这个实例的某些部分的时候。其他情况下都不用去管它。

@property 装饰器实际上是一个类,它里面定义了三个方法 getter(), setter(), deleter() , 每一个方法都是一个装饰器。它为什么要这么定义的主要原因是各种不同的装饰器方法会在关联的 property 实例上操作它的状态。 因此,任何时候只要你碰到需要在装饰器中记录或绑定信息,那么这不失为一种可行方法。

类作为装饰器

为了将装饰器定义成一个实例,你需要确保它实现了 __call__()__get__() 方法。

import types
from functools import wraps

class Profiled:
    def __init__(self, func):
        wraps(func)(self)
        self.ncalls = 0

    def __call__(self, *args, **kwargs):
        self.ncalls += 1
        return self.__wrapped__(*args, **kwargs)

    def __get__(self, instance, cls):
        # 区分实例与类调用
        if instance is None:
            return self
        else:
            return types.MethodType(self, instance)

@Profiled
def add(x, y):
    return x + y

class Spam:
    @Profiled
    def bar(self, x):
        print(self, x)

__get__() 方法是为了确保绑定方法对象能被正确的创建。 type.MethodType() 手动创建一个绑定方法来使用。

闭包与 nonlocal 实现装饰器。

import types
from functools import wraps

def profiled(func):
    ncalls = 0
    @wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal ncalls
        ncalls += 1
        print(ncalls)
        return func(*args, **kwargs)
    wrapper.ncalls = lambda: ncalls
    return wrapper

# Example
@profiled
def add(x, y):
    return x + y

add.ncalls()

类或静态方法提供装饰器

给类或静态方法提供装饰器是很简单的,不过要确保装饰器在 @classmethod 或 @staticmethod 之前。

如果你把装饰器的顺序写错了就会出错。同样 @abstractmethod 也需要注意顺序,问题在于 @classmethod 和 @staticmethod 实际上并不会创建可直接调用的对象, 而是创建特殊的描述器对象,因此当你试着在其他装饰器中将它们当做函数来使用时就会出错。

装饰器为被包装函数增加参数

from functools import wraps

def optional_debug(func):
    @wraps(func)
    def wrapper(*args, debug=False, **kwargs):
        if debug:
            print('Calling', func.__name__)
        return func(*args, **kwargs)

    return wrapper

@optional_debug
def spam(a,b,c):
    print(a,b,c)

spam(1,2,3)
# 1 2 3
spam(1,2,3, debug=True)
# Calling spam
# 1 2 3

多个函数需要扩充参数

from functools import wraps
import inspect

def optional_debug(func):
    # 排除已有 debug 参数的函数
    if 'debug' in inspect.getargspec(func).args:
        raise TypeError('debug argument already defined')

    @wraps(func)
    def wrapper(*args, debug=False, **kwargs):
        if debug:
            print('Calling', func.__name__)
        return func(*args, **kwargs)
    # 修复签名问题
    sig = inspect.signature(func)
    parms = list(sig.parameters.values())
    parms.append(inspect.Parameter('debug',
                inspect.Parameter.KEYWORD_ONLY,
                default=False))
    wrapper.__signature__ = sig.replace(parameters=parms)
    return wrapper

使用装饰器扩充类的功能

想通过自省或者重写类定义的某部分来修改它的行为,但不希望使用继承或元类的方式。

继承方式

class LoggedGetattribute:
    def __getattribute__(self, name):
        print('getting:', name)
        return super().__getattribute__(name)

# Example:
class A(LoggedGetattribute):
    def __init__(self,x):
        self.x = x
    def spam(self):
        pass

装饰器方式

def log_getattribute(cls):
    # Get the original implementation
    orig_getattribute = cls.__getattribute__

    # Make a new definition
    def new_getattribute(self, name):
        print('getting:', name)
        return orig_getattribute(self, name)

    # Attach to the class and return
    cls.__getattribute__ = new_getattribute
    return cls

# Example use
@log_getattribute
class A:
    def __init__(self,x):
        self.x = x
    def spam(self):
        pass

某种程度上来讲,类装饰器方案就显得更加直观,并且它不会引入新的继承体系。它的运行速度也更快一些, 因为它并不依赖 super() 函数。

使用元类控制实例的创建

通过改变实例创建方式来实现单例、缓存或其他类似的特性。

元类单例

class Singleton(type):
    def __init__(self, *args, **kwargs):
        self.__instance = None
        super().__init__(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        if self.__instance is None:
            self.__instance = super().__call__(*args, **kwargs)
            return self.__instance
        else:
            return self.__instance

# Example
class Spam(metaclass=Singleton):
    def __init__(self):
        print('Creating Spam')

工厂实现单例模式

class _Spam:
    def __init__(self):
        print('Creating Spam')

_spam_instance = None

def Spam():
    global _spam_instance

    if _spam_instance is not None:
        return _spam_instance
    else:
        _spam_instance = _Spam()
        return _spam_instance

缓存实例

import weakref

class Cached(type):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__cache = weakref.WeakValueDictionary()

    def __call__(self, *args):
        if args in self.__cache:
            return self.__cache[args]
        else:
            obj = super().__call__(*args)
            self.__cache[args] = obj
            return obj

# Example
class Spam(metaclass=Cached):
    def __init__(self, name):
        print('Creating Spam({!r})'.format(name))
        self.name = name

利用元类实现多种实例创建模式通常要比不使用元类的方式优雅得多。更多关于创建缓存实例、弱引用等内容,请 参考

捕获类的属性定义顺序

自动记录一个类中属性和方法定义的顺序,然后可以利用它来做很多操作(比如序列化、映射到数据库等等)。

利用元类可以很容易的捕获类的定义信息。

from collections import OrderedDict

# A set of descriptors for various types
class Typed:
    _expected_type = type(None)
    def __init__(self, name=None):
        self._name = name

    def __set__(self, instance, value):
        if not isinstance(value, self._expected_type):
            raise TypeError('Expected ' + str(self._expected_type))
        instance.__dict__[self._name] = value

class Integer(Typed):
    _expected_type = int

class Float(Typed):
    _expected_type = float

class String(Typed):
    _expected_type = str

# Metaclass that uses an OrderedDict for class body
class OrderedMeta(type):
    def __new__(cls, clsname, bases, clsdict):
        d = dict(clsdict)
        order = []
        for name, value in clsdict.items():
            if isinstance(value, Typed):
                # 名称
                value._name = name
                order.append(name)
        d['_order'] = order
        return type.__new__(cls, clsname, bases, d)

    # 最先调用返回 clsdict
    @classmethod
    def __prepare__(cls, clsname, bases):
        return OrderedDict()


class Structure(metaclass=OrderedMeta):
    def as_csv(self):
        return ','.join(str(getattr(self.name)) for name in self._order)

# Example use
class Stock(Structure):
    name = String()
    shares = Integer()
    price = Float()

    def __init__(self, name, shares, price):
        self.name = name
        self.shares = shares
        self.price = price

s = Stock('GOOG',100,490.1)
s.name
# 'GOOG'
s.as_csv()
# 'GOOG,100,490.1'

一个关键点就是 OrderedMeta 元类中定义的 __prepare__() 方法。 这个方法会在开始定义类和它的父类的时候被执行。它必须返回一个映射对象以便在类定义体中被使用到。 我们这里通过返回了一个 OrderedDict 而不是一个普通的字典,可以很容易的捕获定义的顺序。

定义有可选参数的元类

为了使元类支持这些关键字参数,你必须确保在 __prepare__() , __new__()__init__() 方法中 都使用强制关键字参数。

class MyMeta(type):
    # Optional
    @classmethod
    def __prepare__(cls, name, bases, *, debug=False, synchronize=False):
        # Custom processing
        pass
        return super().__prepare__(name, bases)

    # Required
    def __new__(cls, name, bases, ns, *, debug=False, synchronize=False):
        # Custom processing
        pass
        return super().__new__(cls, name, bases, ns)

    # Required
    def __init__(self, name, bases, ns, *, debug=False, synchronize=False):
        # Custom processing
        pass
        super().__init__(name, bases, ns)

给一个元类添加可选关键字参数需要你完全弄懂类创建的所有步骤, 因为这些参数会被传递给每一个相关的方法。 __prepare__() 方法在所有类定义开始执行前首先被调用,用来创建类命名空间。 通常来讲,这个方法只是简单的返回一个字典或其他映射对象。 __new__() 方法被用来实例化最终的类对象。它在类的主体被执行完后开始执行。 __init__() 方法最后被调用,用来执行其他的一些初始化工作。

当我们构造元类的时候,通常只需要定义一个 __new__()__init__() 方法,但不是两个都定义。 但是,如果需要接受其他的关键字参数的话,这两个方法就要同时提供,并且都要提供对应的参数签名。 默认的 __prepare__() 方法接受任意的关键字参数,但是会忽略它们, 所以只有当这些额外的参数可能会影响到类命名空间的创建时你才需要去定义 __prepare__() 方法。

使用关键字参数配置一个元类还可以视作对类变量的一种替代方式。

class Spam(metaclass=MyMeta):
    debug = True
    synchronize = True
    pass

将这些属性定义为参数的好处在于它们不会污染类的名称空间, 这些属性仅仅只从属于类的创建阶段,而不是类中的语句执行阶段。 另外,它们在 __prepare__() 方法中是可以被访问的,因为这个方法会在所有类主体执行前被执行。 但是类变量只能在元类的 __new__()__init__() 方法中可见。

*args 和 **kwargs 的强制参数签名

对任何涉及到操作函数调用签名的问题,你都应该使用 inspect 模块中的签名特性。主要关注两个类:Signature 和 Parameter 。

>>> from inspect import Signature, Parameter
>>> # Make a signature for a func(x, y=42, *, z=None)
>>> parms = [ Parameter('x', Parameter.POSITIONAL_OR_KEYWORD),
...         Parameter('y', Parameter.POSITIONAL_OR_KEYWORD, default=42),
...         Parameter('z', Parameter.KEYWORD_ONLY, default=None) ]
>>> sig = Signature(parms)
>>> print(sig)
(x, y=42, *, z=None)
>>>

# 一旦你有了一个签名对象,你就可以使用它的 bind() 方法很容易的将它绑定到 *args 和 **kwargs 上去。

>>> def func(*args, **kwargs):
...     bound_values = sig.bind(*args, **kwargs)
...     for name, value in bound_values.arguments.items():
...         print(name,value)
from inspect import Signature, Parameter

def make_sig(*names):
    parms = [Parameter(name, Parameter.POSITIONAL_OR_KEYWORD)
            for name in names]
    return Signature(parms)

class Structure:
    __signature__ = make_sig()
    def __init__(self, *args, **kwargs):
        bound_values = self.__signature__.bind(*args, **kwargs)
        for name, value in bound_values.arguments.items():
            setattr(self, name, value)

# Example use
class Stock(Structure):
    __signature__ = make_sig('name', 'shares', 'price')

class Point(Structure):
    __signature__ = make_sig('x', 'y')

>>> import inspect
>>> print(inspect.signature(Stock))
(name, shares, price)
>>> s1 = Stock('ACME', 100, 490.1)
>>> s2 = Stock('ACME', 100)
Traceback (most recent call last):
...
TypeError: 'price' parameter lacking default value
>>> s3 = Stock('ACME', 100, 490.1, shares=50)
Traceback (most recent call last):
...
TypeError: multiple values for argument 'shares'
>>>

元类创建签名对象

from inspect import Signature, Parameter

def make_sig(*names):
    parms = [Parameter(name, Parameter.POSITIONAL_OR_KEYWORD)
            for name in names]
    return Signature(parms)

class StructureMeta(type):
    def __new__(cls, clsname, bases, clsdict):
        clsdict['__signature__'] = make_sig(*clsdict.get('_fields',[]))
        return super().__new__(cls, clsname, bases, clsdict)

class Structure(metaclass=StructureMeta):
    _fields = []
    def __init__(self, *args, **kwargs):
        bound_values = self.__signature__.bind(*args, **kwargs)
        for name, value in bound_values.arguments.items():
            setattr(self, name, value)

# Example
class Stock(Structure):
    _fields = ['name', 'shares', 'price']

class Point(Structure):
    _fields = ['x', 'y']

当我们自定义签名的时候,将签名存储在特定的属性 signature 中通常是很有用的。 这样在使用 inspect 模块执行内省的代码就能发现签名并将它作为调用。

在类上强制使用编程规约

如果你想监控类的定义,通常可以通过定义一个元类。一个基本元类通常是继承自 type 并重定义它的 __new__() 方法 或者是 __init__() 方法。

class MyMeta(type):
    def __new__(self, clsname, bases, clsdict):
        # clsname is name of class being defined
        # bases is tuple of base classes
        # clsdict is class dictionary
        return super().__new__(cls, clsname, bases, clsdict)

class MyMeta(type):
    def __init__(self, clsname, bases, clsdict):
        super().__init__(clsname, bases, clsdict)
        # clsname is name of class being defined
        # bases is tuple of base classes
        # clsdict is class dictionary

元类的一个关键特点是它允许你在定义的时候检查类的内容。在重新定义 __init__() 方法中, 你可以很轻松的检查类字典、父类等等。并且,一旦某个元类被指定给了某个类,那么就会被继承到所有子类中去。 因此,一个框架的构建者就能在大型的继承体系中通过给一个顶级父类指定一个元类去捕获所有下面子类的定义。

class NoMixedCaseMeta(type):
    def __new__(cls, clsname, bases, clsdict):
        for name in clsdict:
            if name.lower() != name:
                raise TypeError('Bad attribute name: ' + name)
        return super().__new__(cls, clsname, bases, clsdict)

class Root(metaclass=NoMixedCaseMeta):
    pass

class A(Root):
    def foo_bar(self): # Ok
        pass

class B(Root):
    def fooBar(self): # TypeError
        pass

检测重载方法

确保它的调用参数跟父类中原始方法有着相同的参数签名。

from inspect import signature
import logging

class MatchSignaturesMeta(type):

    def __init__(self, clsname, bases, clsdict):
        super().__init__(clsname, bases, clsdict)
        # 找位于继承体系中构建 self 父类的定义
        sup = super(self, self)
        for name, value in clsdict.items():
            if name.startswith('_') or not callable(value):
                continue
            # Get the previous definition (if any) and compare the signatures
            prev_dfn = getattr(sup, name, None)
            if prev_dfn:
                prev_sig = signature(prev_dfn)
                val_sig = signature(value)
                if prev_sig != val_sig:
                    logging.warning('Signature mismatch in %s. %s != %s',
                                    value.__qualname__, prev_sig, val_sig)

# Example
class Root(metaclass=MatchSignaturesMeta):
    pass

class A(Root):
    def foo(self, x, y):
        pass

    def spam(self, x, *, z):
        pass

# Class with redefined methods, but slightly different signatures
class B(A):
    def foo(self, a, b):
        pass

    def spam(self,x,z):
        pass

# WARNING:root:Signature mismatch in B.spam. (self, x, *, z) != (self, x, z)
# WARNING:root:Signature mismatch in B.foo. (self, x, y) != (self, a, b)

在元类中选择重新定义 __new__() 方法还是 __init__() 方法取决于你想怎样使用结果类。 __new__() 方法在类创建之前被调用,通常用于通过某种方式(比如通过改变类字典的内容)修改类的定义。 而 __init__() 方法是在类被创建之后被调用,当你需要完整构建类对象的时候会很有用。 在最后一个例子中,这是必要的,因为它使用了 super() 函数来搜索之前的定义。 它只能在类的实例被创建之后,并且相应的方法解析顺序也已经被设置好了。

代码中有一行使用了 super(self, self) 并不是排版错误。 当使用元类的时候,我们要时刻记住一点就是 self 实际上是一个类对象。 因此,这条语句其实就是用来寻找位于继承体系中构建 self 父类的定义

以编程方式定义类

你可以使用函数 types.new_class() 来初始化新的类对象。 你需要做的只是提供类的名字、父类元组、关键字参数,以及一个用成员变量填充类字典的回调函数。

# stock.py
# Example of making a class manually from parts

# Methods
def __init__(self, name, shares, price):
    self.name = name
    self.shares = shares
    self.price = price
def cost(self):
    return self.shares * self.price

cls_dict = {
    '__init__' : __init__,
    'cost' : cost,
}

# Make a class
import types

Stock = types.new_class('Stock', (), {}, lambda ns: ns.update(cls_dict))
Stock.__module__ = __name__

每次当一个类被定义后,它的 __module__ 属性包含定义它的模块名。 这个名字用于生成 __repr__() 方法的输出。(<class ‘main.Stock’>)

new_class() 第三个参数还可以包含其他的关键字参数。第四个参数最神秘,它是一个用来接受类命名空间的映射对象的函数。 通常这是一个普通的字典,但是它实际上是 __prepare__() 方法返回的任意对象, 这个函数需要使用 update() 方法给命名空间增加内容。

exec 方式实现 named tuple

import operator
import types
import sys

def named_tuple(classname, fieldnames):
    # Populate a dictionary of field property accessors
    cls_dict = { name: property(operator.itemgetter(n))
                for n, name in enumerate(fieldnames) }

    # Make a __new__ function and add to the class dict
    def __new__(cls, *args):
        if len(args) != len(fieldnames):
            raise TypeError('Expected {} arguments'.format(len(fieldnames)))
        return tuple.__new__(cls, args)

    cls_dict['__new__'] = __new__

    # Make the class
    cls = types.new_class(classname, (tuple,), {},
                        lambda ns: ns.update(cls_dict))

    # Set the module to that of the caller
    cls.__module__ = sys._getframe(1).f_globals['__name__']
    return cls

Stock = type('Stock', (), cls_dict) 这种方法的问题在于它忽略了一些关键步骤,比如对于元类中 __prepare__() 方法的调用。 通过使用 types.new_class() ,你可以保证所有的必要初始化步骤都能得到执行。

如果你仅仅只是想执行准备步骤,可以使用 types.prepare_class()

import types
metaclass, kwargs, ns = types.prepare_class('Stock', (), {'metaclass': type})

在定义的时候初始化类的成员

在类被定义的时候就初始化一部分类的成员,而不是要等到实例被创建后。

import operator

class StructTupleMeta(type):
    def __init__(cls, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for n, name in enumerate(cls._fields):
            setattr(cls, name, property(operator.itemgetter(n)))

class StructTuple(tuple, metaclass=StructTupleMeta):
    _fields = []
    def __new__(cls, *args):
        if len(args) != len(cls._fields):
            raise ValueError('{} arguments required'.format(len(cls._fields)))
        return super().__new__(cls,args)

class Stock(StructTuple):
    _fields = ['name', 'shares', 'price']

class Point(StructTuple):
    _fields = ['x', 'y']

s = Stock('ACME', 50, 91.1)

函数 operator.itemgetter() 创建一个访问器函数, 然后 property() 函数将其转换成一个属性。

本节最难懂的部分是知道不同的初始化步骤是什么时候发生的。 StructTupleMeta 中的 __init__() 方法只在每个类被定义时被调用一次。 cls 参数就是那个被定义的类。实际上,上述代码使用了 _fields 类变量来保存新的被定义的类, 然后给它再添加一点新的东西。

StructTuple 类作为一个普通的基类,供其他使用者来继承。 这个类中的 __new__() 方法用来构造新的实例。 这里使用 __new__() 并不是很常见,主要是因为我们要修改元组的调用签名, 使得我们可以像普通的实例调用那样创建实例。

__init__() 不同的是,__new__() 方法在实例被创建之前被触发。 由于元组是不可修改的,所以一旦它们被创建了就不可能对它做任何改变。

利用函数注解实现方法重载

// TODO 复习

实现基于类型的方法重载。

元类方法

# multiple.py
import inspect
import types

class MultiMethod:
    '''
    Represents a single multimethod.
    '''
    def __init__(self, name):
        self._methods = {}
        self.__name__ = name

    def register(self, meth):
        '''
        Register a new method as a multimethod
        '''
        sig = inspect.signature(meth)

        # Build a type signature from the method's annotations
        types = []
        for name, parm in sig.parameters.items():
            if name == 'self':
                continue
            if parm.annotation is inspect.Parameter.empty:
                raise TypeError(
                    'Argument {} must be annotated with a type'.format(name)
                )
            if not isinstance(parm.annotation, type):
                raise TypeError(
                    'Argument {} annotation must be a type'.format(name)
                )
            if parm.default is not inspect.Parameter.empty:
                self._methods[tuple(types)] = meth
            types.append(parm.annotation)

        self._methods[tuple(types)] = meth

    def __call__(self, *args):
        '''
        Call a method based on type signature of the arguments
        '''
        types = tuple(type(arg) for arg in args[1:])
        meth = self._methods.get(types, None)
        if meth:
            return meth(*args)
        else:
            raise TypeError('No matching method for types {}'.format(types))

    def __get__(self, instance, cls):
        '''
        Descriptor method needed to make calls work in a class
        '''
        if instance is not None:
            return types.MethodType(self, instance)
        else:
            return self

class MultiDict(dict):
    '''
    Special dictionary to build multimethods in a metaclass
    '''
    def __setitem__(self, key, value):
        if key in self:
            # If key already exists, it must be a multimethod or callable
            current_value = self[key]
            if isinstance(current_value, MultiMethod):
                current_value.register(value)
            else:
                mvalue = MultiMethod(key)
                mvalue.register(current_value)
                mvalue.register(value)
                super().__setitem__(key, mvalue)
        else:
            super().__setitem__(key, value)

class MultipleMeta(type):
    '''
    Metaclass that allows multiple dispatch of methods
    '''
    def __new__(cls, clsname, bases, clsdict):
        return type.__new__(cls, clsname, bases, dict(clsdict))

    @classmethod
    def __prepare__(cls, clsname, bases):
        return MultiDict()

class Spam(metaclass=MultipleMeta):
    def bar(self, x:int, y:int):
        print('Bar 1:', x, y)

    def bar(self, s:str, n:int = 0):
        print('Bar 2:', s, n)

# Example: overloaded __init__
import time

class Date(metaclass=MultipleMeta):
    def __init__(self, year: int, month:int, day:int):
        self.year = year
        self.month = month
        self.day = day

    def __init__(self):
        t = time.localtime()
        self.__init__(t.tm_year, t.tm_mon, t.tm_mday)

本节的实现中的主要思路其实是很简单的。MutipleMeta 元类使用它的 __prepare__() 方法 来提供一个作为 MultiDict 实例的自定义字典。这个跟普通字典不一样的是, MultiDict 会在元素被设置的时候检查是否已经存在,如果存在的话,重复的元素会在 MultiMethod 实例中合并。

MultiMethod 实例通过构建从类型签名到函数的映射来收集方法。 在这个构建过程中,函数注解被用来收集这些签名然后构建这个映射。 这个过程在 MultiMethod.register() 方法中实现。 这种映射的一个关键特点是对于多个方法,所有参数类型都必须要指定,否则就会报错。

为了让 MultiMethod 实例模拟一个调用,__call__() 方法被实现了。 这个方法从所有排除 self 的参数中构建一个类型元组,在内部map中查找这个方法, 然后调用相应的方法。为了能让 MultiMethod 实例在类定义时正确操作,__get__() 是必须得实现的。 它被用来构建正确的绑定方法。

缺陷:

  • 不能使用关键字参数
  • 对于继承有限制

描述器实现

import types

class multimethod:
    def __init__(self, func):
        self._methods = {}
        self.__name__ = func.__name__
        self._default = func

    def match(self, *types):
        def register(func):
            # 查找默认参数个数
            ndefaults = len(func.__defaults__) if func.__defaults__ else 0
            for n in range(ndefaults+1):
                self._methods[types[:len(types) - n]] = func
            return self
        return register

    def __call__(self, *args):
        types = tuple(type(arg) for arg in args[1:])
        meth = self._methods.get(types, None)
        if meth:
            return meth(*args)
        else:
            return self._default(*args)

    def __get__(self, instance, cls):
        if instance is not None:
            return types.MethodType(self, instance)
        else:
            return self

class Spam:
    @multimethod
    def bar(self, *args):
        # Default method called if no match
        raise TypeError('No matching method for bar')

    @bar.match(int, int)
    def bar(self, x, y):
        print('Bar 1:', x, y)

    @bar.match(str, int)
    def bar(self, s, n = 0):
        print('Bar 2:', s, n)

描述器方案同样也有前面提到的限制(不支持关键字参数和继承)。

Guido van Rossum

避免重复的属性方法

def typed_property(name, expected_type):
    storage_name = '_' + name

    @property
    def prop(self):
        return getattr(self, storage_name)

    @prop.setter
    def prop(self, value):
        if not isinstance(value, expected_type):
            raise TypeError('{} must be a {}'.format(name, expected_type))
        setattr(self, storage_name, value)

    return prop

# Example use
class Person:
    name = typed_property('name', str)
    age = typed_property('age', int)

    def __init__(self, name, age):
        self.name = name
        self.age = age

# 简化方法

from functools import partial

String = partial(typed_property, expected_type=str)
Integer = partial(typed_property, expected_type=int)

# Example:
class Person:
    name = String('name')
    age = Integer('age')

    def __init__(self, name, age):
        self.name = name
        self.age = age

定义上下文管理器

import time
from contextlib import contextmanager

@contextmanager
def timethis(label):
    start = time.time()
    try:
        yield
    finally:
        end = time.time()
        print('{}: {}'.format(label, end - start))

# Example use
with timethis('counting'):
    n = 10000000
    while n > 0:
        n -= 1

在函数 timethis() 中,yield 之前的代码会在上下文管理器中作为 __enter__() 方法执行, 所有在 yield 之后的代码会作为 __exit__() 方法执行。 如果出现了异常,异常会在 yield 语句那里抛出。

实现了列表对象上的某种事务

@contextmanager
def list_transaction(orig_list):
    working = list(orig_list)
    yield working
    orig_list[:] = working

@contextmanager 应该仅仅用来写自包含的上下文管理函数。如果你有一些对象(比如一个文件、网络连接或锁),需要支持 with 语句,那么你就需要单独实现 __enter__() 方法和 __exit__() 方法。

在局部变量域中执行代码

调用 exec() 之前使用 locals() 函数来得到一个局部变量字典。

默认情况下,exec() 会在调用者局部和全局范围内执行代码。然而,在函数里面, 传递给 exec() 的局部范围是拷贝实际局部变量组成的一个字典。 因此,如果 exec() 如果执行了修改操作,这种修改后的结果对实际局部变量值是没有影响的。

调用 locals() 获取局部变量时,你获得的是传递给 exec() 的局部变量的一个拷贝。 通过在代码执行后审查这个字典的值,那就能获取修改后的值了。

解析与分析Python源码

// TODO

模块与包

构建一个模块的层级包

封装成包是很简单的。在文件系统上组织你的代码,并确保每个目录都定义了一个__init__.py文件。

控制模块被全部导入的内容

在你的模块中定义一个变量 __all__ 来明确地列出需要导出的内容。

如果你不做任何事, 这样的导入将会导入所有不以下划线开头的。 另一方面,如果定义了 __all__ , 那么只有被列举出的东西会被导出。如果你将 __all__ 定义成一个空列表, 没有东西将被导入。 如果 __all__ 包含未定义的名字, 在导入时引起 AttributeError。

使用相对路径名导入包中子模块

import语句的 . 和 .. 看起来很滑稽, 但它指定目录名.为当前目录,..B为目录../B。这种语法只适用于 import

尽管使用相对导入看起来像是浏览文件系统,但是不能到定义包的目录之外。也就是说,使用点的这种模式从不是包的目录中导入将会引发错误。

最后,相对导入只适用于在合适的包中的模块。尤其是在顶层的脚本的简单模块中,它们将不起作用。

将模块分割成多个文件

# mymodule.py
class A:
    def spam(self):
        print('A.spam')

class B(A):
    def bar(self):
        print('B.bar')

# mymodule/
#     __init__.py
#     a.py
#     b.py

# __init__.py
from .a import A
from .b import B

延迟导入

# __init__.py
def A():
    from .a import A
    return A()

def B():
    from .b import B
    return B()

if isinstance(x, mymodule.A): # Error
if isinstance(x, mymodule.a.A): # Ok

延迟加载的主要缺点是继承和类型检查可能会中断。

延迟加载的真实例子, 见标准库 multiprocessing/__init__.py 的源码。

利用命名空间导入目录分散的代码

foo-package/
    spam/
        blah.py

bar-package/
    spam/
        grok.py

>>> import sys
>>> sys.path.extend(['foo-package', 'bar-package'])
>>> import spam.blah
>>> import spam.grok

包命名空间的一个重要特点是任何人都可以用自己的代码来扩展命名空间。

一个包是否被作为一个包命名空间的主要方法是检查其 file 属性。如果没有,那包是个命名空间。这也可以由其字符表现形式中的“namespace”这个词体现出来。

重新加载模块

>>> import spam
>>> import imp
>>> imp.reload(spam)

reload() 擦除了模块底层字典的内容,并通过重新执行模块的源代码来刷新它。模块对象本身的身份保持不变。因此,该操作在程序中所有已经被导入了的地方更新了模块。

在生产环境中可能需要避免重新加载模块。在交互环境下调试,解释程序并试图弄懂它。

读取位于包中的数据文件

pkgutil.get_data() 函数是一个读取数据文件的高级工具,不用管包是如何安装以及安装在哪。它只是工作并将文件内容以字节字符串返回。

get_data() 的第一个参数是包含包名的字符串。你可以直接使用包名,也可以使用特殊的变量,比如 __package__ 。第二个参数是包内文件的相对名称。如果有必要,可以使用标准的Unix命名规范到不同的目录,只要最后的目录仍然位于包中。

将文件夹加入到sys.path

  • 使用PYTHONPATH环境变量来添加
  • 是创建一个.pth文件,将目录列举出来,.pth 文件需要放在某个 Python 的 site-packages 目录
  • 编码实现
import sys
from os.path import abspath, join, dirname
sys.path.insert(0, join(abspath(dirname(__file__)), 'src'))

通过字符串名导入模块

使用 importlib.import_module() 函数来手动导入名字为字符串给出的一个模块或者包的一部分。

import_module只是简单地执行和import相同的步骤,但是返回生成的模块对象。你只需要将其存储在一个变量,然后像正常的模块一样使用。

通过钩子远程加载模块

// TODO

网络与Web编程

作为客户端与HTTP服务交互

from urllib import request, parse

# Base URL being accessed
url = 'http://httpbin.org/get'

# Dictionary of query parameters (if any)
parms = {
   'name1' : 'value1',
   'name2' : 'value2'
}

# Encode the query string
querystring = parse.urlencode(parms)

# Make a GET request and read the response
u = request.urlopen(url+'?' + querystring)
resp = u.read()

from urllib import request, parse

# Extra headers
headers = {
    'User-agent' : 'none/ofyourbusiness',
    'Spam' : 'Eggs'
}

req = request.Request(url, querystring.encode('ascii'), headers=headers)

# Make a request and read the response
u = request.urlopen(req)
resp = u.read()

创建 TCP 服务器

from socketserver import BaseRequestHandler, StreamRequestHandler, TCPServer

# 文本类型
class EchoHandler(BaseRequestHandler):
    def handle(self):
        print('Got connection from', self.client_address)
        while True:
            msg = self.request.recv(8192)
            if not msg:
                break
            self.request.send(msg)

# 流类型
class EchoHandler2(StreamRequestHandler):
    def handle(self):
        print('Got connection from', self.client_address)
        # self.rfile is a file-like object for reading
        for line in self.rfile:
            # self.wfile is a file-like object for writing
            self.wfile.write(line)

if __name__ == '__main__':
    serv = TCPServer(('', 20000), EchoHandler)
    serv = TCPServer(('', 20000), EchoHandler2)
    serv.serve_forever()

并发 TCP 服务

# ForkingTCPServer 或者是 ThreadingTCPServer
from socketserver import ThreadingTCPServer


if __name__ == '__main__':
    serv = ThreadingTCPServer(('', 20000), EchoHandler)
    serv.serve_forever()

# 限制线程个数

if __name__ == '__main__':
    from threading import Thread
    NWORKERS = 16
    serv = TCPServer(('', 20000), EchoHandler)
    for n in range(NWORKERS):
        t = Thread(target=serv.serve_forever)
        t.daemon = True
        t.start()
    serv.serve_forever()

端口复用

一般来讲,一个 TCPServer 在实例化的时候会绑定并激活相应的 socket 。 不过,有时候你想通过设置某些选项去调整底下的 socket ,可以设置参数 bind_and_activate=False 。

# 允许服务器重新绑定一个之前使用过的端口号
if __name__ == '__main__':
    serv = TCPServer(('', 20000), EchoHandler, bind_and_activate=False)
    # Set up various socket options
    serv.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
    # Bind and activate
    serv.server_bind()
    serv.server_activate()
    serv.serve_forever()

# 简化版
if __name__ == '__main__':
    TCPServer.allow_reuse_address = True
    serv = TCPServer(('', 20000), EchoHandler)
    serv.serve_forever()

socket 底层实现

from socket import socket, AF_INET, SOCK_STREAM

def echo_handler(address, client_sock):
    print('Got connection from {}'.format(address))
    while True:
        msg = client_sock.recv(8192)
        if not msg:
            break
        client_sock.sendall(msg)
    client_sock.close()

def echo_server(address, backlog=5):
    sock = socket(AF_INET, SOCK_STREAM)
    sock.bind(address)
    sock.listen(backlog)
    while True:
        client_sock, client_addr = sock.accept()
        echo_handler(client_addr, client_sock)

if __name__ == '__main__':
    echo_server(('', 20000))

创建 UDP 服务器

from socketserver import BaseRequestHandler, UDPServer
import time

class TimeHandler(BaseRequestHandler):
    def handle(self):
        print('Got connection from', self.client_address)
        # Get message and client socket
        msg, sock = self.request
        resp = time.ctime()
        sock.sendto(resp.encode('ascii'), self.client_address)

if __name__ == '__main__':
    serv = UDPServer(('', 20000), TimeHandler)
    serv.serve_forever()

并发 UDP 服务

from socketserver import ThreadingUDPServer

   if __name__ == '__main__':
    serv = ThreadingUDPServer(('',20000), TimeHandler)
    serv.serve_forever()

socket 底层实现

from socket import socket, AF_INET, SOCK_DGRAM
import time

def time_server(address):
    sock = socket(AF_INET, SOCK_DGRAM)
    sock.bind(address)
    while True:
        msg, addr = sock.recvfrom(8192)
        print('Got message from', addr)
        resp = time.ctime()
        sock.sendto(resp.encode('ascii'), addr)

if __name__ == '__main__':
    time_server(('', 20000))

CIDR地址生成对应的IP地址集

>>> import ipaddress
>>> net = ipaddress.ip_network('123.45.67.64/27')
>>> net
IPv4Network('123.45.67.64/27')

要注意的是,ipaddress 模块跟其他一些和网络相关的模块比如 socket 库交集很少。 所以,你不能使用 IPv4Address 的实例来代替一个地址字符串,你首先得显式的使用 str() 转换它。

创建一个简单的REST接口

# resty.py

import cgi

def notfound_404(environ, start_response):
    start_response('404 Not Found', [ ('Content-type', 'text/plain') ])
    return [b'Not Found']

class PathDispatcher:
    def __init__(self):
        self.pathmap = { }

    def __call__(self, environ, start_response):
        path = environ['PATH_INFO']
        # 查询参数
        params = cgi.FieldStorage(environ['wsgi.input'],
                                  environ=environ)
        method = environ['REQUEST_METHOD'].lower()
        environ['params'] = { key: params.getvalue(key) for key in params }
        handler = self.pathmap.get((method,path), notfound_404)
        return handler(environ, start_response)

    def register(self, method, path, function):
        self.pathmap[method.lower(), path] = function
        return function

import time

_hello_resp = '''\
<html>
  <head>
     <title>Hello {name}</title>
   </head>
   <body>
     <h1>Hello {name}!</h1>
   </body>
</html>'''

def hello_world(environ, start_response):
    start_response('200 OK', [ ('Content-type','text/html')])
    params = environ['params']
    resp = _hello_resp.format(name=params.get('name'))
    yield resp.encode('utf-8')

_localtime_resp = '''\
<?xml version="1.0"?>
<time>
  <year>{t.tm_year}</year>
  <month>{t.tm_mon}</month>
  <day>{t.tm_mday}</day>
  <hour>{t.tm_hour}</hour>
  <minute>{t.tm_min}</minute>
  <second>{t.tm_sec}</second>
</time>'''

def localtime(environ, start_response):
    start_response('200 OK', [ ('Content-type', 'application/xml') ])
    resp = _localtime_resp.format(t=time.localtime())
    yield resp.encode('utf-8')

if __name__ == '__main__':
    # from resty import PathDispatcher
    from wsgiref.simple_server import make_server

    # Create the dispatcher and register functions
    dispatcher = PathDispatcher()
    dispatcher.register('GET', '/hello', hello_world)
    dispatcher.register('GET', '/localtime', localtime)

    # Launch a basic server
    httpd = make_server('', 8080, dispatcher)
    print('Serving on port 8080...')
    httpd.serve_forever()
  • environ 属性是一个字典,包含了从web服务器提供的CGI接口中获取的值。
  • cgi.FieldStorage() 可以从请求中提取查询参数并将它们放入一个类字典对象中以便后面使用。
  • start_response 参数是一个为了初始化一个请求对象而必须被调用的函数。 第一个参数是返回的HTTP状态值,第二个参数是一个(名,值)元组列表,用来构建返回的HTTP头。
  • 为了返回数据,一个WSGI程序必须返回一个字节字符串序列。

WSGI本身是一个很小的标准。因此它并没有提供一些高级的特性比如认证、cookies、重定向等。

通过 XML-RPC 实现简单的远程调用

from xmlrpc.server import SimpleXMLRPCServer

class KeyValueServer:
    _rpc_methods_ = ['get', 'set', 'delete', 'exists', 'keys']
    def __init__(self, address):
        self._data = {}
        self._serv = SimpleXMLRPCServer(address, allow_none=True)
        for name in self._rpc_methods_:
            self._serv.register_function(getattr(self, name))

    def get(self, name):
        return self._data[name]

    def set(self, name, value):
        self._data[name] = value

    def delete(self, name):
        del self._data[name]

    def exists(self, name):
        return name in self._data

    def keys(self):
        return list(self._data)

    def serve_forever(self):
        self._serv.serve_forever()

# Example
if __name__ == '__main__':
    kvserv = KeyValueServer(('', 15000))
    kvserv.serve_forever()

# 客户端访问

>>> from xmlrpc.client import ServerProxy
>>> s = ServerProxy('http://localhost:15000', allow_none=True)
>>> s.set('foo', 'bar')
>>> s.set('spam', [1, 2, 3])
>>> s.keys()
['spam', 'foo']
>>> s.get('foo')
'bar'
>>> s.get('spam')
[1, 2, 3]
>>> s.delete('spam')
>>> s.exists('spam')
False
>>>
# 二进制需要特别处理
>>> s.set('foo', b'Hello World')
>>> s.get('foo')
<xmlrpc.client.Binary object at 0x10131d410>

>>> _.data
b'Hello World'
>>>

在不同的Python解释器之间交互

你在不同的机器上面运行着多个 Python 解释器实例,并希望能够在这些解释器之间通过消息来交换数据。

from multiprocessing.connection import Listener
import traceback

def echo_client(conn):
    try:
        while True:
            msg = conn.recv()
            conn.send(msg)
    except EOFError:
        print('Connection closed')

def echo_server(address, authkey):
    serv = Listener(address, authkey=authkey)
    while True:
        try:
            client = serv.accept()

            echo_client(client)
        except Exception:
            traceback.print_exc()

echo_server(('', 25000), authkey=b'peekaboo')

# 访问

>>> from multiprocessing.connection import Client
>>> c = Client(('localhost', 25000), authkey=b'peekaboo')
>>> c.send('hello')
>>> c.recv()
'hello'
>>> c.send(42)
>>> c.recv()
42
>>> c.send([1, 2, 3, 4, 5])
>>> c.recv()
[1, 2, 3, 4, 5]

跟底层 socket 不同的是,每个消息会完整保存(每一个通过 send() 发送的对象能通过 recv() 来完整接受)。 另外,所有对象会通过 pickle 序列化。因此,任何兼容 pickle 的对象都能在此连接上面被发送和接受。

通用准则,不要使用 multiprocessing 来实现一个对外的公共服务。 Client()Listener() 中的 authkey 参数用来认证发起连接的终端用户。

实现远程方法调用

服务端

# rpcserver.py

import pickle
class RPCHandler:
    def __init__(self):
        self._functions = { }

    def register_function(self, func):
        self._functions[func.__name__] = func

    def handle_connection(self, connection):
        try:
            while True:
                # Receive a message
                func_name, args, kwargs = pickle.loads(connection.recv())
                # Run the RPC and send a response
                try:
                    r = self._functions[func_name](*args, **kwargs)
                    connection.send(pickle.dumps(r))
                except Exception as e:
                    connection.send(pickle.dumps(e))
        except EOFError:
             pass

from multiprocessing.connection import Listener
from threading import Thread

def rpc_server(handler, address, authkey):
    sock = Listener(address, authkey=authkey)
    while True:
        client = sock.accept()
        t = Thread(target=handler.handle_connection, args=(client,))
        t.daemon = True
        t.start()

# Some remote functions
def add(x, y):
    return x + y

def sub(x, y):
    return x - y

# Register with a handler
handler = RPCHandler()
handler.register_function(add)
handler.register_function(sub)

# Run the server
rpc_server(handler, ('localhost', 17000), authkey=b'peekaboo')

客户端

import pickle

class RPCProxy:
    def __init__(self, connection):
        self._connection = connection
    def __getattr__(self, name):
        def do_rpc(*args, **kwargs):
            self._connection.send(pickle.dumps((name, args, kwargs)))
            result = pickle.loads(self._connection.recv())
            if isinstance(result, Exception):
                raise result
            return result
        return do_rpc

>>> from multiprocessing.connection import Client
>>> c = Client(('localhost', 17000), authkey=b'peekaboo')
>>> proxy = RPCProxy(c)
>>> proxy.add(2, 3)

5
>>> proxy.sub(2, 3)
-1
>>> proxy.sub([1, 2], 4)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "rpcserver.py", line 37, in do_rpc
    raise result
TypeError: unsupported operand type(s) for -: 'list' and 'int'

简单的客户端认证

import hmac
import os

def client_authenticate(connection, secret_key):
    '''
    Authenticate client to a remote service.
    connection represents a network connection.
    secret_key is a key known only to both client/server.
    '''
    message = connection.recv(32)
    hash = hmac.new(secret_key, message)
    digest = hash.digest()
    connection.send(digest)

def server_authenticate(connection, secret_key):
    '''
    Request client authentication.
    '''
    message = os.urandom(32)
    connection.send(message)
    hash = hmac.new(secret_key, message)
    digest = hash.digest()
    response = connection.recv(len(digest))
    return hmac.compare_digest(digest,response)

# 使用

from socket import socket, AF_INET, SOCK_STREAM

secret_key = b'peekaboo'
def echo_handler(client_sock):
    if not server_authenticate(client_sock, secret_key):
        client_sock.close()
        return
    while True:

        msg = client_sock.recv(8192)
        if not msg:
            break
        client_sock.sendall(msg)

def echo_server(address):
    s = socket(AF_INET, SOCK_STREAM)
    s.bind(address)
    s.listen(5)
    while True:
        c,a = s.accept()
        echo_handler(c)

echo_server(('', 18000))

# Within a client, you would do this:

from socket import socket, AF_INET, SOCK_STREAM

secret_key = b'peekaboo'

s = socket(AF_INET, SOCK_STREAM)
s.connect(('localhost', 18000))
client_authenticate(s, secret_key)
s.send(b'Hello World')
resp = s.recv(1024)

网络服务中加入 SSL

# 服务端

from socket import socket, AF_INET, SOCK_STREAM
import ssl

KEYFILE = 'server_key.pem'   # Private key of the server
CERTFILE = 'server_cert.pem' # Server certificate (given to client)

def echo_client(s):
    while True:
        data = s.recv(8192)
        if data == b'':
            break
        s.send(data)
    s.close()
    print('Connection closed')

def echo_server(address):
    s = socket(AF_INET, SOCK_STREAM)
    s.bind(address)
    s.listen(1)

    # Wrap with an SSL layer requiring client certs
    s_ssl = ssl.wrap_socket(s,
                            keyfile=KEYFILE,
                            certfile=CERTFILE,
                            server_side=True
                            )
    # Wait for connections
    while True:
        try:
            c,a = s_ssl.accept()
            print('Got connection', c, a)
            echo_client(c)
        except Exception as e:
            print('{}: {}'.format(e.__class__.__name__, e))

echo_server(('', 20000))

# 客户端
>>> from socket import socket, AF_INET, SOCK_STREAM
>>> import ssl
>>> s = socket(AF_INET, SOCK_STREAM)
>>> s_ssl = ssl.wrap_socket(s,
                cert_reqs=ssl.CERT_REQUIRED,
                ca_certs = 'server_cert.pem')
>>> s_ssl.connect(('localhost', 20000))
>>> s_ssl.send(b'Hello World?')
12
>>> s_ssl.recv(8192)
b'Hello World?'

mixin 类

import ssl

class SSLMixin:
    '''
    Mixin class that adds support for SSL to existing servers based
    on the socketserver module.
    '''
    def __init__(self, *args,
                 keyfile=None, certfile=None, ca_certs=None,
                 cert_reqs=ssl.CERT_NONE,
                 **kwargs):
        self._keyfile = keyfile
        self._certfile = certfile
        self._ca_certs = ca_certs
        self._cert_reqs = cert_reqs
        super().__init__(*args, **kwargs)

    def get_request(self):
        client, addr = super().get_request()
        client_ssl = ssl.wrap_socket(client,
                                     keyfile = self._keyfile,
                                     certfile = self._certfile,
                                     ca_certs = self._ca_certs,
                                     cert_reqs = self._cert_reqs,
                                     server_side = True)
        return client_ssl, addr

SSL XMLRPC

# XML-RPC server with SSL

from xmlrpc.server import SimpleXMLRPCServer

class SSLSimpleXMLRPCServer(SSLMixin, SimpleXMLRPCServer):
    pass

Here's the XML-RPC server from Recipe 11.6 modified only slightly to use SSL:

import ssl
from xmlrpc.server import SimpleXMLRPCServer
from sslmixin import SSLMixin

class SSLSimpleXMLRPCServer(SSLMixin, SimpleXMLRPCServer):
    pass

class KeyValueServer:
    _rpc_methods_ = ['get', 'set', 'delete', 'exists', 'keys']
    def __init__(self, *args, **kwargs):
        self._data = {}
        self._serv = SSLSimpleXMLRPCServer(*args, allow_none=True, **kwargs)
        for name in self._rpc_methods_:
            self._serv.register_function(getattr(self, name))

    def get(self, name):
        return self._data[name]

    def set(self, name, value):
        self._data[name] = value

    def delete(self, name):
        del self._data[name]

    def exists(self, name):
        return name in self._data

    def keys(self):
        return list(self._data)

    def serve_forever(self):
        self._serv.serve_forever()

if __name__ == '__main__':
    KEYFILE='server_key.pem'    # Private key of the server
    CERTFILE='server_cert.pem'  # Server certificate
    kvserv = KeyValueServer(('', 15000),
                            keyfile=KEYFILE,
                            certfile=CERTFILE)
    kvserv.serve_forever()

>>> from xmlrpc.client import ServerProxy
>>> s = ServerProxy('https://localhost:15000', allow_none=True)
>>> s.set('foo','bar')
>>> s.set('spam', [1, 2, 3])
>>> s.keys()
['spam', 'foo']
>>> s.get('foo')
'bar'
>>> s.get('spam')
[1, 2, 3]
>>> s.delete('spam')
>>> s.exists('spam')
False

验证 XMLRPC 服务器证书

from xmlrpc.client import SafeTransport, ServerProxy
import ssl

class VerifyCertSafeTransport(SafeTransport):
    def __init__(self, cafile, certfile=None, keyfile=None):
        SafeTransport.__init__(self)
        self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        self._ssl_context.load_verify_locations(cafile)
        if certfile:
            self._ssl_context.load_cert_chain(certfile, keyfile)
        self._ssl_context.verify_mode = ssl.CERT_REQUIRED

    def make_connection(self, host):
        # Items in the passed dictionary are passed as keyword
        # arguments to the http.client.HTTPSConnection() constructor.
        # The context argument allows an ssl.SSLContext instance to
        # be passed with information about the SSL configuration
        s = super().make_connection((host, {'context': self._ssl_context}))
        return s

# Create the client proxy
s = ServerProxy('https://localhost:15000',
                transport=VerifyCertSafeTransport('server_cert.pem'),
                allow_none=True)

进程间传递 Socket 文件描述符

import multiprocessing
from multiprocessing.reduction import recv_handle, send_handle
import socket

def worker(in_p, out_p):
    out_p.close()
    while True:
        fd = recv_handle(in_p)
        print('CHILD: GOT FD', fd)
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=fd) as s:
            while True:
                msg = s.recv(1024)
                if not msg:
                    break
                print('CHILD: RECV {!r}'.format(msg))
                s.send(msg)

def server(address, in_p, out_p, worker_pid):
    in_p.close()
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
    s.bind(address)
    s.listen(1)
    while True:
        client, addr = s.accept()
        print('SERVER: Got connection from', addr)
        send_handle(out_p, client.fileno(), worker_pid)
        client.close()

if __name__ == '__main__':
    c1, c2 = multiprocessing.Pipe()
    worker_p = multiprocessing.Process(target=worker, args=(c1,c2))
    worker_p.start()

    server_p = multiprocessing.Process(target=server,
                  args=(('', 15000), c1, c2, worker_p.pid))
    server_p.start()

    c1.close()
    c2.close()

两个进程被创建并通过一个 multiprocessing 管道连接起来。 服务器进程打开一个socket并等待客户端连接请求。 工作进程仅仅使用 recv_handle() 在管道上面等待接收一个文件描述符。 当服务器接收到一个连接,它将产生的socket文件描述符通过 send_handle() 传递给工作进程。

// TODO

理解事件驱动IO

// TODO 理解

事件驱动I/O本质上来讲就是将基本I/O操作(比如读和写)转化为你程序需要处理的事件。

class EventHandler:
    def fileno(self):
        'Return the associated file descriptor'
        raise NotImplemented('must implement')

    def wants_to_receive(self):
        'Return True if receiving is allowed'
        return False

    def handle_receive(self):
        'Perform the receive operation'
        pass

    def wants_to_send(self):
        'Return True if sending is requested'
        return False

    def handle_send(self):
        'Send outgoing data'
        pass

import select

def event_loop(handlers):
    while True:
        wants_recv = [h for h in handlers if h.wants_to_receive()]
        wants_send = [h for h in handlers if h.wants_to_send()]
        can_recv, can_send, _ = select.select(wants_recv, wants_send, [])
        for h in can_recv:
            h.handle_receive()
        for h in can_send:
            h.handle_send()

事件循环的关键部分是 select() 调用,它会不断轮询文件描述符从而激活它。 在调用 select() 之前,事件循环会询问所有的处理器来决定哪一个想接受或发生。 然后它将结果列表提供给 select() 。然后 select() 返回准备接受或发送的对象组成的列表。 然后相应的 handle_receive()handle_send() 方法被触发。

UDP 网络服务

import socket
import time

class UDPServer(EventHandler):
    def __init__(self, address):
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.sock.bind(address)

    def fileno(self):
        return self.sock.fileno()

    def wants_to_receive(self):
        return True

class UDPTimeServer(UDPServer):
    def handle_receive(self):
        msg, addr = self.sock.recvfrom(1)
        self.sock.sendto(time.ctime().encode('ascii'), addr)

class UDPEchoServer(UDPServer):
    def handle_receive(self):
        msg, addr = self.sock.recvfrom(8192)
        self.sock.sendto(msg, addr)

if __name__ == '__main__':
    handlers = [ UDPTimeServer(('',14000)), UDPEchoServer(('',15000))  ]
    event_loop(handlers)

>>> from socket import *
>>> s = socket(AF_INET, SOCK_DGRAM)
>>> s.sendto(b'',('localhost',14000))
0
>>> s.recvfrom(128)
(b'Tue Sep 18 14:29:23 2012', ('127.0.0.1', 14000))
>>> s.sendto(b'Hello',('localhost',15000))
5
>>> s.recvfrom(128)
(b'Hello', ('127.0.0.1', 15000))

TCP 网络服务

每一个客户端都要初始化一个新的处理器对象。

class TCPServer(EventHandler):
    def __init__(self, address, client_handler, handler_list):
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
        self.sock.bind(address)
        self.sock.listen(1)
        self.client_handler = client_handler
        self.handler_list = handler_list

    def fileno(self):
        return self.sock.fileno()

    def wants_to_receive(self):
        return True

    def handle_receive(self):
        client, addr = self.sock.accept()
        # Add the client to the event loop's handler list
        self.handler_list.append(self.client_handler(client, self.handler_list))

class TCPClient(EventHandler):
    def __init__(self, sock, handler_list):
        self.sock = sock
        self.handler_list = handler_list
        self.outgoing = bytearray()

    def fileno(self):
        return self.sock.fileno()

    def close(self):
        self.sock.close()
        # Remove myself from the event loop's handler list
        self.handler_list.remove(self)

    def wants_to_send(self):
        return True if self.outgoing else False

    def handle_send(self):
        nsent = self.sock.send(self.outgoing)
        self.outgoing = self.outgoing[nsent:]

class TCPEchoClient(TCPClient):
    def wants_to_receive(self):
        return True

    def handle_receive(self):
        data = self.sock.recv(8192)
        if not data:
            self.close()
        else:
            self.outgoing.extend(data)

if __name__ == '__main__':
   handlers = []
   handlers.append(TCPServer(('',16000), TCPEchoClient, handlers))
   event_loop(handlers)

TCP例子的关键点是从处理器中列表增加和删除客户端的操作。 对每一个连接,一个新的处理器被创建并加到列表中。当连接被关闭后,每个客户端负责将其从列表中删除。

事件驱动I/O的缺点是没有真正的同步机制。 如果任何事件处理器方法阻塞或执行一个耗时计算,它会阻塞所有的处理进程。 调用那些并不是事件驱动风格的库函数也会有问题,同样要是某些库函数调用会阻塞,那么也会导致整个事件循环停止。

对于阻塞或耗时计算的问题可以通过将事件发送个其他单独的现场或进程来处理。

from concurrent.futures import ThreadPoolExecutor
import os

class ThreadPoolHandler(EventHandler):
    def __init__(self, nworkers):
        if os.name == 'posix':
            self.signal_done_sock, self.done_sock = socket.socketpair()
        else:
            server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            server.bind(('127.0.0.1', 0))
            server.listen(1)
            self.signal_done_sock = socket.socket(socket.AF_INET,
                                                  socket.SOCK_STREAM)
            self.signal_done_sock.connect(server.getsockname())
            self.done_sock, _ = server.accept()
            server.close()

        self.pending = []
        self.pool = ThreadPoolExecutor(nworkers)

    def fileno(self):
        return self.done_sock.fileno()

    # Callback that executes when the thread is done
    def _complete(self, callback, r):

        self.pending.append((callback, r.result()))
        self.signal_done_sock.send(b'x')

    # Run a function in a thread pool
    def run(self, func, args=(), kwargs={},*,callback):
        r = self.pool.submit(func, *args, **kwargs)
        r.add_done_callback(lambda r: self._complete(callback, r))

    def wants_to_receive(self):
        return True

    # Run callback functions of completed work
    def handle_receive(self):
        # Invoke all pending callback functions
        for callback, result in self.pending:
            callback(result)
            self.done_sock.recv(1)
        self.pending = []

run() 方法被用来将工作提交给回调函数池,处理完成后调用。 实际工作被提交给 ThreadPoolExecutor 实例。 不过一个难点是协调计算结果和事件循环,为了解决它,我们创建了一对 socket 并将其作为某种信号量机制来使用。 当线程池完成工作后,它会执行类中的 _complete() 方法。 这个方法再某个 socket 上写入字节之前会将挂起的回调函数和结果放入队列中。 fileno() 方法返回另外的那个 socket。 因此,这个字节被写入时,它会通知事件循环, 然后 handle_receive() 方法被激活并为所有之前提交的工作执行回调函数。

# A really bad Fibonacci implementation
def fib(n):
    if n < 2:
        return 1
    else:
        return fib(n - 1) + fib(n - 2)

class UDPFibServer(UDPServer):
    def handle_receive(self):
        msg, addr = self.sock.recvfrom(128)
        n = int(msg)
        pool.run(fib, (n,), callback=lambda r: self.respond(r, addr))

    def respond(self, result, addr):
        self.sock.sendto(str(result).encode('ascii'), addr)

if __name__ == '__main__':
    pool = ThreadPoolHandler(16)
    handlers = [ pool, UDPFibServer(('',16000))]
    event_loop(handlers)

from socket import *
sock = socket(AF_INET, SOCK_DGRAM)
for x in range(40):
    sock.sendto(str(x).encode('ascii'), ('localhost', 16000))
    resp = sock.recvfrom(8192)
    print(resp[0])

发送与接收大型数组

# zerocopy.py

def send_from(arr, dest):
    view = memoryview(arr).cast('B')
    while len(view):
        nsent = dest.send(view)
        view = view[nsent:]

def recv_into(arr, source):
    view = memoryview(arr).cast('B')
    while len(view):
        nrecv = source.recv_into(view)
        view = view[nrecv:]

# Server

>>> from socket import *
>>> s = socket(AF_INET, SOCK_STREAM)
>>> s.bind(('', 25000))
>>> s.listen(1)
>>> c,a = s.accept()
>>> import numpy
>>> a = numpy.arange(0.0, 50000000.0)
>>> send_from(a, c)

# Client

>>> from socket import *
>>> c = socket(AF_INET, SOCK_STREAM)
>>> c.connect(('localhost', 25000))
>>> import numpy
>>> a = numpy.zeros(shape=50000000, dtype=float)
>>> a[0:10]
array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.])
>>> recv_into(a, c)
>>> a[0:10]
array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.])

这个视图能被传递给socket相关函数, 比如 socket.send()send.recv_into() 。 在内部,这些方法能够直接操作这个内存区域。

这里有个问题就是接受者必须事先知道有多少数据要被发送, 以便它能预分配一个数组或者确保它能将接受的数据放入一个已经存在的数组中。 如果没办法知道的话,发送者就得先将数据大小发送过来,然后再发送实际的数组数据。

并发编程

启动与停止线程

threading 库可以在单独的线程中执行任何的在 Python 中可以调用的对象。你可以创建一个 Thread 对象并将你要执行的对象以 target 参数的形式提供给该对象。

# Code to execute in an independent thread
import time
def countdown(n):
    while n > 0:
        print('T-minus', n)
        n -= 1
        time.sleep(5)

# Create and launch a thread
from threading import Thread
t = Thread(target=countdown, args=(10,))
t.start()

当你创建好一个线程对象后,该对象并不会立即执行,除非你调用它的 start() 方法。

由于全局解释锁(GIL)的原因,Python 的线程被限制到同一时刻只允许一个线程执行这样一个执行模型。所以,Python 的线程更适用于处理 I/O 和其他需要并发执行的阻塞操作(比如等待I/O、等待从数据库获取数据等等),而不是需要多处理器并行的计算密集型任务。

线程同步方式

线程的一个关键特性是每个线程都是独立运行且状态不可预测。Event 对象包含一个可由线程设置的信号标志,它允许线程等待某些事件的发生。在初始情况下,event 对象中的信号标志被设置为假。如果有线程等待一个 event 对象,而这个 event 对象的标志为假,那么这个线程将会被一直阻塞直至该标志为真。一个线程如果将一个 event 对象的信号标志设置为真,它将唤醒所有等待这个 event 对象的线程。如果一个线程等待一个已经被设置为真的 event 对象,那么它将忽略这个事件,继续执行。

from threading import Thread, Event
import time

# Code to execute in an independent thread
def countdown(n, started_evt):
    print('countdown starting')
    started_evt.set()
    while n > 0:
        print('T-minus', n)
        n -= 1
        time.sleep(5)

# Create the event object that will be used to signal startup
started_evt = Event()

# Launch the thread and pass the startup event
print('Launching countdown')
t = Thread(target=countdown, args=(10, started_evt))
t.start()

# Wait for the thread to start
started_evt.wait()
print('countdown is running')

# Launching countdown
# countdown starting
# T-minus 10
# countdown is running
# T-minus 9
# T-minus 8
# T-minus 7
# T-minus 6
# T-minus 5
# T-minus 4
# T-minus 3
# T-minus 2
# T-minus 1

event 对象最好单次使用,就是说,你创建一个 event 对象,让某个线程等待这个对象,一旦这个对象被设置为真,你就应该丢弃它。尽管可以通过 clear() 方法来重置 event 对象,但是很难确保安全地清理 event 对象并对它重新赋值。很可能会发生错过事件、死锁或者其他问题(特别是,你无法保证重置 event 对象的代码会在线程再次等待这个 event 对象之前执行)。如果一个线程需要不停地重复使用 event 对象,你最好使用 Condition 对象来代替。

Condition

event 对象的一个重要特点是当它被设置为真时,会唤醒所有等待它的线程。如果你只想唤醒单个线程,最好是使用信号量或者 Condition 对象来替代。

import threading
import time

class PeriodicTimer:
    def __init__(self, interval):
        self._interval = interval
        self._flag = 0
        self._cv = threading.Condition()

    def start(self):
        t = threading.Thread(target=self.run)
        t.daemon = True
        t.start()

    def run(self):
        '''
        Run the timer and notify waiting threads after each interval
        '''
        while True:
            time.sleep(self._interval)
            with self._cv:
                 self._flag ^= 1
                 self._cv.notify_all()

    def wait_for_tick(self):
        '''
        Wait for the next tick of the timer
        '''
        with self._cv:
            last_flag = self._flag
            while last_flag == self._flag:
                self._cv.wait()

# Example use of the timer
ptimer = PeriodicTimer(5)
ptimer.start()

# Two threads that synchronize on the timer
def countdown(nticks):
    while nticks > 0:
        ptimer.wait_for_tick()
        print('T-minus', nticks)
        nticks -= 1

def countup(last):
    n = 0
    while n < last:
        ptimer.wait_for_tick()
        print('Counting', n)
        n += 1

threading.Thread(target=countdown, args=(10,)).start()
threading.Thread(target=countup, args=(5,)).start()

Semaphore

# Worker thread
def worker(n, sema):
    # Wait to be signaled
    sema.acquire()

    # Do some work
    print('Working', n)

# Create some threads
sema = threading.Semaphore(0)
nworkers = 10
for n in range(nworkers):
    t = threading.Thread(target=worker, args=(n, sema,))
    t.start()

>>> sema.release()
Working 0
>>> sema.release()
Working 1

线程间通信

from queue import Queue
from threading import Thread

# Object that signals shutdown
_sentinel = object()

# A thread that produces data
def producer(out_q):
    while running:
        # Produce some data
        pass
        out_q.put(data)

    # Put the sentinel on the queue to indicate completion
    out_q.put(_sentinel)

# A thread that consumes data
def consumer(in_q):
    while True:
        # Get some data
        data = in_q.get()

        # Check for termination
        if data is _sentinel:
            in_q.put(_sentinel)
            break

        # Process the data
        pass

# Create the shared queue and launch both threads
q = Queue()
t1 = Thread(target=consumer, args=(q,))
t2 = Thread(target=producer, args=(q,))
t1.start()
t2.start()

Queue 对象已经包含了必要的锁,所以你可以通过它在多个线程间多安全地共享数据。本例中有一个特殊的地方:消费者在读到这个特殊值之后立即又把它放回到队列中,将之传递下去。这样,所有监听这个队列的消费者线程就可以全部关闭了。

Condition 实现优先队列



class PriorityQueue:
    def __init__(self):
        self._queue = []
        self._count = 0
        self._cv = threading.Condition()
    def put(self, item, priority):
        with self._cv:
            heapq.heappush(self._queue, (-priority, self._count, item))
            self._count += 1
            self._cv.notify()

    def get(self):
        with self._cv:
            while len(self._queue) == 0:
                self._cv.wait()
            return heapq.heappop(self._queue)[-1]

使用队列来进行线程间通信是一个单向、不确定的过程。通常情况下,你没有办法知道接收数据的线程是什么时候接收到的数据并开始工作的。不过队列对象提供一些基本完成的特性,比如 task_done()join()

from queue import Queue
from threading import Thread

# A thread that produces data
def producer(out_q):
    while running:
        # Produce some data
        ...
        out_q.put(data)

# A thread that consumes data
def consumer(in_q):
    while True:
        # Get some data
        data = in_q.get()

        # Process the data
        ...
        # Indicate completion
        in_q.task_done()

# Create the shared queue and launch both threads
q = Queue()
t1 = Thread(target=consumer, args=(q,))
t2 = Thread(target=producer, args=(q,))
t1.start()
t2.start()

# Wait for all produced items to be consumed
q.join()

如果一个线程需要在一个“消费者”线程处理完特定的数据项时立即得到通知,你可以把要发送的数据和一个 Event 放到一起使用,这样“生产者”就可以通过这个Event对象来监测处理的过程了。

from queue import Queue
from threading import Thread, Event

# A thread that produces data
def producer(out_q):
    while running:
        # Produce some data
        ...
        # Make an (data, event) pair and hand it to the consumer
        evt = Event()
        out_q.put((data, evt))
        ...
        # Wait for the consumer to process the item
        evt.wait()

# A thread that consumes data
def consumer(in_q):
    while True:
        # Get some data
        data, evt = in_q.get()
        # Process the data
        ...
        # Indicate completion
        evt.set()

使用线程队列有一个要注意的问题是,向队列中添加数据项时并不会复制此数据项,线程间通信实际上是在线程间传递对象引用。如果你担心对象的共享状态,那你最好只传递不可修改的数据结构(如:整型、字符串或者元组)或者一个对象的深拷贝。

Queue 对象提供一些在当前上下文很有用的附加特性。比如在创建 Queue 对象时提供可选的 size 参数来限制可以添加到队列中的元素数量。get()put() 方法都支持非阻塞方式和设定超时。

import queue
q = queue.Queue()

try:
    data = q.get(block=False)
except queue.Empty:
    ...

try:
    q.put(item, block=False)
except queue.Full:
    ...

try:
    data = q.get(timeout=5.0)
except queue.Empty:
    ...

最后,有 q.qsize()q.full()q.empty() 等实用方法可以获取一个队列的当前大小和状态。但要注意,这些方法都不是线程安全的。

线程锁

import threading

class SharedCounter:
    '''
    A counter object that can be shared by multiple threads.
    '''
    def __init__(self, initial_value = 0):
        self._value = initial_value
        self._value_lock = threading.Lock()

    def incr(self,delta=1):
        '''
        Increment the counter with locking
        '''
        with self._value_lock:
             self._value += delta

    def decr(self,delta=1):
        '''
        Decrement the counter with locking
        '''
        with self._value_lock:
             self._value -= delta

Lock 对象和 with 语句块一起使用可以保证互斥执行,就是每次只有一个线程可以执行 with 语句包含的代码块。with 语句会在这个代码块执行前自动获取锁,在执行结束后自动释放锁。

RLock

RLock (可重入锁)可以被同一个线程多次获取,主要用来实现基于监测对象模式的锁定和同步。在使用这种锁的情况下,当锁被持有时,只有一个线程可以使用完整的函数或者类中的方法。

# 所有实例共享的类级锁
import threading

class SharedCounter:
    '''
    A counter object that can be shared by multiple threads.
    '''
    _lock = threading.RLock()
    def __init__(self, initial_value = 0):
        self._value = initial_value

    def incr(self,delta=1):
        '''
        Increment the counter with locking
        '''
        with SharedCounter._lock:
            self._value += delta

    def decr(self,delta=1):
        '''
        Decrement the counter with locking
        '''
        with SharedCounter._lock:
             self.incr(-delta)

防止死锁的加锁机制

import threading
from contextlib import contextmanager

# Thread-local state to stored information on locks already acquired
_local = threading.local()

@contextmanager
def acquire(*locks):
    # Sort locks by object identifier
    locks = sorted(locks, key=lambda x: id(x))

    # Make sure lock order of previously acquired locks is not violated
    acquired = getattr(_local,'acquired',[])
    if acquired and max(id(lock) for lock in acquired) >= id(locks[0]):
        raise RuntimeError('Lock Order Violation')

    # Acquire all of the locks
    acquired.extend(locks)
    _local.acquired = acquired

    try:
        for lock in locks:
            lock.acquire()
        yield
    finally:
        # Release locks in reverse order of acquisition
        for lock in reversed(locks):
            lock.release()
        del acquired[-len(locks):]

import threading
x_lock = threading.Lock()
y_lock = threading.Lock()

def thread_1():
    while True:
        with acquire(x_lock, y_lock):
            print('Thread-1')

def thread_2():
    while True:
        with acquire(y_lock, x_lock):
            print('Thread-2')

t1 = threading.Thread(target=thread_1)
t1.daemon = True
t1.start()

t2 = threading.Thread(target=thread_2)
t2.daemon = True
t2.start()

如果你执行这段代码,你会发现它即使在不同的函数中以不同的顺序获取锁也没有发生死锁。 其关键在于,在第一段代码中,我们对这些锁进行了排序。通过排序,使得不管用户以什么样的顺序来请求锁,这些锁都会按照固定的顺序被获取

避免死锁

死锁的检测与恢复是一个几乎没有优雅的解决方案的扩展话题。一个比较常用的死锁检测与恢复的方案是引入看门狗计数器。当线程正常运行的时候会每隔一段时间重置计数器,在没有发生死锁的情况下,一切都正常进行。一旦发生死锁,由于无法重置计数器导致定时器超时,这时程序会通过重启自身恢复到正常状态。

避免死锁是另外一种解决死锁问题的方式,在进程获取锁的时候会严格按照对象id升序排列获取,经过数学证明,这样保证程序不会进入死锁状态。避免死锁的主要思想是,单纯地按照对象id递增的顺序加锁不会产生循环依赖,而循环依赖是 死锁的一个必要条件,从而避免程序进入死锁状态。

保存线程的状态信息

可使用 thread.local() 创建一个本地线程存储对象。 对这个对象的属性的保存和读取操作都只会对执行线程可见,而其他线程并不可见。

from socket import socket, AF_INET, SOCK_STREAM
import threading

class LazyConnection:
    def __init__(self, address, family=AF_INET, type=SOCK_STREAM):
        self.address = address
        self.family = AF_INET
        self.type = SOCK_STREAM
        self.local = threading.local()

    def __enter__(self):
        if hasattr(self.local, 'sock'):
            raise RuntimeError('Already connected')
        self.local.sock = socket(self.family, self.type)
        self.local.sock.connect(self.address)
        return self.local.sock

    def __exit__(self, exc_ty, exc_val, tb):
        self.local.sock.close()
        del self.local.sock

from functools import partial
def test(conn):
    with conn as s:
        s.send(b'GET /index.html HTTP/1.0\r\n')
        s.send(b'Host: www.python.org\r\n')

        s.send(b'\r\n')
        resp = b''.join(iter(partial(s.recv, 8192), b''))

    print('Got {} bytes'.format(len(resp)))

if __name__ == '__main__':
    conn = LazyConnection(('www.python.org', 80))

    t1 = threading.Thread(target=test, args=(conn,))
    t2 = threading.Thread(target=test, args=(conn,))
    t1.start()
    t2.start()
    t1.join()
    t2.join()

初始化为一个 threading.local() 实例。 其他方法操作被存储为 self.local.sock 的套接字对象。每个线程会创建一个自己专属的套接字连接(存储为self.local.sock)。 因此,当不同的线程执行套接字操作时,由于操作的是不同的套接字,因此它们不会相互影响。

其原理是,每个 threading.local() 实例为每个线程维护着一个单独的实例字典。 所有普通实例操作比如获取、修改和删除值仅仅操作这个字典。 每个线程使用一个独立的字典就可以保证数据的隔离了。

线程池

ThreadPoolExecutor 实现

from socket import AF_INET, SOCK_STREAM, socket
from concurrent.futures import ThreadPoolExecutor

def echo_client(sock, client_addr):
    '''
    Handle a client connection
    '''
    print('Got connection from', client_addr)
    while True:
        msg = sock.recv(65536)
        if not msg:
            break
        sock.sendall(msg)
    print('Client closed connection')
    sock.close()

def echo_server(addr):
    pool = ThreadPoolExecutor(128)
    sock = socket(AF_INET, SOCK_STREAM)
    sock.bind(addr)
    sock.listen(5)
    while True:
        client_sock, client_addr = sock.accept()
        pool.submit(echo_client, client_sock, client_addr)

echo_server(('',15000))

Queue 实现

from socket import socket, AF_INET, SOCK_STREAM
from threading import Thread
from queue import Queue

def echo_client(q):
    '''
    Handle a client connection
    '''
    sock, client_addr = q.get()
    print('Got connection from', client_addr)
    while True:
        msg = sock.recv(65536)
        if not msg:
            break
        sock.sendall(msg)
    print('Client closed connection')

    sock.close()

def echo_server(addr, nworkers):
    # Launch the client workers
    q = Queue()
    for n in range(nworkers):
        t = Thread(target=echo_client, args=(q,))
        t.daemon = True
        t.start()

    # Run the server
    sock = socket(AF_INET, SOCK_STREAM)
    sock.bind(addr)
    sock.listen(5)
    while True:
        client_sock, client_addr = sock.accept()
        q.put((client_sock, client_addr))

echo_server(('',15000), 128)

使用 ThreadPoolExecutor 相对于手动实现的一个好处在于它使得任务提交者更方便的从被调用函数中获取返回值。

from concurrent.futures import ThreadPoolExecutor
import urllib.request

def fetch_url(url):
    u = urllib.request.urlopen(url)
    data = u.read()
    return data

pool = ThreadPoolExecutor(10)
# Submit work to the pool
a = pool.submit(fetch_url, 'http://www.python.org')
b = pool.submit(fetch_url, 'http://www.pypy.org')

# Get the results back
x = a.result()
y = b.result()

a.result() 操作会阻塞进程直到对应的函数执行完成并返回一个结果。

当创建一个线程时,操作系统会预留一个虚拟内存区域来放置线程的执行栈(通常是8MB大小)。但是这个内存只有一小片段被实际映射到真实内存中。 因此,Python 进程使用到的真实内存其实很小 (比如,对于2000个线程来讲,只使用到了70MB的真实内存,而不是9GB)。 如果你担心虚拟内存大小,可以使用 threading.stack_size() 函数来降低它。注意线程栈大小必须至少为32768字节,通常是系统内存页大小(4096、8192等)的整数倍。

简单的并行编程

concurrent.futures 库提供了一个 ProcessPoolExecutor 类, 可被用来在一个单独的 Python 解释器中执行计算密集型函数。

# findrobots.py

import gzip
import io
import glob
from concurrent import futures

def find_robots(filename):
    '''
    Find all of the hosts that access robots.txt in a single log file
    '''
    robots = set()
    with gzip.open(filename) as f:
        for line in io.TextIOWrapper(f,encoding='ascii'):
            fields = line.split()
            if fields[6] == '/robots.txt':
                robots.add(fields[0])
    return robots

def find_all_robots(logdir):
    '''
    Find all hosts across and entire sequence of files
    '''
    files = glob.glob(logdir+'/*.log.gz')
    all_robots = set()
    # pool map 函数
    with futures.ProcessPoolExecutor() as pool:
        for robots in pool.map(find_robots, files):
            all_robots.update(robots)
    return all_robots

if __name__ == '__main__':
    robots = find_all_robots('logs')
    for ipaddr in robots:
        print(ipaddr)

pool 通用使用方式

# A function that performs a lot of work
def work(x):
    ...
    return result

# Nonparallel code
results = map(work, data)

# Parallel implementation
with ProcessPoolExecutor() as pool:
    results = pool.map(work, data)

# Some function
def work(x):
    ...
    return result

with ProcessPoolExecutor() as pool:
    ...
    # Example of submitting work to the pool
    future_result = pool.submit(work, arg)

    # Obtaining the result (blocks until done)
    r = future_result.result()
    ...

# callback

def when_done(r):
    print('Got:', r.result())

with ProcessPoolExecutor() as pool:
     future_result = pool.submit(work, arg)
     future_result.add_done_callback(when_done)

如果你手动提交一个任务,结果是一个 Future 实例。 要获取最终结果,你需要调用它的 result() 方法。 它会阻塞进程直到结果被返回来。如果不想阻塞,你还可以使用一个回调函数,回调函数接受一个 Future 实例,被用来获取最终的结果(比如通过调用它的result()方法)。

Note

  • 这种并行处理技术只适用于那些可以被分解为互相独立部分的问题。
  • 被提交的任务必须是简单函数形式。对于方法、闭包和其他类型的并行执行还不支持。
  • 函数参数和返回值必须兼容pickle,因为要使用到进程间的通信,所有解释器之间的交换数据必须被序列化
  • 被提交的任务函数不应保留状态或有副作用。除了打印日志之类简单的事情。

定义一个Actor任务

// TODO 理解

actor 模式是一种最古老的也是最简单的并行和分布式计算解决方案。一个 actor 就是一个并发执行的任务,只是简单的执行发送给它的消息任务。响应这些消息时,它可能还会给其他 actor 发送更进一步的消息。actor 之间的通信是单向和异步的。

from queue import Queue
from threading import Thread, Event

# Sentinel used for shutdown
class ActorExit(Exception):
    pass

class Actor:
    def __init__(self):
        self._mailbox = Queue()

    def send(self, msg):
        '''
        Send a message to the actor
        '''
        self._mailbox.put(msg)

    def recv(self):
        '''
        Receive an incoming message
        '''
        msg = self._mailbox.get()
        if msg is ActorExit:
            raise ActorExit()
        return msg

    def close(self):
        '''
        Close the actor, thus shutting it down
        '''
        self.send(ActorExit)

    def start(self):
        '''
        Start concurrent execution
        '''
        self._terminated = Event()
        t = Thread(target=self._bootstrap)
        t.daemon = True
        t.start()

    def _bootstrap(self):
        try:
            self.run()
        except ActorExit:
            pass
        finally:
            self._terminated.set()

    def join(self):
        self._terminated.wait()

    def run(self):
        '''
        Run method to be implemented by the user
        '''
        while True:
            msg = self.recv()

# Sample ActorTask
class PrintActor(Actor):
    def run(self):
        while True:
            msg = self.recv()
            print('Got:', msg)

# Sample use
p = PrintActor()
p.start()
p.send('Hello')
p.send('World')
p.close()
p.join()

使用actor实例的 send() 方法发送消息给它们。 其机制是,这个方法会将消息放入一个队里中, 然后将其转交给处理被接受消息的一个内部线程。 close() 方法通过在队列中放入一个特殊的哨兵值(ActorExit)来关闭这个actor。 用户可以通过继承Actor并定义实现自己处理逻辑 run() 方法来定义新的actor。 ActorExit 异常的使用就是用户自定义代码可以在需要的时候来捕获终止请求 (异常被 get() 方法抛出并传播出去)。

生成器方法

def print_actor():
    while True:
        try:
            msg = yield      # Get a message
            print('Got:', msg)
        except GeneratorExit:
            print('Actor terminating')

# Sample use
p = print_actor()
next(p)     # Advance to the yield (ready to receive)
p.send('Hello')
p.send('World')
p.close()

拓展

actor 模式的魅力就在于它的简单性。 实际上,这里仅仅只有一个核心操作 send()。 甚至,对于在基于 actor 系统中的“消息”的泛化概念可以已多种方式被扩展。

class TaggedActor(Actor):
    def run(self):
        while True:
             tag, *payload = self.recv()
             getattr(self,'do_'+tag)(*payload)

    # Methods correponding to different message tags
    def do_A(self, x):
        print('Running A', x)

    def do_B(self, x, y):
        print('Running B', x, y)

# Example
a = TaggedActor()
a.start()
a.send(('A', 1))      # Invokes do_A(1)
a.send(('B', 2, 3))   # Invokes do_B(2,3)

actor 允许在一个工作者中运行任意的函数, 并且通过一个特殊的 Result 对象返回结果。

from threading import Event
class Result:
    def __init__(self):
        self._evt = Event()
        self._result = None

    def set_result(self, value):
        self._result = value
        self._evt.set()

    def result(self):
        self._evt.wait()
        return self._result

class Worker(Actor):
    def submit(self, func, *args, **kwargs):
        r = Result()
        self.send((func, args, kwargs, r))
        return r

    def run(self):
        while True:
            func, args, kwargs, r = self.recv()
            r.set_result(func(*args, **kwargs))

# Example use
worker = Worker()
worker.start()
r = worker.submit(pow, 2, 3)
print(r.result())

实现消息发布/订阅模型

// TODO 理解

目的:基于线程通信的程序,想让它们实现发布/订阅模式的消息通信。

交换机

要实现发布/订阅的消息通信模式, 你通常要引入一个单独的“交换机”或“网关”对象作为所有消息的中介。

from collections import defaultdict

class Exchange:
    def __init__(self):
        self._subscribers = set()

    def attach(self, task):
        self._subscribers.add(task)

    def detach(self, task):
        self._subscribers.remove(task)

    def send(self, msg):
        for subscriber in self._subscribers:
            subscriber.send(msg)

# Dictionary of all created exchanges
_exchanges = defaultdict(Exchange)

# Return the Exchange instance associated with a given name
def get_exchange(name):
    return _exchanges[name]

一个交换机就是一个普通对象,负责维护一个活跃的订阅者集合,并为绑定、解绑和发送消息提供相应的方法。 每个交换机通过一个名称定位,get_exchange() 通过给定一个名称返回相应的 Exchange 实例。

# Example of a task.  Any object with a send() method

class Task:
    ...
    def send(self, msg):
        ...

task_a = Task()
task_b = Task()

# Example of getting an exchange
exc = get_exchange('name')

# Examples of subscribing tasks to it
exc.attach(task_a)
exc.attach(task_b)

# Example of sending messages
exc.send('msg1')
exc.send('msg2')

# Example of unsubscribing
exc.detach(task_a)
exc.detach(task_b)

使用发布订阅模式的优势

  • 使用一个交换机可以简化大部分涉及到线程通信的工作。
  • 交换机广播消息给多个订阅者的能力带来了一个全新的通信模式。
  • 兼容多个“task-like”对象。消息接受者可以是actor、协程、网络连接或任何实现了正确的 send() 方法的东西。

交换机拓展上下文管理

关于交换机的一个可能问题是对于订阅者的正确绑定和解绑。 为了正确的管理资源,每一个绑定的订阅者必须最终要解绑。

from contextlib import contextmanager
from collections import defaultdict

class Exchange:
    def __init__(self):
        self._subscribers = set()

    def attach(self, task):
        self._subscribers.add(task)

    def detach(self, task):
        self._subscribers.remove(task)

    @contextmanager
    def subscribe(self, *tasks):
        for task in tasks:
            self.attach(task)
        try:
            yield
        finally:
            for task in tasks:
                self.detach(task)

    def send(self, msg):
        for subscriber in self._subscribers:
            subscriber.send(msg)

# Dictionary of all created exchanges
_exchanges = defaultdict(Exchange)

# Return the Exchange instance associated with a given name
def get_exchange(name):
    return _exchanges[name]

# Example of using the subscribe() method
exc = get_exchange('name')
with exc.subscribe(task_a, task_b):
     ...
     exc.send('msg1')
     exc.send('msg2')
     ...

# task_a and task_b detached here

生成器代替线程

yield 语句会让一个生成器挂起它的执行,这样就可以编写一个调度器, 将生成器当做某种“任务”并使用任务协作切换来替换它们的执行。

任务调度器

# Two simple generator functions
def countdown(n):
    while n > 0:
        print('T-minus', n)
        yield
        n -= 1
    print('Blastoff!')

def countup(n):
    x = 0
    while x < n:
        print('Counting up', x)
        yield
        x += 1

from collections import deque

class TaskScheduler:
    def __init__(self):
        self._task_queue = deque()

    def new_task(self, task):
        '''
        Admit a newly started task to the scheduler

        '''
        self._task_queue.append(task)

    def run(self):
        '''
        Run until there are no more tasks
        '''
        while self._task_queue:
            task = self._task_queue.popleft()
            try:
                # Run until the next yield statement
                next(task)
                self._task_queue.append(task)
            except StopIteration:
                # Generator is no longer executing
                pass

# Example use
sched = TaskScheduler()
sched.new_task(countdown(10))
sched.new_task(countdown(5))
sched.new_task(countup(15))
sched.run()

# T-minus 5
# Counting up 0
# T-minus 9
# T-minus 4
# Counting up 1
# T-minus 8
# T-minus 3
# Counting up 2
# T-minus 7
# T-minus 2
# ...

TaskScheduler 类在一个循环中运行生成器集合——每个都运行到碰到yield语句为止。到此为止,我们实际上已经实现了一个“操作系统”的最小核心部分。 生成器函数就是认为,而 yield 语句是任务挂起的信号。调度器循环检查任务列表直到没有任务要执行为止。

生成器版 actor

from collections import deque

class ActorScheduler:
    def __init__(self):
        self._actors = { }          # Mapping of names to actors
        self._msg_queue = deque()   # Message queue

    def new_actor(self, name, actor):
        '''
        Admit a newly started actor to the scheduler and give it a name
        '''
        self._msg_queue.append((actor,None))
        self._actors[name] = actor

    def send(self, name, msg):
        '''
        Send a message to a named actor
        '''
        actor = self._actors.get(name)
        if actor:
            self._msg_queue.append((actor,msg))

    def run(self):
        '''
        Run as long as there are pending messages.
        '''
        while self._msg_queue:
            actor, msg = self._msg_queue.popleft()
            try:
                 actor.send(msg)
            except StopIteration:
                 pass

# Example use
if __name__ == '__main__':
    def printer():
        while True:
            msg = yield
            print('Got:', msg)

    def counter(sched):
        while True:
            # Receive the current count
            n = yield
            if n == 0:
                break
            # Send to the printer task
            sched.send('printer', n)
            # Send the next count to the counter task (recursive)
            sched.send('counter', n-1)

    sched = ActorScheduler()
    # Create the initial actors
    sched.new_actor('printer', printer())
    sched.new_actor('counter', counter(sched))

    # Send an initial message to the counter to initiate
    sched.send('counter', 10000)
    sched.run()

生成器版并发网络程序

from collections import deque
from select import select

# This class represents a generic yield event in the scheduler
class YieldEvent:
    def handle_yield(self, sched, task):
        pass
    def handle_resume(self, sched, task):
        pass

# Task Scheduler
class Scheduler:
    def __init__(self):
        self._numtasks = 0       # Total num of tasks
        self._ready = deque()    # Tasks ready to run
        self._read_waiting = {}  # Tasks waiting to read
        self._write_waiting = {} # Tasks waiting to write

    # Poll for I/O events and restart waiting tasks
    def _iopoll(self):
        rset,wset,eset = select(self._read_waiting,
                                self._write_waiting,[])
        for r in rset:
            evt, task = self._read_waiting.pop(r)
            evt.handle_resume(self, task)
        for w in wset:
            evt, task = self._write_waiting.pop(w)
            evt.handle_resume(self, task)

    def new(self,task):
        '''
        Add a newly started task to the scheduler
        '''

        self._ready.append((task, None))
        self._numtasks += 1

    def add_ready(self, task, msg=None):
        '''
        Append an already started task to the ready queue.
        msg is what to send into the task when it resumes.
        '''
        self._ready.append((task, msg))

    # Add a task to the reading set
    def _read_wait(self, fileno, evt, task):
        self._read_waiting[fileno] = (evt, task)

    # Add a task to the write set
    def _write_wait(self, fileno, evt, task):
        self._write_waiting[fileno] = (evt, task)

    def run(self):
        '''
        Run the task scheduler until there are no tasks
        '''
        while self._numtasks:
             if not self._ready:
                  self._iopoll()
             task, msg = self._ready.popleft()
             try:
                 # Run the coroutine to the next yield
                 r = task.send(msg)
                 if isinstance(r, YieldEvent):
                     r.handle_yield(self, task)
                 else:
                     raise RuntimeError('unrecognized yield event')
             except StopIteration:
                 self._numtasks -= 1

# Example implementation of coroutine-based socket I/O
class ReadSocket(YieldEvent):
    def __init__(self, sock, nbytes):
        self.sock = sock
        self.nbytes = nbytes
    def handle_yield(self, sched, task):
        sched._read_wait(self.sock.fileno(), self, task)
    def handle_resume(self, sched, task):
        data = self.sock.recv(self.nbytes)
        sched.add_ready(task, data)

class WriteSocket(YieldEvent):
    def __init__(self, sock, data):
        self.sock = sock
        self.data = data
    def handle_yield(self, sched, task):

        sched._write_wait(self.sock.fileno(), self, task)
    def handle_resume(self, sched, task):
        nsent = self.sock.send(self.data)
        sched.add_ready(task, nsent)

class AcceptSocket(YieldEvent):
    def __init__(self, sock):
        self.sock = sock
    def handle_yield(self, sched, task):
        sched._read_wait(self.sock.fileno(), self, task)
    def handle_resume(self, sched, task):
        r = self.sock.accept()
        sched.add_ready(task, r)

# Wrapper around a socket object for use with yield
class Socket(object):
    def __init__(self, sock):
        self._sock = sock
    def recv(self, maxbytes):
        return ReadSocket(self._sock, maxbytes)
    def send(self, data):
        return WriteSocket(self._sock, data)
    def accept(self):
        return AcceptSocket(self._sock)
    def __getattr__(self, name):
        return getattr(self._sock, name)

if __name__ == '__main__':
    from socket import socket, AF_INET, SOCK_STREAM
    import time

    # Example of a function involving generators.  This should
    # be called using line = yield from readline(sock)
    def readline(sock):
        chars = []
        while True:
            c = yield sock.recv(1)
            if not c:
                break
            chars.append(c)
            if c == b'\n':
                break
        return b''.join(chars)

    # Echo server using generators
    class EchoServer:
        def __init__(self,addr,sched):
            self.sched = sched
            sched.new(self.server_loop(addr))

        def server_loop(self,addr):
            s = Socket(socket(AF_INET,SOCK_STREAM))

            s.bind(addr)
            s.listen(5)
            while True:
                c,a = yield s.accept()
                print('Got connection from ', a)
                self.sched.new(self.client_handler(Socket(c)))

        def client_handler(self,client):
            while True:
                line = yield from readline(client)
                if not line:
                    break
                line = b'GOT:' + line
                while line:
                    nsent = yield client.send(line)
                    line = line[nsent:]
            client.close()
            print('Client closed')

    sched = Scheduler()
    EchoServer(('',16000),sched)
    sched.run()

如果使用生成器编程,要提醒你的是它还是有很多缺点的。 特别是,你得不到任何线程可以提供的好处。例如,如果你执行CPU依赖或I/O阻塞程序, 它会将整个任务挂起知道操作完成。为了解决这个问题, 你只能选择将操作委派给另外一个可以独立运行的线程或进程。 另外一个限制是大部分Python库并不能很好的兼容基于生成器的线程。

多个线程队列轮询

目的:你有一个线程队列集合,想为传入的参数轮询.

对于轮询问题的一个常见解决方案中有个很少有人知道的技巧,包含了一个隐藏的回路网络连接。 本质上讲其思想就是:对于每个你想要轮询的队列,你创建一对连接的套接字。 然后你在其中一个套接字上面编写代码来标识存在的数据, 另外一个套接字被传给 select() 或类似的一个轮询数据到达的函数。

import queue
import socket
import os

class PollableQueue(queue.Queue):
    def __init__(self):
        super().__init__()
        # Create a pair of connected sockets
        if os.name == 'posix':
            self._putsocket, self._getsocket = socket.socketpair()
        else:
            # Compatibility on non-POSIX systems
            server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            server.bind(('127.0.0.1', 0))
            server.listen(1)
            self._putsocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self._putsocket.connect(server.getsockname())
            self._getsocket, _ = server.accept()
            server.close()

    def fileno(self):
        return self._getsocket.fileno()

    def put(self, item):
        super().put(item)
        self._putsocket.send(b'x')

    def get(self):
        self._getsocket.recv(1)
        return super().get()

# 消费者

import select
import threading

def consumer(queues):
    '''
    Consumer that reads data on multiple queues simultaneously
    '''
    while True:
        can_read, _, _ = select.select(queues,[],[])
        for r in can_read:
            item = r.get()
            print('Got:', item)

q1 = PollableQueue()
q2 = PollableQueue()
q3 = PollableQueue()
t = threading.Thread(target=consumer, args=([q1,q2,q3],))
t.daemon = True
t.start()

# Feed data to the queues
q1.put(1)
q2.put(10)
q3.put('hello')
q2.put(15)

fileno() 方法使用一个函数比如 select() 来让这个队列可以被轮询。 它仅仅只是暴露了底层被 get() 函数使用到的socket的文件描述符而已。

这个方案通过将队列和套接字等同对待来解决了大部分的问题。 一个单独的 select() 调用可被同时用来轮询。 使用超时或其他基于时间的机制来执行周期性检查并没有必要。 甚至,如果数据被加入到一个队列,消费者几乎可以实时的被通知。 尽管会有一点点底层的I/O损耗,使用它通常会获得更好的响应时间并简化编程。

Unix 系统上面启动守护进程

tdonald

#!/usr/bin/env python3
# daemon.py

import os
import sys

import atexit
import signal

def daemonize(pidfile, *, stdin='/dev/null',
                          stdout='/dev/null',
                          stderr='/dev/null'):

    if os.path.exists(pidfile):
        raise RuntimeError('Already running')

    # First fork (detaches from parent)
    try:
        if os.fork() > 0:
            raise SystemExit(0)   # Parent exit
    except OSError as e:
        raise RuntimeError('fork #1 failed.')

    os.chdir('/')
    os.umask(0)
    # 使进程成为会话组长,脱离终端
    os.setsid()
    # Second fork (relinquish session leadership)
    try:
        if os.fork() > 0:
            raise SystemExit(0)
    except OSError as e:
        raise RuntimeError('fork #2 failed.')

    # Flush I/O buffers
    sys.stdout.flush()
    sys.stderr.flush()

    # Replace file descriptors for stdin, stdout, and stderr
    with open(stdin, 'rb', 0) as f:
        os.dup2(f.fileno(), sys.stdin.fileno())
    with open(stdout, 'ab', 0) as f:
        os.dup2(f.fileno(), sys.stdout.fileno())
    with open(stderr, 'ab', 0) as f:
        os.dup2(f.fileno(), sys.stderr.fileno())

    # Write the PID file
    with open(pidfile,'w') as f:
        print(os.getpid(),file=f)

    # Arrange to have the PID file removed on exit/signal
    atexit.register(lambda: os.remove(pidfile))

    # Signal handler for termination (required)
    def sigterm_handler(signo, frame):
        raise SystemExit(1)

    signal.signal(signal.SIGTERM, sigterm_handler)

def main():
    import time
    sys.stdout.write('Daemon started with pid {}\n'.format(os.getpid()))
    while True:
        sys.stdout.write('Daemon Alive! {}\n'.format(time.ctime()))
        time.sleep(10)

if __name__ == '__main__':
    PIDFILE = '/tmp/daemon.pid'

    if len(sys.argv) != 2:
        print('Usage: {} [start|stop]'.format(sys.argv[0]), file=sys.stderr)
        raise SystemExit(1)

    if sys.argv[1] == 'start':
        try:
            daemonize(PIDFILE,
                      stdout='/tmp/daemon.log',
                      stderr='/tmp/dameon.log')
        except RuntimeError as e:
            print(e, file=sys.stderr)
            raise SystemExit(1)

        main()

    elif sys.argv[1] == 'stop':
        if os.path.exists(PIDFILE):
            with open(PIDFILE) as f:
                os.kill(int(f.read()), signal.SIGTERM)
        else:
            print('Not running', file=sys.stderr)
            raise SystemExit(1)

    else:
        print('Unknown command {!r}'.format(sys.argv[1]), file=sys.stderr)
        raise SystemExit(1)