본문 바로가기
Algorithm(CodeTree, Python)/Backtracking

[코드트리] 단순한 동전 챙기기 Python

by kurooru 2023. 2. 1.
# n 입력
n = int(input())
# grid 입력
grid = [
    input()
    for _ in range(n)
]

# 함수들
# get_min_dist(curr_comb)
def get_min_dist(curr_comb):
    # num_1, num_2, num_3
    num_1, num_2, num_3 = curr_comb[0], curr_comb[1], curr_comb[2]
    
    # grid를 돌면서
    for i in range(n):
        for j in range(n):
            # 'S'를 찾으면,
            if grid[i][j] == 'S':
                # sx, sy 기록
                sx, sy = i, j
            # 'E'를 찾으면,
            elif grid[i][j] == 'E':
                # ex, ey 기록
                ex, ey = i, j
            # num_1을 찾으면
            elif grid[i][j] == str(num_1):
                # num_1_x, num_1_y 기록
                num_1_x, num_1_y = i, j
            # num_2을 찾으면
            elif grid[i][j] == str(num_2):
                # num_2_x, num_2_y 기록
                num_2_x, num_2_y = i, j
            # num_3을 찾으면
            elif grid[i][j] == str(num_3):
                # num_3_x, num_3_y 기록
                num_3_x, num_3_y = i, j
    
    # curr_dist
    curr_dist = 0

    # s -> num_1 -> num_2 -> num_3 -> e
    curr_dist += abs(sx - num_1_x) + abs(sy - num_1_y)
    curr_dist += abs(num_2_x - num_1_x) + abs(num_2_y - num_1_y)
    curr_dist += abs(num_3_x - num_2_x) + abs(num_3_y - num_2_y)
    curr_dist += abs(ex - num_3_x) + abs(ey - num_3_y)

    # 반환
    return curr_dist

# make_comb(curr_idx)
def make_comb(curr_idx):
    # 전역 변수 선언
    global min_dist

    # 종료조건
    if curr_idx == 4:
        # min_dist update
        min_dist = min(min_dist, get_min_dist(comb))
        return

    for num in num_list:
        if curr_idx >= 2 and comb[-1] >= num:
            continue
        else:
            comb.append(num)
            make_comb(curr_idx + 1)
            comb.pop()

# 설계
# num_list
num_list = []

# grid를 돌면서
for i in range(n):
    for j in range(n):
        # 숫자면,
        try:
            # num_list에 추가
            num_list.append(int(grid[i][j]))
        # 숫자가 아니면,
        except:
            # 
            continue

# num_list 2 이하면
if len(num_list) <= 2:
    # -1 출력
    print(-1)
# 2 이상이면
else:
    # min_dist
    import sys
    min_dist = sys.maxsize
    # num_list 정렬
    num_list.sort()
    # comb
    comb = []
    # make_comb
    make_comb(1)
    # 출력
    print(min_dist)