use derivatives.frink
use integrals.frink
use solvingTransformations.frink
use powerTransformations.frink

// Solver for a system of equations.  The system is held as a graph:
// every equation is wrapped in an EquationNode, and two nodes are joined
// by an Edge for each variable that both of their equations contain.
class Solver
{
   // Array of the EquationNodes that make up the graph.
   var equationNodes

   // Set of strings naming the variables that we are not going to solve for.
   var ignoreSet

   // Boolean flag: true once the phase 1 simplifications have been done.
   var initialized

   // Cache of the final solutions for each variable, keyed by variable name.
   var finalSolutions

   // A (currently-unused) set of aliases that will simplify equations.
   var aliases

   // Create a new Solver.
   //   eqs:        the equations to add, each one via addEquation.
   //   ignoreList: names of variables to ignore; they are copied into
   //               ignoreSet before any equation is added.
   new[eqs, ignoreList=[]] :=
   {
      equationNodes = new array
      finalSolutions = new dict
      ignoreSet = new set

      for ignoredName = ignoreList
         ignoreSet.put[ignoredName]

      for equation = eqs
         addEquation[equation]

      initialized = false
   }

   // Add an equation to the system.  A new EquationNode is created and
   // connected to every EquationNode already in the system that shares at
   // least one unknown with it (one edge per shared unknown).
   //
   // If index is undef, the node is pushed onto the end of the list.
   // If index is a number, the node at that index is removed and the new
   // node is inserted in its place.
   //
   // An equation that is structurally equal to one already in the system is
   // reported and discarded; in that case nothing is added or replaced.
   //
   // (public)
   addEquation[eq, index=undef] :=
   {
      // THINK ABOUT:  Call transformExpression to canonicalize first?

      // Refuse duplicates of equations we already hold.
      for existingNode = equationNodes
         if structureEquals[eq, existingNode.getOriginalEquation[]]
         {
            println["Eliminating duplicate equation $eq"]
            return
         }

      // The unknowns are the equation's symbols minus the ignored ones.
      liveUnknowns = setDifference[getSymbols[eq], ignoreSet]
      if length[liveUnknowns] == 0
         println["WARNING: Equation $eq has no unknowns!"]

      newNode = new EquationNode[eq, liveUnknowns]

      // When replacing, take the old node out of the graph first so the new
      // node is never connected to the one it replaces.
      if index != undef
         remove[index]

      for peerNode = equationNodes
      {
         // THINK ABOUT: Should we check to see if one equation is a proper
         // subset of the variables of the other and push the simpler into
         // the more complex right now?

         // One edge for each variable shared between the two equations.
         for sharedName = intersection[liveUnknowns, peerNode.getUnknowns[]]
            connect[newNode, peerNode, sharedName]
      }

      if index != undef
         equationNodes.insert[index, newNode]
      else
         equationNodes.push[newNode]

      // Discard any previously-found solutions and force re-simplification.
      finalSolutions = new dict
      initialized = false
   }

   // Initialize and simplify the system.  Does nothing when the system is
   // already initialized.
   initialize[] :=
   {
      if ! initialized
      {
         pushSimpler[]
         // draw[]
         wasChanged = solveSimultaneous[]
         // draw[]
         // pushAliases[]
         // draw[]

         // Eliminating simultaneous equations may expose more simple pushes.
         if wasChanged
            pushSimpler[]

         // Set last: the calls above go through remove/addEquation, which
         // reset this flag to false.
         initialized = true
      }
   }

   // Removes the node at the specified index from the graph, along with all
   // connections to it.
   // (public)
   remove[index] :=
   {
      removedNode = equationNodes.remove[index]
      disconnect[removedNode]

      // Discard any previously-found solutions.
      finalSolutions = new dict
      initialized = false
   }

   // Disconnect the specified node from the graph.
   // (private)
   disconnect[node] := node.disconnectAllEdges[]

   // Connect the two specified equations by the specified variable.
// (private)
   connect[n1 is EquationNode, n2 is EquationNode, varName is string] :=
   {
      // A single Edge object is registered with both endpoints.
      sharedEdge = new Edge[n1, n2, varName]
      n1.addEdge[sharedEdge]
      n2.addEdge[sharedEdge]
   }

