2018/09/25

Python Tips: functools.reduce() を活用したい

Python が標準で提供する関数のひとつに functoolsreduce() があります。

from functools import reduce

reduce() は一見使いどころがわかりづらいのですが、活用できるようになるととても便利な関数です。今回はそんな reduce() について、 使いどころと実践的なサンプル をご紹介してみたいと思います。

使いどころ


reduce() は「 シーケンス → ひとつの値 」という処理をしたいときに利用できる関数です。関数でいえば「引数にシーケンスを受け取り、何らかの処理をして、戻り値をひとつだけ返す」というふるまいの関数を書きたいときに使えます。

このパターンは標準的なビジネスロジックの中にたくさん登場します。いくつか例をあげてみます。

注文の合計金額を計算する: 「注文のラインアイテム」というシーケンスから「合計金額」を計算する。

複数の権限がすべて満たされているかをチェックする: 「複数の権限」というシーケンスから「操作の可否」を計算する。

複数のロールのうちひとつでもユーザが所属するものがあるかをチェックする: 「複数のロール」というシーケンスから「ユーザの所属の有無」を計算する。

リストを連結する: 「複数のリスト」というシーケンスから「連結したリスト」を計算する。

集合の共通部分を取る: 「複数の集合」というシーケンスから「共通部分」を計算する。

集合の和集合を取る: 「複数の集合」というシーケンスから「和集合」を計算する。

複数のフラグをまとめる: 「複数のフラグ」というシーケンスから「合算したフラグ」を計算する。

reduce() については「 reduce() を使うと [3, 5, 10]150 といった数列の積の計算がかんたんにできます」といったシンプルな例を使った說明がよくなされますが、それだけ聞いても「え、それ、業務のプログラミングで使います?」となりがちです。対して、上のような実際にありそうな例を見ると「 reduce() が活用できる場面は意外に多いなぁ」と思えるのではないかと思うのですがいかがでしょうか。

続いて、上にあげた例のいくつかに対してサンプルコードを見てみましょう。

実践的なサンプル


注文の合計金額を計算する

注文のラインアイテムから合計金額を計算します。

from collections import namedtuple
from functools import reduce

LineItem = namedtuple('LineItem', ['商品ID', '単価', '数量'])

def calc_total(items):
    """ラインアイテムの合計金額を計算する"""
    return reduce(lambda accum, line_item: accum + line_item.単価 * line_item.数量, items, 0)

line_items = [
    LineItem('shirt a', 1000, 2),
    LineItem('shirt b', 1200, 1),
    LineItem('shirt c', 4000, 2),
]

calc_total(line_items)
# => 11200

関数 calc_total() は各ラインアイテムの単価と数量を見てその合計金額を計算する関数です。 reduce() の中の accum には「単価 ✕ 数量」の値が蓄積され、最終的に注文全体の合計金額が出ます。

余談ですが、 accum の値がどのように変化していくのかを見たい場合は lambda をローカル関数に差し替えて中身を出力するとよいでしょう。

def calc_total(items):
    """ラインアイテムの合計金額を計算する"""
    def _sumproduct(accum, line_item):
        print(accum)
        return accum + line_item.単価 * line_item.数量

    return reduce(_sumproduct, items, 0)
    # return reduce(lambda accum, line_item: accum + line_item.単価 * line_item.数量, items, 0)

calc_total(line_items)
# 0
# 2000
# 3200
# => 11200

ただ、この処理は、例えば内包表記を使って次のように書くこともできます。

def calc_total(line_items):
    return sum(x.単価 * x.数量 for x in line_items)

reduce() を使った書き方がしっくり来ない場合はムリに reduce() を使わず、そのときどきでコードの意図がわかりやすい・メンテナンスしやすい書き方を選ぶとよいかと思います。

複数の権限がすべて満たされているかをチェックする

あらかじめ定義されている権限を対象ユーザがすべて持っているかどうかをチェックします。

from functools import reduce


class User:
    def __init__(self, permissions):
        self._permissions = permissions

    def has_perm(self, permission):
        return permission in self._permissions

    def has_all_perms(self, permissions):
        """指定された権限をすべて持っているかチェックする"""
        return reduce(lambda accum, p: accum and self.has_perm(p), permissions, True)


required_permissions = ('perm_a', 'perm_b', 'perm_c')


user1 = User(('perm_c',))
print(user1.has_all_perms(required_permissions))
# => False

user2 = User(('perm_a', 'perm_b', 'perm_c'))
print(user2.has_all_perms(required_permissions))
# => True

クラス User のメソッド has_all_perms() は、 list あるいは tuple で渡された一連の権限をユーザがすべて持っているかどうかをチェックします。 reduce() を使うことでわかりやすくシンプルに書くことができています。

