Pytorch实验常用代码段汇总


Posted in Python onNovember 19, 2020

1. 大幅度提升 Pytorch 的训练速度

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

但加了这一行,似乎运行结果不一样了。

2. 把原有的记录文件加个后缀变为 .bak 文件,避免直接覆盖

# from co-teaching train codetxtfile = save_dir + "/" + model_str + "_%s.txt"%str(args.optimizer)  ## good job!
nowTime=datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
if os.path.exists(txtfile):
  os.system('mv %s %s' % (txtfile, txtfile+".bak-%s" % nowTime)) # bakeup 备份文件

3. 计算 Accuracy 返回list, 调用函数时,直接提取值,而非提取list

# from co-teaching code but MixMatch_pytorch code also has itdef accuracy(logit, target, topk=(1,)):
  """Computes the precision@k for the specified values of k"""
  output = F.softmax(logit, dim=1) # but actually not need it 
  maxk = max(topk)
  batch_size = target.size(0)

  _, pred = output.topk(maxk, 1, True, True) # _, pred = logit.topk(maxk, 1, True, True)
  pred = pred.t()
  correct = pred.eq(target.view(1, -1).expand_as(pred))

  res = []
  for k in topk:
    correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
    res.append(correct_k.mul_(100.0 / batch_size)) # it seems this is a bug, when not all batch has same size, the mean of accuracy of each batch is not the mean of accu of all dataset
  return res

prec1, = accuracy(logit, labels, topk=(1,)) # , indicate tuple unpackage
prec1, prec5 = accuracy(logits, labels, topk=(1, 5))

4. 善于利用 logger 文件来记录每一个 epoch 的实验值

# from Pytorch_MixMatch codeclass Logger(object):
  '''Save training process to log file with simple plot function.'''
  def __init__(self, fpath, title=None, resume=False): 
    self.file = None
    self.resume = resume
    self.title = '' if title == None else title
    if fpath is not None:
      if resume: 
        self.file = open(fpath, 'r') 
        name = self.file.readline()
        self.names = name.rstrip().split('\t')
        self.numbers = {}
        for _, name in enumerate(self.names):
          self.numbers[name] = []

        for numbers in self.file:
          numbers = numbers.rstrip().split('\t')
          for i in range(0, len(numbers)):
            self.numbers[self.names[i]].append(numbers[i])
        self.file.close()
        self.file = open(fpath, 'a') 
      else:
        self.file = open(fpath, 'w')

  def set_names(self, names):
    if self.resume: 
      pass
    # initialize numbers as empty list
    self.numbers = {}
    self.names = names
    for _, name in enumerate(self.names):
      self.file.write(name)
      self.file.write('\t')
      self.numbers[name] = []
    self.file.write('\n')
    self.file.flush()


  def append(self, numbers):
    assert len(self.names) == len(numbers), 'Numbers do not match names'
    for index, num in enumerate(numbers):
      self.file.write("{0:.4f}".format(num))
      self.file.write('\t')
      self.numbers[self.names[index]].append(num)
    self.file.write('\n')
    self.file.flush()

  def plot(self, names=None):  
    names = self.names if names == None else names
    numbers = self.numbers
    for _, name in enumerate(names):
      x = np.arange(len(numbers[name]))
      plt.plot(x, np.asarray(numbers[name]))
    plt.legend([self.title + '(' + name + ')' for name in names])
    plt.grid(True)

  def close(self):
    if self.file is not None:
      self.file.close()
# usage
logger = Logger(new_folder+'/log_for_%s_WebVision1M.txt'%data_type, title=title)
logger.set_names(['epoch', 'val_acc', 'val_acc_ImageNet'])
for epoch in range(100):
  logger.append([epoch, val_acc, val_acc_ImageNet])
logger.close()

5. 利用 argparser 命令行工具来进行代码重构,使用不同参数适配不同数据集,不同优化方式,不同setting, 避免多个高度冗余的重复代码

# argparser 命令行工具有一个坑的地方是,无法设置 bool 变量, flag=FALSE, 然后会解释为 字符串,仍然当做 True