   // Return a count of the EquationNodes in the system
   // (public)
   getEquationCount[] :=
   {
      return length[equationNodes]
   }

   // Returns the EquationNode with the specified index
   // (public)
   getEquationNode[index] :=
   {
      return equationNodes@index
   }

   // Returns an array whose elements are two-element arrays
   //    [ unknowns, index ]
   // one per equation in the system, ordered so that the simplest equations
   // (those with the fewest unknowns) come first.  unknowns is a set and
   // index is the equation's position in equationNodes.
   getEquationsSortedByComplexity[] :=
   {
      sortedPairs = new array
      for idx = 0 to length[equationNodes] - 1
         sortedPairs.push[[equationNodes@idx.getUnknowns[], idx]]

      // Order by the size of the unknowns set.
      sort[sortedPairs, {|lhs, rhs| length[lhs@0] <=> length[rhs@0]}]
      return sortedPairs
   }

   // Returns the unknowns for the specified index.
   getUnknowns[index] :=
   {
      return equationNodes@index.getUnknowns[]
   }

   // Prints out the state of the solver for debugging.  Each line holds the
   // node index, its original equation, and one [variable,otherIndex] entry
   // for every edge of the node.
   // (public)
   dump[] :=
   {
      maxIndex = getEquationCount[] - 1
      for i = 0 to maxIndex
      {
         node = getEquationNode[i]
         print["$i\t" + node.getOriginalEquation[] + "\t"]
         for nodeEdge = node.getEdges[]
         {
            edgeVar = nodeEdge.getVariableName[]
            farIndex = getNodeIndex[nodeEdge.getOtherNode[node]]
            print["[" + edgeVar + "," + farIndex +"] "]
         }
         println[]
      }
      println[]
   }

   // Draw a representation of the system.
// g is the graphics object drawn into; left, top, right and bottom are the
   // bounds from which every position and size below is derived.
   draw[g is graphics, left=0, top=0, right=1, bottom=2] :=
   {
      last = getEquationCount[]-1
      width = right-left
      height = bottom-top
      g.font["Serif", "italic", height/30]
      g.color[0,0,0]
      cy = top + height/2       // The equation listing starts at mid-height.
      w = 0.7 width/20;         // Maximum random offset applied to line ends.
      for i = 0 to last
      {
         node = getEquationNode[i]
         [x,y] = getPosition[i, left, top, right, bottom]

         for e = node.getEdges[]
         {
            oi = getNodeIndex[e.getOtherNode[node]]

            // Each edge is stored on both of its nodes; draw it only from
            // the lower-indexed end so it is drawn once.
            if (i < oi)
            {
               [ox, oy] = getPosition[oi, left, top,right,bottom]
               g.color[0,0,0]
               // The endpoints are jittered by a random amount in [-w, w].
               g.line[x+randomFloat[-w,w],y+randomFloat[-w,w],ox+randomFloat[-w, w],oy+randomFloat[-w,w]]
            }
         }

         // Filled grey ellipse with a black outline and the node's index.
         g.color[.9,.9,.9]
         g.fillEllipseCenter[x,y,width/10, height/2/10]
         g.color[0,0,0]
         g.drawEllipseCenter[x,y,width/10, height/2/10]
         g.text[i,x,y,"center","center"]

         // One text line per equation, listed downward from cy.
         g.text["$i: " + node.getOriginalEquation[], left, i * height/20 + cy, "left", "top"]
      }
   }

   // Calculates the center of the specified node in the graph.
   // Nodes are spaced evenly around a circle centered at
   //    (left + width/2, top + height/4)
   // with a radius of 0.45 width.  Returns [x, y].
   getPosition[index, left, top, right, bottom] :=
   {
      phase = circle / getEquationCount[]   // Angle between adjacent nodes.
      width = right-left
      height = bottom-top
      cx = left + width/2
      cy = top + height/4
      radius = 1/2 (.9) width
      x = radius sin[index phase] + cx
      y = radius * -cos[index phase] + cy
      return [x,y]
   }

