Python グリッド用ライブラリ

グリッド用ライブラリ

競プロでグリッド系の問題が出てとき、手間取りがちなのでライブラリにまとめました。
numpy実装でnumpy特有のスライスも対応してます。
np.allやnp.whereなどndarrayを引数にする関数を使いたい場合は、Grid.gridをお使いください()

import numpy as np
class Grid():
    def __init__(self, grid, w=0, h=0, function=lambda x: x):
        self.w = w = w if w else len(grid[0])
        self.h = h = h if h else len(grid)
        dtype = type(function(grid[0][0]))
        self.grid = np.empty((h, w), dtype=dtype)
        for i, row in zip(range(h), grid):
            for j, val in zip(range(w), row):
                self.grid[i][j] = function(val)
    
    def is_valid_x(self, x):
        return 0 <= x < self.w
    def is_valid_y(self, y):
        return 0 <= y < self.h
    def is_valid_xy(self, x, y):
        return self.is_valid_x(x) and self.is_valid_y(y) 
    
    def __iter__(self):
        return iter(self.grid)
    def __repr__(self):
        return '\n'.join([' '.join(map(str, row)) for row in self.grid])
    def __getitem__(self, x):
        return self.grid[x]
    def __setitem__(self, x, val):
        self.grid[x] = val

def dfs(root):
    x, y = root
    grid[y, x] = 0
    stack = [root]
    while stack:
        x, y = stack.pop()
        for dx, dy in zip([1, 0, -1, 0], [0, 1, 0, -1]):
            if grid.is_valid_xy(x+dx, y+dy) and grid[y+dy, x+dx]:
                stack.append((x+dx, y+dy))
                grid[y, x] = 0

from collections import deque
def bfs(root):
    x, y = root
    grid[y, x] = 0
    queue = deque([root])
    while queue:
        x, y = queue.popleft()
        for dx, dy in zip([1, 0, -1, 0], [0, 1, 0, -1]):
            if grid.is_valid_xy(x+dx, y+dy) and grid[y+dy, x+dx]:
                queue.append((x+dx, y+dy))
                grid[y+dy, x+dx] = 0

解説

引数にはインデクシングが可能な2次元イテレータを入れればOKです。
要はAtCoderの入力の2次元リストや文字列のリストをそのまま入れれば大丈夫です。
入力がmapやgeneratorの場合は w, hを引数にする必要があるので注意!
functionはmapみたいな感じで初期化時にgrid全体に関数が適用されます。
ついでにdfs,bfs実装のテンプレートも書いてます。

使用例

AtCoder ABC 96 C - Grid Repainting 2
AtCoder ABC 151 D - Maze Master
AtCoder AGC 43 A - Range Flip Find Route

# C - Grid Repainting 2
from itertools import product
from collections import deque
import numpy as np
def Grid():
## 略
inputs = open(0).readlines()
h, w = map(int, inputs[0].split())
grid = Grid(inputs[1:], function=lambda x: int(x == '#'))

def dfs(root):
    count = 1
    x, y = root
    grid[y, x] = 0
    stack = [root]
    while stack:
        x, y = stack.pop()
        grid[y, x] = 0
        for dx, dy in zip([1, 0, -1, 0], [0, 1, 0, -1]):
            if grid.is_valid_xy(x+dx, y+dy) and grid[y+dy, x+dx]:
                stack.append((x+dx, y+dy))
                count += 1
    return count
    
ans = 'Yes'
for i, j in product(range(h), range(w)):
    if grid[i, j]:
        if dfs((j, i)) == 1:
            ans = 'No'
            break
print(ans)
# D - Maze Master
from copy import deepcopy
from collections import deque
from itertools import product
import numpy as np
def Grid():
## 略
inputs = open(0).readlines()
h, w = map(int, inputs[0].split())
grid_origin = Grid(inputs[1:], function=lambda x: int(x == '.'))

def bfs(root):
    x, y, _ = root
    grid[y, x] = 0
    queue = deque([root])
    while queue:
        x, y, d = queue.popleft()
        for dx, dy in zip([1, 0, -1, 0], [0, 1, 0, -1]):
            if grid.is_valid_xy(x+dx, y+dy) and grid[y+dy, x+dx]:
                queue.append((x+dx, y+dy, d+1))
                grid[y+dy, x+dx] = 0
    return d
    
ans = 0
for i, j in product(range(h), range(w)):
    if grid_origin[i, j]:
        grid = deepcopy(grid_origin)
        ans = max(ans, bfs((j, i, 0)))
print(ans)
# Grid Compression
import numpy as np
def Grid():
## 略
inputs = [s.strip() for s in open(0).readlines()]
h, w = map(int, inputs[0].split())
grid = Grid(inputs[1:])

dp = Grid(np.zeros((h, w), np.int))
dp[0, 0] = int(grid[0, 0] == '#')
for i in range(1, w):
    dp[0, i] = dp[0, i-1] + int(grid[0, i-1] + grid[0, i] == '.#')
for i in range(1, h):
    dp[i, 0] = dp[i-1, 0] + int(grid[i-1, 0] + grid[i, 0] == '.#')
x = y = 1
while not (x == w-1 and y == h-1):
    for i in range(x, w):
        dp[y, i] = min(dp[y, i-1] + int(grid[y, i-1] + grid[y, i] == '.#'),
                       dp[y-1, i] + int(grid[y-1, i] + grid[y, i] == '.#'))
    for i in range(y+1, h):
        dp[i, x] = min(dp[i-1, x] + int(grid[i-1, x] + grid[i, x] == '.#'),
                       dp[i, x-1] + int(grid[i, x-1] + grid[i, x] == '.#'))
    x, y = min(x+1, w-1), min(y+1, h-1)
print(min(dp[y-1, x] + int(grid[y-1, x] + grid[y, x] == '.#'), 
          dp[y, x-1] + int(grid[y, x-1] + grid[y, x] == '.#')))