ADMB Documentation  11.1.1916
 All Classes Files Functions Variables Typedefs Friends Defines
df1b2chkder.cpp
Go to the documentation of this file.
00001 /*
00002  * $Id: df1b2chkder.cpp 1784 2014-03-12 23:32:25Z johnoel $
00003  *
00004  * Author: David Fournier
00005  * Copyright (c) 2008-2012 Regents of the University of California
00006  */
00011 #if defined(USE_LAPLACE)
00012 #  include <fvar.hpp>
00013 #  include <admodel.h>
00014 #  include <df1b2fun.h>
00015 #  include <adrndeff.h>
// Forward declarations of free functions that are defined outside this file
// and called from default_calculations_check_derivatives() below.

// Called below with the random-effects vector uhat; its return value is
// wrapped in fabs() and stored as maxg by the caller.
00016 double evaluate_function(const dvector& x,function_minimizer * pfmin);
// Called once after the Newton-Raphson loop with (xsize,usize,y,Hess,Dux,...).
// NOTE(review): y is declared as a by-value `const init_df1b2vector` here --
// confirm that this matches the definition.
00017 void get_second_ders(int xs,int us,const init_df1b2vector y,dmatrix& Hess,
00018   dmatrix& Dux, df1b2_gradlist * f1b2gradlist,function_minimizer * pfmin,
00019   laplace_approximation_calculator* lap);
// Used below when num_importance_samples==0.
00020 double calculate_laplace_approximation(const dvector& x,const dvector& u0,
00021   const dmatrix& Hess,const dvector& _xadjoint,const dvector& _uadjoint,
00022   const dmatrix& _Hessadjoint,function_minimizer * pmin);
00023 
// Used below when num_importance_samples!=0 and isfunnel_flag==0.
00024 double calculate_importance_sample(const dvector& x,const dvector& u0,
00025   const dmatrix& Hess,const dvector& _xadjoint,const dvector& _uadjoint,
00026   const dmatrix& _Hessadjoint,function_minimizer * pmin);
00027 
// Used below when num_importance_samples!=0 and isfunnel_flag!=0.
00028 double calculate_importance_sample_funnel(const dvector& x,const dvector& u0,
00029   const dmatrix& Hess,const dvector& _xadjoint,const dvector& _uadjoint,
00030   const dmatrix& _Hessadjoint,function_minimizer * pmin);
00031 
// Only called below inside the USE_ATLAS branch. As declared here the second
// argument is a double passed by value.
00032 dmatrix choleski_decomp_positive(const dmatrix& M,double b);
00033 
/**
 * Placeholder for the derivative check: writes "need to define this" to cerr
 * and then calls ad_exit(1).
 *
 * None of the parameters is read by the body.
 *
 * \param _x    unused here
 * \param pfmin unused here
 * \param f     unused here
 */
00038 void laplace_approximation_calculator::
00039   check_derivatives(const dvector& _x,function_minimizer * pfmin,double f)
00040 {
00041   cerr << "need to define this" << endl;
00042   ad_exit(1);
00043 }
00044 
/**
 * Derivative-checking variant of the Laplace-approximation calculations for
 * the case with no separability.
 *
 * Outline of what the body does, in order:
 *  -# Evaluates the inner objective once (pfmin->AD_uf_inner()) with
 *     derivatives switched off and keeps the value as fval1.
 *  -# Copies x and uhat into y and runs up to niters passes of a loop that
 *     calls check_derivatives(x,pfmin,fval1), solves Hess*step=-grad and
 *     updates uhat.
 *  -# Calls get_second_ders() and then one of
 *     calculate_laplace_approximation(), calculate_importance_sample() or
 *     calculate_importance_sample_funnel(); the result is written to _f.
 *     If ierr is nonzero, _f is set to 1.e+30 instead.
 *  -# Walks the derivative blocks from num_der_blocks down to 1 and
 *     accumulates into xadjoint and uadjoint.
 *  -# If df1b2quadratic_prior::get_num_quadratic_prior()>0, adds the
 *     quadratic-prior contributions to Dux, xadjoint and Hess.
 *  -# Subtracts solve(Hess,uadjoint)*Dux from xadjoint and returns it.
 *
 * NOTE(review): check_derivatives() as defined in this file prints a message
 * and calls ad_exit(1). If ad_exit terminates the program (presumably it
 * does -- confirm), nothing after the first loop pass runs when niters>=1.
 *
 * \param _x    values passed to initial_params::reset() and to the
 *              calculate_* functions (through the local name x).
 * \param pfmin minimizer object; its AD_uf_inner() and user_function()
 *              members are called.
 * \param _f    declared const but written through a cast: receives the value
 *              returned by the selected calculate_* function, or 1.e+30.
 * \return the member xadjoint after all updates.
 */
