|
楼主

楼主 |
发表于 2012-5-24 21:50:53
|
只看该作者
Stochastic Gradient Descent Logistic Regression in SAS
From oloolo's blog on SasProgramming
<p><a href="http://feedads.g.doubleclick.net/~a/PjvzRIUQaqOILxxFQcDQHXpxHU8/0/da"><img src="http://feedads.g.doubleclick.net/~a/PjvzRIUQaqOILxxFQcDQHXpxHU8/0/di" border="0" ismap="true"></img></a><br/>
<a href="http://feedads.g.doubleclick.net/~a/PjvzRIUQaqOILxxFQcDQHXpxHU8/1/da"><img src="http://feedads.g.doubleclick.net/~a/PjvzRIUQaqOILxxFQcDQHXpxHU8/1/di" border="0" ismap="true"></img></a></p>Testing Stochastic Gradient Descent logistic regression in SAS. The logic and code follow a code snippet by Ravi Varadhan, Ph.D., from this <a href="http://r.789695.n4.nabble.com/Stochastic-Gradient-Ascent-for-logistic-regression-td884272.html" target="_blank"><strong><em>discussion</em></strong> </a>on the R-help mailing list. The blog <a href="http://sasdiehard.blogspot.com/" target="_blank"><strong><em>SAS Die Hard</em></strong></a> also has a post about SGD logistic regression in SAS.<br />
<br />
<br />
<br />
<pre style="background-color: #ebebeb; border-bottom: #999999 1px dashed; border-left: #999999 1px dashed; border-right: #999999 1px dashed; border-top: #999999 1px dashed; color: #000001; font-family: Andale Mono, Lucida Console, Monaco, fixed, monospace; font-size: 12px; line-height: 14px; overflow: auto; padding-bottom: 5px; padding-left: 5px; padding-right: 5px; padding-top: 5px; width: 100%;"><code>
/* Demo: fit the low-birth-weight (lbw) logistic model with %lr_sgd and
   compare it against PROC LOGISTIC. The macros %logistic, %compare and
   %lr_sgd must already be compiled before this code is submitted. */
filename foo url "http://www.biostat.jhsph.edu/~ririzarr/Teaching/754/lbw.dat" ;

/* Read the raw data. The first record is skipped (treated as a header
   line), which is what the former IF _N_>1 subsetting did. */
data temp;
   infile foo firstobs=2;
   input low age lwt race smoke ptl ht ui ftv bwt;
run;

/* Covariates used by both the SGD fit and the PROC LOGISTIC fit. */
%let covars=age lwt smoke ht ui;

/* Standardize the covariates to mean 0 and standard deviation 1. */
proc standard data=temp out=temp mean=0 std=1;
   var &covars;
run;

/* Build &covars2 (B_-prefixed coefficient names, in the same order as
   &covars) and &nparms (number of coefficients including the intercept).
   Both are global, matching what the former PROC SQL step produced. */
data _null_;
   length names $ 32767;
   do k=1 to countw("&covars", " ");
      names=catx(" ", names, cats("b_", scan("&covars", k, " ")));
   end;
   call symputx("covars2", names, "G");
   call symputx("nparms", countw("&covars", " ")+1, "G");
run;
%put covars=&covars covars2=&covars2 nparms=&nparms;

/* Release _xbeta in case an earlier %lr_sgd run stopped before its own
   SASFILE CLOSE and left the data set loaded in memory.
   NOTE(review): this may log a message when _xbeta is not loaded. */
sasfile _xbeta close;

/* SGD fit. The response is LOW, the same variable PROC LOGISTIC models below. */
%lr_sgd(temp, beta, low, &covars, alpha=0.008, decay=0.8, criterion=0.00001, maxiter=1000);

/* Reference fit. DESC models P(low=1), matching the residual
   low - 1/(1+exp(-score)) used inside %logistic. */
options fullstimer;
proc logistic data=temp outest=_beta desc noprint;
   model low = &covars;
