/* Generate all possible expressions by backtracking.

   Each expression is given a weight.  First the weight 1 expressions
   are generated; then the weight 2 expressions; and so on.  The weight
   of an atomic expression is defined by the clauses for "unknown",
   "number", and "logical".  The weight of a compound expression is the
   sum of the weight of its operator and the weights of its arguments.

   The arguments "Old_Max" and "New_Max" in the various clauses are used
   to control the generation of unknowns, to avoid generating expressions
   that are equivalent except for different names for unknowns.
   "Old_Max" is passed in and indicates that unknowns up to the
   Old_Max-th one may be used, the Old_Max-th being a fresh one.
   "New_Max" is returned; it is the value of "Old_Max" to use for the
   next subexpression generated alongside this one. */



/* generate(Expr) enumerates, by backtracking, every expression in order
   of increasing weight, starting at weight 1.  generate(Weight,Expr)
   does the same starting at Weight.  Each top-level expression numbers
   its unknowns from 1.  The weight is never bounded, so the enumeration
   never runs out of solutions. */

generate(Expr) :-
   generate(1, Expr).

generate(Weight, Expr) :-
   weight_from(Weight, Current),
   expression(Current, Expr, 1, _).

/* weight_from(W0,W) yields W = W0, W0+1, W0+2, ... on backtracking. */
weight_from(W, W).
weight_from(W0, W) :-
   Next is W0 + 1,
   weight_from(Next, W).

/* expression(Weight,Expr,Old_Max,New_Max) generates, by backtracking,
   every expression Expr of weight exactly Weight.  Old_Max/New_Max
   thread the count of unknowns used so far (see the header comment). */

expression(Weight,Expr,Old_Max,New_Max) :-
   unknown(Weight,Expr,Old_Max,New_Max).

/* Constants introduce no unknowns, so the counter passes through
   unchanged.  These must be expression/4 clauses: written with two
   arguments they would define a separate predicate that nothing calls,
   and no numbers or logical constants would ever be generated. */
expression(Weight,Expr,Max,Max) :-
   number(Weight,Expr).
expression(Weight,Expr,Max,Max) :-
   logical(Weight,Expr).

expression(Weight,Expr,Old_Max,New_Max) :-
   /* hack to speed things up for small weights -- given
      current weights, this can't succeed for Weight<2 */
   Weight >= 2,
   unary(Op,Op_Weight),
   W1 is Weight - Op_Weight,
   W1 >= 1,
   expression(W1,E1,Old_Max,New_Max),
   Expr =.. [Op,E1].

expression(Weight,Expr,Old_Max,New_Max) :-
   Weight >= 3,
   binary(Op,Op_Weight),
   W1 is Weight - Op_Weight,
   W1 >= 1,
   sum(I,J,W1),
   /* sum/3 also yields splits with a zero side; no expression has
      weight 0, so skip them instead of searching for one */
   I >= 1,
   J >= 1,
   /* the left argument's New_Max becomes the right argument's Old_Max */
   expression(I,E1,Old_Max,Max1),
   expression(J,E2,Max1,New_Max),
   Expr =.. [Op,E1,E2].


/* sum(I,J,N) produces by backtracking all pairs of integers I and J
   with I >= 0, J >= 0 and I+J = N (for a non-negative integer N).  The
   first pair is I=N, J=0; each retry moves one unit from I to J, ending
   at I=0, J=N. */

sum(I,J,N) :-
   sum_split(N,0,I,J).

/* sum_split(I0,J0,I,J): yield the current pair (I0,J0), then keep
   shifting one unit from the first component to the second while the
   first is still positive. */
sum_split(I,J,I,J).
sum_split(I0,J0,I,J) :-
   I0 > 0,
   I1 is I0 - 1,
   J1 is J0 + 1,
   sum_split(I1,J1,I,J).





/* unknown(1,Expr,Old_Max,New_Max): an unknown has weight 1.  The first
   solution is the fresh Old_Max-th unknown, with New_Max = Old_Max+1.
   The remaining solutions reuse the unknowns already in play, from the
   (Old_Max-1)-th down to the first, leaving the counter unchanged. */

unknown(1,Expr,Old_Max,New_Max) :-
   unknwn(Old_Max,Expr),
   New_Max is Old_Max+1.

unknown(1,Expr,Max,Max) :-
   Max > 1,
   Newest is Max - 1,
   reused_unknown(Newest,Expr).

/* reused_unknown(K,Expr) yields the names of unknowns K, K-1, ..., 1. */
reused_unknown(K,Expr) :-
   unknwn(K,Expr).
reused_unknown(K,Expr) :-
   K > 1,
   K1 is K - 1,
   reused_unknown(K1,Expr).


/* unknwn(I,Name) returns a canonical name for the Ith unknown:
   u, v, w, x, y, z for I = 1..6, then u1, u2, u3, ... for I > 6.
   Fails for I < 1. */

unknwn(1,u) :- !.
unknwn(2,v) :- !.
unknwn(3,w) :- !.
unknwn(4,x) :- !.
unknwn(5,y) :- !.
unknwn(6,z) :- !.

unknwn(I,Name) :-
   I > 6,
   J is I - 6,
   /* build the numeric suffix with ISO builtins; concat/3 is not part
      of ISO Prolog and is missing from many current systems */
   number_codes(J, Digits),
   atom_codes(Suffix, Digits),
   atom_concat(u, Suffix, Name).



/* number(Weight,Value): the numeric constants and their weights.  The
   weighting gives preference to 0, 1 and 2, which each weigh 1; a
   larger value N weighs N itself, so numbers greater than 2 are
   heavily penalized. */

number(1, Value) :-
   (   Value = 0
   ;   Value = 1
   ;   Value = 2
   ).
number(Weight, Weight) :-
   Weight > 2.

/* logical(Weight,Value): the logical constants, each of weight 1. */
logical(1, Value) :-
   (   Value = true
   ;   Value = false
   ).



/* Operator tables: unary(Op,Weight) and binary(Op,Weight) list the
   operators the generator may use, each with its weight.  Clause order
   determines the order in which operators are tried. */

unary('-',  2).
unary(sin,  2).
unary(cos,  2).
unary(tan,  4).

binary('+', 1).
binary('*', 1).
binary('^', 4).
binary('=', 2).


/* go prints every generated expression, one per line, in order of
   increasing weight.  The trailing fail forces backtracking into
   generate/1.  Since generate/1 never runs out of weights, this does
   not terminate. */
go :-
   generate(Expression),
   write(Expression),
   nl,
   fail.
