root/hodgestar/PythonCode/PyramidProblem/pyramid.py

Revision 720, 3.5 kB (checked in by hodgestar, 2 years ago)

Always print sum of residual.

Line 
1import numpy as np
2import numpy.linalg as lg
3
class Pyramid(object):
    """Number pyramid whose upper cells are sums of the two cells below.

    Rows and columns are 1-based: row ``r`` (row 1 is the apex) holds
    columns ``1..r`` and the bottom row is row ``height``. Every cell above
    the bottom row must equal the sum of the two cells directly beneath it,
    and some cells are pinned to known values by ``initial``. :meth:`solve`
    finds a least-squares assignment of all cells under these linear
    constraints.

    Parameters
    ----------
    height : int
        height of pyramid (number of rows)
    initial : dict mapping (row, col) -> int
        known cell values
    """

    def __init__(self, height, initial):
        self.height = height
        self.initial = initial
        # Filled in by solve(): ``solution`` holds one value per cell in the
        # flat order defined by _loc; ``residuals`` is what lstsq returns.
        self.solution = None
        self.residuals = None
        # Legacy attribute name kept for backward compatibility; solve()
        # keeps it in sync with ``residuals``.
        self.residual = None

    def _loc(self, row, col):
        """Return the 0-based flat index of the 1-based cell (row, col)."""
        # Rows 1..row-1 contain row*(row-1)/2 cells in total. Floor division
        # keeps this an int on both Python 2 and Python 3.
        return row * (row - 1) // 2 + col - 1

    def _below(self, row, col):
        """Return the two (row, col) cells directly beneath (row, col)."""
        return (row + 1, col), (row + 1, col + 1)

    def _check_position(self, row, col):
        """Raise ValueError unless (row, col) is a cell of this pyramid."""
        if not (1 <= row <= self.height and 1 <= col <= row):
            raise ValueError(
                "initial position (%r, %r) is outside a pyramid of height %d"
                % (row, col, self.height))

    def solve(self):
        """Solve the pyramid puzzle by linear least squares.

        Builds one equation per non-bottom cell (``cell == below_left +
        below_right``) and one per entry of ``initial`` (``cell == value``).
        It then solves the system with ``numpy.linalg.lstsq``. The solution
        (a column vector with one row per cell) is stored in
        ``self.solution`` and the lstsq residuals in ``self.residuals``.

        Raises
        ------
        ValueError
            if a position in ``initial`` lies outside the pyramid.
        """
        # Validate up front: an out-of-range position would otherwise be
        # mapped by _loc onto a different cell, or wrap around via a
        # negative index, without any error.
        for row, col in self.initial:
            self._check_position(row, col)

        n = self.height
        n_unknowns = (n + 1) * n // 2
        n_tri_constraints = n * (n - 1) // 2
        n_init_constraints = len(self.initial)
        n_constraints = n_tri_constraints + n_init_constraints

        constraints = np.zeros((n_constraints, n_unknowns))
        rhs = np.zeros((n_constraints, 1))
        constraint = 0
        loc = self._loc
        below = self._below

        # Triangle constraints: -cell + below_left + below_right == 0
        # (the rhs rows are already zero).
        for row in range(1, n):
            for col in range(1, row + 1):
                b1, b2 = below(row, col)
                constraints[constraint, loc(row, col)] = -1
                constraints[constraint, loc(*b1)] = 1
                constraints[constraint, loc(*b2)] = 1
                constraint += 1

        # Initial-value constraints: cell == val.
        for (row, col), val in self.initial.items():
            constraints[constraint, loc(row, col)] = 1
            rhs[constraint] = val
            constraint += 1

        # rcond=-1 keeps the historical numpy default (machine-precision
        # cutoff) and avoids the FutureWarning newer numpy emits without it.
        self.solution, self.residuals = lg.lstsq(constraints, rhs, rcond=-1)[:2]
        self.residual = self.residuals

    def __str__(self):
        """Return a string representation of the pyramid.

        Shows whether the pyramid has been solved. If it has, the output also
        includes:

        * each row, with cell values rounded to integers;
        * any ``initial`` value the solution misses by more than 1e-5;
        * the sum of the lstsq residuals.
        """
        s = []
        s.append("Pyramid (height: %d)" % (self.height,))
        s.append("  [solved: %s]" % ("yes" if self.solution is not None else "no"))
        s.append("  -----")

        if self.solution is not None:
            # solve() passes a column-vector rhs, so solution is (n, 1);
            # flatten it so indexing yields scalars, not 1-element arrays.
            solution = np.ravel(self.solution)

            for row in range(1, self.height + 1):
                cells = ["%5d" % int(np.round(solution[self._loc(row, col)]))
                         for col in range(1, row + 1)]
                s.append("  " + "".join(cells))

            s.append("  -----")

            good = True
            for (row, col), val in self.initial.items():
                solval = solution[self._loc(row, col)]
                if np.abs(solval - val) > 0.00001:
                    s.append("  !! (%d, %d) != %d (%f)" % (row, col, val, solval))
                    good = False

            if good:
                s.append("  No constraints violated.")

            # lstsq returns an empty residual array for rank-deficient or
            # non-overdetermined systems; np.sum of that is 0.0.
            s.append("  Residual: %f" % np.sum(self.residuals))
            s.append("  -----")

        return "\n".join(s)
92
93
94
def print_pyramid(pyramid):
    """Print *pyramid* (its ``str()`` form) to standard output."""
    # A single parenthesised argument prints identically under Python 2's
    # print statement and Python 3's print function.
    print(pyramid)
98
def parse_initial_args(args):
    """Parse ``row,col=value`` command-line arguments into a dict.

    Parameters
    ----------
    args : iterable of str
        arguments such as ``"6,4=-69"``

    Returns
    -------
    dict mapping (row, col) -> int
        later arguments for the same position overwrite earlier ones

    Raises
    ------
    ValueError
        if an argument is not of the form ``row,col=value`` with integers.
    """
    initial = {}
    for arg in args:
        try:
            pos, val = arg.split("=")
            row, col = pos.split(",")
            initial[(int(row), int(col))] = int(val)
        except ValueError:
            # Wrong number of '=' / ',' separators or non-integer parts.
            raise ValueError(
                "invalid initial value %r (expected row,col=value)" % (arg,))
    return initial


if __name__ == "__main__":
    import sys
    if len(sys.argv) < 3:
        print("Usage: pyramid.py <height> <initial value> ...")
        print("e.g. pyramid.py 6 1,1=234 4,2=29 5,2=29 5,4=56 6,1=79 6,4=-69 6,5=125")
        print("e.g. pyramid.py 2 1,1=2 2,1=1")
        sys.exit(1)

    try:
        height = int(sys.argv[1])
        initial = parse_initial_args(sys.argv[2:])
        pyramid = Pyramid(height, initial)
        pyramid.solve()
    except ValueError as err:
        # Bad height, malformed initial value or out-of-range position:
        # report it instead of dumping a traceback.
        sys.stderr.write("pyramid.py: %s\n" % (err,))
        sys.exit(1)
    print(pyramid)
Note: See TracBrowser for help on using the browser.