%% 
%% 
% Fixed-size clusters k-means clustering algorithm
% Uses the Hungarian algorithm in the assignment phase
% Mikko Malinen
% University of Eastern Finland
% 2013, 2024

function partition_best=Fixed_sized_clusters(X,sizes_of_clusters,k,number_of_repeats)
% FIXED_SIZED_CLUSTERS  k-means clustering with prescribed cluster sizes.
%
%   partition_best = Fixed_sized_clusters(X, sizes_of_clusters, k)
%   partition_best = Fixed_sized_clusters(X, sizes_of_clusters, k, number_of_repeats)
%
%   Runs k-means whose assignment step is solved as a minimum-cost
%   assignment problem (Hungarian algorithm, via an external MUNKRES
%   function that must be on the path), so that cluster j receives exactly
%   sizes_of_clusters(j) points in every iteration. The procedure is
%   restarted from random initial centroids number_of_repeats times and the
%   partition with the smallest MSE is returned.
%
%   Inputs:
%     X                  n-by-d data matrix, one point per row. Must contain
%                        at least k distinct rows.
%     sizes_of_clusters  vector of non-negative integer cluster sizes whose
%                        sum must equal n.
%     k                  number of clusters. Optional; defaults to
%                        numel(sizes_of_clusters). Must not be smaller than
%                        numel(sizes_of_clusters).
%     number_of_repeats  number of random restarts. Optional; default 10.
%
%   Output:
%     partition_best     1-by-n vector; partition_best(p) is the cluster
%                        label of point p in the lowest-MSE run.
%
%   Side effects: prints per-iteration progress, run statistics and the
%   members of each cluster, and plots the first two dimensions of X
%   coloured by cluster (skipped when d < 2).
%
%   Example:
%     X = rand(12,2);
%     p = Fixed_sized_clusters(X, [4 4 4], 3);

if nargin < 3 || isempty(k)
    k = numel(sizes_of_clusters);
end
if nargin < 4 || isempty(number_of_repeats)
    number_of_repeats = 10;
end

n = size(X,1);   % number of points

% Validate up front so that inconsistent arguments fail with a clear
% message instead of an indexing error deep inside the iterations.
sizes_of_clusters = sizes_of_clusters(:)';
if any(sizes_of_clusters < 0) || any(sizes_of_clusters ~= round(sizes_of_clusters))
    error('Fixed_sized_clusters:badSizes', ...
        'Cluster sizes must be non-negative integers.');
end
if sum(sizes_of_clusters) ~= n
    error('Fixed_sized_clusters:sizeMismatch', ...
        'Cluster sizes sum to %d but X has %d points.', sum(sizes_of_clusters), n);
end
if numel(sizes_of_clusters) > k
    error('Fixed_sized_clusters:tooManySizes', ...
        'Got %d cluster sizes but only k = %d clusters.', numel(sizes_of_clusters), k);
end
if size(unique(X,'rows'),1) < k
    % initial centroids are k distinct data points; fewer would loop forever
    error('Fixed_sized_clusters:tooFewDistinctPoints', ...
        'X must contain at least k = %d distinct points.', k);
end

% Slot s (s = 1..n) of the assignment problem belongs to cluster
% slot_cluster(s): the first sizes_of_clusters(1) slots to cluster 1, the
% next sizes_of_clusters(2) slots to cluster 2, and so on. This depends
% only on the sizes, so it is computed once rather than per iteration.
size_of_cluster_cumsum = cumsum(sizes_of_clusters);
slot_cluster = zeros(1,n);
for s = 1:n
    slot_cluster(s) = cl_number_of_cl_slot(s,k,size_of_cluster_cumsum);
end

max_iterations = 100;
number_of_iterations_distribution = zeros(max_iterations,1);
MSE_repeats = zeros(1,number_of_repeats);
earth_movers_array = zeros(1,number_of_repeats);
MSE_best = Inf;
partition_best = zeros(1,n);

tic;

