马上注册，结交更多好友，享用更多功能，让你轻松玩转社区。
您需要登录才可以下载或查看，没有账号？我要加入
x
function [u,sig,t,iter] = fit_mix_gaussian( X,M )
%
% fit_mix_gaussian - fit parameters for a mixed-gaussian distribution using EM algorithm
%
% format:   [u,sig,t,iter] = fit_mix_gaussian( X,M )
%
% input:    X   - input samples, a vector of N values (row or column)
%           M   - number of gaussians which are assumed to compose the
%                 distribution (positive integer scalar, default 2)
%
% output:   u   - fitted mean for each gaussian (1*M)
%           sig - fitted standard deviation for each gaussian (1*M)
%           t   - probability of each gaussian in the complete distribution (1*M)
%           iter- number of iterations done by the function
%
% Called with no arguments, the function runs a demo: it draws 256*M
% samples with build_mix_gaussian( 8,3,1,256*M ), plots them in a new
% figure, and fits them.
%
% In every case the iteration count is printed to the console and the
% fitted distribution is drawn with plot_mix_gaussian before returning.

% Argument handling.
% NOTE: an earlier version started with "clear all" and then overwrote X
% and M with demo data, so the caller's inputs were always ignored. The
% demo now runs only when no samples are passed in.
if nargin < 2 || isempty( M )
    M = 2;
end
if ~isnumeric( M ) || ~isscalar( M ) || M < 1 || M ~= round( M )
    error( 'fit_mix_gaussian:badM','M must be a positive integer scalar' );
end

% demo mode: no samples supplied
if nargin < 1
    X = build_mix_gaussian( 8,3,1,256*M );
    figure;
    plot( X );
end
if isempty( X )
    error( 'fit_mix_gaussian:emptyX','X must contain at least one sample' );
end

% initialize and initial guesses
X = X(:);                               % work on a column vector (N*1), rows are accepted too
N = length( X );
Z = ones(N,M) / M;                      % responsibility of each gaussian for each sample
t = ones(1,M) / M;                      % distribution of the gaussian models in the samples
u = linspace( min(X),max(X),M );        % mean vector, spread over the sample range
% Variance floor: keeps every sig2 strictly positive, so constant data or a
% gaussian collapsing onto one point cannot produce 0/0 (NaN) below.
sig2_floor = max( eps*var(X),realmin );
sig2 = max( ones(1,M) * var(X) / sqrt(M),sig2_floor );   % variance vector
log_C = -0.5*log(2*pi);                 % log of the normalization constant 1/sqrt(2*pi)
Ic = ones(N,1);                         % Ic*row replicates a 1*M row over N rows
Ir = ones(1,M);                         % col*Ir replicates an N*1 column over M columns
thresh = 1e-3;                          % relative change of the step size treated as converged
step = N;
last_step = inf;
iter = 0;
min_iter = 10;                          % always run at least this many iterations

% main convergence loop, assume gaussians are 1D
while ( abs( step/last_step - 1 ) > thresh && step > N*eps ) || iter < min_iter

    % E step
    % ========
    % logP(n,m) = log( t(m) * normpdf( X(n); u(m), sqrt(sig2(m)) ) ), N*M.
    % Normalizing each row after subtracting its maximum (log-sum-exp) keeps
    % samples that are far from every gaussian from underflowing to 0/0.
    Q = Z;
    logP = log_C - 0.5*(Ic*log(sig2)) ...
         - ((X*Ir - Ic*u).^2) ./ (2*Ic*sig2) ...
         + Ic*log(t);
    Z = exp( logP - max( logP,[],2 )*Ir );
    Z = Z ./ ( sum( Z,2 )*Ir );

    % estimate convergence step size and update iteration number
    % (prog_text backspaces over the previously printed "<k> iterations\n")
    prog_text = sprintf( repmat( '\b',1,(iter>0)*12+ceil(log10(iter+1)) ) );
    iter = iter + 1;
    last_step = step * (1 + eps) + eps;  % never zero, so step/last_step is defined
    step = sum( sum( abs(Q-Z) ) );
    fprintf( '%s%d iterations\n',prog_text,iter );

    % M step
    % ========
    Zm = sum( Z,1 );                     % sum each column (stays 1*M even when N==1)
    Zm( Zm==0 ) = eps;                   % avoid division by zero
    u = (X')*Z ./ Zm;
    sig2 = max( sum( ((X*Ir - Ic*u).^2) .* Z,1 ) ./ Zm,sig2_floor );
    t = Zm / N;
end

% plot the fitted distribution
% =============================
sig = sqrt( sig2 );
plot_mix_gaussian( u,sig,t );
复制代码 |