【问题标题】:Cartesian product in MataMata 中的笛卡尔积
【发布时间】:2015-03-27 19:50:36
【问题描述】:

要构造一组向量,我需要取集合 C[1]..C[d] 的笛卡尔积,

D := {x : x[i] ϵ C[i], i = 1..d}

例子:如果*C[1]=(5,6,7)';*C[2]=(3,5,6)';*C[3]=(1,3,5)',那么D的一些元素是(5,3,1), (5,3,3) ...

我想知道:一般来说,在 Mata 中获取笛卡尔积的最佳方法是什么?我发现了 d=3 的笨拙方法,如下图所示。


详细示例。此代码应说明我尝试过的内容和所需的输出。 mm_expand 函数来自ssc install moremata

mata

// prep

lo = (5,3,1)'
hi = (7,6,5)'
all = uniqrows((lo\hi))

n_cols = length(lo)
n_vals = length(all)

c_list = J( 1,n_cols,NULL )
c_lens = J( 1,n_cols,0 )

for (i=1;i<=n_cols;i++){
    c_list[i] = &(select( all,all :>= lo[i] :& all :<= hi[i] ))
    c_lens[i] = length(*c_list[i])
}   

// question: How should I take this Cartesian product?

grid_box = 
mm_expand(*c_list[1],c_lens[2]*c_lens[3],0,1),
mm_expand(mm_expand(*c_list[2],c_lens[1],0,1),c_lens[3],0,0),
mm_expand(*c_list[3],c_lens[1]*c_lens[2],0,0)

// (just fyi) my next step

is_decr = ! rowsum( grid_box[,1..(n_cols-1)]-grid_box[,2..n_cols] :< 0 )
select(grid_box,is_decr)

end

代码的符号和“prep”部分与my application相关。

【问题讨论】:

    标签: set stata cartesian-product


    【解决方案1】:

    最简单的方法是使用递归:

    real matrix cart_prod(pointer vector c_list ,| real scalar curr_i){
        if(curr_i==.) curr_i=1
        myret = (*c_list[curr_i])
        if (curr_i<length(c_list)){
            ret = cart_prod(c_list, curr_i+1)
            myret = mm_expand(myret,rows(ret),1,1), mm_expand(ret, rows(myret),1,0)
        }
        return(myret)
    }
    cart_prod(c_list)
    

    即使从c_list 指向的向量的长度不同,这也可以工作。

    【讨论】:

    • 谢谢!我接受你的优雅解决方案,但也发布我自己的,因为我怀疑它会快一点。
    【解决方案2】:

    结果的每一列只需要mm_expanded 两次。以这种方式直接构建列应该更快(而不是在第 k 列的 d-k 步骤中,如@BeingQuisitive 的答案)。

    function prod(x,| real scalar need_int){
        y = exp(sum(log(x)))
        if (need_int==0) return(y)
        return(round(y))
    }
    
    function cartem(pointer vector vlist){
    
        // input: vlist should point to column vectors
    
        d = length( vlist )
        lens = J(1,d,.)
        for (i=1;i<=d;i++){
            lens[i] = length ( *vlist[i] )
        }
        tot_len = prod(lens)
        out = J(tot_len,d,.)
    
        out[,1] = mm_expand(*vlist[1],round(tot_len/lens[1]),0,1)
        out[,d] = mm_expand(*vlist[d],round(tot_len/lens[d]),0,0)
    
        if (d == 2) return(out)
    
        for (i=2;i<=d-1;i++){
            out[,i] = mm_expand(
            mm_expand(*vlist[i],prod(lens[1..i-1]),0,0)
            ,prod(lens[i+1..d]),0,1)
        }
    
        return(out)
    }
    

    这是一个基于我的用例的示例。它表明上面的代码产生了与@BeingQuisitive 的解决方案相同的(期望的)结果:

    c_list = (&(7\10),&(5\6\7),&(3\5\6),&(1\3\5))
    
    cartem(c_list) == cart_prod(c_list)
    // ^ it's true
    

    我并不精通 Stata 的基准测试工具,所以我还没有证实我对效率提升的怀疑。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2012-04-07
      • 2013-08-17
      • 2011-01-29
      • 2019-04-20
      • 2012-10-27
      相关资源
      最近更新 更多