ADMB Documentation  11.1.2274
 All Classes Files Functions Variables Typedefs Friends Defines
df1b2chkder.cpp
Go to the documentation of this file.
00001 /*
00002  * $Id: df1b2chkder.cpp 1935 2014-04-26 02:02:58Z johnoel $
00003  *
00004  * Author: David Fournier
00005  * Copyright (c) 2008-2012 Regents of the University of California
00006  */
00011 #  include <fvar.hpp>
00012 #  include <admodel.h>
00013 #  include <df1b2fun.h>
00014 #  include <adrndeff.h>
00015 double evaluate_function(const dvector& x,function_minimizer * pfmin);
00016 void get_second_ders(int xs,int us,const init_df1b2vector y,dmatrix& Hess,
00017   dmatrix& Dux, df1b2_gradlist * f1b2gradlist,function_minimizer * pfmin,
00018   laplace_approximation_calculator* lap);
00019 double calculate_laplace_approximation(const dvector& x,const dvector& u0,
00020   const dmatrix& Hess,const dvector& _xadjoint,const dvector& _uadjoint,
00021   const dmatrix& _Hessadjoint,function_minimizer * pmin);
00022 
00023 double calculate_importance_sample(const dvector& x,const dvector& u0,
00024   const dmatrix& Hess,const dvector& _xadjoint,const dvector& _uadjoint,
00025   const dmatrix& _Hessadjoint,function_minimizer * pmin);
00026 
00027 double calculate_importance_sample_funnel(const dvector& x,const dvector& u0,
00028   const dmatrix& Hess,const dvector& _xadjoint,const dvector& _uadjoint,
00029   const dmatrix& _Hessadjoint,function_minimizer * pmin);
00030 
00031 dmatrix choleski_decomp_positive(const dmatrix& M,double b);
00032 
00037 void laplace_approximation_calculator::
00038   check_derivatives(const dvector& _x,function_minimizer * pfmin,double f)
00039 {
00040   cerr << "need to define this" << endl;
00041   ad_exit(1);
00042 }
00043 
00048 dvector laplace_approximation_calculator::
00049   default_calculations_check_derivatives(const dvector& _x,
00050     function_minimizer * pfmin, const double& _f)
00051 {
00052   // for use when there is no separability
00053   ADUNCONST(dvector,x)
00054   int i,j;
00055   double& f = (double&)_f;
00056 
00057   initial_params::set_inactive_only_random_effects();
00058   gradient_structure::set_NO_DERIVATIVES();
00059   initial_params::reset(x);    // get current x values into the model
00060 
00061 
00062   pfmin->AD_uf_inner();
00063   double fval1=value(*objective_function_value::pobjfun);
00064 
00065   gradient_structure::set_YES_DERIVATIVES();
00066 
00067   initial_params::set_active_only_random_effects();
00068   initial_params::xinit(uhat);    // get current x values into the model
00069   //int lmn_flag=0;
00070   if (ad_comm::time_flag)
00071   {
00072     if (ad_comm::ptm1)
00073     {
00074       ad_comm::ptm1->get_elapsed_time_and_reset();
00075     }
00076     if (ad_comm::ptm)
00077     {
00078       ad_comm::ptm->get_elapsed_time_and_reset();
00079     }
00080   }
00081   if (ad_comm::time_flag)
00082   {
00083     if (ad_comm::ptm)
00084     {
00085       double time=ad_comm::ptm->get_elapsed_time();
00086       if (ad_comm::global_logfile)
00087       {
00088         (*ad_comm::global_logfile) << " Time pos 0 "
00089           << time << endl;
00090       }
00091     }
00092   }
00093 
00094   double maxg = 0;
00095   dvector uhat_old(1,usize);
00096   //double f_from_1=0.0;
00097 
00098   for (i=1;i<=xsize;i++)
00099   {
00100     y(i)=x(i);
00101   }
00102   for (i=1;i<=usize;i++)
00103   {
00104     y(i+xsize)=uhat(i);
00105   }
00106 
00107   int ierr=0;
00108   int niters=0;
00109   if (function_minimizer::first_hessian_flag)
00110     niters=num_nr_iters+1;
00111   else
00112     niters=num_nr_iters;
00113 
00114   int nv=0;
00115   if (quadratic_prior::get_num_quadratic_prior()>0)
00116   {
00117     nv=initial_df1b2params::set_index();
00118     if (allocated(used_flags))
00119     {
00120       if (used_flags.indexmax() != nv)
00121       {
00122         used_flags.safe_deallocate();
00123       }
00124     }
00125     if (!allocated(used_flags))
00126     {
00127       used_flags.safe_allocate(1,nv);
00128     }
00129   }
00130 
00131   for(int ii=1;ii<=niters;ii++)
00132   {
00133     if (quadratic_prior::get_num_quadratic_prior()>0)
00134     {
00135       check_pool_size();
00136     }
00137     {
00138       // test newton raphson
00139       Hess.initialize();
00140       cout << "Checking derivatives " << ii << endl;
00141       check_derivatives(x,pfmin,fval1);
00142 
00143       if (quadratic_prior::get_num_quadratic_prior()>0)
00144       {
00145         laplace_approximation_calculator::where_are_we_flag=2;
00146         /*double maxg = */evaluate_function_quiet(uhat,pfmin);
00147         laplace_approximation_calculator::where_are_we_flag=0;
00148         quadratic_prior::get_cHessian_contribution(Hess,xsize);
00149         quadratic_prior::get_cgradient_contribution(grad,xsize);
00150       }
00151 
00152       /*
00153       if (ii == 1)
00154         { double diff = fabs(re_objective_function_value::fun_without_pen - objective_function_value::fun_without_pen); }
00155       */
00156 
00157       dvector step;
00158       int print_hess_in_newton_raphson_flag=0;
00159       if (print_hess_in_newton_raphson_flag)
00160       {
00161         cout << norm2(Hess-trans(Hess)) << endl;
00162         if (ad_comm::global_logfile)
00163         {
00164           (*ad_comm::global_logfile) << setprecision(4) << setscientific()
00165             << setw(12) << sort(eigenvalues(Hess)) << endl;
00166           (*ad_comm::global_logfile) << setprecision(4) << setscientific()
00167             << setw(12) << Hess << endl;
00168         }
00169       }
00170 
00171 #if defined(USE_ATLAS)
00172       if (!ad_comm::no_atlas_flag)
00173       {
00174         step=-atlas_solve_spd(Hess,grad,ierr);
00175       }
00176       else
00177       {
00178         dmatrix A=choleski_decomp_positive(Hess,ierr);
00179         if (!ierr)
00180         {
00181           step=-solve(Hess,grad);
00182           //step=-solve(A*trans(A),grad);
00183         }
00184       }
00185       if (ierr)
00186       {
00187         f1b2gradlist->reset();
00188         f1b2gradlist->list.initialize();
00189         f1b2gradlist->list2.initialize();
00190         f1b2gradlist->list3.initialize();
00191         f1b2gradlist->nlist.initialize();
00192         f1b2gradlist->nlist2.initialize();
00193         f1b2gradlist->nlist3.initialize();
00194         break;
00195       }
00196 #else
00197       step=-solve(Hess,grad);
00198 #endif
00199 
00200       if (ad_comm::time_flag)
00201       {
00202         if (ad_comm::ptm)
00203         {
00204           double time=ad_comm::ptm->get_elapsed_time_and_reset();
00205           if (ad_comm::global_logfile)
00206           {
00207             (*ad_comm::global_logfile) << " time_in solve " <<  ii << "  "
00208               << time << endl;
00209           }
00210         }
00211       }
00212 
00213       f1b2gradlist->reset();
00214       f1b2gradlist->list.initialize();
00215       f1b2gradlist->list2.initialize();
00216       f1b2gradlist->list3.initialize();
00217       f1b2gradlist->nlist.initialize();
00218       f1b2gradlist->nlist2.initialize();
00219       f1b2gradlist->nlist3.initialize();
00220 
00221       uhat_old=uhat;
00222       uhat+=step;
00223 
00224       double maxg_old=maxg;
00225       maxg=fabs(evaluate_function(uhat,pfmin));
00226       if (maxg>maxg_old)
00227       {
00228         uhat=uhat_old;
00229         evaluate_function(uhat,pfmin);
00230         break;
00231       }
00232       if (maxg < 1.e-13)
00233       {
00234         break;
00235       }
00236     }
00237     for (i=1;i<=usize;i++)
00238     {
00239       y(i+xsize)=uhat(i);
00240     }
00241   }
00242 
00243   if (num_nr_iters<=0)
00244   {
00245     evaluate_function(uhat,pfmin);
00246   }
00247 
00248   for (i=1;i<=usize;i++)
00249   {
00250     y(i+xsize)=uhat(i);
00251   }
00252 
00253 
00254   if (ad_comm::time_flag)
00255   {
00256     if (ad_comm::ptm)
00257     {
00258       double time=ad_comm::ptm->get_elapsed_time_and_reset();
00259       if (ad_comm::global_logfile)
00260       {
00261         (*ad_comm::global_logfile) << " Time in reset and evaluate function"
00262           << time << endl;
00263       }
00264     }
00265   }
00266   get_second_ders(xsize,usize,y,Hess,Dux,f1b2gradlist,pfmin,this);
00267   //int sgn=0;
00268 
00269   if (ad_comm::time_flag)
00270   {
00271     if (ad_comm::ptm)
00272     {
00273       double time=ad_comm::ptm->get_elapsed_time_and_reset();
00274       if (ad_comm::global_logfile)
00275       {
00276         (*ad_comm::global_logfile) << " Time in dget second ders "
00277           << time << endl;
00278       }
00279     }
00280   }
00281   if (!ierr)
00282   {
00283     if (num_importance_samples==0)
00284     {
00285       //cout << "Hess " << endl << Hess << endl;
00286       f=calculate_laplace_approximation(x,uhat,Hess,xadjoint,uadjoint,
00287         Hessadjoint,pfmin);
00288     }
00289     else
00290     {
00291       if (isfunnel_flag==0)
00292       {
00293         f=calculate_importance_sample(x,uhat,Hess,xadjoint,uadjoint,
00294           Hessadjoint,pfmin);
00295       }
00296       else
00297       {
00298         f=calculate_importance_sample_funnel(x,uhat,Hess,xadjoint,uadjoint,
00299           Hessadjoint,pfmin);
00300       }
00301     }
00302   }
00303   else
00304   {
00305     f=1.e+30;
00306   }
00307 
00308   if (ad_comm::time_flag)
00309   {
00310     if (ad_comm::ptm)
00311     {
00312       double time=ad_comm::ptm->get_elapsed_time_and_reset();
00313       if (ad_comm::global_logfile)
00314       {
00315         (*ad_comm::global_logfile) << "Time in calculate laplace approximation "
00316           << time << endl;
00317       }
00318     }
00319   }
00320 
00321   for (int ip=num_der_blocks;ip>=1;ip--)
00322   {
00323     df1b2variable::minder=minder(ip);
00324     df1b2variable::maxder=maxder(ip);
00325     int mind=y(1).minder;
00326     int jmin=max(mind,xsize+1);
00327     int jmax=min(y(1).maxder,xsize+usize);
00328     for (i=1;i<=usize;i++)
00329     {
00330       for (j=jmin;j<=jmax;j++)
00331       {
00332         //Hess(i,j-xsize)=y(i+xsize).u_bar[j-mind];
00333         y(i+xsize).get_u_bar_tilde()[j-mind]=Hessadjoint(i,j-xsize);
00334       }
00335     }
00336 
00337     if (initial_df1b2params::separable_flag)
00338     {
00339       for (j=1;j<=xsize+usize;j++)
00340       {
00341         *y(j).get_u_tilde()=0;
00342       }
00343       Hess.initialize();
00344       initial_df1b2params::separable_calculation_type=3;
00345       pfmin->user_function();
00346     }
00347     else
00348     {
00349       if (ip<num_der_blocks)
00350       {
00351         f1b2gradlist->reset();
00352         set_u_dot(ip);
00353         df1b2_gradlist::set_yes_derivatives();
00354         (*re_objective_function_value::pobjfun)=0;
00355         df1b2variable pen=0.0;
00356         df1b2variable zz=0.0;
00357 
00358         initial_df1b2params::reset(y,pen);
00359         pfmin->user_function();
00360 
00361         re_objective_function_value::fun_without_pen=
00362           value(*re_objective_function_value::pobjfun);
00363 
00364         (*re_objective_function_value::pobjfun)+=pen;
00365         (*re_objective_function_value::pobjfun)+=zz;
00366 
00367         set_dependent_variable(*re_objective_function_value::pobjfun);
00368         df1b2_gradlist::set_no_derivatives();
00369         df1b2variable::passnumber=1;
00370         df1b2_gradcalc1();
00371       }
00372 
00373       for (i=1;i<=usize;i++)
00374       {
00375         for (j=jmin;j<=jmax;j++)
00376         {
00377           //Hess(i,j-xsize)=y(i+xsize).u_bar[j-mind];
00378           y(i+xsize).get_u_bar_tilde()[j-mind]=Hessadjoint(i,j-xsize);
00379         }
00380       }
00381 
00382       //int mind=y(1).minder;
00383       df1b2variable::passnumber=2;
00384       df1b2_gradcalc1();
00385 
00386       df1b2variable::passnumber=3;
00387       df1b2_gradcalc1();
00388 
00389       f1b2gradlist->reset();
00390       f1b2gradlist->list.initialize();
00391       f1b2gradlist->list2.initialize();
00392       f1b2gradlist->list3.initialize();
00393       f1b2gradlist->nlist.initialize();
00394       f1b2gradlist->nlist2.initialize();
00395       f1b2gradlist->nlist3.initialize();
00396     }
00397 
00398     if (ad_comm::time_flag)
00399     {
00400       if (ad_comm::ptm)
00401       {
00402         double time=ad_comm::ptm->get_elapsed_time_and_reset();
00403         if (ad_comm::global_logfile)
00404         {
00405           (*ad_comm::global_logfile) << " time for 3rd derivatives "
00406             << time << endl;
00407         }
00408       }
00409     }
00410 
00411     dvector dtmp(1,xsize);
00412     for (i=1;i<=xsize;i++)
00413     {
00414       dtmp(i)=*y(i).get_u_tilde();
00415     }
00416     if (initial_df1b2params::separable_flag)
00417     {
00418       dvector scale(1,nvar);   // need to get scale from somewhere
00419       /*int check=*/initial_params::stddev_scale(scale,x);
00420       dvector sscale=scale(1,Dux(1).indexmax());
00421       for (i=1;i<=usize;i++)
00422       {
00423         Dux(i)=elem_prod(Dux(i),sscale);
00424       }
00425       dtmp=elem_prod(dtmp,sscale);
00426     }
00427 
00428     for (i=1;i<=xsize;i++)
00429     {
00430       xadjoint(i)+=dtmp(i);
00431     }
00432     for (i=1;i<=usize;i++)
00433       uadjoint(i)+=*y(xsize+i).get_u_tilde();
00434   }
00435  // *****************************************************************
00436  // new stuff to deal with quadraticprior
00437  // *****************************************************************
00438 
00439     int xstuff=3;
00440     if (xstuff && df1b2quadratic_prior::get_num_quadratic_prior()>0)
00441     {
00442       initial_params::straight_through_flag=0;
00443       funnel_init_var::lapprox=0;
00444       block_diagonal_flag=0;
00445       dvector scale1(1,nvar);   // need to get scale from somewhere
00446       initial_params::set_inactive_only_random_effects();
00447       /*int check=*/initial_params::stddev_scale(scale1,x);
00448 
00449       laplace_approximation_calculator::where_are_we_flag=3;
00450       quadratic_prior::in_qp_calculations=1;
00451       funnel_init_var::lapprox=this;
00452       df1b2_gradlist::set_no_derivatives();
00453       dvector scale(1,nvar);   // need to get scale from somewhere
00454       /*check=*/initial_params::stddev_scale(scale,x);
00455       dvector sscale=scale(1,Dux(1).indexmax());
00456 
00457       for (i=1;i<=usize;i++)
00458       {
00459         Dux(i)=elem_div(Dux(i),sscale);
00460       }
00461 
00462       if (xstuff>1)
00463       {
00464         df1b2quadratic_prior::get_Lxu_contribution(Dux);
00465       }
00466       quadratic_prior::in_qp_calculations=0;
00467       funnel_init_var::lapprox=0;
00468       laplace_approximation_calculator::where_are_we_flag=0;
00469 
00470       for (i=1;i<=usize;i++)
00471       {
00472         Dux(i)=elem_prod(Dux(i),sscale);
00473       }
00474       //local_dtemp=elem_prod(local_dtemp,sscale);
00475 
00476       if (xstuff>2)
00477       {
00478         dvector tmp=evaluate_function_with_quadprior(x,usize,pfmin);
00479         for (i=1;i<=xsize;i++)
00480         {
00481           xadjoint(i)+=tmp(i);
00482         }
00483       }
00484 
00485       if (xstuff>2)
00486       {
00487         quadratic_prior::get_cHessian_contribution_from_vHessian(Hess,xsize);
00488       }
00489     }
00490 
00491  // *****************************************************************
00492  // new stuff to deal with quadraticprior
00493  // *****************************************************************
00494   if (ad_comm::ptm)
00495   {
00496     /*double time=*/ad_comm::ptm->get_elapsed_time_and_reset();
00497   }
00498 
00499 #if defined(USE_ATLAS)
00500       if (!ad_comm::no_atlas_flag)
00501       {
00502         //xadjoint -= uadjoint*atlas_solve_spd_trans(Hess,Dux);
00503         xadjoint -= atlas_solve_spd_trans(Hess,uadjoint)*Dux;
00504       }
00505       else
00506       {
00507         //xadjoint -= uadjoint*solve(Hess,Dux);
00508         xadjoint -= solve(Hess,uadjoint)*Dux;
00509       }
00510 #else
00511       //xadjoint -= uadjoint*solve(Hess,Dux);
00512       xadjoint -= solve(Hess,uadjoint)*Dux;
00513 #endif
00514 
00515 
00516   if (ad_comm::ptm)
00517   {
00518     double time=ad_comm::ptm->get_elapsed_time_and_reset();
00519     if (ad_comm::global_logfile)
00520     {
00521       (*ad_comm::global_logfile) << " Time in second solve "
00522         << time << endl;
00523     }
00524   }
00525   if (ad_comm::ptm1)
00526   {
00527     double time=ad_comm::ptm1->get_elapsed_time_and_reset();
00528     if (ad_comm::global_logfile)
00529     {
00530       (*ad_comm::global_logfile) << " Total time in function evaluation "
00531         << time << endl << endl;
00532     }
00533   }
00534 
00535   return xadjoint;
00536 }