ADMB Documentation  11.1.2192
 All Classes Files Functions Variables Typedefs Friends Defines
vgamdev.cpp
Go to the documentation of this file.
00001 /*
00002   $Id: vgamdev.cpp 1935 2014-04-26 02:02:58Z johnoel $
00003 
00004   Author: David Fournier
00005   Copyright (c) 2008, 2009, 2010 Regents of the University of California
00006  */
00007 #include <fvar.hpp>
00008 #define ITMAX 100
00009 //#define EPS 3.0e-7
00010 #define EPS 1.0e-9
00011 #define FPMIN 1.0e-30
00012 static void gcf(double& gammcf,double a,double x,double &gln);
00013 static void gser(double& gamser,double a,double x,double& gln);
00014 
00019   dvariable gamma_deviate(const prevariable& _x,const prevariable& _a)
00020   {
00021     prevariable& x= (prevariable&)(_x);
00022     prevariable& a= (prevariable&)(_a);
00023 
00024     dvariable y=cumd_norm(x);
00025 
00026     y=.9999*y+.00005;
00027 
00028     dvariable z=inv_cumd_gamma(y,a);
00029 
00030     return z;
00031   }
00032 
00033 
00034 static double gammp(double a,double x)
00035 {
00036   double gamser,gammcf,gln;
00037 
00038   if (x < 0.0 || a <= 0.0)
00039     cerr << "Invalid arguments in routine gammp" << endl;
00040   if (x < (a+1.0)) {
00041     gser(gamser,a,x,gln);
00042     return gamser;
00043   } else {
00044     gcf(gammcf,a,x,gln);
00045     return 1.0-gammcf;
00046   }
00047 }
00048 
00055 static void gcf(double& gammcf,double a,double x,double &gln)
00056 {
00057   int i;
00058   double an,b,c,d,del,h;
00059 
00060   gln=gammln(a);
00061   b=x+1.0-a;
00062   c=1.0/FPMIN;
00063   d=1.0/b;
00064   h=d;
00065   for (i=1;i<=ITMAX;i++) {
00066     an = -i*(i-a);
00067     b += 2.0;
00068     d=an*d+b;
00069     if (fabs(d) < FPMIN) d=FPMIN;
00070     c=b+an/c;
00071     if (fabs(c) < FPMIN) c=FPMIN;
00072     d=1.0/d;
00073     del=d*c;
00074     h *= del;
00075     if (fabs(del-1.0) < EPS) break;
00076   }
00077   if (i > ITMAX)
00078     cerr << "a too large, ITMAX too small in gcf" << endl;
00079   gammcf=exp(-x+a*log(x)-(gln))*h;
00080 }
00081 
00088 static void gser(double& gamser,double a,double x,double& gln)
00089 {
00090   int n;
00091   double sum,del,ap;
00092 
00093   gln=gammln(a);
00094   if (x <= 0.0) {
00095     if (x < 0.0)
00096       cerr << "x less than 0 in routine gser" << endl;
00097     gamser=0.0;
00098     return;
00099   } else {
00100     ap=a;
00101     del=sum=1.0/a;
00102     for (n=1;n<=ITMAX;n++) {
00103       ++ap;
00104       del *= x/ap;
00105       sum += del;
00106       if (fabs(del) < fabs(sum)*EPS) {
00107         gamser=sum*exp(-x+a*log(x)-(gln));
00108         return;
00109       }
00110     }
00111     cerr << "a too large, ITMAX too small in routine gser" << endl;
00112     return;
00113   }
00114 }
00115 
00116 static double get_initial_u(double a,double y);
00117 
00118 static double Sn(double x,double a);
00119 
00120 #include <df32fun.h>
00121 df3_two_variable cumd_gamma(const df3_two_variable& x,
00122   const df3_two_variable& a);
00123 
00124 dvariable inv_cumd_gamma(const prevariable& _y,const prevariable& _a)
00125 {
00126   double a=value(_a);
00127   double y=value(_y);
00128   if (a<0.05)
00129   {
00130     cerr << "a musdt be > 0.1" << endl;
00131     ad_exit(1);
00132   }
00133   double u=get_initial_u(a,y);
00134   double h;
00135   int loop_counter=0;
00136   do
00137   {
00138     loop_counter++;
00139     double z=gammp(a,a*exp(u));
00140     double d=y-z;
00141     //cout << d << endl;
00142     double log_fprime=a*log(a)+a*(u-exp(u)) -gammln(a);
00143     double fprime=exp(log_fprime);
00144     h=d/fprime;
00145     u+=h;
00146     if (loop_counter>1000)
00147     {
00148       cerr << "Error in inv_cumd_gamma"
00149         " maximum number of interations exceeded for values"
00150         << endl << "  x = " << y << "  a =  " << a  << "  h =  " << h  << endl;
00151     }
00152   }
00153   while(fabs(h)>1.e-12);
00154 
00155   double x=a*exp(u);
00156 
00157   init_df3_two_variable xx(x);
00158   init_df3_two_variable aa(a);
00159   *xx.get_u_x()=1.0;
00160   *aa.get_u_y()=1.0;
00161 
00162   df3_two_variable z=cumd_gamma(xx,aa);
00163   double F_x=1.0/(*z.get_u_x());
00164   double F_y=-F_x*(*z.get_u_y());
00165 
00166   dvariable vz=0.0;
00167   value(vz)=x;
00168 
00169   gradient_structure::GRAD_STACK1->set_gradient_stack(default_evaluation,
00170     &(vz.v->x),&(_y.v->x),F_x,&(_a.v->x),F_y);
00171 
00172   return vz;
00173 }
00174 
00175 #undef ITMAX
00176 #undef EPS
00177 
00178 static double Sn(double x,double a)
00179 {
00180   int i=1;
00181   double xp=x;
00182   double prod=1.0;
00183   double summ=1.0;
00184   double summand;
00185   do
00186   {
00187     prod*=(a+i);
00188     summand=xp/prod;
00189     if (summand<1.e-4) break;
00190     summ+=summand;
00191     i++;
00192     if (i>50)
00193     {
00194       cerr << "convergence error" << endl;
00195       ad_exit(1);
00196     }
00197   }
00198   while(1);
00199   return summ;
00200 }
00201 
00202 static double get_initial_u(double a,double y)
00203 {
00204   const double c=0.57721;
00205   // note that P = y;
00206   double logP=log(y);
00207   double logQ=log(1-y);
00208   double logB=logQ+gammln(a);
00209   double x0=1.e+100;
00210   double log_x0=1.e+100;
00211 
00212   if (a<1.0)
00213   {
00214     if ( logB>log(.6) || (logB > log(.45) && a>=.3) )
00215     {
00216       double logu;
00217       if (logB+logQ > log(1.e-8))
00218       {
00219         logu=(logP+gammln(1.0+a))/a;
00220       }
00221       else
00222       {
00223         logu=-exp(logQ)/a -c;
00224       }
00225       double u=exp(logu);
00226       x0=u/(1-u/(1.0+a));
00227       double tmp=log(1-u/(1.0+a));
00228       log_x0=logu;
00229       log_x0-=tmp;
00230     }
00231     else if ( a<.3 && log(.35) <= logB && logB <= log(.6) )
00232     {
00233       double t=exp(-c-exp(logB));
00234       double logt=-c-exp(logB);
00235       double u=t*exp(t);
00236       x0=t*exp(u);
00237       log_x0=logt+u;
00238     }
00239     else if ( (log(.15)<=logB && logB <=log(.35)) ||
00240        ((log(.15)<=logB && logB <=log(.45)) && a>=.3) )
00241     {
00242       double y=-logB;
00243       double v=y-(1-a)*log(y);
00244       x0=y-(1-a)*log(v)-log(1+(1.0-a)/(1.0+v));
00245       log_x0=log(x0);
00246     }
00247     else if (log(.01)<logB && logB < log(.15))
00248     {
00249       double y=-logB;
00250       double v=y-(1-a)*log(y);
00251       x0=y-(1-a)*log(v)-log((v*v+2*(3-a)*v+(2-a)*(3-a))/(v*v +(5-a)*v+2));
00252       log_x0=log(x0);
00253     }
00254     else if (logB < log(.01))
00255     {
00256       double y=-logB;
00257       double v=y-(1-a)*log(y);
00258       x0=y-(1-a)*log(v)-log((v*v+2*(3-a)*v+(2-a)*(3-a))/(v*v +(5-a)*v+2));
00259       log_x0=log(x0);
00260     }
00261     else
00262     {
00263       cerr << "this can't happen" << endl;
00264       ad_exit(1);
00265     }
00266   }
00267   else  if (a>=1.0)
00268   {
00269     const double a0 = 3.31125922108741;
00270     const double b1 = 6.61053765625462;
00271     const double a1 = 11.6616720288968;
00272     const double b2 = 6.40691597760039;
00273     const double a2 = 4.28342155967104;
00274     const double b3 = 1.27364489782223;
00275     const double a3 = .213623493715853;
00276     const double b4 = .03611708101884203;
00277 
00278     int sgn=1;
00279     double logtau;
00280     if (logP< log(0.5))
00281     {
00282       logtau=logP;
00283       sgn=-1;
00284     }
00285     else
00286     {
00287       logtau=logQ;
00288       sgn=1;
00289     }
00290 
00291     double t=sqrt(-2.0*logtau);
00292 
00293     double num = (((a3*t+a2)*t+a1)*t)+a0;
00294     double den = ((((b4*t+b3)*t+b2)*t)+b1)*t+1;
00295     double s=sgn*(t-num/den);
00296     double s2=s*s;
00297     double s3=s2*s;
00298     double s4=s3*s;
00299     double s5=s4*s;
00300     double roota=sqrt(a);
00301     double w=a+s*roota+(s2-1)/3.0+(s3-7.0*s)/(36.*roota)
00302       -(3.0*s4+7.0*s2-16)/(810.0*a)
00303       +(9.0*s5+256.0*s3-433.0*s)/(38880.0*a*roota);
00304     if (logP< log(0.5))
00305     {
00306       if (w>.15*(a+1))
00307       {
00308         x0=w;
00309       }
00310       else
00311       {
00312         double v=logP+gammln(a+1);
00313         double u1=exp(v+w)/a;
00314         double S1=1+u1/(a+1);
00315         double u2=exp((v+u1-log(S1))/a);
00316         double S2=1+u2/(a+1)+u2*u2/((a+1)*(a+2));
00317         double u3=exp((v+u2-log(S2))/a);
00318         double S3=1+u3/(a+1)+u3*u3/((a+1)*(a+2))
00319          + u3*u3*u3/((a+1)*(a+2)*(a+3));
00320         double z=exp((v+u3-log(S3))/a);
00321         if (z<.002*(a+1.0))
00322         {
00323           x0=z;
00324         }
00325         else
00326         {
00327           double sn=Sn(z,a);
00328           double zbar=exp((v+z-log(sn))/a);
00329           x0=zbar*(1.0-(a*log(zbar)-zbar-v+log(sn))/(a-zbar));
00330         }
00331       }
00332       log_x0=log(x0);
00333     }
00334     else
00335     {
00336       double u = -logB +(a-1.0)*log(w)-log(1.0+(1.0-a)/(1+w));
00337       x0=u;
00338       log_x0=log(x0);
00339     }
00340   }
00341   if (a==1.0)
00342   {
00343     x0=-log(1.0-y);
00344     log_x0=log(x0);
00345   }
00346   return log_x0-log(a);
00347 }