发现可以使用如下命令来进行修补,来自 ICML-19-SGC github 上代码

parser.add_argument('--test', action='store_true', default=False, help='inductive training.')

当命令行出现 test 字样时,则为 args.test = true

若未出现 test 字样,则为 args.test = false

6. 使用shell 变量来设置所使用的显卡, 便于利用shell 脚本进行程序的串行,从而挂起来跑。或者多开几个 screen 进行同一张卡上多个程序并行跑,充分利用显卡的内存。

命令行中使用如下语句,或者把语句写在 shell 脚本中 # 不要忘了 export

export CUDA_VISIBLE_DEVICES=1 #设置当前可用显卡为编号为1的显卡(从 0 开始编号),即不在 0 号上跑
export CUDA_VISIBlE_DEVICES=0,1 # 设置当前可用显卡为 0,1 显卡,当 0 用满后,就会自动使用 1 显卡

一般经验,即使多个程序并行跑时,即使显存完全足够,单个程序的速度也会变慢,这可能是由于还有 cpu 和内存的限制。

这里显存占用不是阻碍,应该主要看GPU 利用率(也就是计算单元的使用,如果达到了 99% 就说明程序过多了。)

使用 watch nvidia-smi 来监测每个程序当前是否在正常跑。

7. 使用 python 时间戳来保存并进行区别不同的 result 文件

参照自己很早之前写的 co-training 的代码

8. 把训练时 命令行窗口的 print 输出全部保存到一个 log 文件:(参照 DIEN)

mkdir dnn_save_path
mkdir dnn_best_model
CUDA_VISIBLE_DEVICES=0 /usr/bin/python2.7 script/train.py train DIEN >train_dein2.log 2>&1 &

并且使用如下命令 | tee 命令则可以同时保存到文件并且写到命令行输出:

python script/train.py train DIEN | tee train_dein2.log

9. git clone 可以用来下载 github 上的代码,更快。(由 DIEN 的下载)

git clone https://github.com/mouna99/dien.git 使用这个命令可以下载 github 上的代码库

10. (来自 DIEN ) 对于命令行参数不一定要使用 argparser 来读取,也可以直接使用 sys.argv 读取,不过这样的话,就无法指定关键字参数,只能使用位置参数。

### run.sh ###
CUDA_VISIBLE_DEVICES=0 /usr/bin/python2.7 script/train.py train DIEN >train_dein2.log 2>&1 &
#############

if __name__ == '__main__':
  if len(sys.argv) == 4:
    SEED = int(sys.argv[3]) # 0,1,2,3
  else:
    SEED = 3
  tf.set_random_seed(SEED)
  numpy.random.seed(SEED)
  random.seed(SEED)
  if sys.argv[1] == 'train':
    train(model_type=sys.argv[2], seed=SEED)
  elif sys.argv[1] == 'test':
    test(model_type=sys.argv[2], seed=SEED)
  else:
    print('do nothing...')

11.代码的一种逻辑:time_point 是一个参数变量,可以有两种方案来处理

一种直接在外面判断:

#适用于输出变量的个数不同的情况
if time_point:
A, B, C = f1(x, y, time_point=True)
else:

A, B = f1(x, y, time_point=False)
# 适用于输出变量个数和类型相同的情况
C, D = f2(x, y, time_point=time_point)

12. 写一个 shell 脚本文件来进行调节超参数, 来自 [NIPS-20 Grand]

mkdir cora
for num in $(seq 0 99) do
python train_grand.py --hidden 32 --lr 0.01 --patience 200 --seed $num --dropnode_rate 0.5 > cora/"$num".txt
done

13. 使用 或者 不使用 cuda 运行结果可能会不一样,有细微差别。

cuda 也有一个相关的随机数种子的参数,当不使用 cuda 时,这一个随机数种子没有起到作用,因此可能会得到不同的结果。