run;
</code></pre>
The %LR_SGD macro and its helper macros: <br />
<pre style="background-color: #ebebeb; border-bottom: #999999 1px dashed; border-left: #999999 1px dashed; border-right: #999999 1px dashed; border-top: #999999 1px dashed; color: #000001; font-family: Andale Mono, Lucida Console, Monaco, fixed, monospace; font-size: 12px; line-height: 14px; overflow: auto; padding-bottom: 5px; padding-left: 5px; padding-right: 5px; padding-top: 5px; width: 100%;"><code>
/*
SAS macro:
Logistic Regression using Stochastic Gradient Descent.
Name:
%lr_sgd();
Copyright (c) 2009, Liang Xie (Contact me @ xie1978 at gmail dot com)
The SAS macro is a demonstration of an implementation of logistic
regression models trained by Stochastic Gradient Descent (SGD). This
program reads a training set specified as &dsn, trains a logistic
regression model, and outputs the estimated coefficients to &outest.
Example usage:
%let inputdata=train_data;
%let beta=coefficient;
%let response=Event;
%lr_sgd(&inputdata, &beta, &response, &covars,
alpha=0.008, decay=0.8,
criterion=0.00001, maxiter=1000);
The following topics are not covered for simplicity:
- bias term
- regularization
- multiclass logistic regression (maximum entropy model)
- calibration of learning rate
<i>Distributed under GNU Affero General Public License version 3. This
program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, only version 3 of the
License.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
</i>*/
/*
 Macro %logistic: one pass of gradient updates over the training data.
   dsn_in   - training data set; must contain INTERCEPT, the &covars
              variables and &response.
   outest   - PARMS-type coefficient data set (INTERCEPT &covars, _NAME_
              SCORE) used to score the data. The updated coefficients are
              written in place to the companion data set _x&outest, whose
              variables are B_INTERCEPT &covars2.
   response - name of the response variable (the residual below assumes
              it is coded 0/1 -- confirm for your data).
   alpha    - learning rate applied to every per-observation update.
 Uses macro variables &covars and &covars2 resolved from the caller's scope.
 Creates WORK data set SCORE and WORK view _XTEMP.
 Note: the residuals come from a single PROC SCORE run against &outest, so
 every observation in this pass is scored with the coefficients from the
 start of the pass; the updates are accumulated into _x&outest.
*/
%macro logistic(dsn_in, outest, response, alpha=0.0005);
   /* Linear predictor for every observation, output in variable SCORE. */
   proc score data=&dsn_in score=&outest type=parms out=score(keep=score);
      var intercept &covars;
   run;

   /* View: the training data next to its score, plus the residual
      _w = response - 1/(1+exp(-score)). One-to-one merge (no BY). */
   data _xtemp/view=_xtemp;
      merge &dsn_in score;
      _w=&response - 1/(1+exp(-score));
   run;

   /* Rewrite the coefficient row of _x&outest in place: add
      alpha * x * residual for every observation of the view. */
   data _x&outest;
      array x{*} intercept &covars;
      array _x{*} b_intercept &covars2;
      retain b_intercept &covars2;
      modify _x&outest;
      do until (eof);
         set _xtemp end=eof;
         do i=1 to dim(x);
            _x[i]=_x[i]+&alpha*x[i]*_w;
         end;
      end;
      replace;
   run;
%mend;
/*
 Macro %compare: largest absolute difference between two coefficient rows.
   dsn1 - data set holding INTERCEPT &covars (the previous coefficients).
   dsn2 - data set holding B_INTERCEPT &covars2 (the updated coefficients).
 Sets macro variable MAXDIFF to max |dsn1 value - dsn2 value| over all
 coefficient pairs. CALL SYMPUTX follows the same symbol-table rules as
 CALL SYMPUT, so an existing MAXDIFF in a calling macro is updated. It
 converts with a width of up to 32 characters and no numeric-to-character
 NOTE, instead of BEST12.
*/
%macro compare(dsn1, dsn2);
   data _null_;
      merge &dsn1 &dsn2;
      array _x1{*} intercept &covars;
      array _x2{*} b_intercept &covars2;
      retain maxdiff 0;
      do i=1 to dim(_x1);
         maxdiff=max(maxdiff, abs(_x1[i]-_x2[i]));
      end;
      call symputx('maxdiff', maxdiff);
   run;