   // Draws the system into a new graphics object and shows it.
   draw[] :=
   {
      g = new graphics
      draw[g]
      g.show[]
   }

   // Gets the node index given the node.
   // Returns the index into equationNodes of the first element equal to
   // node, or the string "Node not found!" if there is none.
   getNodeIndex[node is EquationNode] :=
   {
      // The last valid index is length-1.  (This loop previously ran to
      // length, probing one element past the end when the node was absent.)
      last = length[equationNodes] - 1
      for i = 0 to last
         if equationNodes@i == node
            return i

      return "Node not found!"
   }

   // Keep replacing simpler variables until we're done, that is, until
   // pushSimplerOnce[] reports that it made no change.
   pushSimpler[] :=
   {
      // dump[]
      while pushSimplerOnce[]
      {
         // dump[]
      }
   }

   // Pushes the simpler equations into the more complex
   // equations.  This returns true if changes were made to the system,
   // false otherwise.
// At most one equation is rewritten per call: after the first successful
   // substitution this returns true at once, because remove/addEquation have
   // changed equationNodes and the sorted index list is stale.
   pushSimplerOnce[] :=
   {
      // [unknowns, index] pairs, fewest unknowns first.
      sortedEqs = getEquationsSortedByComplexity[]
      last = length[sortedEqs]-1
      for i=0 to last-1
      {
         // nodeI and nodeJ are indices into equationNodes, not nodes.
         [unknownsI, nodeI] = sortedEqs@i
         JLOOP:
         for j=i+1 to last
         {
            [unknownsJ, nodeJ] = sortedEqs@j
            if isProperSubset[unknownsI, unknownsJ]
            {
               // println["Node $nodeI is a proper subset of $nodeJ"]

               // Candidate variables, in the order pickSimplestVariable
               // returns them.
               simpleVarList = pickSimplestVariable[nodeI, nodeJ]
               for [simpleVar, count] = simpleVarList
               {
                  sols = equationNodes@nodeI.getSolutions[simpleVar]
                  if length[sols] < 1
                  {
                     // No solution for this variable; try the next candidate.
                     println["Ugh. Has " + length[sols] + " solutions in pushSimplerOnce. Maybe improve pickSimplestVariable? Solutions are: $sols, equations are " + equationNodes@nodeI.getOriginalEquation[] + ", " + equationNodes@nodeJ.getOriginalEquation[] ]
                     next
                  } else
                  {
                     // Save the equation before its node is removed.
                     origEq = equationNodes@nodeJ.getOriginalEquation[]
                     // println["Removing node $nodeJ"]
                     remove[nodeJ]
                     if (length[sols] > 1)
                        println["Warning: pushSimplerOnce split into " + length[sols] + " solutions: $sols"]

                     // The removed equation is replaced by one new equation
                     // per solution, appended at the end of the list.
                     for sol = sols
                     {
                        // println["Replacing $simpleVar in $origEq with " + child[sol,1]]
                        // child[sol,1] is used as the replacement for
                        // simpleVar.
                        newEq = replaceVar[origEq, simpleVar, child[sol,1]]
                        [newEq, fullyReduced] = prettify[newEq]
                        if !fullyReduced
                           newEq = transformExpression[newEq]
                        addEquation[newEq]
                     }
                     return true
                  }
               }
            }
         }
      }
      return false // We made no changes.
   }

   // Pushes aliases like a == b or a == 2 b around the system so that
   // the system is as simple and disconnected as possible.
   //
   // NOTE(review): the body of this method is damaged.  The loop header
   // reads "while i rs ? ls : rs", which is not a valid condition, and
   // node, first, ls and rs are never assigned; text appears to be missing
   // between "while i" and "rs ? ls : rs".  The closing braces below no
   // longer match any opening braces.  Restore this method from a
   // known-good copy before enabling the (commented-out) call to it in
   // initialize[].
   pushAliases[] :=
   {
      size = length[equationNodes]
      i = 0
      while i rs ? ls : rs
      sols = node.getSolutions[first]
      if length[sols] != 1
      {
         println["Ugh. Has " + length[sols] + " solutions in pushAliases. Solutions are:\n" + sols]
      } else
      {
         println["Replacing " + sols@0]
         replaceAll[sols@0, i]
         remove[i]
         i = i - 1 // Don't increment i
         size = size - 1
      }
      }
      }
      i = i + 1
      }
   }

   // Simplifies an equation to the form a === 10 if only one variable
   // remains.
   // returns [equation, reduced]
   // where reduced is a boolean flag indicating if the equation has been
   // simplified to the form above.
