声振论坛

 找回密码
 我要加入

QQ登录

只需一步,快速开始

查看: 2014|回复: 1

[人工智能] 隐节点合成算法matlab程序

[复制链接]
发表于 2007-10-13 15:55 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?我要加入

x
该程序来源于《神经网络结构设计的理论与方法》一书，
我将它录入并调试通过。
相关的理论部分可以参考《神经网络结构设计的理论与方法》一书。
程序较多，输入和调试花了我几天时间。
一共一个主函数、7个子函数，
我就不在这里重复贴过来了，原帖地址：
http://www.2nsoft.cn/bbs/read.php?tid=7675
有兴趣的可以学习一下。
为方便起见，想转帖过来也可以。

评分

1

查看全部评分

回复
分享到:

使用道具 举报

 楼主| 发表于 2007-10-13 22:51 | 显示全部楼层
function main()
% MAIN  Hidden-node combination demo for a two-hidden-layer BP network.
%
% Trains a fully connected 5-5-5-1 network with logsig units on the Boolean
% function  out = (x1 | x2) & (x3 | x4 | x5)  over all 32 five-bit inputs.
% 24 randomly chosen patterns are used for training and the other 8 for
% testing. Whenever the training SSE drops below ERRCOMBINE, the hidden
% layers are simplified (layer 1 first, then layer 2):
%   - two hidden units whose outputs are strongly correlated are merged
%     (the second is expressed as a*out(first)+b), or
%   - a hidden unit whose output barely varies is folded into the next
%     layer's bias.
% Training stops when the SSE falls below ERRGOAL or after MAXEPOCH epochs.
% Displays the final hidden-layer sizes and the number of misclassified
% test patterns, and plots the SSE history on a log scale.
%
% NOTE(review): rands, logsig and sumsqr are not defined in this file;
% presumably they come from the Neural Network (Deep Learning) Toolbox.
indim=5;
outdim=1;
hidden1unitnum=5;
hidden2unitnum=5;
allsamnum=32;
traindatanum=24;
testdatanum=8;

% All input patterns as columns: pattern k is the indim-bit binary code of
% k-1, most significant bit in row 1 (zero-padded to indim digits).
allsamin=(dec2bin(0:allsamnum-1,indim)-'0')';
% Target: (x1 OR x2) AND (x3 OR x4 OR x5).
allsamout=(allsamin(1,:)|allsamin(2,:))&...
    (allsamin(3,:)|allsamin(4,:)|allsamin(5,:));

% Random train/test split.
permpos=randperm(allsamnum);
trainidx=permpos(1:traindatanum);
testidx=permpos(traindatanum+1:traindatanum+testdatanum);
traindatain=allsamin(:,trainidx);
traindataout=allsamout(:,trainidx);
testdatain=allsamin(:,testidx);
testdataout=allsamout(:,testidx);

w1=0.5*rands(hidden1unitnum,indim);
b1=0.5*rands(hidden1unitnum,1);
w2=0.5*rands(hidden2unitnum,hidden1unitnum);
b2=0.5*rands(hidden2unitnum,1);
w3=0.5*rands(outdim,hidden2unitnum);
b3=0.5*rands(outdim,1);
lr=0.9;                     % learning rate
alpha=0.9;                  % momentum coefficient
maxepoch=2000;
errcombine=0.001;           % SSE below which unit combination is attempted
errgoal=0.00005;            % SSE at which training stops
unitscombinethreshold=0.8;  % |correlation| needed to merge two units
biascombinethreshold=0.01;  % output variance below which a unit becomes bias