%mend;
/*
 Macro %lr_sgd: fit a logistic regression by repeated gradient passes.
   dsn       - training data set. NOTE: it is overwritten in place with two
               added variables, INTERCEPT (=1) and _W.
   outest    - output data set for the coefficients (_TYPE_=PARMS,
               _NAME_=SCORE, variables INTERCEPT and &covars). A working
               copy named _x&outest (B_-prefixed names) is also created.
   response  - response variable passed to %logistic.
   covars    - space-separated list of covariate variables.
   alpha     - initial learning rate.
   decay     - alpha is multiplied by this factor after every pass.
   criterion - stop once the largest absolute coefficient change of a pass
               (computed by %compare) is below this value.
   maxiter   - maximum number of passes.
   minalpha  - lower bound for the decayed learning rate (default 0.00005,
               the value previously hard-coded).
 Requires macros %logistic and %compare. The caller's SOURCE, NOTES,
 MLOGIC and MPRINT settings are saved on entry and restored on exit.
*/
%macro lr_sgd(dsn, outest, response, covars,
              alpha=0.0005, decay=0.9,
              criterion=0.00001, maxiter=1000, minalpha=0.00005);
   %local i k t0 t1 t00 t11 dt maxdiff status stopiter covars2
          opt_source opt_notes opt_mlogic opt_mprint;

   /* Remember the caller's logging options so they can be restored. */
   %let opt_source=%sysfunc(getoption(source));
   %let opt_notes=%sysfunc(getoption(notes));
   %let opt_mlogic=%sysfunc(getoption(mlogic));
   %let opt_mprint=%sysfunc(getoption(mprint));
   options nosource nonotes nomlogic nomprint;

   %let t00=%sysfunc(datetime());

   /* B_-prefixed coefficient names for _x&outest. %logistic and %compare
      resolve &covars2 from this scope, so no global &covars2 is required. */
   %let covars2=;
   %do k=1 %to %sysfunc(countw(&covars, %str( )));
      %let covars2=&covars2 b_%scan(&covars, &k, %str( ));
   %end;

   /* Constant column for the intercept; _W is recomputed by %logistic. */
   data &dsn;
      set &dsn;
      intercept=1; _w=1;
   run;

   /* Start from all-zero coefficients. */
   data &outest;
      retain _TYPE_ "PARMS" _NAME_ "SCORE";
      array x{*} intercept &covars;
      do i=1 to dim(x);
         x[i]=0;
      end;
      drop i;
      output;
   run;

   /* Working copy under B_ names; %logistic rewrites it on every pass,
      so keep it loaded in memory with SASFILE. */
   data _x&outest;
      retain _TYPE_ "PARMS" _NAME_ "SCORE";
      array bx{*} b_intercept &covars2;
      array x{*} intercept &covars;
      set &outest;
      do j=1 to dim(x); bx[j]=x[j]; end;
      keep b_intercept &covars2 _TYPE_ _NAME_;
      drop j;
   run;
   sasfile _x&outest load;

   %let stopiter=&maxiter;
   %let status=Not Converged.;
   %do i=1 %to &maxiter;
      %let t0=%sysfunc(datetime());
      %logistic(&dsn, &outest, &response, alpha=&alpha);
      /* MAXDIFF = largest |previous - updated| coefficient in this pass. */
      %compare(&outest, _x&outest);
      /* Copy the updated coefficients back into &outest for the next pass. */
      data &outest;
         retain _TYPE_ "PARMS" _NAME_ "SCORE";
         array bx{*} b_intercept &covars2;
         array x{*} intercept &covars;
         set _x&outest;
         do j=1 to dim(x); x[j]=bx[j]; end;
         keep intercept &covars _TYPE_ _NAME_;
         drop j;
      run;
      /* Decay the learning rate, but never below &minalpha. */
      %let alpha=%sysevalf(&alpha * &decay);
      %let alpha=%sysfunc(max(&minalpha, &alpha));
      %let t1=%sysfunc(datetime());
      %let dt=%sysfunc(round(&t1-&t0, 0.001));
      %put Iteration &i, time used &dt, converge criteria is &maxdiff;
      %if %sysevalf(&maxdiff<&criterion) %then %do;
         %let stopiter=&i;
         /* Push the loop index past &maxiter to leave the %DO loop. */
         %let i=%eval(&maxiter+1);
         %let status=Converged.;
      %end;
   %end;
   sasfile _x&outest close;

   %let t11=%sysfunc(datetime());
   %let dt=%sysfunc(round(&t11-&t00, 0.01));
   %put Total Time is &dt sec.;
   %put Total Iteration is &stopiter, convergence status is &status;
   %put At Final Iteration, max difference is &maxdiff;

   /* Restore the caller's options instead of forcing them all on. */
   options &opt_mlogic &opt_mprint &opt_notes &opt_source;
%mend;
</code></pre><div class="blogger-post-footer"><img width='1' height='1' src='https://blogger.googleusercontent.com/tracker/29815492-2505374346425236390?l=www.sas-programming.com' alt='' /></div><img src="http://feeds.feedburner.com/~r/SasProgramming/~4/d_1rR80APks" height="1" width="1"/> |
|