上の「注文の合計金額」が sum() を使って書けたのと同様に、こちらは all() と内容表記を使って書くこともできます。

def has_all_perms(self, permissions):
        """指定された権限をすべて持っているかチェックする"""
        return all(self.has_perm(p) for p in permissions)

こちらも all() の方が読みやすくメンテナンスしやすいのであればムリに reduce() を使う必要はありません。

ちなみに、上の「 複数のロールのうちひとつでもユーザが所属するものがあるかをチェックする 」はこのチェック対象のデータが権限からロールに変わって all()any() に変わるようなイメージです。

リストを連結する

各要素がリストのリストに対して、その要素をすべて連結した大きなリストを生成します。

from functools import reduce


def concat_lists(lists):
    """リストを連結する"""
    return reduce(lambda accum, x: accum + x, lists, [])


lists1 = [[1, 2, 3], [5, 8], [13, 21, 34, 55]]
print(concat_lists(lists1))
# => [1, 2, 3, 5, 8, 13, 21, 34, 55]

関数 concat_lists() は、リストからなるリストの各要素を連結します。こちらも reduce() を使うことで、わかりやすく簡潔に書くことができています。

もし処理の途中経過が見たければ、上で述べたように lambda の部分を print() 等で出力を行うローカル関数に差し替えるとよいでしょう。

def concat_lists(lists):
    """リストを連結する"""
    def _extend(accum, x):
        print(accum)
        return accum + x
    return reduce(_extend, lists, [])
    # return reduce(lambda accum, x: accum + x, lists, [])

lists1 = [[1, 2, 3], [5, 8], [13, 21, 34, 55]]
print(concat_lists(lists1))
# []
# [1, 2, 3]
# [1, 2, 3, 5, 8]
# [1, 2, 3, 5, 8, 13, 21, 34, 55]
# => [1, 2, 3, 5, 8, 13, 21, 34, 55]

この concat_lists() の処理は sum() を使って書くことも可能です。

def concat_lists(lists):
    """リストを連結する"""
    return sum(lists, [])

この単純な例だと sum() を使った方がむしろわかりやすいかもしれませんね。

集合の和集合を取る

複数ある集合の和集合を計算します。

from functools import reduce


def union_multiple(*sets):
    """和集合を作る"""
    return reduce(lambda accum, x: accum | x, sets, set({}))


set1 = {'鹿児島', '宮崎', '熊本'}
set2 = {'大分', '佐賀'}
set3 = {'長崎', '福岡'}

print(union_multiple(set1, set2, set3))
# => {'熊本', '長崎', '福岡', '大分', '宮崎', '佐賀', '鹿児島'}

関数 union_multiple() は任意の数の引数( set )を受け取って、それらの和集合を返す関数です。 Python では 2 つの集合の和集合は set_a | set_b (あるいは set_a.union(set_b) )で計算できるので、ここで必要なのは reduce() を使ってそれを蓄積していくことだけです。

標準ライブラリ operator になじみがあって、コードのわかりやすさが損なわれないと考えられるのであれば、同じ処理は operator.or_ を使って次のように書いてもよいでしょう。

from functools import reduce
from operator import or_


def union_multiple(*sets):
    """和集合を作る"""
    # return reduce(lambda accum, x: accum | x, sets, set({}))
    return reduce(or_, sets, set({}))

Python では +-*/&| 等の 2 項演算子をオブジェクトとして扱うことができませんが、その代わりに operator の中にそれらに相当する関数が用意されています。

一例:

  • +: operator.add() または operator.concat()
  • -: operator.sub()
  • *: operator.mul()
  • /: operator.truediv()
  • &: operator.and_()
  • |: operator.or_()

ちなみに、 reduce()operator が提供する関数を知っておくとよりシンプルにわかりやすく書けることがあるので、 reduce() を活用したい方は operator もあわせてチェックしておくとよいかもしれません。

尚、上の「 集合の和集合を取る 」と「 複数のフラグをまとめる 」については、ここではサンプルコードは示しませんがこの union_multiple() と同じような考え方で書くことができます。

応用的なサンプル


ちょっと応用的(トリッキー?)な例もいくつかあげてみます。 reduce() をむやみやたらと使ってしまうとかえってコードがわかりづらくなりますが、こういう使い方もできるということを知っておくと実装の選択肢の幅が広がってよいのではないかと思います。

入れ子の辞書の要素にアクセスする

これは Stack Overflow で紹介されていた使い方です。私は目からウロコでした。

入れ子になった dict がありそれを掘り下げる一連のキーが変数として与えられたときに、要素にアクセスするロジックを reduce() を使ってスムーズに書くことができます。