for repeats = 1:number_of_repeats

    C = choose_initial_centroids(X,k);

    partition = 0;   % dummy value: differs from every real labelling
    partition_changed = true;
    kmeans_iteration_number = 0;

    while partition_changed && kmeans_iteration_number < max_iterations
        partition_previous = partition;

        % assignment step: cluster sizes are enforced by the Hungarian algorithm
        partition = fixed_size_assignment(X,C,slot_cluster);

        % update step
        for j = 1:k
            members = (partition == j);
            if any(members)
                % Average along dimension 1 explicitly: for a one-point
                % cluster X(members,:) is a row vector and plain mean()
                % would average its coordinates instead.
                C(j,:) = mean(X(members,:),1);
            end
            % a cluster with no members keeps its previous centroid
        end

        kmeans_iteration_number = kmeans_iteration_number + 1   % no semicolon: progress output

        partition_changed = any(partition ~= partition_previous);
    end

    % mean (over points) of squared Euclidean distance to own centroid
    residual = X - C(partition,:);
    MSE = sum(residual(:).^2)/n;

    if (MSE < MSE_best) || (repeats == 1)
        MSE_best = MSE;
        partition_best = partition;
    end
    MSE_repeats(repeats) = MSE;

    number_of_iterations_distribution(kmeans_iteration_number) = ...
        number_of_iterations_distribution(kmeans_iteration_number) + 1;

    % Earth mover's distance to an equal-sized partition: twice the number
    % of points by which oversized clusters exceed ceil(n/k). The threshold
    % n/k + 0.9999 approximates "more than ceil(n/k) points".
    cluster_size = accumarray(partition(:),1,[k 1])';
    oversized = cluster_size > n/k + 0.9999;
    earth_movers_array(repeats) = 2*sum(cluster_size(oversized) - ceil(n/k));

end % repeats

MSE = MSE_best;

number_of_iterations_distribution

MSE

mean_MSE_repeats = mean(MSE_repeats)
std_MSE_repeats = std(MSE_repeats)

mean_earthmovers_array = mean(earth_movers_array)

% print cluster contents
for i = 1:k
    cluster_number = i
    find(partition_best==i)
end

toc;

plot_partition(X,partition_best,k);

toc;

end


function C = choose_initial_centroids(X,k)
% Returns k distinct rows of X as a k-by-d matrix. Each row is drawn
% uniformly at random; a draw equal to an earlier pick is rejected and
% redrawn. The caller must guarantee that X has at least k distinct rows,
% otherwise the rejection loop never terminates.
n = size(X,1);
picked = zeros(1,k);
for j = 1:k
    accepted = false;
    while ~accepted
        candidate = randi(n);
        accepted = true;
        for l = 1:j-1
            if isequal(X(candidate,:),X(picked(l),:))
                accepted = false;
            end
        end
    end
    picked(j) = candidate;
end
C = X(picked,:);
end


function partition = fixed_size_assignment(X,C,slot_cluster)
% Assigns each point to one slot by solving a square assignment problem
% with MUNKRES, and returns the 1-by-n vector of cluster labels.
% Row s of the cost matrix is slot s (owned by cluster slot_cluster(s)),
% column p is point p, and the cost is the squared Euclidean distance from
% point p to that cluster's centroid.
n = size(X,1);
k = size(C,1);
dist_to_centroid = zeros(n,k);
for c = 1:k
    offsets = bsxfun(@minus, X, C(c,:));
    dist_to_centroid(:,c) = sum(offsets.^2, 2);
end
costMat = dist_to_centroid(:,slot_cluster)';

% assignment(s) is the point (column) given to slot s; 0 means unassigned
assignment = munkres(costMat);

partition = zeros(1,n);
assigned = (assignment ~= 0);
partition(assignment(assigned)) = slot_cluster(assigned);
end


function plot_partition(X,partition,k)
% Scatter-plots the first two dimensions of X in a new figure, one marker
% style per cluster. The 15-entry style table is reused cyclically when
% there are more than 15 clusters. Does nothing but print a notice when X
% has fewer than two columns.
if size(X,2) < 2
    disp('Data has fewer than 2 dimensions; skipping plot.');
    return;
end
styles = {'r+','bO','r.','b.','r+','bO','b+','b+','r.','b.','g+','gO','g+','g.','g.'};
figure
for c = 1:k
    members = (partition == c);
    plot(X(members,1),X(members,2),styles{mod(c-1,numel(styles))+1});
    if c == 1
        hold on
    end
end
end



function cl_number=cl_number_of_cl_slot(i,k,size_of_cluster_cumsum)
% Returns the cluster that owns assignment slot i.
% That is the first index c with i <= size_of_cluster_cumsum(c), where
% size_of_cluster_cumsum is the cumulative sum of the cluster sizes.
% Returns empty when i exceeds the total of all sizes.
% k is accepted for interface compatibility but is not used.

cl_number = find(size_of_cluster_cumsum >= i, 1, 'first');

end