【Question title】: How to pass parameters into reference of a function in C++?
【Posted】: 2013-09-16 13:18:01
【Question】:

I'm trying to understand a library (dlib) that I downloaded from the Internet. Here is the source code:

double func1(double a, double b)
{
    return a - b;
}

double func2(const func_type& f);

int main()
{
    func2(&func1);

    return 0;
}
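Assuming func2 is a function template, which is what one of the comments below says dlib's functions are, here is a minimal sketch of how func2(&func1) could reach func1 with arguments. The body of func2 and the values 5.0 and 2.0 are hypothetical; the real definition lives inside dlib.

```cpp
// The callback from the question: takes two doubles and returns a double.
double func1(double a, double b)
{
    return a - b;
}

// Hypothetical body for func2 (not dlib's code). func_type is a template
// parameter deduced from the argument; for func2(&func1) it becomes
// double (*)(double, double), and f is a const reference to that pointer.
template <typename func_type>
double func2(const func_type& f)
{
    // func2 itself decides which arguments to pass; 5.0 and 2.0 are arbitrary.
    return f(5.0, 2.0);
}
```

With this sketch, func2(&func1) returns 3.0. Writing func2(func1) without the & also compiles: func_type is then deduced as the function type itself, and f is a reference to func1.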

A more detailed version is at this link: http://dlib.net/least_squares_ex.cpp.html

When I run this example, it works fine. However, I can't see how the arguments are passed to func1.

Can anyone help me understand what is going on here?

======================================================== EDIT:

Here is the full version of the source code:

// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*

    This is an example illustrating the use of the general purpose non-linear 
    least squares optimization routines from the dlib C++ Library.

    This example program will demonstrate how these routines can be used for data fitting.
    In particular, we will generate a set of data and then use the least squares  
    routines to infer the parameters of the model which generated the data.
*/


#include <dlib/optimization.h>
#include <iostream>
#include <vector>

using namespace std;
using namespace dlib;

typedef matrix<double,2,1> input_vector;
typedef matrix<double,3,1> parameter_vector;

// We will use this function to generate data.  It represents a function of 2 variables
// and 3 parameters.   The least squares procedure will be used to infer the values of 
// the 3 parameters based on a set of input/output pairs.
double model (const input_vector& input, const parameter_vector& params)
{
    const double p0 = params(0);
    const double p1 = params(1);
    const double p2 = params(2);

    const double i0 = input(0);
    const double i1 = input(1);

    const double temp = p0*i0 + p1*i1 + p2;

    return temp*temp;
}

// This function is the "residual" for a least squares problem.   It takes an input/output
// pair and compares it to the output of our model and returns the amount of error.  The idea
// is to find the set of parameters which makes the residual small on all the data pairs.
double residual (const std::pair<input_vector, double>& data,const parameter_vector& params)
{
    return model(data.first, params) - data.second;
}

// This function is the derivative of the residual() function with respect to the parameters.
parameter_vector residual_derivative (const std::pair<input_vector, double>& data,const parameter_vector& params)
{
    parameter_vector der;

    const double p0 = params(0);
    const double p1 = params(1);
    const double p2 = params(2);

    const double i0 = data.first(0);
    const double i1 = data.first(1);

    const double temp = p0*i0 + p1*i1 + p2;

    der(0) = i0*2*temp;
    der(1) = i1*2*temp;
    der(2) = 2*temp;

    return der;
}

int main()
{
    try
    {
        // randomly pick a set of parameters to use in this example
        const parameter_vector params = 10*randm(3,1);
        cout << "params: " << trans(params) << endl;

        // Now let's generate a bunch of input/output pairs according to our model.
        std::vector<std::pair<input_vector, double> > data_samples;
        input_vector input;
        for (int i = 0; i < 1000; ++i)
        {
            input = 10*randm(2,1);
            const double output = model(input, params);

            // save the pair
            data_samples.push_back(make_pair(input, output));
        }

        // Before we do anything, let's make sure that our derivative function defined above matches
        // the approximate derivative computed using central differences (via derivative()).  
        // If this value is big then it means we probably typed the derivative function incorrectly.
        cout << "derivative error: " << length(residual_derivative(data_samples[0], params) - 
                                               derivative(&residual)(data_samples[0], params) ) << endl;

        // Now let's use the solve_least_squares_lm() routine to figure out what the
        // parameters are based on just the data_samples.
        parameter_vector x;
        x = 1;

        cout << "Use Levenberg-Marquardt" << endl;
        // Use the Levenberg-Marquardt method to determine the parameters which
        // minimize the sum of all squared residuals.
        solve_least_squares_lm(objective_delta_stop_strategy(1e-7).be_verbose(), 
                               &residual,
                               &residual_derivative,
                               data_samples,
                               x);

        // Now x contains the solution.  If everything worked it will be equal to params.
        cout << "inferred parameters: "<< trans(x) << endl;
        cout << "solution error:      "<< length(x - params) << endl;
        cout << endl;




        x = 1;
        cout << "Use Levenberg-Marquardt, approximate derivatives" << endl;
        // If we didn't create the residual_derivative function then we could
        // have used this method which numerically approximates the derivatives for you.
        solve_least_squares_lm(objective_delta_stop_strategy(1e-7).be_verbose(), 
                               &residual,
                               derivative(&residual),
                               data_samples,
                               x);

        // Now x contains the solution.  If everything worked it will be equal to params.
        cout << "inferred parameters: "<< trans(x) << endl;
        cout << "solution error:      "<< length(x - params) << endl;
        cout << endl;




        x = 1;
        cout << "Use Levenberg-Marquardt/quasi-newton hybrid" << endl;
        // This version of the solver uses a method which is appropriate for problems
        // where the residuals don't go to zero at the solution.  So in these cases
        // it may provide a better answer.
        solve_least_squares(objective_delta_stop_strategy(1e-7).be_verbose(), 
                            &residual,
                            &residual_derivative,
                            data_samples,
                            x);

        // Now x contains the solution.  If everything worked it will be equal to params.
        cout << "inferred parameters: "<< trans(x) << endl;
        cout << "solution error:      "<< length(x - params) << endl;

    }
    catch (std::exception& e)
    {
        cout << e.what() << endl;
    }
}

// Example output:
/*
params: 8.40188 3.94383 7.83099 

derivative error: 9.78267e-06
Use Levenberg-Marquardt
iteration: 0   objective: 2.14455e+10
iteration: 1   objective: 1.96248e+10
iteration: 2   objective: 1.39172e+10
iteration: 3   objective: 1.57036e+09
iteration: 4   objective: 2.66917e+07
iteration: 5   objective: 4741.9
iteration: 6   objective: 0.000238674
iteration: 7   objective: 7.8815e-19
iteration: 8   objective: 0
inferred parameters: 8.40188 3.94383 7.83099 

solution error:      0

Use Levenberg-Marquardt, approximate derivatives
iteration: 0   objective: 2.14455e+10
iteration: 1   objective: 1.96248e+10
iteration: 2   objective: 1.39172e+10
iteration: 3   objective: 1.57036e+09
iteration: 4   objective: 2.66917e+07
iteration: 5   objective: 4741.87
iteration: 6   objective: 0.000238701
iteration: 7   objective: 1.0571e-18
iteration: 8   objective: 4.12469e-22
inferred parameters: 8.40188 3.94383 7.83099 

solution error:      5.34754e-15

Use Levenberg-Marquardt/quasi-newton hybrid
iteration: 0   objective: 2.14455e+10
iteration: 1   objective: 1.96248e+10
iteration: 2   objective: 1.3917e+10
iteration: 3   objective: 1.5572e+09
iteration: 4   objective: 2.74139e+07
iteration: 5   objective: 5135.98
iteration: 6   objective: 0.000285539
iteration: 7   objective: 1.15441e-18
iteration: 8   objective: 3.38834e-23
inferred parameters: 8.40188 3.94383 7.83099 

solution error:      1.77636e-15
*/
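The full example works the same way as the simplified one: its comments say solve_least_squares_lm() minimizes the sum of all squared residuals, so somewhere inside it the solver presumably calls residual(sample, params) for the data samples itself. The sketch below is a hypothetical helper, not dlib's implementation. It uses a scalar input and a std::vector<double> parameter list in place of the example's matrix types, and toy_residual and sum_squared_residuals are illustrative names.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Simplified stand-in for the example's data pairs: (input, observed output).
using sample = std::pair<double, double>;

// Hypothetical residual with the same shape as residual() above:
// model output minus observed output, for the toy model y = p0 * x + p1.
double toy_residual(const sample& data, const std::vector<double>& params)
{
    return params[0] * data.first + params[1] - data.second;
}

// Hypothetical helper, NOT dlib's code: the caller only passes the callback,
// the samples, and the current parameters; the helper supplies the
// (sample, params) arguments to the callback and sums the squared results.
template <typename residual_type>
double sum_squared_residuals(const residual_type& residual,
                             const std::vector<sample>& samples,
                             const std::vector<double>& params)
{
    double total = 0.0;
    for (std::size_t i = 0; i < samples.size(); ++i)
    {
        const double r = residual(samples[i], params);  // arguments supplied here
        total += r * r;
    }
    return total;
}
```

For samples taken from the line y = 2x + 1, sum_squared_residuals(&toy_residual, samples, {2.0, 1.0}) is 0, which is the quantity a least-squares solver tries to drive down.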

【Question comments】:

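  • Not enough information. Source code of what? What is func_type? What is the definition of func2? Why not print the result of the call?
  • Well: 1. It's the source code of a conceptual outline of a downloaded library. If you read carefully, you'll find a link to the full version. I didn't post it here because it is too long and might be distracting. 2. The definitions of func2 and func_type are hidden somewhere I can't find. 3. Again, if you read carefully, you can find what gets printed at the link I posted. 4. Please be sure to read carefully. Thanks. :-)
  • I did read your linked code. It is different from your simple example: every call there that takes a reference/pointer to a function also takes other arguments.
  • In dlib, all of these functions are templates, so they are really declared in the form double f(funct_type), where funct_type is a template parameter. That's where it comes from.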
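The same template pattern explains derivative(&residual) in the linked example, whose result is called like a function with (data_samples[0], params). The sketch below is a hypothetical, simplified scalar analogue, not dlib's derivative(): a function template accepts any callable and returns a new callable that approximates its derivative by central differences, calling the original with arguments (x + h and x - h) that it computes itself. The name central_difference and the step size h = 1e-5 are assumptions; the deduced return type needs C++14.

```cpp
// Hypothetical scalar analogue of derivative(&residual), NOT dlib's code.
// funct_type is deduced, so a function pointer, lambda, or functor works.
template <typename funct_type>
auto central_difference(funct_type f, double h = 1e-5)
{
    // The returned lambda stores f and later calls it with arguments it computes.
    return [f, h](double x) { return (f(x + h) - f(x - h)) / (2 * h); };
}

// A sample function to differentiate; its exact derivative is 2 * x.
double square(double x)
{
    return x * x;
}
```

For example, central_difference(&square) returns a callable whose value at 3.0 is approximately 6.0.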

Tags: c++ function-pointers


【Answer 1】:

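You are handing func2 a function that takes two doubles and returns a double. Somewhere in its implementation, func2 is probably calling that function (func1 in your case). That is the point where the appropriate arguments are passed in.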
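A minimal sketch of the classic callback this answer describes, assuming a hypothetical non-template func2 that takes a plain function pointer (dlib's real func2 is a template): the caller receives only the function's address, and the arguments are supplied at the point where it calls f.

```cpp
double func1(double a, double b)
{
    return a - b;
}

// Hypothetical non-template func2: f points to any function that takes
// two doubles and returns a double.
double func2(double (*f)(double, double))
{
    const double a = 7.0;  // arbitrary values chosen inside func2
    const double b = 2.5;
    return f(a, b);        // same as (*f)(a, b): func1 receives a and b here
}
```

Here func2(&func1) and func2(func1) are equivalent, because a function name converts implicitly to a pointer to that function; both return 4.5.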

This mechanism is called a callback function. Callbacks are less common in C++ than in C, because in C++ templates can be used instead. See What is a callback function?
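To illustrate how templates can stand in for plain callbacks, here is a sketch with a hypothetical function template call_with_pair and a hypothetical functor scaled_difference. Because the callback type is deduced, the template accepts a function pointer, a capturing lambda, or a function object with state; a double (*)(double, double) parameter would reject the last two.

```cpp
// Hypothetical template caller: callback_type is deduced from the argument.
template <typename callback_type>
double call_with_pair(const callback_type& f, double a, double b)
{
    return f(a, b);  // works for anything callable with two doubles
}

double func1(double a, double b)
{
    return a - b;
}

// A function object (functor) that carries a scale factor as state.
struct scaled_difference
{
    double scale;
    double operator()(double a, double b) const { return scale * (a - b); }
};
```

The design trade-off is that each distinct callback type produces its own instantiation of the template, in exchange for accepting any callable and letting the compiler inline the call.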


【Answer 2】:

Somewhere in the implementation of func2 there will be an expression that looks like f(foo, bar). That's where the arguments come from.