% Augmented weight matrices: the last column of each wNex is the bias, and
% the inputs of every layer get an extra row of ones to match.
w1ex=[w1 b1];
w2ex=[w2 b2];
w3ex=[w3 b3];
traindatainex=[traindatain;ones(1,traindatanum)];
errhistory=[];
resizeflag=1;
for epoch=1:maxepoch
    if(resizeflag==1)
        % First epoch or the network just shrank: refresh the unit counts
        % and the bias-free weight views, and reset the momentum terms
        % (their sizes no longer match the weights).
        [hidden2unitnum,hidden1unitnum]=size(w2ex);
        hidden1unitnum=hidden1unitnum-1;
        w2=w2ex(:,1:hidden1unitnum);
        w3=w3ex(:,1:hidden2unitnum);
        dw1ex=zeros(size(w1ex));
        dw2ex=zeros(size(w2ex));
        dw3ex=zeros(size(w3ex));
        resizeflag=0;
    end

    % Forward pass over the whole training set.
    hidden1out=logsig(w1ex*traindatainex);
    hidden1outex=[hidden1out;ones(1,traindatanum)];
    hidden2out=logsig(w2ex*hidden1outex);
    hidden2outex=[hidden2out;ones(1,traindatanum)];
    networkout=logsig(w3ex*hidden2outex);
    err=traindataout-networkout;   % named err to avoid shadowing error()
    sse=sumsqr(err);
    errhistory(end+1)=sse;

    if(sse<errcombine)
        % The network fits well: look for a unit to remove, layer 1 first.
        hidden1var=var(hidden1out')';
        hidden2var=var(hidden2out')';
        hidden1corr=corrcoef(hidden1out');
        hidden2corr=corrcoef(hidden2out');
        [hidden1unit1,hidden1unit2]=findunittocombine(hidden1corr,...
            hidden1var,unitscombinethreshold,biascombinethreshold);
        if(hidden1unit1>0)
            if(hidden1unit2>0)
                % Correlated pair: unit hidden1unit2 ~ a*unit hidden1unit1 + b.
                [a,b]=linearreg(hidden1out(hidden1unit1,:),...
                    hidden1out(hidden1unit2,:));
                epoch
                combinetype=11
                drawcorrelatedunitsout(hidden1out(hidden1unit1,:),...
                    hidden1out(hidden1unit2,:));
                [w1ex,w2ex]=combinetwounits(hidden1unit1,...
                    hidden1unit2,a,b,w1ex,w2ex);
            else
                % Low-variance unit: replace it by its mean in the next bias.
                epoch
                combinetype=12
                drawbiasedunitout(hidden1out(hidden1unit1,:));
                unitmean=mean(hidden1out(hidden1unit1,:));
                [w1ex,w2ex]=combineunittobias...
                    (hidden1unit1,unitmean,w1ex,w2ex);
            end
            resizeflag=1;
            continue;
        end
        [hidden2unit1,hidden2unit2]=findunittocombine(hidden2corr,...
            hidden2var,unitscombinethreshold,biascombinethreshold);
        if(hidden2unit1>0)
            if(hidden2unit2>0)
                epoch
                combinetype=21
                [a,b]=linearreg(hidden2out...
                    (hidden2unit1,:),hidden2out(hidden2unit2,:));
                drawcorrelatedunitsout(hidden2out(hidden2unit1,:),...
                    hidden2out(hidden2unit2,:));
                [w2ex,w3ex]=combinetwounits(hidden2unit1,...
                    hidden2unit2,a,b,w2ex,w3ex);
            else
                epoch
                combinetype=22
                drawbiasedunitout(hidden2out(hidden2unit1,:));
                unitmean=mean(hidden2out(hidden2unit1,:));
                [w2ex,w3ex]=combineunittobias(hidden2unit1,unitmean,w2ex,w3ex);
            end
            resizeflag=1;
            continue;
        end
    end

    if(sse<errgoal),break,end

    % Backward pass. logsig'(y) = y.*(1-y); the bias columns are excluded
    % from the weights used to back-propagate the deltas.
    delta3=err.*networkout.*(1-networkout);
    delta2=(w3'*delta3).*hidden2out.*(1-hidden2out);
    delta1=(w2'*delta2).*hidden1out.*(1-hidden1out);
    % Momentum term: alpha times the previous epoch's lr-scaled gradient.
    dw1ex0=lr*dw1ex;
    dw2ex0=lr*dw2ex;
    dw3ex0=lr*dw3ex;
    dw3ex=delta3*hidden2outex';
    dw2ex=delta2*hidden1outex';
    dw1ex=delta1*traindatainex';
    w1ex=w1ex+lr*dw1ex+alpha*dw1ex0;
    w2ex=w2ex+lr*dw2ex+alpha*dw2ex0;
    w3ex=w3ex+lr*dw3ex+alpha*dw3ex0;
    w2=w2ex(:,1:hidden1unitnum);
    w3=w3ex(:,1:hidden2unitnum);
end

% Derive the final layer sizes from the weights: if the last epoch ended
% with a combination the loop-maintained counts are one unit too large.
hidden1unitnum=size(w1ex,1)
hidden2unitnum=size(w2ex,1)
w1=w1ex(:,1:indim);
b1=w1ex(:,indim+1);
w2=w2ex(:,1:hidden1unitnum);
b2=w2ex(:,hidden1unitnum+1);
w3=w3ex(:,1:hidden2unitnum);
b3=w3ex(:,hidden2unitnum+1);
testnnout=bpnet(testdatain,w1,b1,w2,b2,w3,b3);
binout=testnnout>0.5;
% Number of misclassified test patterns.
errnum=sum(binout~=testdataout)

% Training-error history on a logarithmic scale.
figure
semilogy(1:numel(errhistory),errhistory,'r-');
grid on
function [unit1,unit2]=findunittocombine(hiddencorr,hiddenvar,...
    unitscombinethreshold,biascombinethreshold)
% FINDUNITTOCOMBINE  Choose a hidden unit that can be removed.
%   HIDDENCORR: square correlation matrix between the units' outputs.
%   HIDDENVAR:  per-unit output variance.
%   Returns
%     unit1>0, unit2>0 : units unit1 < unit2 whose |correlation| is at least
%                        UNITSCOMBINETHRESHOLD and whose variances both exceed
%                        BIASCOMBINETHRESHOLD (merge unit2 into unit1);
%     unit1>0, unit2=0 : the lowest-variance unit, when no such pair exists
%                        and its variance is below BIASCOMBINETHRESHOLD
%                        (fold it into the next layer's bias);
%     unit1=0, unit2=0 : nothing to combine.
%   Candidate pairs are tried from the highest |correlation| downwards.

% Only the strict upper triangle is searched, so every pair is considered
% once and a unit is never paired with itself.
candidates=triu(hiddencorr)-eye(size(hiddencorr));
unit1=0;
unit2=0;
found=false;
while ~found
    [colbest,rowofbest]=max(abs(candidates));
    [bestcorr,col]=max(colbest);
    if bestcorr<unitscombinethreshold
        break
    end
    row=rowofbest(col);
    if hiddenvar(row)>biascombinethreshold && hiddenvar(col)>biascombinethreshold
        unit1=row;
        unit2=col;
        found=true;
    else
        % A nearly constant unit is handled by the bias rule below, not by
        % pairing; discard this pair and try the next best one.
        candidates(row,col)=0;
    end
end
if found
    return
end
[lowestvar,lowestunit]=min(hiddenvar);
if lowestvar<biascombinethreshold
    unit1=lowestunit;
end

function [a,b]=linearreg(vect1,vect2)
% LINEARREG  Least-squares straight-line fit  vect2 ~ a*vect1 + b.
%   VECT1, VECT2: row vectors of paired samples with equal length
%   (the fit uses the inner product vect*vect').
%   A is the least-squares slope  cov(vect1,vect2)/var(vect1)  and
%   B = mean(vect2) - A*mean(vect1).
%   If VECT1 is constant the slope is undefined (A becomes NaN or Inf);
%   callers should only pass vectors with non-negligible variance.
meanv1=mean(vect1);
meanv2=mean(vect2);
% Centered form: the 1/n factors of covariance and variance cancel, and
% centering avoids the cancellation error of E[xy]-E[x]E[y].
d1=vect1-meanv1;
d2=vect2-meanv2;
% The denominator is the (scaled) variance of the regressor vect1.
a=(d1*d2')/(d1*d1');
b=meanv2-a*meanv1;
function out=bpnet(in,w1,b1,w2,b2,w3,b3)
% BPNET  Forward pass of a two-hidden-layer network with logsig units in
% every layer.
%   IN holds one sample per column; each bias vector is replicated across
%   all sample columns before being added. Returns one output column per
%   sample.
nsamples=size(in,2);
layerweights={w1,w2,w3};
layerbiases={b1,b2,b3};
out=in;
for layer=1:3
    out=logsig(layerweights{layer}*out+repmat(layerbiases{layer},1,nsamples));
end

function [w1ex,w2ex]=combinetwounits(unit1,unit2,a,b,w1ex,w2ex)
% COMBINETWOUNITS  Remove hidden unit UNIT2, whose output is approximated
% by a*out(UNIT1)+b, by folding its outgoing weights into the next layer.
%   W1EX: weights into the hidden layer, one row per unit, bias last.
%   W2EX: weights of the next layer, one column per hidden unit, bias last.
%   UNIT2's outgoing column is added to UNIT1's column scaled by A and to
%   the bias column scaled by B. Row UNIT2 of W1EX and column UNIT2 of W2EX
%   are then deleted.
w2ex(:,unit1)=w2ex(:,unit1)+a*w2ex(:,unit2);
w2ex(:,end)=w2ex(:,end)+b*w2ex(:,unit2);
w2ex=w2ex(:,[1:unit2-1, unit2+1:end]);
w1ex=w1ex([1:unit2-1, unit2+1:end],:);

function [w1ex,w2ex]=combineunittobias(unit,unitmean,w1ex,w2ex)
% COMBINEUNITTOBIAS  Remove a nearly constant hidden unit by folding it
% into the next layer's bias.
%   The unit's output is treated as the constant UNITMEAN, so UNITMEAN times
%   its outgoing weights (column UNIT of W2EX) is added to the bias (last)
%   column of W2EX. Row UNIT of W1EX and column UNIT of W2EX are then
%   deleted.
outgoing=w2ex(:,unit);
w2ex(:,end)=w2ex(:,end)+unitmean*outgoing;
w2ex=w2ex(:,[1:unit-1, unit+1:end]);
w1ex=w1ex([1:unit-1, unit+1:end],:);

function drawcorrelatedunitsout(unitout1,unitout2)
% DRAWCORRELATEDUNITSOUT  Show two hidden-unit output sequences in a new
% figure: UNITOUT1 in blue and UNITOUT2 in black, plotted against sample
% index, with the y-axis fixed to [0 1].
[nrows,npts]=size(unitout1);
xs=1:npts;
figure
echo off
axis([0 npts 0 1])
axis on
grid
hold on
plot(xs,unitout1,'b-',xs,unitout2,'k-')

function  drawbiasedunitout(unitout)
% DRAWBIASEDUNITOUT  Show the output of one hidden unit against sample index
% in a new 400x300 figure, with the y-axis fixed to [0 1]. This lets a
% nearly constant unit be inspected before it is folded into a bias.
%   UNITOUT is expected to be a row vector (one value per sample).
[nrows,ptnum]=size(unitout);
figure('position',[300 300 400 300])
echo off
axis([0 ptnum 0 1])
axis on
grid
hold on
% Without this call the figure would only contain empty axes.
plot(1:ptnum,unitout,'b-')
您需要登录后才可以回帖 登录 | 我要加入

本版积分规则

QQ|小黑屋|Archiver|手机版|联系我们|声振论坛

GMT+8, 2024-11-28 12:12 , Processed in 0.072701 second(s), 21 queries , Gzip On.

Powered by Discuz! X3.4

Copyright © 2001-2021, Tencent Cloud.

快速回复 返回顶部 返回列表