from functools import reduce


def dict_deep_access(adict, key_tree):
    """ネストされた dict の要素にアクセスする"""
    return reduce(lambda elem, key: elem[key], key_tree, adict)
    # 次のように書くこともできる
    # return reduce(dict.__getitem__, key_tree, adict)


d1 = {
    '沖縄': {
        '恩納村': {
            '万座毛': '隆起サンゴの断崖',
            'なかゆくい市場': '道の駅',
            '真栄田岬': '夕日スポット',
        }
    }
}

key_trees = [('沖縄', '恩納村', '万座毛'), ('沖縄', '恩納村', '真栄田岬')]

for tree in key_trees:
    print(dict_deep_access(d1, tree))
# =>
# 隆起サンゴの断崖
# 夕日スポット

特定のディレクトリからの相対パスでアクセス

特定のディレクトリから対象ファイルへの相対パスが list として与えられたときに、対象ファイルのパスを生成するというものです。

from functools import reduce
from pathlib import Path


def file_relative_from(root, relative_path):
    """指定された場所からの相対パスでファイルを取得する"""
    file = reduce(lambda parent, dir: parent / dir, relative_path, root)
    return file.resolve()


path1 = Path('/private/tmp/abc/abc1.txt')
relative_path1 = ['..', '..', 'def', 'def1.txt']

file_relative_from(path1, relative_path1)
# => /private/tmp/def/def1.txt

関数の組み合わせ

複数の関数を組み合わせて順次適用していくというものです。

from functools import reduce


def apply_filters(seq, filters):
    """シーケンスに複数のフィルタを連続で適用する"""
    return reduce(lambda result, fn: fn(result), filters, seq)


filters = (
    lambda seq: [x.upper() for x in seq],
    lambda seq: [x.center(12) for x in seq],
    lambda seq: ['⚡ {} ⚡'.format(x) for x in seq],
    lambda seq: '\n'.join(seq),
)

関数 apply_filters() は、第 2 引数に callable のシーケンスを受け取り、その要素を第 1 引数に順次適用した結果を返してくれます。

apply_filters()filters を組み合わせると次のような処理になります。

apply_filters(['teenage', 'mutant', 'ninja', 'turtles'], filters)
# =>
# ⚡   TEENAGE    ⚡
# ⚡    MUTANT    ⚡
# ⚡    NINJA     ⚡
# ⚡   TURTLES    ⚡

apply_filters(['genetically', 'modified', 'punk', 'rock', 'pandas'], filters)
# =>
# ⚡ GENETICALLY  ⚡
# ⚡   MODIFIED   ⚡
# ⚡     PUNK     ⚡
# ⚡     ROCK     ⚡
# ⚡    PANDAS    ⚡

この処理は例えば map() や内包表記を使うと次のように書くこともできます。

# 内包表記で書いた場合
def apply_filters_hard_1(seq):
    return '\n'.join(
        '⚡ {} ⚡'.format(x) for x in (
            x.center(12) for x in (
                x.upper() for x in seq
            )
        )
    )

# map() で書いた場合
def apply_filters_hard_2(seq):
    return '\n'.join(
        map(
            lambda x: '⚡ {} ⚡'.format(x),
            map(lambda x: x.center(12),
                map(lambda x: x.upper(),
                    seq
                )
            )
        )
    )

・・・が、これだと関数の適用順序とは逆に lambda を書かなくてはなりません。このように書くぐらいであれば各処理の結果を変数に格納し行を分けて書いた方がよいでしょう。

サンプルの紹介は以上です。

まとめ


というわけで、 Python の reduce() の使いどころと実践的なサンプルについてでした。

これは受け売りですが、 reduce() については、要は「 sum()all()any() 等の「シーケンスからひとつの値を生成する」タイプの処理を抽象化したものが reduce() 」という捉え方ができます。もし OOP のクラスと同じような親子関係(継承関係)が関数にもあるなら、 reduce()sum()all() の親関数、と捉えてもよさそうです。

余談ですが、 Python ・ Ruby ・ JavaScript あたりではこの操作を行う関数に reduce という名前がつけられていますが、関数型ベースの言語においては reduce よりも fold という名前が選ばれることが多いような気がします。

以上です。

関連記事


参考

reduce() については以下のページ等もおもしろいです。

2018/09/10

Python のアトリビュート取り扱いの仕組み

Python で オブジェクトのアトリビュートへのアクセスがあったときに内部で起こっていること について説明してみます。

# オブジェクトのアトリビュートへのアクセスがあると・・・?
obj.attr1

私は他の言語においてこのあたりの仕組みをよく理解してないため厳密な比較はできませんが、 Python のこの仕組みはとてもユニークでおもしろいと思います。

