ADMB Documentation  11.2.2853
 All Classes Files Functions Variables Typedefs Friends Defines
df1b2chkder.cpp
Go to the documentation of this file.
00001 /*
00002  * $Id: df1b2chkder.cpp 2637 2014-11-13 05:58:08Z johnoel $
00003  *
00004  * Author: David Fournier
00005  * Copyright (c) 2008-2012 Regents of the University of California
00006  */
00011 #include <fvar.hpp>
00012 #include <admodel.h>
00013 #include <df1b2fun.h>
00014 #include <adrndeff.h>
00015 #ifndef OPT_LIB
00016   #include <cassert>
00017   #include <climits>
00018 #endif
// Forward declarations of free functions used by
// default_calculations_check_derivatives() below. No definitions for them
// are visible in this listing.

// Called with the current random-effects vector (uhat) after each Newton
// step; its return value is used there through fabs() as the convergence
// measure.
00019 double evaluate_function(const dvector& x,function_minimizer * pfmin);
// Called once after the Newton loop with (xsize,usize,y,Hess,Dux,...).
// Note that y is taken by value (const init_df1b2vector), not by reference.
00020 void get_second_ders(int xs,int us,const init_df1b2vector y,dmatrix& Hess,
00021   dmatrix& Dux, df1b2_gradlist * f1b2gradlist,function_minimizer * pfmin,
00022   laplace_approximation_calculator* lap);
// The next three share one signature; exactly one of them supplies the
// returned objective value f, chosen by num_importance_samples and
// isfunnel_flag. The adjoint arguments are declared const but carry a leading
// underscore; presumably they are written through a cast inside the
// definitions -- TODO confirm against the implementations.
00023 double calculate_laplace_approximation(const dvector& x,const dvector& u0,
00024   const dmatrix& Hess,const dvector& _xadjoint,const dvector& _uadjoint,
00025   const dmatrix& _Hessadjoint,function_minimizer * pmin);
00026 
00027 double calculate_importance_sample(const dvector& x,const dvector& u0,
00028   const dmatrix& Hess,const dvector& _xadjoint,const dvector& _uadjoint,
00029   const dmatrix& _Hessadjoint,function_minimizer * pmin);
00030 
00031 double calculate_importance_sample_funnel(const dvector& x,const dvector& u0,
00032   const dmatrix& Hess,const dvector& _xadjoint,const dvector& _uadjoint,
00033   const dmatrix& _Hessadjoint,function_minimizer * pmin);
00034 
// Only referenced in the USE_ATLAS build path below, where it is called as
// choleski_decomp_positive(Hess,ierr) with an int ierr.
// NOTE(review): with this by-value double parameter the call cannot report an
// error back through ierr; check whether a header declares an overload taking
// an int reference.
00035 dmatrix choleski_decomp_positive(const dmatrix& M,double b);
00036 
/**
 * Placeholder for a derivative check.
 *
 * As written, the body ignores all three arguments: it writes
 * "need to define this" to cerr and then calls ad_exit(1).
 * default_calculations_check_derivatives() in this file calls it once per
 * Newton-Raphson iteration.
 *
 * \param _x    unused here.
 * \param pfmin unused here.
 * \param f     unused here.
 */
