TensorFlow: "cannot be larger than 2GB" when loading large embedding model parameters
Published: 2019-06-28


The root of this problem is that each TensorFlow variable is serialized through protobuf, which caps a single tensor at 2GB. The workaround is to split the embedding matrix into N shards so that every resulting variable stays below 2GB. The example below splits a toy 9×3 matrix into 3 shards and then looks up embeddings across the shards:

  

# !/usr/bin/env python
# coding=utf-8
import tensorflow as tf
import numpy as np

input_ids = tf.placeholder(dtype=tf.int32, shape=[None, None])

num_shards = 3
weights = []
# Toy stand-in for a large pretrained embedding matrix: 9 rows x 3 dims.
weights_shape = np.arange(27).reshape(9, 3)
assert weights_shape.shape[0] % num_shards == 0
num_shards_len = weights_shape.shape[0] // num_shards  # rows per shard (integer division)
begin_ = 0
ends_ = num_shards_len
for i in range(0, num_shards):
    # Compute the row range [begin_, ends_) owned by shard i.
    if (i + 1) * num_shards_len < weights_shape.shape[0]:
        begin_ = i * num_shards_len
        if i + 1 == num_shards:
            ends_ = weights_shape.shape[0]
        else:
            ends_ = (i + 1) * num_shards_len
    else:
        begin_ = i * num_shards_len
        ends_ = weights_shape.shape[0]
    # Each shard becomes its own variable, so each one stays below the 2GB protobuf limit.
    weights_i = tf.get_variable("words-%02d" % i,
                                initializer=tf.constant(weights_shape[begin_: ends_, ]))
    weights.append(weights_i)

# "div" assigns contiguous id ranges to shards, matching the row-wise split above.
input_embedding = tf.nn.embedding_lookup(weights, input_ids, partition_strategy="div")

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print(sess.run(weights))

print(sess.run(input_embedding, feed_dict={input_ids: [[1, 2], [3, 0], [8, 2], [5, 1]]}))

The output is:

    

[array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]]), array([[ 9, 10, 11],
       [12, 13, 14],
       [15, 16, 17]]), array([[18, 19, 20],
       [21, 22, 23],
       [24, 25, 26]])]
[[[ 3  4  5]
  [ 6  7  8]]

 [[ 9 10 11]
  [ 0  1  2]]

 [[24 25 26]
  [ 6  7  8]]

 [[15 16 17]
  [ 3  4  5]]]
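As a side note, when the embedding is trained from scratch rather than copied from an existing array, TensorFlow 1.x can do the same row-wise sharding automatically through a variable partitioner. The sketch below is my own illustration under that assumption (the names "emb" and "ids" are made up), not part of the original snippet:

# Minimal sketch (assumes TF 1.x): let a partitioner create the shards instead of slicing by hand.
import tensorflow as tf

num_shards = 3
emb = tf.get_variable(
    "emb",
    shape=[9, 3],
    dtype=tf.float32,
    initializer=tf.truncated_normal_initializer(),
    partitioner=tf.fixed_size_partitioner(num_shards, axis=0))  # 3 shard variables of 3 rows each

ids = tf.placeholder(tf.int32, shape=[None, None])
# "div" matches the contiguous row split produced by fixed_size_partitioner.
looked_up = tf.nn.embedding_lookup(emb, ids, partition_strategy="div")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(looked_up, feed_dict={ids: [[1, 2], [3, 0]]}))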

 

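The "cannot be larger than 2GB" error itself comes from serializing the whole array as a tf.constant inside the GraphDef. Another common workaround, sketched below on the assumption of TensorFlow 1.x (the names embedding_matrix, embedding_ph, etc. are only illustrative), is to keep the pretrained matrix out of the graph entirely and feed it in at runtime through a placeholder-driven assign:

# Sketch: initialize the embedding from a feed_dict instead of a tf.constant,
# so the GraphDef never contains the (potentially >2GB) array.
import tensorflow as tf
import numpy as np

pretrained = np.arange(27, dtype=np.float32).reshape(9, 3)  # stand-in for a huge matrix

embedding_var = tf.get_variable(
    "embedding_matrix", shape=pretrained.shape, dtype=tf.float32,
    trainable=False, initializer=tf.zeros_initializer())

embedding_ph = tf.placeholder(tf.float32, shape=pretrained.shape)
assign_embedding = embedding_var.assign(embedding_ph)

input_ids = tf.placeholder(tf.int32, shape=[None, None])
embedded = tf.nn.embedding_lookup(embedding_var, input_ids)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The actual values arrive via feed_dict at runtime, not as a graph constant.
    sess.run(assign_embedding, feed_dict={embedding_ph: pretrained})
    print(sess.run(embedded, feed_dict={input_ids: [[1, 2], [3, 0]]}))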