早速説明していきます。馴染みの無い方にとっては少し複雑なので、ざっくりとした説明から始めて徐々により詳細で厳密な説明へと進んでいきます。

目次


  • レベル 1: 基本 1
  • レベル 2: 基本 2
  • レベル 3: 基本 3
  • レベル 4: 発展 1
  • レベル 5: 発展 2
  • レベル 6: 発展 3


レベル 1: 基本 1


アトリビュートへのアクセスがあると、オブジェクト固有のデータ保持領域で要素が探索される。要素が存在すればその値が返され、存在しなければ AttributeError があがる。

これは標準的なオブジェクト指向の概念に慣れている方にとっては直感的な挙動ですね。

「オブジェクト固有のデータ保持領域」というのは、具体的には各オブジェクトに備わった __dict__ アトリビュートのことを指しています。 __dict__ はデフォルトでは空の dict です。

コードで確認してみましょう。

class A:
    pass

a1 = A()
# アトリビュートがセットされなければ、 __dict__ の初期状態は空の dict
print(a1.__dict__)
# => {}

# アトリビュートへのアクセスがあると __dict__ 内の該当する要素が返される
a1.__dict__['attr1'] = 10
print(a1.attr1)
# => 10

# __dict__ 内に該当する要素がなければ AttributeError があがる
a2 = A()
print(a2.attr1)
# => AttributeError

a1 については、あらかじめ __dict__attr1 というキーで要素を格納したあとに a1.attr1 にアクセスしています。 a1.attr1 にアクセスすると a1.__dict__['attr1'] の値が返されることが確認できています。

一方、 a2 では前準備などせずすぐに a2.attr1 にアクセスしています。結果、例外 AttributeError があがります。

まずはこれが基本です。

レベル 2: 基本 2


レベル 1 の説明を少し更新します。

アトリビュートへのアクセスがあると、オブジェクト固有のデータ保持領域で要素が探索される。要素が存在すればその値が返される。 存在しなければ、オブジェクトのクラスが __getattr__ メソッドを持っているかどうかがチェックされる。持っていればそれが呼び出され、その戻り値が返される。 持っていなければ AttributeError があがる。

強調部分がレベル 1 との違いです。

レベル 1 では「オブジェクト固有のデータ保持領域に該当する要素が存在しなければ AttributeError があがる」と説明しましたが、実はその間にはプログラマが自由に処理をはさめるようになっていて、そのための仕組みが __getattr__ メソッドです。

コードで確認してみましょう。

class B:
    def __getattr__(self, name):
        return name

# __dict__ 内に該当する要素がなくて、クラスが __getattr__ を定義していればその戻り値が返される
b1 = B()
print(b1.attr1)
# => 'attr1'

a1 のところで見たとおり、コンストラクタで何もせず単純にオブジェクトを生成すると、そのオブジェクトの __dict__ は空の dict となります。 b1__dict__ は空なので、キー attr1 に対応する要素は存在しません。結果、 __getattr__ の呼び出しが発生し、その戻り値が返されます。

__getattr__ の引数 name には オブジェクト.アトリビュートアトリビュート に相当する文字列が渡されます。つまり、 b1.attr1 が実行された場合の name には文字列 'attr1' が格納されています。 B__getattr__ は戻り値として name をそのまま返しているので、結果として b1.attr1 にアクセスすると文字列 'attr1' が返ってきます。

ここでは説明のために __getattr__name を返す単純は実装にしていますが、実践的なコードではここにさまざまな工夫を加えます。例えば次のようにすると、データ保持領域の要素を int に変換して返させることができます。

class B2:
    def __getattr__(self, name):
        # 実在するアトリビュートの後ろに _as_int をつけた名前に対応する
        if name.endswith('_as_int'):
            stripped = name[:-len('_as_int')]
            if stripped in self.__dict__:
                return int(self.__dict__[stripped])
        raise AttributeError()

b2 = B2()
b2.pi = 3.14
b2.radius = 5.25

# __getattr__ の戻り値が返される
print(b2.pi_as_int)
# => 3
print(b2.radius_as_int)
# => 5

レベル 3: 基本 3


アトリビュートへのアクセスがあると、オブジェクト固有のデータ保持領域で要素が探索される。要素が存在すればその値が返される。存在しなければ、 オブジェクトのクラスのデータ保持領域で要素が探索される。存在しなければ、継承をたどってすべての親のデータ保持領域で要素が探索される。それでも存在しなければ、 オブジェクトのクラスが __getattr__ メソッドを持っているかどうかがチェックされる。持っていればそれが呼び出され、その戻り値が返される。持っていなければ AttributeError があがる。

強調部分がレベル 2 との違いです。