00041 void laplace_approximation_calculator::
00042   check_derivatives(const dvector& _x,function_minimizer * pfmin,double f)
00043 {
  // Not implemented yet: report and stop.
  // NOTE(review): ad_exit(1) presumably terminates the process -- confirm.
00044   cerr << "need to define this" << endl;
00045   ad_exit(1);
00046 }
00047 
/**
 * Non-separable Laplace-approximation calculation that additionally calls
 * check_derivatives() inside every Newton-Raphson iteration.
 *
 * Sequence of work visible in the body:
 *  -# Evaluate the inner objective once at x (AD_uf_inner) and keep the
 *     value as fval1.
 *  -# Copy x and uhat into y (x first, then uhat).
 *  -# Run up to niters iterations: call check_derivatives(), add the
 *     quadratic-prior contributions to Hess/grad when such priors exist,
 *     solve Hess*step = -grad, update uhat, and stop when the value returned
 *     by evaluate_function() grows or drops below 1.e-13.
 *  -# Call get_second_ders(), then one of calculate_laplace_approximation(),
 *     calculate_importance_sample() or calculate_importance_sample_funnel()
 *     to obtain f (or 1.e+30 when ierr is non-zero).
 *  -# For each derivative block, from last to first, load Hessadjoint into
 *     the u_bar_tilde slots of y, run the gradient passes and accumulate
 *     xadjoint and uadjoint.
 *  -# Add the quadratic-prior contributions to Dux, xadjoint and Hess.
 *  -# Finish with xadjoint -= solve(Hess,uadjoint)*Dux.
 *
 * NOTE(review): check_derivatives() as defined in this file prints a message
 * and calls ad_exit(1), so whenever niters >= 1 the first iteration ends in
 * that call.
 *
 * \param _x    outer parameter values; accessed through the name x produced
 *              by ADUNCONST(dvector,x) (macro not shown here -- presumably a
 *              non-const alias of _x).
 * \param pfmin minimizer whose AD_uf_inner()/user_function() are invoked.
 * \param _f    declared const but written through a cast: receives the
 *              computed objective value.
 * \return the member xadjoint after the final correction.
 */
