0%

稀疏图计算

稀疏图计算的实现

遇到图中元素很稀疏时可以使用sparse tensor计算所需要的值,然后再转化为dense tensor。假如涉及图的运算中稀疏程度很大、或者中间结果的维度很高,都能够有效的降低时间跟内存开销。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 选择所有非0元素的索引
indices = torch.nonzero(x)
# 从参与计算的tensor中取值
values = x[tuple(indices[i] for i in range(indices.shape[0]))]
# 进行需要的索引变换
j_indices = indices.clone()
value_indices = list(j_indices[i] for i in range(j_indices.shape[0]))
value_indices[1] = value_indices[1].zero_()
# 从某个参与计算的tensor中取相应的值
emb_j = node_j[tuple(value_indices)]
# 进行计算
final_values = func(values, emb_j)
# 最后可能需要再次对索引进行变换(比如加上一维)
extra_indices = torch.arange(0, window_size).to(indices.device).unsqueeze(1).repeat(1, n_num).t().reshape(
(1, n_num, window_size))
indices = torch.cat([indices[0:1], extra_indices, indices[1:]], dim=0)
indices = indices.reshape((4, -1))
# 最后生成sparse tensor,再转换回来
x_typename = torch.typename(x).split('.')[-1]
sparse_tensortype = getattr(torch.sparse, x_typename)
res = sparse_tensortype(indices, final_values, (b, window_size, l, l, 1)).requires_grad_(True).to_dense()

另外,中间尝试了把几个计算涉及的tensor分别转为sparse tensor然后运算最后再转回来,但是在backward的时候会出错,用tensor.contiguous()也没解决。