這篇文章給大家介紹怎么在pytorch中處理可變長(zhǎng)度序列,內(nèi)容非常詳細(xì),感興趣的小伙伴們可以參考借鑒,希望對(duì)大家能有所幫助。
成都創(chuàng)新互聯(lián)長(zhǎng)期為數(shù)千家客戶(hù)提供的網(wǎng)站建設(shè)服務(wù),團(tuán)隊(duì)從業(yè)經(jīng)驗(yàn)10年,關(guān)注不同地域、不同群體,并針對(duì)不同對(duì)象提供差異化的產(chǎn)品和服務(wù);打造開(kāi)放共贏平臺(tái),與合作伙伴共同營(yíng)造健康的互聯(lián)網(wǎng)生態(tài)環(huán)境。為紅橋企業(yè)提供專(zhuān)業(yè)的網(wǎng)站設(shè)計(jì)制作、成都網(wǎng)站建設(shè),紅橋網(wǎng)站改版等技術(shù)服務(wù)。擁有10年豐富建站經(jīng)驗(yàn)和眾多成功案例,為您定制開(kāi)發(fā)。1、torch.nn.utils.rnn.PackedSequence()
NOTE: 這個(gè)類(lèi)的實(shí)例不能手動(dòng)創(chuàng)建。它們只能被 pack_padded_sequence() 實(shí)例化。
PackedSequence對(duì)象包括:
一個(gè)data對(duì)象:一個(gè)torch.Variable(令牌的總數(shù),每個(gè)令牌的維度),在這個(gè)簡(jiǎn)單的例子中有五個(gè)令牌序列(用整數(shù)表示):(18,1)
一個(gè)batch_sizes對(duì)象:每個(gè)時(shí)間步長(zhǎng)的令牌數(shù)列表,在這個(gè)例子中為:[6,5,2,4,1]
用pack_padded_sequence函數(shù)來(lái)構(gòu)造這個(gè)對(duì)象非常的簡(jiǎn)單:
如何構(gòu)造一個(gè)PackedSequence對(duì)象(batch_first = True)
PackedSequence對(duì)象有一個(gè)很不錯(cuò)的特性,就是我們無(wú)需對(duì)序列解包(這一步操作非常慢)即可直接在PackedSequence數(shù)據(jù)變量上執(zhí)行許多操作。特別是我們可以對(duì)令牌執(zhí)行任何操作(即對(duì)令牌的順序/上下文不敏感)。當(dāng)然,我們也可以使用接受PackedSequence作為輸入的任何一個(gè)pyTorch模塊(pyTorch 0.2)。
2、torch.nn.utils.rnn.pack_padded_sequence()
這里的pack,理解成壓緊比較好。 將一個(gè) 填充過(guò)的變長(zhǎng)序列 壓緊。(填充時(shí)候,會(huì)有冗余,所以壓緊一下)
輸入的形狀可以是(T×B×* )。T是最長(zhǎng)序列長(zhǎng)度,B是batch size,*代表任意維度(可以是0)。如果batch_first=True的話(huà),那么相應(yīng)的 input size 就是 (B×T×*)。
Variable中保存的序列,應(yīng)該按序列長(zhǎng)度的長(zhǎng)短排序,長(zhǎng)的在前,短的在后。即input[:,0]代表的是最長(zhǎng)的序列,input[:, B-1]保存的是最短的序列。
NOTE: 只要是維度大于等于2的input都可以作為這個(gè)函數(shù)的參數(shù)。你可以用它來(lái)打包labels,然后用RNN的輸出和打包后的labels來(lái)計(jì)算loss。通過(guò)PackedSequence對(duì)象的.data屬性可以獲取 Variable。
參數(shù)說(shuō)明:
input (Variable) – 變長(zhǎng)序列 被填充后的 batch
lengths (list[int]) – Variable 中 每個(gè)序列的長(zhǎng)度。
batch_first (bool, optional) – 如果是True,input的形狀應(yīng)該是B*T*size。
返回值:
一個(gè)PackedSequence 對(duì)象。
3、torch.nn.utils.rnn.pad_packed_sequence()
填充packed_sequence。
上面提到的函數(shù)的功能是將一個(gè)填充后的變長(zhǎng)序列壓緊。 這個(gè)操作和pack_padded_sequence()是相反的。把壓緊的序列再填充回來(lái)。
返回的Varaible的值的size是 T×B×*, T 是最長(zhǎng)序列的長(zhǎng)度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*。
Batch中的元素將會(huì)以它們長(zhǎng)度的逆序排列。
參數(shù)說(shuō)明:
sequence (PackedSequence) – 將要被填充的 batch
batch_first (bool, optional) – 如果為T(mén)rue,返回的數(shù)據(jù)的格式為 B×T×*。
返回值: 一個(gè)tuple,包含被填充后的序列,和batch中序列的長(zhǎng)度列表。
例子:
import torch import torch.nn as nn from torch.autograd import Variable from torch.nn import utils as nn_utils batch_size = 2 max_length = 3 hidden_size = 2 n_layers =1 tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1) tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1] seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step # pack it pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True) # initialize rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True) h0 = Variable(torch.randn(n_layers, batch_size, hidden_size)) #forward out, _ = rnn(pack, h0) # unpack unpacked = nn_utils.rnn.pad_packed_sequence(out) print('111',unpacked)
輸出:
111 (Variable containing: (0 ,.,.) = 0.5406 0.3584 -0.1403 0.0308 (1 ,.,.) = -0.6855 -0.9307 0.0000 0.0000 [torch.FloatTensor of size 2x2x2] , [2, 1])pytorch的優(yōu)點(diǎn)
1.PyTorch是相當(dāng)簡(jiǎn)潔且高效快速的框架;2.設(shè)計(jì)追求最少的封裝;3.設(shè)計(jì)符合人類(lèi)思維,它讓用戶(hù)盡可能地專(zhuān)注于實(shí)現(xiàn)自己的想法;4.與google的Tensorflow類(lèi)似,F(xiàn)AIR的支持足以確保PyTorch獲得持續(xù)的開(kāi)發(fā)更新;5.PyTorch作者親自維護(hù)的論壇 供用戶(hù)交流和求教問(wèn)題6.入門(mén)簡(jiǎn)單
關(guān)于怎么在pytorch中處理可變長(zhǎng)度序列就分享到這里了,希望以上內(nèi)容可以對(duì)大家有一定的幫助,可以學(xué)到更多知識(shí)。如果覺(jué)得文章不錯(cuò),可以把它分享出去讓更多的人看到。
名稱(chēng)欄目:怎么在pytorch中處理可變長(zhǎng)度序列-創(chuàng)新互聯(lián)
本文路徑:http://m.rwnh.cn/article18/cechgp.html
成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供軟件開(kāi)發(fā)、網(wǎng)站維護(hù)、品牌網(wǎng)站制作、外貿(mào)建站、手機(jī)網(wǎng)站建設(shè)、網(wǎng)站策劃
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶(hù)投稿、用戶(hù)轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請(qǐng)盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀(guān)點(diǎn)不代表本網(wǎng)站立場(chǎng),如需處理請(qǐng)聯(lián)系客服。電話(huà):028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來(lái)源: 創(chuàng)新互聯(lián)
猜你還喜歡下面的內(nèi)容
營(yíng)銷(xiāo)型網(wǎng)站建設(shè)知識(shí)