00052 dvector laplace_approximation_calculator::
00053   default_calculations_check_derivatives(const dvector& _x,
00054     function_minimizer * pfmin, const double& _f)
00055 {
00056   // for use when there is no separability
00057   ADUNCONST(dvector,x)
00058   int i,j;
  // f aliases the caller's _f; the const is cast away so the result can be
  // stored in it further down.
00059   double& f = (double&)_f;
00060 
  // One evaluation of the inner objective with set_NO_DERIVATIVES() in
  // effect; the value is kept as fval1 and later handed to
  // check_derivatives().
00061   initial_params::set_inactive_only_random_effects();
00062   gradient_structure::set_NO_DERIVATIVES();
00063   initial_params::reset(x);    // get current x values into the model
00064 
00065 
00066   pfmin->AD_uf_inner();
00067   double fval1=value(*objective_function_value::pobjfun);
00068 
00069   gradient_structure::set_YES_DERIVATIVES();
00070 
00071   initial_params::set_active_only_random_effects();
00072   initial_params::xinit(uhat);    // get current x values into the model
00073   //int lmn_flag=0;
  // Timing support: reset both timers (ptm1 is read again only at the very
  // end for the total; ptm is reset after every stage).
00074   if (ad_comm::time_flag)
00075   {
00076     if (ad_comm::ptm1)
00077     {
00078       ad_comm::ptm1->get_elapsed_time_and_reset();
00079     }
00080     if (ad_comm::ptm)
00081     {
00082       ad_comm::ptm->get_elapsed_time_and_reset();
00083     }
00084   }
  // Second time_flag block, immediately after the reset above: logs the
  // elapsed time of ptm without resetting it.
00085   if (ad_comm::time_flag)
00086   {
00087     if (ad_comm::ptm)
00088     {
00089       double time=ad_comm::ptm->get_elapsed_time();
00090       if (ad_comm::global_logfile)
00091       {
00092         (*ad_comm::global_logfile) << " Time pos 0 "
00093           << time << endl;
00094       }
00095     }
00096   }
00097 
  // maxg holds fabs() of the latest evaluate_function() result.
  // NOTE(review): it starts at 0, so in the first iteration any positive
  // result exceeds maxg_old and triggers the "revert and break" branch.
00098   double maxg = 0;
00099   dvector uhat_old(1,usize);
00100   //double f_from_1=0.0;
00101 
  // y is the combined vector: entries 1..xsize are x, entries
  // xsize+1..xsize+usize are uhat.
00102   for (i=1;i<=xsize;i++)
00103   {
00104     y(i)=x(i);
00105   }
00106   for (i=1;i<=usize;i++)
00107   {
00108     y(i+xsize)=uhat(i);
00109   }
00110 
  // ierr is only modified in the USE_ATLAS branch of the loop below; in the
  // other build it stays 0.
00111   int ierr=0;
  // One extra iteration is done while first_hessian_flag is set.
00112   int niters=0;
00113   if (function_minimizer::first_hessian_flag)
00114     niters=num_nr_iters+1;
00115   else
00116     niters=num_nr_iters;
00117 
  // With quadratic priors present, make sure used_flags is allocated with
  // index range 1..nv, where nv is the value returned by set_index().
00118   int nv=0;
00119   if (quadratic_prior::get_num_quadratic_prior()>0)
00120   {
00121     nv=initial_df1b2params::set_index();
00122     if (allocated(used_flags))
00123     {
00124       if (used_flags.indexmax() != nv)
00125       {
00126         used_flags.safe_deallocate();
00127       }
00128     }
00129     if (!allocated(used_flags))
00130     {
00131       used_flags.safe_allocate(1,nv);
00132     }
00133   }
00134 
  // Newton-Raphson iterations on the random effects uhat.
00135   for(int ii=1;ii<=niters;ii++)
00136   {
00137     if (quadratic_prior::get_num_quadratic_prior()>0)
00138     {
00139       check_pool_size();
00140     }
00141     {
00142       // test newton raphson
00143       Hess.initialize();
00144       cout << "Checking derivatives " << ii << endl;
      // See the NOTE(review) in the function header: the definition in this
      // file exits here.
00145       check_derivatives(x,pfmin,fval1);
00146 
      // Quadratic priors: evaluate quietly with where_are_we_flag set to 2
      // (restored to 0 afterwards), then add the prior contributions to
      // Hess and grad.
00147       if (quadratic_prior::get_num_quadratic_prior()>0)
00148       {
00149         laplace_approximation_calculator::where_are_we_flag=2;
00150         /*double maxg = */evaluate_function_quiet(uhat,pfmin);
00151         laplace_approximation_calculator::where_are_we_flag=0;
00152         quadratic_prior::get_cHessian_contribution(Hess,xsize);
00153         quadratic_prior::get_cgradient_contribution(grad,xsize);
00154       }
00155 
00156       /*
00157       if (ii == 1)
00158         { double diff = fabs(re_objective_function_value::fun_without_pen - objective_function_value::fun_without_pen); }
00159       */
00160 
00161       dvector step;
      // Debug output switch, hard-coded to 0, so the block below never runs
      // unless the source is edited.
00162       int print_hess_in_newton_raphson_flag=0;
00163       if (print_hess_in_newton_raphson_flag)
00164       {
00165         cout << norm2(Hess-trans(Hess)) << endl;
00166         if (ad_comm::global_logfile)
00167         {
00168           (*ad_comm::global_logfile) << setprecision(4) << setscientific()
00169             << setw(12) << sort(eigenvalues(Hess)) << endl;
00170           (*ad_comm::global_logfile) << setprecision(4) << setscientific()
00171             << setw(12) << Hess << endl;
00172         }
00173       }
00174 
      // Newton step: step = -Hess^{-1} * grad.
00175 #if defined(USE_ATLAS)
00176       if (!ad_comm::no_atlas_flag)
00177       {
00178         step=-atlas_solve_spd(Hess,grad,ierr);
00179       }
00180       else
00181       {
        // A is not used afterwards (the line using it is commented out);
        // the call serves as a check whose outcome is expected in ierr.
        // NOTE(review): the declaration at the top of this file takes a
        // double by value, which cannot update ierr -- confirm which
        // overload is selected.
00182         dmatrix A=choleski_decomp_positive(Hess,ierr);
00183         if (!ierr)
00184         {
00185           step=-solve(Hess,grad);
00186           //step=-solve(A*trans(A),grad);
00187         }
00188       }
      // On failure: clear the df1b2 gradient list state and leave the loop;
      // f is set to 1.e+30 further down because ierr is non-zero.
00189       if (ierr)
00190       {
00191         f1b2gradlist->reset();
00192         f1b2gradlist->list.initialize();
00193         f1b2gradlist->list2.initialize();
00194         f1b2gradlist->list3.initialize();
00195         f1b2gradlist->nlist.initialize();
00196         f1b2gradlist->nlist2.initialize();
00197         f1b2gradlist->nlist3.initialize();
00198         break;
00199       }
00200 #else
00201       step=-solve(Hess,grad);
00202 #endif
00203 
00204       if (ad_comm::time_flag)
00205       {
00206         if (ad_comm::ptm)
00207         {
00208           double time=ad_comm::ptm->get_elapsed_time_and_reset();
00209           if (ad_comm::global_logfile)
00210           {
00211             (*ad_comm::global_logfile) << " time_in solve " <<  ii << "  "
00212               << time << endl;
00213           }
00214         }
00215       }
00216 
      // Clear the df1b2 gradient list state before the next evaluation.
00217       f1b2gradlist->reset();
00218       f1b2gradlist->list.initialize();
00219       f1b2gradlist->list2.initialize();
00220       f1b2gradlist->list3.initialize();
00221       f1b2gradlist->nlist.initialize();
00222       f1b2gradlist->nlist2.initialize();
00223       f1b2gradlist->nlist3.initialize();
00224 
      // Take the step, keeping the previous uhat so it can be restored.
00225       uhat_old=uhat;
00226       uhat+=step;
00227 
      // If the measure got worse, restore the previous uhat, re-evaluate
      // there and stop iterating; if it is below 1.e-13, stop as converged.
00228       double maxg_old=maxg;
00229       maxg=fabs(evaluate_function(uhat,pfmin));
00230       if (maxg>maxg_old)
00231       {
00232         uhat=uhat_old;
00233         evaluate_function(uhat,pfmin);
00234         break;
00235       }
00236       if (maxg < 1.e-13)
00237       {
00238         break;
00239       }
00240     }
    // Reached only when the iteration did not break: refresh the uhat part
    // of y.
00241     for (i=1;i<=usize;i++)
00242     {
00243       y(i+xsize)=uhat(i);
00244     }
00245   }
00246 
  // With no Newton iterations configured, still evaluate once at uhat.
00247   if (num_nr_iters<=0)
00248   {
00249     evaluate_function(uhat,pfmin);
00250   }
00251 
  // Refresh the uhat part of y unconditionally (covers the break paths
  // above).
00252   for (i=1;i<=usize;i++)
00253   {
00254     y(i+xsize)=uhat(i);
00255   }
00256 
00257 
00258   if (ad_comm::time_flag)
00259   {
00260     if (ad_comm::ptm)
00261     {
00262       double time=ad_comm::ptm->get_elapsed_time_and_reset();
00263       if (ad_comm::global_logfile)
00264       {
00265         (*ad_comm::global_logfile) << " Time in reset and evaluate function"
00266           << time << endl;
00267       }
00268     }
00269   }
  // Hess and Dux are passed by non-const reference (see the declaration at
  // the top of the file) and are used below as the results of this call.
00270   get_second_ders(xsize,usize,y,Hess,Dux,f1b2gradlist,pfmin,this);
00271   //int sgn=0;
00272 
00273   if (ad_comm::time_flag)
00274   {
00275     if (ad_comm::ptm)
00276     {
00277       double time=ad_comm::ptm->get_elapsed_time_and_reset();
00278       if (ad_comm::global_logfile)
00279       {
00280         (*ad_comm::global_logfile) << " Time in dget second ders "
00281           << time << endl;
00282       }
00283     }
00284   }
  // Objective value: plain Laplace approximation when no importance samples
  // are requested, otherwise one of the two importance-sampling variants.
  // A non-zero ierr yields the value 1.e+30 instead.
00285   if (!ierr)
00286   {
00287     if (num_importance_samples==0)
00288     {
00289       //cout << "Hess " << endl << Hess << endl;
00290       f=calculate_laplace_approximation(x,uhat,Hess,xadjoint,uadjoint,
00291         Hessadjoint,pfmin);
00292     }
00293     else
00294     {
00295       if (isfunnel_flag==0)
00296       {
00297         f=calculate_importance_sample(x,uhat,Hess,xadjoint,uadjoint,
00298           Hessadjoint,pfmin);
00299       }
00300       else
00301       {
00302         f=calculate_importance_sample_funnel(x,uhat,Hess,xadjoint,uadjoint,
00303           Hessadjoint,pfmin);
00304       }
00305     }
00306   }
00307   else
00308   {
00309     f=1.e+30;
00310   }
00311 
00312   if (ad_comm::time_flag)
00313   {
00314     if (ad_comm::ptm)
00315     {
00316       double time=ad_comm::ptm->get_elapsed_time_and_reset();
00317       if (ad_comm::global_logfile)
00318       {
00319         (*ad_comm::global_logfile) << "Time in calculate laplace approximation "
00320           << time << endl;
00321       }
00322     }
00323   }
00324 
  // Process the derivative blocks from last to first. For each block the
  // static df1b2variable::minder/maxder are set from this object's
  // minder(ip)/maxder(ip).
00325   for (int ip=num_der_blocks;ip>=1;ip--)
00326   {
00327     df1b2variable::minder=minder(ip);
00328     df1b2variable::maxder=maxder(ip);
    // jmin..jmax is the block's index range clamped to the random-effects
    // part of y (xsize+1 .. xsize+usize).
00329     int mind=y(1).minder;
00330     int jmin=max(mind,xsize+1);
00331     int jmax=min(y(1).maxder,xsize+usize);
    // Load Hessadjoint into the u_bar_tilde slots of the random-effect
    // entries of y.
00332     for (i=1;i<=usize;i++)
00333     {
00334       for (j=jmin;j<=jmax;j++)
00335       {
00336         //Hess(i,j-xsize)=y(i+xsize).u_bar[j-mind];
00337         y(i+xsize).get_u_bar_tilde()[j-mind]=Hessadjoint(i,j-xsize);
00338       }
00339     }
00340 
00341     if (initial_df1b2params::separable_flag)
00342     {
      // Separable case: zero every u_tilde, clear Hess, and run the user
      // function with separable_calculation_type set to 3.
00343       for (j=1;j<=xsize+usize;j++)
00344       {
00345         *y(j).get_u_tilde()=0;
00346       }
00347       Hess.initialize();
00348       initial_df1b2params::separable_calculation_type=3;
00349       pfmin->user_function();
00350     }
00351     else
00352     {
      // For every block except the last, the user function is run again for
      // this block's u_dot setting, followed by gradient pass 1.
      // (Presumably the last block's recording is still available from
      // get_second_ders -- TODO confirm.)
00353       if (ip<num_der_blocks)
00354       {
00355         f1b2gradlist->reset();
00356         set_u_dot(ip);
00357         df1b2_gradlist::set_yes_derivatives();
00358         (*re_objective_function_value::pobjfun)=0;
00359         df1b2variable pen=0.0;
00360         df1b2variable zz=0.0;
00361 
00362         initial_df1b2params::reset(y,pen);
00363         pfmin->user_function();
00364 
00365         re_objective_function_value::fun_without_pen=
00366           value(*re_objective_function_value::pobjfun);
00367 
00368         (*re_objective_function_value::pobjfun)+=pen;
00369         (*re_objective_function_value::pobjfun)+=zz;
00370 
00371         set_dependent_variable(*re_objective_function_value::pobjfun);
00372         df1b2_gradlist::set_no_derivatives();
00373         df1b2variable::passnumber=1;
00374         df1b2_gradcalc1();
00375       }
00376 
      // Same assignment loop as before the branch, repeated here ahead of
      // passes 2 and 3.
00377       for (i=1;i<=usize;i++)
00378       {
00379         for (j=jmin;j<=jmax;j++)
00380         {
00381           //Hess(i,j-xsize)=y(i+xsize).u_bar[j-mind];
00382           y(i+xsize).get_u_bar_tilde()[j-mind]=Hessadjoint(i,j-xsize);
00383         }
00384       }
00385 
00386       //int mind=y(1).minder;
00387       df1b2variable::passnumber=2;
00388       df1b2_gradcalc1();
00389 
00390       df1b2variable::passnumber=3;
00391       df1b2_gradcalc1();
00392 
      // Clear the df1b2 gradient list state once the passes are done.
00393       f1b2gradlist->reset();
00394       f1b2gradlist->list.initialize();
00395       f1b2gradlist->list2.initialize();
00396       f1b2gradlist->list3.initialize();
00397       f1b2gradlist->nlist.initialize();
00398       f1b2gradlist->nlist2.initialize();
00399       f1b2gradlist->nlist3.initialize();
00400     }
00401 
00402     if (ad_comm::time_flag)
00403     {
00404       if (ad_comm::ptm)
00405       {
00406         double time=ad_comm::ptm->get_elapsed_time_and_reset();
00407         if (ad_comm::global_logfile)
00408         {
00409           (*ad_comm::global_logfile) << " time for 3rd derivatives "
00410             << time << endl;
00411         }
00412       }
00413     }
00414 
    // dtmp collects u_tilde of the x entries of y.
00415     dvector dtmp(1,xsize);
00416     for (i=1;i<=xsize;i++)
00417     {
00418       dtmp(i)=*y(i).get_u_tilde();
00419     }
    // Separable case: multiply each row of Dux and dtmp elementwise by
    // sscale, the leading Dux(1).indexmax() entries of the vector filled by
    // stddev_scale().
    // NOTE(review): this sits inside the per-block loop, so Dux is rescaled
    // in place on every pass -- confirm this is intended when
    // num_der_blocks > 1.
00420     if (initial_df1b2params::separable_flag)
00421     {
00422 #ifndef OPT_LIB
00423       assert(nvar <= INT_MAX);
00424 #endif
00425       dvector scale(1,(int)nvar);   // need to get scale from somewhere
00426       /*int check=*/initial_params::stddev_scale(scale,x);
00427       dvector sscale=scale(1,Dux(1).indexmax());
00428       for (i=1;i<=usize;i++)
00429       {
00430         Dux(i)=elem_prod(Dux(i),sscale);
00431       }
00432       dtmp=elem_prod(dtmp,sscale);
00433     }
00434 
    // Accumulate this block's contribution into the adjoints.
00435     for (i=1;i<=xsize;i++)
00436     {
00437       xadjoint(i)+=dtmp(i);
00438     }
00439     for (i=1;i<=usize;i++)
00440       uadjoint(i)+=*y(xsize+i).get_u_tilde();
00441   }
00442  // *****************************************************************
00443  // new stuff to deal with quadraticprior
00444  // *****************************************************************
00445 
    // xstuff is hard-coded to 3, so every xstuff guard below is true and the
    // section runs whenever at least one df1b2quadratic_prior exists.
00446     int xstuff=3;
00447     if (xstuff && df1b2quadratic_prior::get_num_quadratic_prior()>0)
00448     {
00449       initial_params::straight_through_flag=0;
00450       funnel_init_var::lapprox=0;
00451       block_diagonal_flag=0;
00452 #ifndef OPT_LIB
00453       assert(nvar <= INT_MAX);
00454 #endif
      // scale1 is filled here but not read again in this function.
00455       dvector scale1(1,(int)nvar);   // need to get scale from somewhere
00456       initial_params::set_inactive_only_random_effects();
00457       /*int check=*/initial_params::stddev_scale(scale1,x);
00458 
      // Flags set for the duration of the prior calculation; they are
      // restored to 0 right after get_Lxu_contribution().
00459       laplace_approximation_calculator::where_are_we_flag=3;
00460       quadratic_prior::in_qp_calculations=1;
00461       funnel_init_var::lapprox=this;
00462       df1b2_gradlist::set_no_derivatives();
00463       dvector scale(1,(int)nvar);   // need to get scale from somewhere
00464       /*check=*/initial_params::stddev_scale(scale,x);
00465       dvector sscale=scale(1,Dux(1).indexmax());
00466 
      // Dux is divided by sscale before the prior contribution is added and
      // multiplied by it again afterwards.
00467       for (i=1;i<=usize;i++)
00468       {
00469         Dux(i)=elem_div(Dux(i),sscale);
00470       }
00471 
00472       if (xstuff>1)
00473       {
00474         df1b2quadratic_prior::get_Lxu_contribution(Dux);
00475       }
00476       quadratic_prior::in_qp_calculations=0;
00477       funnel_init_var::lapprox=0;
00478       laplace_approximation_calculator::where_are_we_flag=0;
00479 
00480       for (i=1;i<=usize;i++)
00481       {
00482         Dux(i)=elem_prod(Dux(i),sscale);
00483       }
00484       //local_dtemp=elem_prod(local_dtemp,sscale);
00485 
      // Add the first xsize entries of the vector returned by
      // evaluate_function_with_quadprior() to xadjoint.
00486       if (xstuff>2)
00487       {
00488         dvector tmp=evaluate_function_with_quadprior(x,usize,pfmin);
00489         for (i=1;i<=xsize;i++)
00490         {
00491           xadjoint(i)+=tmp(i);
00492         }
00493       }
00494 
00495       if (xstuff>2)
00496       {
00497         quadratic_prior::get_cHessian_contribution_from_vHessian(Hess,xsize);
00498       }
00499     }
00500 
00501  // *****************************************************************
00502  // new stuff to deal with quadraticprior
00503  // *****************************************************************
  // Reset the stage timer; the elapsed value is discarded. Unlike the
  // earlier timing blocks, the remaining ones are not guarded by
  // ad_comm::time_flag.
00504   if (ad_comm::ptm)
00505   {
00506     /*double time=*/ad_comm::ptm->get_elapsed_time_and_reset();
00507   }
00508 
  // Final correction of the x adjoint using the u adjoint:
  // xadjoint -= (solution of the Hess system for uadjoint) * Dux.
00509 #if defined(USE_ATLAS)
00510       if (!ad_comm::no_atlas_flag)
00511       {
00512         //xadjoint -= uadjoint*atlas_solve_spd_trans(Hess,Dux);
00513         xadjoint -= atlas_solve_spd_trans(Hess,uadjoint)*Dux;
00514       }
00515       else
00516       {
00517         //xadjoint -= uadjoint*solve(Hess,Dux);
00518         xadjoint -= solve(Hess,uadjoint)*Dux;
00519       }
00520 #else
00521       //xadjoint -= uadjoint*solve(Hess,Dux);
00522       xadjoint -= solve(Hess,uadjoint)*Dux;
00523 #endif
00524 
00525 
00526   if (ad_comm::ptm)
00527   {
00528     double time=ad_comm::ptm->get_elapsed_time_and_reset();
00529     if (ad_comm::global_logfile)
00530     {
00531       (*ad_comm::global_logfile) << " Time in second solve "
00532         << time << endl;
00533     }
00534   }
  // ptm1 has not been reset since the start of the function, so this is the
  // total for the whole call.
00535   if (ad_comm::ptm1)
00536   {
00537     double time=ad_comm::ptm1->get_elapsed_time_and_reset();
00538     if (ad_comm::global_logfile)
00539     {
00540       (*ad_comm::global_logfile) << " Total time in function evaluation "
00541         << time << endl << endl;
00542     }
00543   }
00544 
00545   return xadjoint;
00546 }