レベル 2 では「オブジェクトのデータ保持領域で対応する要素が見つからなかった場合は __getattr__ メソッドを持っているかどうかがチェックされ・・・」と説明しましたが、実は、「オブジェクトのデータ保持領域に要素が見つからなかった」と「 __getattr__ メソッドを持っているかどうかがチェックされ」の間に、クラスのデータ保持領域での要素の探索が発生します。

コードで確認してみましょう。

class C:
    attr1 = 10

    def __getattr__(self, name):
        return name

# クラスアトリビュートは __dict__ に格納される
print(C.__dict__['attr1'])
# => 10

c1 = C()

# クラス C の __dict__['attr1'] が返される
print(c1.attr1)
# => 10

# __getattr__ の戻り値が返される
print(c1.attr2)
# => 'attr2'

オブジェクト c1 において c1.attr1 にアクセスすると、クラス C で定義されたアトリビュート attr1 の値が返されました。これは内部的には C のデータ保持領域である C.__dict__ に格納された値です。クラスのデータ保持領域に該当する要素が見つかった場合は、メソッド __getattr__ の呼び出しは発生しません。

このサンプルでは示されていませんが、クラスのデータ保持領域の探索はクラスの継承関係を辿って行われます。つまり、 C.__dict__['attr1'] が存在しなかった場合はその親クラスの __dict__['attr1'] が探索され、そこに無ければまた親の・・・と続いていきます。最終的な親である object.__dict__ まで探索しても要素が見つからなかった場合は、レベル 2 での説明のとおり __getattr__ へと処理が移っていきます。上の c1.attr2 でのアクセスではまさにこの流れを辿った末に値が返された結果 'attr2' という文字列が返ってきています。

内部的には __dict__ が介在していますが、表面的には単純に「オブジェクトに該当するアトリビュートがなければ、クラスの同名のアトリビュートにフォールバックする」という挙動になるので、このあたりはふだん利用するときには難しく考えなくても直感的に利用できるでしょう。

と、ここまでは他の言語でもわりとよく見られるパターンなので、何らかのプログラミング言語に馴染みのある方であればすんなり受け入れられるところではないかと思います。続いて、 Python の特徴である descriptor (ディスクリプタ)も含めた説明へと進みます。

レベル 4: 発展 1


アトリビュートへのアクセスがあると、オブジェクト固有のデータ保持領域で要素が探索される。要素が存在すればその値が返される。存在しなければ、オブジェクトのクラスのデータ保持領域で要素が探索される。 該当する要素が存在した場合、それが __get__ メソッドを備えていれば __get__ メソッドが呼ばれ、その戻り値が返される。 __get__ メソッドを備えていなければそのオブジェクトそのものが返される。 同様の探索が、継承をたどってすべての親のデータ保持領域で行われる。それでも存在しなければ、オブジェクトのクラスが __getattr__ メソッドを持っているかどうかがチェックされる。持っていればそれが呼び出され、その戻り値が返される。持っていなければ AttributeError があがる。

強調部分がレベル 3 との違いです。ここで新たに __get__ メソッドというものが出てきました。

レベル 2 で述べたとおり、オブジェクトのアトリビュートへのアクセスが起こったときに、オブジェクトそのもののデータ保持領域 __dict__ に該当する要素がなければ、クラスのデータ保持領域 __dict__ での探索が行われます。その際、ふつうはその要素(=オブジェクト)そのものが返されるのですが、それが __get__ メソッドを持っている場合にかぎり、 __get__ メソッドが実行され、その戻り値が返されます。

ことばでの説明だけだと意味が分かりづらいですね。コードで確認してみましょう。

class Descriptor1:
    def __init__(self, name):
        self._name = name

    def __get__(self, instance, owner):
        print(self, instance, owner)
        return '{}.__get__ for {}'.format(self.__class__.__name__, self._name)

class D:
    attr1 = Descriptor1(name='attr1')
    attr2 = 10

d1 = D()

# Descriptor1 の __get__ メソッドの戻り値が返される
d1.attr1
# => 'Descriptor1.__get__ for attr1'

# __get__ を持たないアトリビュートの場合は値がそのまま返される
d1.attr2
# => 10

d1.attr1 にアクセスすると 'Descriptor1.__get__ for attr1' という文字列が返ってきます。これは Descriptor1 で定義されているメソッド __get__ の戻り値です。

通常、オブジェクトのアトリビュートへのアクセスでクラスのアトリビュートの参照が発生するとその値がそのまま返されますが、そのアトリビュートの値が __get__ メソッドを持っている場合にかぎり __get__ メソッドが実行されその戻り値が返されます。

