| 1 | import numpy as np |
|---|
| 2 | import numpy.linalg as lg |
|---|
| 3 | |
|---|
class Pyramid(object):
    """A number pyramid in which every cell equals the sum of the two cells below it.

    Rows are numbered from 1 (the apex) to ``height`` (the base).  Row ``r``
    holds ``r`` cells, numbered ``1..r`` from the left, and cell
    ``(row, col)`` sits directly above ``(row+1, col)`` and ``(row+1, col+1)``.

    Parameters
    ----------
    height : int
        Number of rows in the pyramid (must be >= 1 to solve).
    initial : dict mapping (row, col) -> number
        Known cell values, used as constraints by :meth:`solve`.

    Attributes
    ----------
    solution : numpy.ndarray or None
        Least-squares cell values as a column vector of length
        ``height*(height+1)//2``, stored row by row; ``None`` until solved.
    residuals : numpy.ndarray or None
        Residuals returned by ``numpy.linalg.lstsq``; ``None`` until solved.
    """

    def __init__(self, height, initial):
        self.height = height
        self.initial = initial
        self.solution = None
        self.residuals = None
        # Legacy attribute name; kept in sync with ``residuals`` by solve()
        # so existing code reading ``residual`` keeps working.
        self.residual = None

    def _loc(self, row, col):
        """Return the 0-based flat index of cell ``(row, col)`` (both 1-based).

        Row ``r`` starts after the ``r*(r-1)//2`` cells of the rows above it.
        Floor division keeps the result an ``int`` on Python 3, where ``/``
        would yield a float that numpy refuses as an index.
        """
        return row * (row - 1) // 2 + col - 1

    def _below(self, row, col):
        """Return the two cells directly beneath ``(row, col)``."""
        return (row + 1, col), (row + 1, col + 1)

    def _check_cell(self, row, col):
        """Raise ValueError unless ``(row, col)`` names a cell of this pyramid."""
        if not (1 <= row <= self.height and 1 <= col <= row):
            raise ValueError(
                "cell (%r, %r) is outside a pyramid of height %r"
                % (row, col, self.height))

    def solve(self):
        """Solve the pyramid puzzle in the least-squares sense.

        Builds one linear equation per non-base cell
        (``below_left + below_right - cell = 0``) and one per known value
        (``cell = value``), then solves the system with
        ``numpy.linalg.lstsq``.  Stores the results in ``self.solution`` and
        ``self.residuals`` (mirrored to ``self.residual``).

        Raises
        ------
        ValueError
            If ``height`` is less than 1, or a key of ``initial`` is not a
            cell of the pyramid.  Without this check an out-of-range column
            such as 0 maps to a negative index and silently constrains the
            wrong cell.
        """
        n = self.height
        if n < 1:
            raise ValueError("pyramid height must be at least 1, got %r" % (n,))
        for row, col in self.initial:
            self._check_cell(row, col)

        n_unknowns = n * (n + 1) // 2
        n_tri_constraints = n * (n - 1) // 2
        n_constraints = n_tri_constraints + len(self.initial)

        constraints = np.zeros((n_constraints, n_unknowns))
        rhs = np.zeros((n_constraints, 1))
        loc = self._loc
        constraint = 0

        # Sum constraints: each cell above the base equals the sum of the two
        # cells below it (rhs stays 0).
        for row in range(1, n):
            for col in range(1, row + 1):
                b1, b2 = self._below(row, col)
                constraints[constraint, loc(row, col)] = -1
                constraints[constraint, loc(*b1)] = 1
                constraints[constraint, loc(*b2)] = 1
                constraint += 1

        # Known-value constraints: cell == val.
        for (row, col), val in self.initial.items():
            constraints[constraint, loc(row, col)] = 1
            rhs[constraint] = val
            constraint += 1

        # Passing rcond=None explicitly selects NumPy's current default
        # cutoff and silences the FutureWarning NumPy 1.x emits without it.
        self.solution, self.residuals = lg.lstsq(
            constraints, rhs, rcond=None)[:2]
        self.residual = self.residuals

    def __str__(self):
        """Return a human-readable, multi-line rendering of the pyramid.

        An unsolved pyramid renders as a three-line header only.  A solved one
        also lists every row, with values rounded to the nearest integer.  It
        then flags any known value the solution misses by more than 1e-5 and
        ends with the summed least-squares residual.
        """
        s = [
            "Pyramid (height: %d)" % (self.height,),
            " [solved: %s]" % ("yes" if self.solution is not None else "no"),
            " -----",
        ]
        if self.solution is None:
            return "\n".join(s)

        # lstsq returns a column vector; flatten it so each lookup yields a
        # scalar (int()/%f on 1-element arrays is deprecated in NumPy).
        values = np.ravel(self.solution)

        for row in range(1, self.height + 1):
            cells = ["%5d" % int(np.round(values[self._loc(row, col)]))
                     for col in range(1, row + 1)]
            s.append(" " + "".join(cells))
        s.append(" -----")

        good = True
        for (row, col), val in self.initial.items():
            solval = values[self._loc(row, col)]
            if np.abs(solval - val) > 0.00001:
                s.append(" !! (%d, %d) != %d (%f)" % (row, col, val, solval))
                good = False
        if good:
            s.append(" No constraints violated.")

        s.append(" Residual: %f" % np.sum(self.residuals))
        s.append(" -----")
        return "\n".join(s)
|---|
| 92 | |
|---|
| 93 | |
|---|
| 94 | |
|---|
def print_pyramid(pyramid):
    """Print ``str(pyramid)`` to standard output.

    Uses a single-argument ``print(...)`` call.  It is the print function on
    Python 3 and a parenthesised print statement on Python 2, so the output
    is identical on both (the original bare ``print`` statement was a
    SyntaxError on Python 3).
    """
    print(pyramid)
|---|
| 98 | |
|---|
| 99 | if __name__ == "__main__": |
|---|
| 100 | import sys |
|---|
| 101 | if len(sys.argv) < 3: |
|---|
| 102 | print "Usage: pyramid.py <height> <inital value> ..." |
|---|
| 103 | print "e.g. pyramid.py 6 1,1=234 4,2=29 5,2=29 5,4=56 6,1=79 6,4=-69 6,5=125" |
|---|
| 104 | print "e.g. pyramid.py 2 1,1=2 2,1=1" |
|---|
| 105 | sys.exit(1) |
|---|
| 106 | |
|---|
| 107 | height = int(sys.argv[1]) |
|---|
| 108 | initial = {} |
|---|
| 109 | for arg in sys.argv[2:]: |
|---|
| 110 | pos, val = arg.split("=") |
|---|
| 111 | row, col = pos.split(",") |
|---|
| 112 | val = int(val) |
|---|
| 113 | row = int(row) |
|---|
| 114 | col = int(col) |
|---|
| 115 | initial[(row, col)] = val |
|---|
| 116 | |
|---|
| 117 | pyramid = Pyramid(height, initial) |
|---|
| 118 | pyramid.solve() |
|---|
| 119 | print str(pyramid) |
|---|