A PyTorch + LSTM POS tagging example


Posted in Python on January 14, 2020

After a few days of study, I finally have a rough idea of how to use PyTorch.

The code below is taken directly from the official PyTorch tutorial.

Next I will try to implement other NLP tasks on my own.

# Author: Robert Guthrie

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)


lstm = nn.LSTM(3, 3) # Input dim is 3, output dim is 3
inputs = [torch.randn(1, 3) for _ in range(5)] # make a sequence of length 5

# initialize the hidden state (h_0, c_0).
hidden = (torch.randn(1, 1, 3),
     torch.randn(1, 1, 3))
for i in inputs:
  # Step through the sequence one element at a time.
  # After each step, hidden contains the hidden state.
  out, hidden = lstm(i.view(1, 1, -1), hidden)
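# (Added note:) after this loop, "out" holds the output of the final step only,
# shape (1, 1, 3), and "hidden" is a tuple (h_n, c_n) of tensors, each (1, 1, 3).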

# Alternatively, we can process the entire sequence all at once.
# The first value returned by the LSTM is all of the hidden states throughout
# the sequence; the second is just the most recent hidden state
# (compare the last slice of "out" with "hidden" below; they are the same).
# The reason for this is that:
# "out" gives you access to all hidden states in the sequence;
# "hidden" allows you to continue the sequence and backpropagate,
# by passing it as an argument to the lstm at a later time.
# Add the extra 2nd dimension.
inputs = torch.cat(inputs).view(len(inputs), 1, -1)
hidden = (torch.randn(1, 1, 3),
     torch.randn(1, 1, 3)) # clean out the hidden state
out, hidden = lstm(inputs, hidden)
#print(out)
#print(hidden)
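# A quick sanity check (added, not in the tutorial): the last slice of "out"
# matches the final hidden state h_n returned in "hidden".
print(torch.equal(out[-1], hidden[0][0])) # expected: True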

# Prepare the data: turn a sequence of words into a tensor of word indices.
def prepare_sequence(seq, to_ix):
  idxs = [to_ix[w] for w in seq]
  return torch.tensor(idxs, dtype=torch.long)

training_data = [
  ("The dog ate the apple".split(), ["DET", "NN", "V", "DET", "NN"]),
  ("Everybody read that book".split(), ["NN", "V", "DET", "NN"])
]
word_to_ix = {}
for sent, tags in training_data:
  for word in sent:
    if word not in word_to_ix:
      word_to_ix[word] = len(word_to_ix)
print(word_to_ix)
tag_to_ix = {"DET": 0, "NN": 1, "V": 2}
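# A quick illustration (added, not in the tutorial): prepare_sequence turns a
# list of words into a LongTensor of their indices.
print(prepare_sequence(training_data[0][0], word_to_ix))
# -> tensor([0, 1, 2, 3, 4])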

# These will usually be more like 32 or 64 dimensional.
# We will keep them small, so we can see how the weights change as we train.
EMBEDDING_DIM = 6
HIDDEN_DIM = 6

# LSTMTagger inherits from nn.Module.
class LSTMTagger(nn.Module):

  def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
    super(LSTMTagger, self).__init__()
    self.hidden_dim = hidden_dim

    # An embedding table mapping each of the vocab_size words to an
    # embedding_dim-dimensional vector.
    self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)

    # The LSTM takes word embeddings as inputs, and outputs hidden states
    # with dimensionality hidden_dim.
    self.lstm = nn.LSTM(embedding_dim, hidden_dim)

    # The linear layer that maps from hidden state space to tag space.
    self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
    self.hidden = self.init_hidden()

  def init_hidden(self):
    # Before we've done anything, we don't have any hidden state.
    # Refer to the PyTorch documentation to see exactly
    # why they have this dimensionality.
    # The axes semantics are (num_layers, minibatch_size, hidden_dim).
    return (torch.zeros(1, 1, self.hidden_dim),
        torch.zeros(1, 1, self.hidden_dim))

  def forward(self, sentence):
    embeds = self.word_embeddings(sentence)
    lstm_out, self.hidden = self.lstm(embeds.view(len(sentence), 1, -1), self.hidden)
    tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
    tag_scores = F.log_softmax(tag_space, dim=1)
    return tag_scores
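    # Shape summary (added for clarity): for a sentence of T words, embeds is
    # (T, embedding_dim); the view reshapes it to (T, 1, embedding_dim) for the
    # LSTM; lstm_out is (T, 1, hidden_dim); tag_scores is (T, tagset_size).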

# Arguments: embedding dim, hidden dim, vocabulary size, number of tags.
model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix))

# torch.optim provides a variety of optimization algorithms.
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# See what the scores are before training
# Note that element i,j of the output is the score for tag j for word i.
inputs = prepare_sequence(training_data[0][0], word_to_ix)
tag_scores = model(inputs)
print(tag_scores)

for epoch in range(300): # again, normally you would NOT do 300 epochs, it is toy data
  for sentence, tags in training_data:
    # Step 1. Remember that PyTorch accumulates gradients.
    # We need to clear them out before each instance.
    model.zero_grad()

    # Also, we need to clear out the hidden state of the LSTM,
    # detaching it from its history on the last instance.
    model.hidden = model.init_hidden()

    # Step 2. Get our inputs ready for the network, that is, turn them into
    # tensors of word indices.
    sentence_in = prepare_sequence(sentence, word_to_ix)
    targets = prepare_sequence(tags, tag_to_ix)

    # Step 3. Run our forward pass.
    tag_scores = model(sentence_in)

    # Step 4. Compute the loss, gradients, and update the parameters by
    # calling optimizer.step()
    loss = loss_function(tag_scores, targets)
    loss.backward()
    optimizer.step()
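  # (Added for illustration, not in the original tutorial:) print the loss of
  # the last sentence every 50 epochs to watch training converge.
  if epoch % 50 == 0:
    print("epoch %d, loss %.4f" % (epoch, loss.item()))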

# See what the scores are after training
inputs = prepare_sequence(training_data[0][0], word_to_ix)
tag_scores = model(inputs)
# The sentence is "The dog ate the apple". Element i,j of the output is the
# score for tag j for word i. The predicted tag is the maximum-scoring tag.
# Here, the predicted sequence is 0 1 2 0 1, since 0 is the index of the
# maximum value of row 1, 1 is the index of the maximum value of row 2, etc.
# That is DET NN V DET NN, the correct sequence!
print(tag_scores)
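
# To turn the scores into actual tags (added for illustration), take the
# argmax over the tag dimension and map each index back to its tag name:
ix_to_tag = {ix: tag for tag, ix in tag_to_ix.items()}
predicted = [ix_to_tag[i.item()] for i in tag_scores.argmax(dim=1)]
print(predicted) # expected: ['DET', 'NN', 'V', 'DET', 'NN']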

That's all for this PyTorch + LSTM POS tagging example. I hope it serves as a useful reference, and thank you for supporting this site.