これが Python のいわゆる descriptor です。 Python の descriptor とは「そのインスタンスが他のクラスのアトリビュートとして利用されたときに特殊な挙動をするクラス」です。

descriptor プロトコルを構成するメソッドは __get__ の他に __set____delete__ があります。

ちなみに、上の Descriptor1__init__ メソッドは、 descriptor オブジェクト自身がアトリビュート名 を知れるように次の形で利用するためのものです。

attr1 = Descriptor1(name='attr1')

Python 3.6 で __set_name__ という特殊メソッドが追加され、 Python 3.6 以降では descriptor オブジェクト自身がアトリビュート名をかんたんに知れるようになりました。

class Descriptor1:
    def __set_name__(self, owner, name):
        self._name = name

    def __get__(self, instance, owner):
        print(self, instance, owner)
        return '{}.__get__ for {}'.format(self.__class__.__name__, self._name)

class D:
    attr1 = Descriptor1()

d1 = D()

d1.attr1
# => 'Descriptor1.__get__ for attr1'

attr1 = Descriptor1() が実行されると __set_name__ が呼び出され引数 name にアトリビュート名が渡されるので、 descriptor 側でアトリビュート名を利用することができます。

descriptor のロジックはこれだけではありません。

レベル 5: 発展 2


アトリビュートへのアクセスがあると、 オブジェクトのクラスのデータ保持領域で要素が探索される。該当する要素が存在し、かつ、該当する要素が __get__ メソッドと __set__ メソッドを備えていれば __get__ メソッドが呼ばれその戻り値が返される。 そうでない場合は、オブジェクト固有のデータ保持領域で要素が探索される。要素が存在すればその値が返される。存在しなければ、オブジェクトのクラスのデータ保持領域で要素が探索される。該当する要素が存在した場合、それが __get__ メソッドを備えていれば __get__ メソッドが呼ばれ、その戻り値が返される。 __get__ メソッドを備えていなければそのオブジェクトそのものが返される。同様の探索が、継承をたどってすべての親のデータ保持領域で行われる。それでも存在しなければ、オブジェクトのクラスが __getattr__ メソッドを持っているかどうかがチェックされる。持っていればそれが呼び出され、その戻り値が返される。持っていなければ AttributeError があがる。

強調部分がレベル 4 との違いです。

レベル 4 までは「アトリビュートのアクセスがあるとオブジェクト固有のデータ保持領域で要素が探索される」と言っていましたが、実は、アトリビュートのアクセスがあったときに最初に行われることは、オブジェクトではなくクラスのデータ保持領域 __dict__ での探索です。そこに該当する要素があり、なおかつその要素が __get____set__ の 2 つのメソッドを備えていれば、その __get__ メソッドが呼ばれて戻り値が返されます。クラスのデータ保持領域に該当する要素がなかったり、あっても __get__ メソッド・ __set__ メソッドを備えていない場合は、通常どおりオブジェクトのデータ保持領域 __dict__ での探索が行われます。以降の処理はレベル 4 での説明のとおりです。

コードで確認してみましょう。

class Descriptor2:
    def __set_name__(self, owner, name):
        self._name = name

    def __get__(self, instance, owner):
        return '{}.__get__ for {}'.format(self.__class__.__name__, self._name)

    def __set__(self, instance, value):
        pass

class E:
    attr1 = Descriptor2()

e1 = E()
e1.__dict__['attr1'] = 10

# オブジェクトの __dict__ よりもクラスの __dict__ が優先される
print(e1.attr1)
# => 'Descriptor2.__get__ for attr1'

e1.attr1 にアクセスすると 'Descriptor2.__get__ for attr1' という文字列が返されました。これは Descriptor2__get__ の戻り値です。

ポイントは、 e1.__dict__ には attr1 というキーの要素があらかじめセットされているにもかかわらず Descriptor2__get__ が優先して呼び出されている点です。通常はオブジェクトそのものの __dict__ が先に探索されますが、クラスのデータ保持領域にある同名の要素が __get____set__ の 2 つのメソッドを備えている場合のみ、それが優先的に利用されます。

一見とてもトリッキーな動きですが、 Python がこの仕組みを用意してくれているおかげで、プログラマは「クラス定義時にそのオブジェクトの特定のアトリビュートを特別扱いする指示ができる汎用的な方法」を作ることができます。

この仕組みを利用したかんたんな例をあげてみます。

class TitleField:
    def __init__(self, length):
        self._len = length

    def __set_name__(self, owner, name):
        self._name = name

    def __get__(self, instance, owner):
        return instance.__dict__[self._name]

    def __set__(self, instance, value):
        if not isinstance(value, str):
            raise ValueError('アトリビュート {} には文字列のみがセットできます。'.format(self.self._name))
        if len(value) > self._len:
            raise ValueError('アトリビュート {} の最大長さは {} です。'.format(self.self._name, self._len))
        instance.__dict__[self._name] = value

    def __delete__(self, instance):
        del instance.__dict__[self._name]