// The returned array has exactly two elements: [equation, reduced].
   prettify[eq] :=
   {
      reduced = false

      // Unknowns are the equation's symbols minus the ignored ones.
      solvedUnknowns = getSymbols[eq]
      solvedUnknowns = setDifference[solvedUnknowns, ignoreSet]
      remainingSymbol = undef
      if length[solvedUnknowns] == 1
      {
         remainingSymbol = array[solvedUnknowns]@0
         prettified = array[solveSingle[eq, remainingSymbol]]
         if length[prettified] == 1
         {
            // Exactly one solution: use it as the equation.
            reduced = true
            eq = prettified@0
         } else
            eq = transformExpression[eq]
      }

      return [eq, reduced]
   }

   // Ranks the variables shared by nodes 1 and 2.  Node 1 should be
   // the simpler node.
   //
   // Returns an array of [unknown, count] pairs, one for each shared unknown
   // reported by getSymbolsByComplexity for node 1.  count is the sum of the
   // counts that getSymbolsByComplexity reports for that unknown in the two
   // equations, and the array is sorted by ascending count, so the first
   // element is the preferred variable to substitute.
   pickSimplestVariable[index1, index2] :=
   {
      node1 = equationNodes@index1
      node2 = equationNodes@index2
      u1 = node1.getUnknowns[]
      u2 = node2.getUnknowns[]

      // Named so that it does not shadow the intersection[] function.
      sharedUnknowns = intersection[u1, u2]
      results = new array

      // Seed the results with node 1's count for each shared unknown.
      sortedUnknowns = getSymbolsByComplexity[node1.getOriginalEquation[]]
      //println["Sorted unknowns is $sortedUnknowns"]
      for [unknown, count] = sortedUnknowns
         if sharedUnknowns.contains[unknown]
            results.push[[unknown, count]]

      // Add node 2's count to the matching entries.
      sortedUnknowns = getSymbolsByComplexity[node2.getOriginalEquation[]]
      for [unknown, count] = sortedUnknowns
         if sharedUnknowns.contains[unknown]
            for i = 0 to length[results]-1
               if results@i@0 == unknown
                  results@i@1 = results@i@1 + count

      sort[results, {|a,b| a@1 <=> b@1}]
      return results
   }

   // Returns true if the node at index 1 contains unknowns which are a
   // proper subset of the unknowns in index 2.  This means that node 1
   // is a "simpler" version of equation 2, and its values should be
   // substituted into equation 2.
   isSimpler[index1, index2] :=
   {
      return isProperSubset[equationNodes@index1.getUnknowns[], equationNodes@index2.getUnknowns[]]
   }

   // Eliminate simultaneous equations in the system.
   // For each pair of equations sharing two or more unknowns, the candidate
   // variables from pickSimplestVariable are tried in order; the first one
   // for which either equation yields exactly one solution is substituted
   // into all of the other equations via replaceAll.
   // returns true if the system has been changed.
   solveSimultaneous[] :=
   {
      changed = false
      // TODO: Sort by simplest equations?
      size = length[equationNodes]
      for i=0 to size-2
      {
         JLOOP:
         for j = i+1 to size-1
         {
            // Re-read on every pass: replaceAll replaces nodes in place.
            nodeI = equationNodes@i
            nodeJ = equationNodes@j
            ui = nodeI.getUnknowns[]
            uj = nodeJ.getUnknowns[]
            sharedUnknowns = intersection[ui, uj]
            // println[nodeI.getOriginalEquation[]]
            // println[nodeJ.getOriginalEquation[]]
            // println["$i: $ui\t$j: $uj"]
            // println["$i $j Shared unknowns are $sharedUnknowns"]
            if length[sharedUnknowns] >= 2
            {
               varsToReplace = pickSimplestVariable[i, j]
               // println["varsToReplace is $varsToReplace"]
               for [varToReplace, count] = varsToReplace
               {
                  // skipNode is the equation the solution came from; it
                  // must not have the solution substituted into itself.
                  skipNode = i
                  solution = nodeI.getSolutions[varToReplace]
                  if length[solution] != 1
                  {
                     // Didn't find single solution, try solving
                     // and replacing from the other node.
                     solution = nodeJ.getSolutions[varToReplace]
                     skipNode = j
                  }

                  if length[solution] == 1
                  {
                     replaceAll[solution@0, skipNode]
                     // dump[]
                     changed = true
                     break JLOOP
                  }
               }

               // Only reached when no candidate variable could be replaced.
               println["Ugh. SolveSimultaneous fell through without replacing. Equations were " + nodeI.getOriginalEquation[] + " and " + nodeJ.getOriginalEquation[]]
            }
         }
      }

      return changed
   }

   // Replace the specified symbol, recursively, in all equations except
   // the index specified.
   //   solution:  an equation; child 0 is the symbol to replace and child 1
   //              is the expression that replaces it.
   //   skipIndex: index of the equation that is left untouched.
   // Each changed equation replaces its node in place (same index).  If the
   // changed equation was itself reduced to a single solution by prettify,
   // that solution is in turn pushed into all the other equations.
   // (private)
   replaceAll[solution, skipIndex] :=
   {
      size = length[equationNodes]
      sym = child[solution,0]
      rep = child[solution,1] // Right-hand-side of solution

      // Substitute result into other equations.
      for k = 0 to size-1
         if k != skipIndex
         {
            orig = equationNodes@k.getOriginalEquation[]
            subst = substituteExpression[orig, sym, rep]
            // println["orig is $orig, sym is $sym, solution is $solution, rep is $rep, subst is $subst"]
            if orig != subst // and length[getSymbols[subst]] <= length[getSymbols[orig]]
            {
               [subst, eqSolved] = prettify[subst]
               subst2 = transformExpression[subst] // THINK ABOUT: Do this?

               // Keep the transformed form only if it is still an equation.
               if structureEquals[_a === _b, subst2] // and structureEquals[child[subst2,0],sym] and ! expressionContains[child[subst2,1], sym]
                  subst = subst2
               else
                  println["Warning: In replaceAll, did not get solution. Input was $solution, output was $subst2"]

               // println["Substituted $sym to " + rep + " in $orig, result is $subst"]
               addEquation[subst, k] // Replace equation.
               if eqSolved
               {
                  // println["Going to recursively replace $sym"]
                  replaceAll[subst, k] // Recursively replace others
               }
            }
         }
   }

   // Return a set of all unknowns in the system.
   getAllUnknowns[] :=
   {
      allUnknowns = new set
      for node = equationNodes
         allUnknowns = union[allUnknowns, node.getUnknowns[]]

      return allUnknowns
   }

   // Solves for all variables in the system.
   // Returns a single array holding the results of solveFor for every
   // unknown in the system.
   solveAll[] :=
   {
      allUnknowns = getAllUnknowns[]
      results = new array
      for u = allUnknowns
      {
         res = solveFor[u]
         for eq = res
            results.push[eq]
      }

      return results
   }

   // Solves the system for the specified variable name.
   // Initializes the system first if necessary.  Returns an array of the
   // solutions gathered from every equation that contains varName; the
   // array is cached in finalSolutions and returned directly on later calls.
   // (public)
   solveFor[varName] :=
   {
      if !initialized
         initialize[]

      // A missing cache entry is tested for explicitly, so this does not
      // depend on how an array value behaves when used as a condition.
      cached = finalSolutions@varName
      if cached != undef
         return cached

      results = new array
      size = getEquationCount[]
      for i=0 to size-1
      {
         if getUnknowns[i].contains[varName]
         {
            partialResults = solveNodeForVariable[i, varName]
            for r = partialResults
               results.push[r]
         }
      }

      // Cache results.
      finalSolutions@varName = results
      return results
   }

   // Solve for the specified variable name, substituting the list of
   // arguments.  Args is an array of ["varname", value] pairs.
   // The answer will be returned symbolically as an equation in the form
   //   varName === solution
   // with constants and units still intact.
   solveForSymbolic[varName, args] :=
   {
      results = new array
      sols = solveFor[varName]
      for sol = sols
      {
         // Substitute each supplied value for its symbol.
         for [arg, val] = args
         {
            sym = constructExpression["Symbol", arg]
            sol = substituteExpression[sol, sym, val]
         }

         // THINK ABOUT: Transform expression here to simplify?
         // res = transformExpression[res]
         results.push[eval[sol]]
      }

      return eliminateOverconstrained[results, false, false]
   }

   // Solve for the specified variable name, substituting the list of
   // arguments.  The result is a list of evaluated solutions.