来自 NIPS-20 Grand (2020.11.18)的实验结果发现。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中的hashlib和base64加密模块使用实例
Sep 02 Python
总结用Pdb库调试Python的方式及常用的命令
Aug 18 Python
Python入门_浅谈数据结构的4种基本类型
May 16 Python
基于python爬虫数据处理(详解)
Jun 10 Python
Python列表解析配合if else的方法
Jun 23 Python
Python调用C++,通过Pybind11制作Python接口
Oct 16 Python
selenium+python自动化测试之使用webdriver操作浏览器的方法
Jan 23 Python
python multiprocessing多进程变量共享与加锁的实现
Oct 02 Python
python 实现人和电脑猜拳的示例代码
Mar 02 Python
Python爬虫入门有哪些基础知识点
Jun 02 Python
python如何输出反斜杠
Jun 18 Python
Python中的np.argmin()和np.argmax()函数用法
Jun 02 Python
Ubuntu配置Pytorch on Graph (PoG)环境过程图解
Nov 19 #Python
python基于pygame实现飞机大作战小游戏
Nov 19 #Python
Python numpy大矩阵运算内存不足如何解决
Nov 19 #Python
python3 os进行嵌套操作的实例讲解
Nov 19 #Python
如何创建一个Flask项目并进行简单配置
Nov 18 #Python
使用PyCharm官方中文语言包汉化PyCharm
Nov 18 #Python
Python web框架(django,flask)实现mysql数据库读写分离的示例
Nov 18 #Python
You might like
雄兵连:天使彦天使彦为爱折翼,彦和炙心同时念动的誓言!
2020/03/02 国漫
全国FM电台频率大全 - 8 黑龙江省
2020/03/11 无线电
php的GD库imagettftext函数解决中文乱码问题
2015/01/24 PHP
JavaScript的变量作用域深入理解
2009/10/25 Javascript
基于JQuery实现鼠标点击文本框显示隐藏提示文本
2012/02/23 Javascript
往光标所在位置插入值的js代码
2013/09/22 Javascript
JS获取节点的兄弟,父级,子级元素的方法
2014/01/09 Javascript
JS实现一个按钮的方法
2015/02/05 Javascript
js实现带圆角的两级导航菜单效果代码
2015/08/24 Javascript
json+jQuery实现的无限级树形菜单效果代码
2015/08/27 Javascript
JS组件Bootstrap实现弹出框效果代码
2016/04/26 Javascript
JS中如何比较两个Json对象是否相等实例代码
2016/07/13 Javascript
浅谈js之字面量、对象字面量的访问、关键字in的用法
2016/11/20 Javascript
JS实现音乐导航特效
2020/01/06 Javascript
Element InputNumber计数器的使用方法
2020/07/27 Javascript
vue 项目软键盘回车触发搜索事件
2020/09/09 Javascript
[02:36]DOTA2-DPC中国联赛 正赛 PSG.LGD vs Magma 选手采访
2021/03/11 DOTA
详解Python中heapq模块的用法
2016/06/28 Python
详解Python多线程Selenium跨浏览器测试
2017/04/01 Python
python批量查询、汉字去重处理CSV文件
2018/05/31 Python
VSCode基础使用与VSCode调试python程序入门的图文教程
2020/03/30 Python
python基本算法之实现归并排序(Merge sort)
2020/09/01 Python
Kusmi茶美国官网:优质散叶茶和茶包
2019/10/13 全球购物
英国领先的在线高尔夫商店:Gamola Golf
2019/11/16 全球购物
软件测试常见笔试题
2012/02/04 面试题
普通院校学生的自荐信
2013/11/27 职场文书
医学检验专业个人求职信范文
2013/12/04 职场文书
努力学习演讲稿
2014/05/10 职场文书
2014年度培训工作总结
2014/11/27 职场文书
2015年学校总务处工作总结
2015/05/19 职场文书
中秋节感想
2015/08/10 职场文书
谢师宴学生答谢词
2015/09/30 职场文书
MySQL 使用索引扫描进行排序
2021/06/20 MySQL
详解Spring事件发布与监听机制
2021/06/30 Java/Android
css常用字体属性与背景属性介绍
2022/02/28 HTML / CSS
MyBatis XPathParser解析器使用范例详解
2022/07/15 Java/Android