class Article:
    title = TitleField(length=32)


a1 = Article()
a1.title = 10
# => ValueError: アトリビュート title には文字列のみがセットできます。


class Page:
    title = TitleField(length=20)


p1 = Page()
p1.title = 'Guardians of the Galaxy 2'
# => ValueError: アトリビュート title の最大長さは 20 です。

p1.title = 'Jurassic World'
print(p1.title)
# => 'Jurassic World'

クラス TitleField__get____set__ の 2 つのメソッドを持った descriptor クラスです。これを ArticlePage の 2 つのクラスで利用しています。

TitleField__set__ にバリデーション処理があるので、 ArticlePagetitle アトリビュートに文字列以外のオブジェクトや指定された長さよりも長い文字列をセットすることはできません。

このような処理は __setattr__ メソッドや property を使っても実装することができますが、 TitleField という独立したクラスに定義することによって、複数のクラス・複数のアトリビュートで使い回せるというメリットが生まれます。

尚、ここであげた TitleField の各メソッドの定義では descriptor として不十分なところがあります。 TitleField はあくまでも descriptor の少し実用的なイメージを示すためのサンプルなので、実際に descriptor クラスを書こうというときにはぜひ公式ドキュメントや詳しい書籍を参照してください。

書籍や記事でご存知の方にはおなじみですが、この __set__ メソッドを持つ desriptor を data descriptor (データ・ディスクリプタ) 、持たない descriptor を non-data decriptor (ノンデータ・ディスクリプタ) と呼びます。私は「 non-data 」を日本語で書く場合は「ノンデータ」とカタカナで書くのが好みですが、 non-data descriptor は「非データ・ディスクリプタ」と訳されているのをよく目にします。

この data decriptor ・ non-data descriptor という概念を使って見ると、レベル 4 でのディスクリプタの呼び出しタイミングは non-data descriptor の挙動の説明で、レベル 5 の「オブジェクトの __dict__ の前にクラスの __dict__ が参照される」は data descriptor の挙動の説明でした。

ここまででお腹いっぱいになりそうですが、もうひとレベルあります。次のレベルが最後です。

レベル 6: 発展 3


アトリビュートへのアクセスがあると、 真っ先にメソッド __getattribute__ が呼ばれる。オブジェクトのクラスとその先祖クラスで __getattribute__ を定義しているものがなければ、基底クラス object__getattribute__ が呼ばれる。その中で以下の処理が行われる。

まずは、オブジェクトのクラスのデータ保持領域で要素が探索される。該当する要素が存在し、かつ、該当する要素が __get__ メソッドと __set__ メソッドを備えていれば __get__ メソッドが呼ばれ、その戻り値が返される。そうでない場合は、オブジェクト固有のデータ保持領域で要素が探索される。要素が存在すればその値が返される。存在しなければ、オブジェクトのクラスのデータ保持領域で要素が探索される。該当する要素が存在した場合、それが __get__ メソッドを備えていれば __get__ メソッドが呼ばれその戻り値が返される。 __get__ メソッドを備えていなければそのオブジェクトそのものが返される。同様の探索が、継承をたどってすべての親のデータ保持領域で行われる。それでも存在しなければ、オブジェクトのクラスが __getattr__ メソッドを持っているかどうかがチェックされる。持っていればそれが呼び出され、その戻り値が返される。持っていなければ AttributeError があがる。

強調部分がレベル 5 との違いです。

レベル 5 の説明は、実は object.__getattribute__ が呼ばれた後の処理の流れを説明したものです。プログラマがクラスで __getattribute__ を定義すると、この流れをカスタマイズすることができます。

ただ、 __getattribute__ を定義しないといけないようなケースというのは非常に稀だと思います。 __getattribute__ を上書きできる仕組みは用意されてはいるものの、独自の __getattribute__ はおそらくメリットよりも多くのデメリットをもたらすので、よほどのことでないかぎり __getattribute__ の上書きが必要なケースは無いでしょう。

ここまで来ると、 Python のアトリビュートアクセスの仕組みをある程度正確に把握したと言ってよいのではないでしょうか。

ここでは obj.attr1 という形でアトリビュートが「参照」されたときの処理の流れだけを説明しましたが、 obj.attr1 = ... という「代入」のときの流れや del obj.attr1 という「削除」のときの流れも同様にあります。上の説明の中に __set__ メソッド・ __delete__ メソッドへの言及が少しありましたが、これらが「代入」や「削除」のときの流れをコントロールするためのものになります。