// args is passed unchanged to solveForSymbolic.  Only the right-hand side
   // of each symbolic solution is evaluated, and a value equal to one
   // already in the result list is not added again.
   solveFor[varName, args] :=
   {
      sols = solveForSymbolic[varName, args]
      results = new array
      for sol = sols
      {
         right = child[sol,1]
         final = eval[right]
         exists = false
         CHECKDUPLICATES:
         for r = results
            // conforms is tested first, before the values are compared.
            if (final conforms r) and (final == r)
            {
               exists = true
               break CHECKDUPLICATES
            }

         if ! exists
            results.push[final]
      }

      return results
   }

   // Recursive method to find the solutions for the specified variable
   // starting from the specified node.  This recursively enumerates all
   // of the permutations of substitutions in the system.
   // This method just sets up parameters for the recursive call.
   //   index:           index of the starting node in equationNodes.
   //   varName:         name of the variable to solve for.
   //   cachedSolutions: dictionary of partial results shared by the
   //                    recursion; a new one is created when undef.
   // (Private method)
   solveNodeForVariable[index, varName, cachedSolutions = undef] :=
   {
      if cachedSolutions == undef
         cachedSolutions = new dict

      node = getEquationNode[index]
      sols = node.getSolutions[varName]
      // println["Solutions for $varName are $sols"]

      // Start the recursion with an empty set of used edges.
      results = solveNodeForVariable[node, varName, sols, new set, cachedSolutions]
      results = transformExpression[results]
      // return results
      return eliminateOverconstrained[results, true, false]
   }

   // The recursive (private) call to solve for the particular variable.
   //   node:            the EquationNode being solved.
   //   varName:         name of the variable being solved for.
   //   inEqs:           the solutions for varName that substitutions are
   //                    applied to.
   //   usedEdges:       set of edges already followed on this path.
   //   cachedSolutions: looked up as node -> varName -> usedEdges -> results.
   solveNodeForVariable[node, varName, inEqs, usedEdges, cachedSolutions] :=
   {
      // print["Solving for $varName in " + getNodeIndex[node] + ", {"]
      // for e = usedEdges
      //    print[e.getVariableName[] + " "]
      // println["}"]

      // Return partial solution from cache if possible.
      if cachedSolutions.containsKey[node]
      {
         varDict = cachedSolutions@node
         if varDict.containsKey[varName]
         {
            edgeDict = varDict@varName
            if edgeDict.containsKey[usedEdges]
               return edgeDict@usedEdges
         }
      }

      // The unsubstituted solutions are always part of the result.
      results = inEqs.shallowCopy[]

      // Candidate edges: those not already used, and not for the variable
      // being solved for.
      // NOTE(review): this removes elements from edges while iterating over
      // it -- confirm that iterating a Frink set tolerates this.
      edges = setDifference[node.getEdges[], usedEdges]
      for e = edges
         if e.getVariableName[] == varName
            edges.remove[e]

      len = length[edges]
      if (len == 0) // No more replacements to do.
      {
         putCache[cachedSolutions, node, varName, usedEdges, results]
         return results
      }

      // Set up states array to enumerate through permutations.
      // states@j is true when edge j is substituted in the current pass;
      // the array is stepped like a binary counter over all such subsets.
      states = new array
      for i=0 to len-2
         states@i = false

      states@(len-1) = true // Skip all-false state (no replacements)
      i = len-1
      edgeArray = array[edges]
      //newUsedEdges = union[node.getEdges[], usedEdges]
      while i >= 0
      {
         // Edges selected in this state count as used in the recursion.
         newUsedEdges = usedEdges.shallowCopy[]
         for j = 0 to len-1
            if states@j
               newUsedEdges.put[edgeArray@j]

         // Perform replacements on each edge
         EDGELOOP:
         for j = 0 to len-1
         {
            edge = edgeArray@j
            // newUsedEdges.put[edge] // Mark this edge as used.
            replacingVar = edge.getVariableName[]
            if states@j
            {
               replacingSymbol = edge.getSymbol[]
               otherNode = edge.getOtherNode[node]
               // newGlobalUsedNodes.put[newUsedNodesHere]

               // Recursively solve the other node for the variable
               // represented by this edge.
               repList = solveNodeForVariable[otherNode, replacingVar, otherNode.getSolutions[replacingVar], newUsedEdges, cachedSolutions ]
               // println["repList is $repList"]
               for repWithFull = repList
               {
                  repWith = child[repWithFull, 1] // Get right-hand-side
                  for eq = inEqs
                  {
                     res = substituteExpression[eq, replacingSymbol, repWith]
                     // println["Replacing $replacingVar with $repWith in $eq, result is $res"]

                     // Check to see if the variable we're solving for occurs on the right
                     rightSyms = getSymbols[child[res,1]]
                     if rightSyms.contains[varName]
                     {
                        //println["WARNING: Right side contains $varName in $res"]
                        res2 = solveSingle[res, varName]
                        //println["Re-solving: $res2"]
                        // TODO: This may return a whole lot of solutions.
                        // We need to evaluate each one and push them all
                        // onto the solutions list.
                        varSymbol = constructExpression["Symbol", varName]

                        // Accept only results of the form
                        //    varName === (expression without varName)
                        for subR = array[res2]
                           if structureEquals[_a === _b, subR] and structureEquals[child[subR,0],varSymbol] and ! expressionContains[child[subR,1], varSymbol]
                           {
                              //println["Re-solving successful."]
                              results.push[subR]
                           } else
                           {
                              // println["WARNING: Right side contains $varName in $res"]
                              println["Re-solving FAILED: $res2"]
                              // println["Re-solving FAILED."]
                           }
                     } else
                     {
                        // res = transformExpression[res]
                        results.push[res]
                     }
                  }
               }
            }
         }

         // Advance to next binary state
         flipped = false
         i = len-1
         while i>=0 and !flipped
         {
            // Enter next state
            if states@i == false
            {
               states@i = true
               flipped = true
            } else
            {
               // Carry
               states@i = false
               i = i - 1
            }
         }
         // i now contains the last index flipped. If i < 0, we're done
      }

      results = eliminateOverconstrained[results, true, false]
      putCache[cachedSolutions, node, varName, usedEdges, results]
      return results
   }

   // This function eliminates overconstrained equations. For example, a
   // system containing the solutions a===1/2 c r and a===c d^-1 r^2 is
   // overconstrained because a value can always be obtained with the first
   // equation. The second is not necessary, and could lead to
   // inconsistent results. This method ignores any symbols listed in the
   // ignoreSymbols list, (these are probably units,) eliminating them from
   // the equations.
   eliminateOverconstrained[eqArray, dupsOnly, debug=false] :=
   {
      size = length[eqArray]
      unknowns = new array
      lefts = new array
      for i = 0 to size-1
      {
         lefts@i = child[eqArray@i, 0]
         unknowns@i = setDifference[getSymbols[child[eqArray@i,1]], ignoreSet]
      }

      res = new array
      // Check for duplicates.
      for i=0 to size-1
      {
         remove = false
         j = 0
         do
         {
            if i != j and structureEquals[lefts@i, lefts@j]
            {
               remove = (i