00049 dvector laplace_approximation_calculator::
00050   default_calculations_check_derivatives(const dvector& _x,
00051     function_minimizer * pfmin, const double& _f)
00052 {
00053   // for use when there is no separability
        // presumably ADUNCONST declares a non-const name x for _x -- confirm
        // against the macro definition
00054   ADUNCONST(dvector,x)
00055   int i,j;
        // constness of _f is cast away so the result can be stored in it below
00056   double& f = (double&)_f;
00057 
        // First evaluation: derivatives are turned off around AD_uf_inner()
00058   initial_params::set_inactive_only_random_effects();
00059   gradient_structure::set_NO_DERIVATIVES();
00060   initial_params::reset(x);    // get current x values into the model
00061 
00062 
00063   pfmin->AD_uf_inner();
        // fval1 is later handed to check_derivatives() as its third argument
00064   double fval1=value(*objective_function_value::pobjfun);
00065 
00066   gradient_structure::set_YES_DERIVATIVES();
00067 
00068   initial_params::set_active_only_random_effects();
        // NOTE(review): the trailing comment on the next line is identical to
        // the one on reset(x) above, but the argument here is uhat -- confirm
        // what xinit() does with it.
00069   initial_params::xinit(uhat);    // get current x values into the model
00070   //int lmn_flag=0;
        // Timing: reset both timers when they exist ...
00071   if (ad_comm::time_flag)
00072   {
00073     if (ad_comm::ptm1)
00074     {
00075       ad_comm::ptm1->get_elapsed_time_and_reset();
00076     }
00077     if (ad_comm::ptm)
00078     {
00079       ad_comm::ptm->get_elapsed_time_and_reset();
00080     }
00081   }
        // ... and write the elapsed time to the log file, if one is open
00082   if (ad_comm::time_flag)
00083   {
00084     if (ad_comm::ptm)
00085     {
00086       double time=ad_comm::ptm->get_elapsed_time();
00087       if (ad_comm::global_logfile)
00088       {
00089         (*ad_comm::global_logfile) << " Time pos 0 "
00090           << time << endl;
00091       }
00092     }
00093   }
00094 
        // NOTE(review): maxg starts at 0 and is copied to maxg_old before the
        // first evaluation in the loop below, so on the first pass any
        // positive maxg takes the "maxg>maxg_old" branch (restore uhat and
        // break). Confirm that this is intended.
00095   double maxg = 0;
00096   dvector uhat_old(1,usize);
00097   //double f_from_1=0.0;
00098 
        // y holds x in positions 1..xsize and uhat in xsize+1..xsize+usize
00099   for (i=1;i<=xsize;i++)
00100   {
00101     y(i)=x(i);
00102   }
00103   for (i=1;i<=usize;i++)
00104   {
00105     y(i+xsize)=uhat(i);
00106   }
00107 
        // In this function ierr is only passed to other functions inside the
        // USE_ATLAS branch; without USE_ATLAS it keeps the value 0.
00108   int ierr=0;
        // one extra pass is made when first_hessian_flag is set
00109   int niters=0;
00110   if (function_minimizer::first_hessian_flag)
00111     niters=num_nr_iters+1;
00112   else
00113     niters=num_nr_iters;
00114 
        // With quadratic priors present, make sure used_flags is allocated
        // with index range 1..nv (it is reallocated if its upper index differs)
00115   int nv=0;
00116   if (quadratic_prior::get_num_quadratic_prior()>0)
00117   {
00118     nv=initial_df1b2params::set_index();
00119     if (allocated(used_flags))
00120     {
00121       if (used_flags.indexmax() != nv)
00122       {
00123         used_flags.safe_deallocate();
00124       }
00125     }
00126     if (!allocated(used_flags))
00127     {
00128       used_flags.safe_allocate(1,nv);
00129     }
00130   }
00131 
        // Newton-Raphson style loop over the random effects uhat
00132   for(int ii=1;ii<=niters;ii++)
00133   {
00134     if (quadratic_prior::get_num_quadratic_prior()>0)
00135     {
00136       check_pool_size();
00137     }
00138     {
00139       // test newton raphson
00140       Hess.initialize();
00141       cout << "Checking derivatives " << ii << endl;
00142       check_derivatives(x,pfmin,fval1);
00143 
            // quadratic priors: re-evaluate with where_are_we_flag set to 2,
            // then add their contributions to Hess and grad
00144       if (quadratic_prior::get_num_quadratic_prior()>0)
00145       {
00146         laplace_approximation_calculator::where_are_we_flag=2;
00147         /*double maxg = */evaluate_function_quiet(uhat,pfmin);
00148         laplace_approximation_calculator::where_are_we_flag=0;
00149         quadratic_prior::get_cHessian_contribution(Hess,xsize);
00150         quadratic_prior::get_cgradient_contribution(grad,xsize);
00151       }
00152 
00153       /*
00154       if (ii == 1)
00155         { double diff = fabs(re_objective_function_value::fun_without_pen - objective_function_value::fun_without_pen); }
00156       */
00157 
00158       dvector step;
            // debugging switch, hard-wired to 0, so the block below never runs
00159       int print_hess_in_newton_raphson_flag=0;
00160       if (print_hess_in_newton_raphson_flag)
00161       {
00162         cout << norm2(Hess-trans(Hess)) << endl;
00163         if (ad_comm::global_logfile)
00164         {
00165           (*ad_comm::global_logfile) << setprecision(4) << setscientific()
00166             << setw(12) << sort(eigenvalues(Hess)) << endl;
00167           (*ad_comm::global_logfile) << setprecision(4) << setscientific()
00168             << setw(12) << Hess << endl;
00169         }
00170       }
00171 
            // Solve for the step. With USE_ATLAS the solver (or the Choleski
            // check) is used and a nonzero ierr aborts the loop after
            // clearing the gradient lists; otherwise a plain solve() is used.
00172 #if defined(USE_ATLAS)
00173       if (!ad_comm::no_atlas_flag)
00174       {
00175         step=-atlas_solve_spd(Hess,grad,ierr);
00176       }
00177       else
00178       {
              // NOTE(review): the prototype declared earlier in this file
              // takes the second argument as a double by value, so this call
              // could not change ierr through it; another overload may exist
              // in the headers -- confirm. A is not used afterwards.
00179         dmatrix A=choleski_decomp_positive(Hess,ierr);
00180         if (!ierr)
00181         {
00182           step=-solve(Hess,grad);
00183           //step=-solve(A*trans(A),grad);
00184         }
00185       }
00186       if (ierr)
00187       {
00188         f1b2gradlist->reset();
00189         f1b2gradlist->list.initialize();
00190         f1b2gradlist->list2.initialize();
00191         f1b2gradlist->list3.initialize();
00192         f1b2gradlist->nlist.initialize();
00193         f1b2gradlist->nlist2.initialize();
00194         f1b2gradlist->nlist3.initialize();
00195         break;
00196       }
00197 #else
00198       step=-solve(Hess,grad);
00199 #endif
00200 
00201       if (ad_comm::time_flag)
00202       {
00203         if (ad_comm::ptm)
00204         {
00205           double time=ad_comm::ptm->get_elapsed_time_and_reset();
00206           if (ad_comm::global_logfile)
00207           {
00208             (*ad_comm::global_logfile) << " time_in solve " <<  ii << "  "
00209               << time << endl;
00210           }
00211         }
00212       }
00213 
            // clear the gradient lists before the next function evaluation
00214       f1b2gradlist->reset();
00215       f1b2gradlist->list.initialize();
00216       f1b2gradlist->list2.initialize();
00217       f1b2gradlist->list3.initialize();
00218       f1b2gradlist->nlist.initialize();
00219       f1b2gradlist->nlist2.initialize();
00220       f1b2gradlist->nlist3.initialize();
00221 
            // take the step, keeping the previous uhat so it can be restored
00222       uhat_old=uhat;
00223       uhat+=step;
00224 
00225       double maxg_old=maxg;
00226       maxg=fabs(evaluate_function(uhat,pfmin));
            // value got larger: go back to the previous uhat, re-evaluate
            // there and stop iterating
00227       if (maxg>maxg_old)
00228       {
00229         uhat=uhat_old;
00230         evaluate_function(uhat,pfmin);
00231         break;
00232       }
            // value below the 1.e-13 threshold: stop iterating
00233       if (maxg < 1.e-13)
00234       {
00235         break;
00236       }
00237     }
          // refresh the uhat part of y (skipped on the passes that break)
00238     for (i=1;i<=usize;i++)
00239     {
00240       y(i+xsize)=uhat(i);
00241     }
00242   }
00243 
        // no Newton-Raphson passes were requested: still evaluate once at uhat
00244   if (num_nr_iters<=0)
00245   {
00246     evaluate_function(uhat,pfmin);
00247   }
00248 
        // make sure y carries the final uhat, also after a break above
00249   for (i=1;i<=usize;i++)
00250   {
00251     y(i+xsize)=uhat(i);
00252   }
00253 
00254 
00255   if (ad_comm::time_flag)
00256   {
00257     if (ad_comm::ptm)
00258     {
00259       double time=ad_comm::ptm->get_elapsed_time_and_reset();
00260       if (ad_comm::global_logfile)
00261       {
00262         (*ad_comm::global_logfile) << " Time in reset and evaluate function"
00263           << time << endl;
00264       }
00265     }
00266   }
        // Hess and Dux are passed by non-const reference (see the prototype
        // declared earlier in this file)
00267   get_second_ders(xsize,usize,y,Hess,Dux,f1b2gradlist,pfmin,this);
00268   //int sgn=0;
00269 
00270   if (ad_comm::time_flag)
00271   {
00272     if (ad_comm::ptm)
00273     {
00274       double time=ad_comm::ptm->get_elapsed_time_and_reset();
00275       if (ad_comm::global_logfile)
00276       {
00277         (*ad_comm::global_logfile) << " Time in dget second ders "
00278           << time << endl;
00279       }
00280     }
00281   }
        // Compute the function value written back through _f:
        //   num_importance_samples==0 -> Laplace approximation
        //   otherwise                 -> importance sampling (funnel variant
        //                                when isfunnel_flag is nonzero)
        //   ierr nonzero              -> 1.e+30
00282   if (!ierr)
00283   {
00284     if (num_importance_samples==0)
00285     {
00286       //cout << "Hess " << endl << Hess << endl;
00287       f=calculate_laplace_approximation(x,uhat,Hess,xadjoint,uadjoint,
00288         Hessadjoint,pfmin);
00289     }
00290     else
00291     {
00292       if (isfunnel_flag==0)
00293       {
00294         f=calculate_importance_sample(x,uhat,Hess,xadjoint,uadjoint,
00295           Hessadjoint,pfmin);
00296       }
00297       else
00298       {
00299         f=calculate_importance_sample_funnel(x,uhat,Hess,xadjoint,uadjoint,
00300           Hessadjoint,pfmin);
00301       }
00302     }
00303   }
00304   else
00305   {
00306     f=1.e+30;
00307   }
00308 
00309   if (ad_comm::time_flag)
00310   {
00311     if (ad_comm::ptm)
00312     {
00313       double time=ad_comm::ptm->get_elapsed_time_and_reset();
00314       if (ad_comm::global_logfile)
00315       {
00316         (*ad_comm::global_logfile) << "Time in calculate laplace approximation "
00317           << time << endl;
00318       }
00319     }
00320   }
00321 
        // Loop over the derivative blocks, last block first
00322   for (int ip=num_der_blocks;ip>=1;ip--)
00323   {
00324     df1b2variable::minder=minder(ip);
00325     df1b2variable::maxder=maxder(ip);
          // restrict j to the random-effects positions (xsize+1..xsize+usize)
          // that lie inside y(1)'s minder..maxder range
00326     int mind=y(1).minder;
00327     int jmin=max(mind,xsize+1);
00328     int jmax=min(y(1).maxder,xsize+usize);
          // store Hessadjoint in the u_bar_tilde arrays of the random effects
00329     for (i=1;i<=usize;i++)
00330     {
00331       for (j=jmin;j<=jmax;j++)
00332       {
00333         //Hess(i,j-xsize)=y(i+xsize).u_bar[j-mind];
00334         y(i+xsize).get_u_bar_tilde()[j-mind]=Hessadjoint(i,j-xsize);
00335       }
00336     }
00337 
00338     if (initial_df1b2params::separable_flag)
00339     {
            // separable case: zero every u_tilde and Hess, then call the user
            // function with separable_calculation_type set to 3
00340       for (j=1;j<=xsize+usize;j++)
00341       {
00342         *y(j).get_u_tilde()=0;
00343       }
00344       Hess.initialize();
00345       initial_df1b2params::separable_calculation_type=3;
00346       pfmin->user_function();
00347     }
00348     else
00349     {
            // for every block except the last one (which is processed first),
            // re-run the user function and pass 1
00350       if (ip<num_der_blocks)
00351       {
00352         f1b2gradlist->reset();
00353         set_u_dot(ip);
00354         df1b2_gradlist::set_yes_derivatives();
00355         (*re_objective_function_value::pobjfun)=0;
00356         df1b2variable pen=0.0;
00357         df1b2variable zz=0.0;
00358 
00359         initial_df1b2params::reset(y,pen);
00360         pfmin->user_function();
00361 
              // record the objective value before pen and zz are added
00362         re_objective_function_value::fun_without_pen=
00363           value(*re_objective_function_value::pobjfun);
00364 
00365         (*re_objective_function_value::pobjfun)+=pen;
00366         (*re_objective_function_value::pobjfun)+=zz;
00367 
00368         set_dependent_variable(*re_objective_function_value::pobjfun);
00369         df1b2_gradlist::set_no_derivatives();
00370         df1b2variable::passnumber=1;
00371         df1b2_gradcalc1();
00372       }
00373 
            // same assignment as the loop at the top of this block iteration
            // (presumably needed again after the re-run above -- confirm)
00374       for (i=1;i<=usize;i++)
00375       {
00376         for (j=jmin;j<=jmax;j++)
00377         {
00378           //Hess(i,j-xsize)=y(i+xsize).u_bar[j-mind];
00379           y(i+xsize).get_u_bar_tilde()[j-mind]=Hessadjoint(i,j-xsize);
00380         }
00381       }
00382 
00383       //int mind=y(1).minder;
            // passes 2 and 3 of the gradient calculation
00384       df1b2variable::passnumber=2;
00385       df1b2_gradcalc1();
00386 
00387       df1b2variable::passnumber=3;
00388       df1b2_gradcalc1();
00389 
00390       f1b2gradlist->reset();
00391       f1b2gradlist->list.initialize();
00392       f1b2gradlist->list2.initialize();
00393       f1b2gradlist->list3.initialize();
00394       f1b2gradlist->nlist.initialize();
00395       f1b2gradlist->nlist2.initialize();
00396       f1b2gradlist->nlist3.initialize();
00397     }
00398 
00399     if (ad_comm::time_flag)
00400     {
00401       if (ad_comm::ptm)
00402       {
00403         double time=ad_comm::ptm->get_elapsed_time_and_reset();
00404         if (ad_comm::global_logfile)
00405         {
00406           (*ad_comm::global_logfile) << " time for 3rd derivatives "
00407             << time << endl;
00408         }
00409       }
00410     }
00411 
          // collect the u_tilde values of the first xsize entries of y
00412     dvector dtmp(1,xsize);
00413     for (i=1;i<=xsize;i++)
00414     {
00415       dtmp(i)=*y(i).get_u_tilde();
00416     }
          // separable case: multiply the rows of Dux and dtmp elementwise by
          // the leading part of scale
00417     if (initial_df1b2params::separable_flag)
00418     {
00419       dvector scale(1,nvar);   // need to get scale from somewhere
00420       /*int check=*/initial_params::stddev_scale(scale,x);
00421       dvector sscale=scale(1,Dux(1).indexmax());
00422       for (i=1;i<=usize;i++)
00423       {
00424         Dux(i)=elem_prod(Dux(i),sscale);
00425       }
00426       dtmp=elem_prod(dtmp,sscale);
00427     }
00428 
          // accumulate this block's contribution into the adjoints
00429     for (i=1;i<=xsize;i++)
00430     {
00431       xadjoint(i)+=dtmp(i);
00432     }
00433     for (i=1;i<=usize;i++)
00434       uadjoint(i)+=*y(xsize+i).get_u_tilde();
00435   }
00436  // *****************************************************************
00437  // new stuff to deal with quadraticprior
00438  // *****************************************************************
00439 
          // xstuff is fixed at 3, so both "xstuff>1" and "xstuff>2" hold below
00440     int xstuff=3;
00441     if (xstuff && df1b2quadratic_prior::get_num_quadratic_prior()>0)
00442     {
00443       initial_params::straight_through_flag=0;
00444       funnel_init_var::lapprox=0;
00445       block_diagonal_flag=0;
            // NOTE(review): scale1 is filled by stddev_scale() but is not read
            // again in this function
00446       dvector scale1(1,nvar);   // need to get scale from somewhere
00447       initial_params::set_inactive_only_random_effects();
00448       /*int check=*/initial_params::stddev_scale(scale1,x);
00449 
            // flags set here are restored to 0 after get_Lxu_contribution()
00450       laplace_approximation_calculator::where_are_we_flag=3;
00451       quadratic_prior::in_qp_calculations=1;
00452       funnel_init_var::lapprox=this;
00453       df1b2_gradlist::set_no_derivatives();
00454       dvector scale(1,nvar);   // need to get scale from somewhere
00455       /*check=*/initial_params::stddev_scale(scale,x);
00456       dvector sscale=scale(1,Dux(1).indexmax());
00457 
            // divide the rows of Dux by sscale here; they are multiplied by
            // sscale again after the quadratic-prior contribution is added
00458       for (i=1;i<=usize;i++)
00459       {
00460         Dux(i)=elem_div(Dux(i),sscale);
00461       }
00462 
00463       if (xstuff>1)
00464       {
00465         df1b2quadratic_prior::get_Lxu_contribution(Dux);
00466       }
00467       quadratic_prior::in_qp_calculations=0;
00468       funnel_init_var::lapprox=0;
00469       laplace_approximation_calculator::where_are_we_flag=0;
00470 
00471       for (i=1;i<=usize;i++)
00472       {
00473         Dux(i)=elem_prod(Dux(i),sscale);
00474       }
00475       //local_dtemp=elem_prod(local_dtemp,sscale);
00476 
            // add the vector returned by evaluate_function_with_quadprior()
            // to xadjoint
00477       if (xstuff>2)
00478       {
00479         dvector tmp=evaluate_function_with_quadprior(x,usize,pfmin);
00480         for (i=1;i<=xsize;i++)
00481         {
00482           xadjoint(i)+=tmp(i);
00483         }
00484       }
00485 
00486       if (xstuff>2)
00487       {
00488         quadratic_prior::get_cHessian_contribution_from_vHessian(Hess,xsize);
00489       }
00490     }
00491 
00492  // *****************************************************************
00493  // new stuff to deal with quadraticprior
00494  // *****************************************************************
        // reset the timer; the elapsed time returned here is discarded
00495   if (ad_comm::ptm)
00496   {
00497     /*double time=*/ad_comm::ptm->get_elapsed_time_and_reset();
00498   }
00499 
        // Final update: xadjoint -= solve(Hess,uadjoint)*Dux, using the ATLAS
        // solver when it is compiled in and not disabled by no_atlas_flag
00500 #if defined(USE_ATLAS)
00501       if (!ad_comm::no_atlas_flag)
00502       {
00503         //xadjoint -= uadjoint*atlas_solve_spd_trans(Hess,Dux);
00504         xadjoint -= atlas_solve_spd_trans(Hess,uadjoint)*Dux;
00505       }
00506       else
00507       {
00508         //xadjoint -= uadjoint*solve(Hess,Dux);
00509         xadjoint -= solve(Hess,uadjoint)*Dux;
00510       }
00511 #else
00512       //xadjoint -= uadjoint*solve(Hess,Dux);
00513       xadjoint -= solve(Hess,uadjoint)*Dux;
00514 #endif
00515 
00516 
        // unlike the earlier timing blocks, these two are not guarded by
        // ad_comm::time_flag
00517   if (ad_comm::ptm)
00518   {
00519     double time=ad_comm::ptm->get_elapsed_time_and_reset();
00520     if (ad_comm::global_logfile)
00521     {
00522       (*ad_comm::global_logfile) << " Time in second solve "
00523         << time << endl;
00524     }
00525   }
00526   if (ad_comm::ptm1)
00527   {
00528     double time=ad_comm::ptm1->get_elapsed_time_and_reset();
00529     if (ad_comm::global_logfile)
00530     {
00531       (*ad_comm::global_logfile) << " Total time in function evaluation "
00532         << time << endl << endl;
00533     }
00534   }
00535 
00536   return xadjoint;
00537 }
00538 #endif //#if defined(USE_LAPLACE)