というわけで、 Python のオブジェクトにおいてアトリビュートへのアクセスがあったときに起こる処理の流れ についてでした。

Python でコードを書くときにこのあたりのところをどこまで押さえておくべきか、についてですが、私は必ずしもすべて頭に入れておく必要は無いと思います。ひとまず基本的な使い方をするだけであればレベル 1 ・ 2 あたりを押さえておけば十分で、アトリビュート周りについて発展的な使い方をしたいときにレベル 3 ・ 4 を、 descriptor を活用したパッケージにコントリビュートしたり自身で descriptor を使ったパッケージを書いたりしたい場合に必要に応じて 5 ・ 6 までを押さえる、というのがよいかと思います。

興味がある方のご参考になれば幸いです :)

descriptor を深掘りしたくて実践的な例を見てみたい方は、 peewee 等の ORM マッパライブラリのコードを見られるとよいかと思います。

参考


2018/08/31

Python Tips: switch 文を使いたい

Python に似た言語にはよくあって Python に無いもののひとつに「 switch 文」があります。今回は switch 文を持たない Python において switch 文を書きたくなったときの代わりの書き方 をご紹介したいと思います。

おそらく Pythonista の多くが使っているのは次の 2 つの方法のどちらかです。

  • switch 文の代わり 1: ifelif
  • switch 文の代わり 2: dict

switch 文の代わり 1: ifelif


第一の方法は単純に ifelif 構文を使うというものです。

def printer_factory(name):
    if name == 'json':
        return JsonPrinter()
    elif name == 'yaml':
        return YamlPrinter()
    elif name == 'csv':
        return CsvPrinter()
    else:
        raise ValueError('Invalid printer: {}'.format(name))

この方法だと name == の部分を分岐の数だけ繰り返す必要がありますが、キーワード elif が短いおかげでわりとシンプルに書くことができます。

最後の else のところに来たときに例外をあげるかフォールバック値を返すかはそのときどきで適切な方を選ぶとよいでしょう。

switch 文の代わり 2: dict


もうひとつの代表的なアプローチは dict を使った方法です。

def printer_factory_改(name):
    printer_map = {
        'json': JsonPrinter,
        'yaml': YamlPrinter,
        'csv': CsvPrinter,
        'xml': XmlPrinter,
        'html': HtmlPrinter,
        'mild': MildPrinter,
        'wild': WildPrinter,
    }

    try:
        return printer_map[name]()
    except KeyError as e:
        raise ValueError('Invalid printer: {}'.format(name))

分岐を表すマップをあらかじめ定義しておき、辞書のキールックアップを使って分岐させます。

指定された値が存在しなかったとき(= switch 文で default に来たときに相当)の挙動として、例外をあげたいのであれば KeyError をキャッチして適切な例外をあげ直せば OK です。フォールバック値を返したければブラケット( [] ではなく get() メソッドを使ってデフォルト値を設定しながら値を返すとよいでしょう。

# 該当するものが見つからなかった場合はフォールバック値を返す
default_value = HtmlPrinter
return printer_map.get(name, default_value)()

こちらの方法で気をつけるべき点は、マップを作成するときに値を評価してしまわないことです。例えば上の printer_factory_改 は次のように書くこともできますが、こうするとマップを用意しているときにすべての Printer クラスのインスタンスが生成されてしまうのであまりよくありません。

def printer_factory_改悪(name):
    printer_map = {
        'json': JsonPrinter(),
        'yaml': YamlPrinter(),
        'csv': CsvPrinter(),
        'xml': XmlPrinter(),
        'html': HtmlPrinter(),
        'mild': MildPrinter(),
        'wild': WildPrinter(),
    }

    try:
        return printer_map[name]
    except KeyError as e:
        raise ValueError('Invalid printer: {}'.format(name))

Python ではクラスそのものもオブジェクトでありクラスを dict の値として格納することができるので、ファクトリクラスはできるだけ printer_factory_改悪 よりも printer_factory_改 の形で書くのがよいでしょう。

以上です。

ifelifdict のどちらを使うべきかはそのときの分岐の数や周辺のコード、コーディングルール等によって変わってくると思うので、そのときどきでより適切な方を選ぶとよいでしょう。

アーキテクチャの良し悪しの観点からいえば、 switch 文を使うべき場面は非常にかぎられてくるはずなので、 switch 文が無いということは Python らしいいい制約、なのかもしれません。

というわけで、 Python における「 switch 文の代わりの書き方」についてでした。

参考

Python 公式のドキュメントの FAQ に「なぜ Python には switch case が無いの?」という項目があるので、経緯等に興味がある方はそちらもご覧になってみるとよいかと思います。