1/*
    2EM clustering on the iris dataset
    3*/

    4:- use_module(library(clpfd), [transpose/2]).
    5:- if(current_predicate(use_rendering/1)).
    6:- use_rendering(c3).
    7:- endif.
    8
    9
   10em(It,True,Post):-
   11  findall([A,B,C,D,Type],data(A,B,C,D,Type),LA),
   12  pca(LA,True),
   13  data_mean_var(1,DM1,V1),
   14  data_mean_var(2,DM2,V2),
   15  data_mean_var(3,DM3,V3),
   16  data_mean_var(4,DM4,V4),
   17  findall(M,(between(1,3,_),gauss(DM1,1.0,M)),[M11,M21,M31]),
   18  findall(M,(between(1,3,_),gauss(DM2,1.0,M)),[M12,M22,M32]),
   19  findall(M,(between(1,3,_),gauss(DM3,1.0,M)),[M13,M23,M33]),
   20  findall(M,(between(1,3,_),gauss(DM4,1.0,M)),[M14,M24,M34]),
   21  em_it(0,It,LA,par([c1:0.2,c2:0.5,c3:0.3],
   22    [M11,M12,M13,M14,V1,V2,V3,V4],[M21,M22,M23,M24,V1,V2,V3,V4],
   23    [M31,M32,M33,M34,V1,V2,V3,V4]),_Par,_,LA1),
   24  find_most_likely(LA1,LPostCat),
   25  maplist(replace_cat,LA,LPostCat,LPost),
   26  pca(LPost,Post).
   27
   28em_it(It,It,_LA,Par,Par,LAOut,LAOut):-!.
   29
   30em_it(It0,It,LA,Par0,Par,_,LAOut):-
   31  expect(LA,Par0,LA1,LL),
   32  write('LL '),write(LL),nl,
   33  maxim(LA1,Par1),
   34  It1 is It0+1,
   35  em_it(It1,It,LA,Par1,Par,LA1,LAOut).
   36
   37expect(LA,par([c1:P1,c2:P2,c3:P3],
   38  G1,G2,G3),[L1,L2,L3],LL):-
   39  maplist(weight(G1,P1),LA,L01),
   40  maplist(weight(G2,P2),LA,L02),
   41  maplist(weight(G3,P3),LA,L03),
   42  normal(L01,L02,L03,L1,L2,L3),
   43  log_lik(L01,L02,L3,P1,P2,P3,LL).
   44
   45maxim([LA1,LA2,LA3],par([c1:P1,c2:P2,c3:P3],C1,C2,C3)):-
   46  stats(LA1,W1,C1),
   47  stats(LA2,W2,C2),
   48  stats(LA3,W3,C3),
   49  SW is W1+W2+W3,
   50  P1 is W1/SW,
   51  P2 is W2/SW,
   52  P3 is W3/SW.
   53
   54find_most_likely([L1,L2,L3],LC):-
   55  maplist(classify,L1,L2,L3,LC).
   56
   57classify(_-W1,_-W2,_-W3,Cat):-
   58  find_max([W1,W2,W3],Cat).
   59
   60normal(L01,L02,L03,L1,L2,L3):-
   61  maplist(px,L01,L02,L03,L1,L2,L3).
   62
   63px(X-W01,X-W02,X-W03,X-W1,X-W2,X-W3):-
   64  S is W01+W02+W03,
   65  W1 is W01/S,
   66  W2 is W02/S,
   67  W3 is W03/S.
   68
   69weight([M1,M2,M3,M4,V1,V2,V3,V4],P,[A,B,C,D,_],[A,B,C,D]-W):-
   70  gauss_density_0(M1,V1,A,W1),
   71  gauss_density_0(M2,V2,B,W2),
   72  gauss_density_0(M3,V3,C,W3),
   73  gauss_density_0(M4,V4,D,W4),
   74  W is W1*W2*W3*W4*P.
   75
   76log_lik(L1,L2,L3,P1,P2,P3,LL):-
   77  foldl(combine(P1,P2,P3),L1,L2,L3,0,LL).
   78
   79combine(P1,P2,P3,_-W1,_-W2,_-W3,LL0,LL):-
   80  LLs is log(P1*W1+P2*W2+P3*W3),
   81  LL is LL0+LLs.
   82
   83
   84
   85find_max(Counts,MaxC):-
   86  max_list(Counts,MV),
   87  nth1(Max,Counts,MV),
   88   concat_atom(['c',Max],MaxC).
   89
   90replace_cat([A,B,C,D,_],Cat,[A,B,C,D,Cat]).
   91
   92
   93pcac(LA,P):-
   94  length(LA,NP),
   95  maplist(add_cat,LA,LCat,L),
   96  L=[H|_],
   97  length(H,Comp),
   98  list_to_set(LCat,Cats),
   99  append(L,LLin),
  100  D =..[c|LLin],
  101  data<- matrix(D,ncol=Comp,byrow='TRUE'),
  102  pc<- prcomp(data),
  103  Data0<-pc["x"],
  104  Data0=[Data1],
  105  foldl(getn(NP),Data2,Data1,[]),
  106  transpose(Data2,Data),
  107  maplist(add_cat,DC,LCat,Data),
  108  maplist(separate(DC),Cats,Mat1),
  109  maplist(keep2,Mat1,Mat2),
  110  maplist(axis,Cats,Axis,LAxis),
  111  dict_create(Ax,_,LAxis),
  112  maplist(tocol,Mat2,Axis,ColData),
  113  append(ColData,Cols),
  114  P = c3{data:_{
  115    xs:Ax, columns:Cols,type:scatter},
  116    axis:_{x:_{ tick:_{fit: false}}}
  117    }.
  118
  119:- <- library("ggplot2").
  120
  121
  122
  123pca(LA,_P):-
  124  length(LA,NP),
  125  maplist(add_cat,LA,LCat,L),
  126  L=[H|_],
  127  length(H,Comp),
  128  append(L,LLin),
  129  D =..[c|LLin],
  130  data<- matrix(D,ncol=Comp,byrow='TRUE'),
  131  pc<- prcomp(data),
  132  Data0<-pc["x"],
  133  Data0=[Data1],
  134  foldl(getn(NP),Data2,Data1,[]),
  135  transpose(Data2,Data),
  136  maplist(getx,Data,X),
  137  maplist(gety,Data,Y),
  138  x<- X,
  139  y<-Y,
  140  class<-LCat,
  141  <-qplot(x, y, colour=class),
  142  r_download. 
  143
  144getn(N,LN,L,Rest):-
  145    length(LN,N),
  146    append(LN,Rest,L).
  147
  148tocol(Data,[X,Y],[[X|DX],[Y|DY]]):-
  149  maplist(xy,Data,DX,DY).
  150
  151getx([X,_,_,_],X).
  152
  153gety([_,Y,_,_],Y).
  154
  155ab([X,Y,_,_,_],[X,Y]).
  156
  157xy([X,Y],X,Y).
  158
  159add_cat([X,Y,Z,W,C],C,[X,Y,Z,W]).
  160
  161cat(Cat,[_,_,_,_,Cat]).
  162
  163group([A,B],[C,D],[E,F],g(A,B,C,D,E,F)).
  164
  165axis(N,[NX,N],NA-NX):-
  166  (number(N)->
  167    atom_number(NA,N)
  168  ;
  169    NA=N
  170  ),
  171  atom_concat(x,NA,NX).
  172
  173keep2(Quad,Cou):-
  174  maplist(ab,Quad,Cou).
  175
  176
  177separate(DC,Cat,DataClass):-
  178  include(cat(Cat),DC,DataClass).
  179
  180
  181assert_data_means:-
  182  findall([A,B,C,D],data(A,B,C,D,_Type),LA),
  183  maplist(component,LA,CA,CB,CC,CD),
  184  mean(CA,M1),
  185  mean(CB,M2),
  186  mean(CC,M3),
  187  mean(CD,M4),
  188  variance(CA,M1,V1),
  189  variance(CB,M2,V2),
  190  variance(CC,M3,V3),
  191  variance(CD,M4,V4),
  192  assert(data_mean_var(1,M1,V1)),
  193  assert(data_mean_var(2,M2,V2)),
  194  assert(data_mean_var(3,M3,V3)),
  195  assert(data_mean_var(4,M4,V4)).
  196
  197mean(L,M):-
  198  length(L,N),
  199  sum_list(L,S),
  200  M is S/N.
  201
  202variance(L,M,Var):-
  203  length(L,N),
  204  foldl(agg_var(M),L,0,S),
  205  Var is S/N.
  206
  207stats(LA,SW,[M1,M2,M3,M4,V1,V2,V3,V4]):-
  208  maplist(component_weight,LA,CA,CB,CC,CD),
  209  weighted_mean(CA,M1,SW),
  210  weighted_mean(CB,M2,_),
  211  weighted_mean(CC,M3,_),
  212  weighted_mean(CD,M4,_),
  213  weighted_var(CA,M1,V1),
  214  weighted_var(CB,M2,V2),
  215  weighted_var(CC,M3,V3),
  216  weighted_var(CD,M4,V4).
  217 
  218weighted_var(L,M,Var):-
  219  foldl(agg_val_var(M),L,(0,0),(S,SW0)),
  220  SW is SW0,
  221  (SW=:=0.0->
  222    write(zero_var),nl,
  223    Var=1.0
  224  ;
  225    Var is S/SW
  226  ).
  227
  228weighted_mean(L,M,SW):-
  229  foldl(agg_val,L,(0,0),(S,SW0)),
  230  SW is SW0,
  231  (SW =:=0.0->
  232    write(zero_mean),nl,
  233    M is 0
  234  ;
  235    M is S/SW
  236  ).
  237
  238agg_val(V -N,(S,SW),(S+V*N,SW+N)).
  239agg_val_var(M,V -N,(S,SW),(S+(M-V)^2*N,SW+N)).
  240agg_var(M,V,S,S+(M-V)^2).
  241
  242
  243component([A,B,C,D],A,B,C,D).
  244component_weight([A,B,C,D]-W,A-W,B-W,C-W,D-W).
  245
  246gauss_density_0(M,V,X,W):-
  247  (V=:=0.0->
  248   write(zero_var_gauss),
  249    W=0.0
  250  ;
  251    gauss_density(M,V,X,W)
  252  ).
  253
  254gauss_density(Mean,Variance,S,D):-
  255  StdDev is sqrt(Variance),
  256  D is 1/(StdDev*sqrt(2*pi))*exp(-(S-Mean)*(S-Mean)/(2*Variance)).
  257
  258gauss(Mean,Variance,S):-
  259  number(Mean),!,
  260  random(U1),
  261  random(U2),
  262  R is sqrt(-2*log(U1)),
  263  Theta is 2*pi*U2,
  264  S0 is R*cos(Theta),
  265  StdDev is sqrt(Variance),
  266  S is StdDev*S0+Mean.

/** <examples>

?- findall([A,B,C,D,Type],data(A,B,C,D,Type),LA), pca(LA,True).
?- em(10,T,P).

*/

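% Iris dataset: each data(A,B,C,D,Type) fact holds the four numeric
% measurements of one flower and its species label Type.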
  275data(5.1,3.5,1.4,0.2,'Iris-setosa').
  276data(4.9,3.0,1.4,0.2,'Iris-setosa').
  277data(4.7,3.2,1.3,0.2,'Iris-setosa').
  278data(4.6,3.1,1.5,0.2,'Iris-setosa').
  279data(5.0,3.6,1.4,0.2,'Iris-setosa').
  280data(5.4,3.9,1.7,0.4,'Iris-setosa').
  281data(4.6,3.4,1.4,0.3,'Iris-setosa').
  282data(5.0,3.4,1.5,0.2,'Iris-setosa').
  283data(4.4,2.9,1.4,0.2,'Iris-setosa').
  284data(4.9,3.1,1.5,0.1,'Iris-setosa').
  285data(5.4,3.7,1.5,0.2,'Iris-setosa').
  286data(4.8,3.4,1.6,0.2,'Iris-setosa').
  287data(4.8,3.0,1.4,0.1,'Iris-setosa').
  288data(4.3,3.0,1.1,0.1,'Iris-setosa').
  289data(5.8,4.0,1.2,0.2,'Iris-setosa').
  290data(5.7,4.4,1.5,0.4,'Iris-setosa').
  291data(5.4,3.9,1.3,0.4,'Iris-setosa').
  292data(5.1,3.5,1.4,0.3,'Iris-setosa').
  293data(5.7,3.8,1.7,0.3,'Iris-setosa').
  294data(5.1,3.8,1.5,0.3,'Iris-setosa').
  295data(5.4,3.4,1.7,0.2,'Iris-setosa').
  296data(5.1,3.7,1.5,0.4,'Iris-setosa').
  297data(4.6,3.6,1.0,0.2,'Iris-setosa').
  298data(5.1,3.3,1.7,0.5,'Iris-setosa').
  299data(4.8,3.4,1.9,0.2,'Iris-setosa').
  300data(5.0,3.0,1.6,0.2,'Iris-setosa').
  301data(5.0,3.4,1.6,0.4,'Iris-setosa').
  302data(5.2,3.5,1.5,0.2,'Iris-setosa').
  303data(5.2,3.4,1.4,0.2,'Iris-setosa').
  304data(4.7,3.2,1.6,0.2,'Iris-setosa').
  305data(4.8,3.1,1.6,0.2,'Iris-setosa').
  306data(5.4,3.4,1.5,0.4,'Iris-setosa').
  307data(5.2,4.1,1.5,0.1,'Iris-setosa').
  308data(5.5,4.2,1.4,0.2,'Iris-setosa').
  309data(4.9,3.1,1.5,0.1,'Iris-setosa').
  310data(5.0,3.2,1.2,0.2,'Iris-setosa').
  311data(5.5,3.5,1.3,0.2,'Iris-setosa').
  312data(4.9,3.1,1.5,0.1,'Iris-setosa').
  313data(4.4,3.0,1.3,0.2,'Iris-setosa').
  314data(5.1,3.4,1.5,0.2,'Iris-setosa').
  315data(5.0,3.5,1.3,0.3,'Iris-setosa').
  316data(4.5,2.3,1.3,0.3,'Iris-setosa').
  317data(4.4,3.2,1.3,0.2,'Iris-setosa').
  318data(5.0,3.5,1.6,0.6,'Iris-setosa').
  319data(5.1,3.8,1.9,0.4,'Iris-setosa').
  320data(4.8,3.0,1.4,0.3,'Iris-setosa').
  321data(5.1,3.8,1.6,0.2,'Iris-setosa').
  322data(4.6,3.2,1.4,0.2,'Iris-setosa').
  323data(5.3,3.7,1.5,0.2,'Iris-setosa').
  324data(5.0,3.3,1.4,0.2,'Iris-setosa').
  325data(7.0,3.2,4.7,1.4,'Iris-versicolor').
  326data(6.4,3.2,4.5,1.5,'Iris-versicolor').
  327data(6.9,3.1,4.9,1.5,'Iris-versicolor').
  328data(5.5,2.3,4.0,1.3,'Iris-versicolor').
  329data(6.5,2.8,4.6,1.5,'Iris-versicolor').
  330data(5.7,2.8,4.5,1.3,'Iris-versicolor').
  331data(6.3,3.3,4.7,1.6,'Iris-versicolor').
  332data(4.9,2.4,3.3,1.0,'Iris-versicolor').
  333data(6.6,2.9,4.6,1.3,'Iris-versicolor').
  334data(5.2,2.7,3.9,1.4,'Iris-versicolor').
  335data(5.0,2.0,3.5,1.0,'Iris-versicolor').
  336data(5.9,3.0,4.2,1.5,'Iris-versicolor').
  337data(6.0,2.2,4.0,1.0,'Iris-versicolor').
  338data(6.1,2.9,4.7,1.4,'Iris-versicolor').
  339data(5.6,2.9,3.6,1.3,'Iris-versicolor').
  340data(6.7,3.1,4.4,1.4,'Iris-versicolor').
  341data(5.6,3.0,4.5,1.5,'Iris-versicolor').
  342data(5.8,2.7,4.1,1.0,'Iris-versicolor').
  343data(6.2,2.2,4.5,1.5,'Iris-versicolor').
  344data(5.6,2.5,3.9,1.1,'Iris-versicolor').
  345data(5.9,3.2,4.8,1.8,'Iris-versicolor').
  346data(6.1,2.8,4.0,1.3,'Iris-versicolor').
  347data(6.3,2.5,4.9,1.5,'Iris-versicolor').
  348data(6.1,2.8,4.7,1.2,'Iris-versicolor').
  349data(6.4,2.9,4.3,1.3,'Iris-versicolor').
  350data(6.6,3.0,4.4,1.4,'Iris-versicolor').
  351data(6.8,2.8,4.8,1.4,'Iris-versicolor').
  352data(6.7,3.0,5.0,1.7,'Iris-versicolor').
  353data(6.0,2.9,4.5,1.5,'Iris-versicolor').
  354data(5.7,2.6,3.5,1.0,'Iris-versicolor').
  355data(5.5,2.4,3.8,1.1,'Iris-versicolor').
  356data(5.5,2.4,3.7,1.0,'Iris-versicolor').
  357data(5.8,2.7,3.9,1.2,'Iris-versicolor').
  358data(6.0,2.7,5.1,1.6,'Iris-versicolor').
  359data(5.4,3.0,4.5,1.5,'Iris-versicolor').
  360data(6.0,3.4,4.5,1.6,'Iris-versicolor').
  361data(6.7,3.1,4.7,1.5,'Iris-versicolor').
  362data(6.3,2.3,4.4,1.3,'Iris-versicolor').
  363data(5.6,3.0,4.1,1.3,'Iris-versicolor').
  364data(5.5,2.5,4.0,1.3,'Iris-versicolor').
  365data(5.5,2.6,4.4,1.2,'Iris-versicolor').
  366data(6.1,3.0,4.6,1.4,'Iris-versicolor').
  367data(5.8,2.6,4.0,1.2,'Iris-versicolor').
  368data(5.0,2.3,3.3,1.0,'Iris-versicolor').
  369data(5.6,2.7,4.2,1.3,'Iris-versicolor').
  370data(5.7,3.0,4.2,1.2,'Iris-versicolor').
  371data(5.7,2.9,4.2,1.3,'Iris-versicolor').
  372data(6.2,2.9,4.3,1.3,'Iris-versicolor').
  373data(5.1,2.5,3.0,1.1,'Iris-versicolor').
  374data(5.7,2.8,4.1,1.3,'Iris-versicolor').
  375data(6.3,3.3,6.0,2.5,'Iris-virginica').
  376data(5.8,2.7,5.1,1.9,'Iris-virginica').
  377data(7.1,3.0,5.9,2.1,'Iris-virginica').
  378data(6.3,2.9,5.6,1.8,'Iris-virginica').
  379data(6.5,3.0,5.8,2.2,'Iris-virginica').
  380data(7.6,3.0,6.6,2.1,'Iris-virginica').
  381data(4.9,2.5,4.5,1.7,'Iris-virginica').
  382data(7.3,2.9,6.3,1.8,'Iris-virginica').
  383data(6.7,2.5,5.8,1.8,'Iris-virginica').
  384data(7.2,3.6,6.1,2.5,'Iris-virginica').
  385data(6.5,3.2,5.1,2.0,'Iris-virginica').
  386data(6.4,2.7,5.3,1.9,'Iris-virginica').
  387data(6.8,3.0,5.5,2.1,'Iris-virginica').
  388data(5.7,2.5,5.0,2.0,'Iris-virginica').
  389data(5.8,2.8,5.1,2.4,'Iris-virginica').
  390data(6.4,3.2,5.3,2.3,'Iris-virginica').
  391data(6.5,3.0,5.5,1.8,'Iris-virginica').
  392data(7.7,3.8,6.7,2.2,'Iris-virginica').
  393data(7.7,2.6,6.9,2.3,'Iris-virginica').
  394data(6.0,2.2,5.0,1.5,'Iris-virginica').
  395data(6.9,3.2,5.7,2.3,'Iris-virginica').
  396data(5.6,2.8,4.9,2.0,'Iris-virginica').
  397data(7.7,2.8,6.7,2.0,'Iris-virginica').
  398data(6.3,2.7,4.9,1.8,'Iris-virginica').
  399data(6.7,3.3,5.7,2.1,'Iris-virginica').
  400data(7.2,3.2,6.0,1.8,'Iris-virginica').
  401data(6.2,2.8,4.8,1.8,'Iris-virginica').
  402data(6.1,3.0,4.9,1.8,'Iris-virginica').
  403data(6.4,2.8,5.6,2.1,'Iris-virginica').
  404data(7.2,3.0,5.8,1.6,'Iris-virginica').
  405data(7.4,2.8,6.1,1.9,'Iris-virginica').
  406data(7.9,3.8,6.4,2.0,'Iris-virginica').
  407data(6.4,2.8,5.6,2.2,'Iris-virginica').
  408data(6.3,2.8,5.1,1.5,'Iris-virginica').
  409data(6.1,2.6,5.6,1.4,'Iris-virginica').
  410data(7.7,3.0,6.1,2.3,'Iris-virginica').
  411data(6.3,3.4,5.6,2.4,'Iris-virginica').
  412data(6.4,3.1,5.5,1.8,'Iris-virginica').
  413data(6.0,3.0,4.8,1.8,'Iris-virginica').
  414data(6.9,3.1,5.4,2.1,'Iris-virginica').
  415data(6.7,3.1,5.6,2.4,'Iris-virginica').
  416data(6.9,3.1,5.1,2.3,'Iris-virginica').
  417data(5.8,2.7,5.1,1.9,'Iris-virginica').
  418data(6.8,3.2,5.9,2.3,'Iris-virginica').
  419data(6.7,3.3,5.7,2.5,'Iris-virginica').
  420data(6.7,3.0,5.2,2.3,'Iris-virginica').
  421data(6.3,2.5,5.0,1.9,'Iris-virginica').
  422data(6.5,3.0,5.2,2.0,'Iris-virginica').
  423data(6.2,3.4,5.4,2.3,'Iris-virginica').
  424data(5.9,3.0,5.1,1.8,'Iris-virginica').
  425
  426:- assert_data_means.