Python实现EM算法实例代码


Posted in Python onOctober 04, 2020

EM算法实例

通过实例可以快速了解EM算法的基本思想,具体推导请点文末链接。图a是让我们预热的,图b是EM算法的实例。

这是一个抛硬币的例子,H表示正面向上,T表示反面向上,参数θ表示正面朝上的概率。硬币有两个,A和B,硬币是有偏的。本次实验总共做了5组,每组随机选一个硬币,连续抛10次。如果知道每次抛的是哪个硬币,那么计算参数θ就非常简单了,如

下图所示:

Python实现EM算法实例代码

如果不知道每次抛的是哪个硬币呢?那么,我们就需要用EM算法,基本步骤为:

  1、给θ_AθA​和θ_BθB​一个初始值;

  2、(E-step)估计每组实验是硬币A的概率(本组实验是硬币B的概率=1-本组实验是硬币A的概率)。分别计算每组实验中,选择A硬币且正面朝上次数的期望值,选择B硬币且正面朝上次数的期望值;

  3、(M-step)利用第三步求得的期望值重新计算θ_AθA​和θ_BθB​;

  4、当迭代到一定次数,或者算法收敛到一定精度,结束算法,否则,回到第2步。

Python实现EM算法实例代码

计算过程详解:初始值θ_A^{(0)}θA(0)​=0.6,θ_B^{(0)}θB(0)​=0.5。

由两个硬币的初始值0.6和0.5,容易得出投掷出5正5反的概率是p_A=C^5_{10}*(0.6^5)*(0.4^5)pA​=C105​∗(0.65)∗(0.45),p_B=C_{10}^5*(0.5^5)*(0.5^5)pB​=C105​∗(0.55)∗(0.55), p_ApA​/(p_ApA​+p_BpB​)=0.449, 0.45就是0.449近似而来的,表示第一组实验选择的硬币是A的概率为0.45。然后,0.449 * 5H = 2.2H ,0.449 * 5T = 2.2T ,表示第一组实验选择A硬币且正面朝上次数和反面朝上次数的期望值都是2.2,其他的值依次类推。最后,求出θ_A^{(1)}θA(1)​=0.71,θ_B^{(1)}θB(1)​=0.58。重复上述过程,不断迭代,直到算法收敛到一定精度为止。

这篇博客对EM算法的推导非常详细,链接如下:

https://blog.csdn.net/zhihua_oba/article/details/73776553

Python实现

#coding=utf-8
from numpy import *
from scipy import stats
import time
start = time.perf_counter()

def em_single(priors,observations):
 """
 EM算法的单次迭代
 Arguments
 ------------
 priors:[theta_A,theta_B]
 observation:[m X n matrix]

 Returns
 ---------------
 new_priors:[new_theta_A,new_theta_B]
 :param priors:
 :param observations:
 :return:
 """
 counts = {'A': {'H': 0, 'T': 0}, 'B': {'H': 0, 'T': 0}}
 theta_A = priors[0]
 theta_B = priors[1]
 #E step
 for observation in observations:
  len_observation = len(observation)
  num_heads = observation.sum()
  num_tails = len_observation-num_heads
  #二项分布求解公式
  contribution_A = stats.binom.pmf(num_heads,len_observation,theta_A)
  contribution_B = stats.binom.pmf(num_heads,len_observation,theta_B)

  weight_A = contribution_A / (contribution_A + contribution_B)
  weight_B = contribution_B / (contribution_A + contribution_B)
  #更新在当前参数下A,B硬币产生的正反面次数
  counts['A']['H'] += weight_A * num_heads
  counts['A']['T'] += weight_A * num_tails
  counts['B']['H'] += weight_B * num_heads
  counts['B']['T'] += weight_B * num_tails

 # M step
 new_theta_A = counts['A']['H'] / (counts['A']['H'] + counts['A']['T'])
 new_theta_B = counts['B']['H'] / (counts['B']['H'] + counts['B']['T'])
 return [new_theta_A,new_theta_B]


def em(observations,prior,tol = 1e-6,iterations=10000):
 """
 EM算法
 :param observations :观测数据
 :param prior:模型初值
 :param tol:迭代结束阈值
 :param iterations:最大迭代次数
 :return:局部最优的模型参数
 """
 iteration = 0;
 while iteration < iterations:
  new_prior = em_single(prior,observations)
  delta_change = abs(prior[0]-new_prior[0])
  if delta_change < tol:
   break
  else:
   prior = new_prior
   iteration +=1
 return [new_prior,iteration]

#硬币投掷结果
observations = array([[1,0,0,0,1,1,0,1,0,1],
      [1,1,1,1,0,1,1,1,0,1],
      [1,0,1,1,1,1,1,0,1,1],
      [1,0,1,0,0,0,1,1,0,0],
      [0,1,1,1,0,1,1,1,0,1]])
print (em(observations,[0.6,0.5]))
end = time.perf_counter()
print('Running time: %f seconds'%(end-start))

总结

到此这篇关于Python实现EM算法实例的文章就介绍到这了,更多相关Python实现EM算法实例内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python函数参数*args**kwargs用法实例
Dec 04 Python
python爬虫常用的模块分析
Aug 29 Python
python进阶教程之异常处理
Aug 30 Python
Python写的服务监控程序实例
Jan 31 Python
讲解Python中运算符使用时的优先级
May 14 Python
python清除字符串里非字母字符的方法
Jul 02 Python
python实现本地图片转存并重命名的示例代码
Oct 27 Python
python实现电子产品商店
Feb 26 Python
详解Python正则表达式re模块
Mar 19 Python
详解Django-channels 实现WebSocket实例
Aug 22 Python
Python内置类型性能分析过程实例
Jan 29 Python
PyCharm 在Windows的有用快捷键详解
Apr 07 Python
python em算法的实现
Oct 03 #Python
浅析Python中字符串的intern机制
Oct 03 #Python
Python实现AES加密,解密的两种方法
Oct 03 #Python
python实现AdaBoost算法的示例
Oct 03 #Python
Django创建一个后台的基本步骤记录
Oct 02 #Python
Python中qutip用法示例详解
Oct 02 #Python
如何利用Python给自己的头像加一个小国旗(小月饼)
Oct 02 #Python
You might like
PHP 增加了对 .ZIP 文件的读取功能
2006/10/09 PHP
php文件类型MIME对照表(比较全)
2016/10/07 PHP
详解Yii实现分页的两种方法
2017/01/14 PHP
PHP类的自动加载与命名空间用法实例分析
2020/06/05 PHP
学习ExtJS border布局
2009/10/08 Javascript
JQuery 将元素显示在屏幕的中央的代码
2010/02/27 Javascript
分享27个jQuery 表单插件集合推荐
2011/04/25 Javascript
使用UglifyJS合并/压缩JavaScript的方法
2012/03/07 Javascript
js中settimeout方法加参数的使用实例
2014/02/27 Javascript
javascript + jquery实现定时修改文章标题
2014/03/19 Javascript
iframe中使用jquery进行查找的方法【案例分析】
2016/06/17 Javascript
EasyUI中在表单提交之前进行验证
2016/07/19 Javascript
JavaScript函数中的this四种绑定形式
2017/08/15 Javascript
详解webpack3如何正确引用并使用jQuery库
2017/08/26 jQuery
JS路由跳转的简单实现代码
2017/09/21 Javascript
详解vuex中mapState,mapGetters,mapMutations,mapActions的作用
2018/04/13 Javascript
JavaScript寄生组合式继承原理与用法分析
2019/01/11 Javascript
layui实现下拉复选功能的例子(包括数据的回显与上传)
2019/09/24 Javascript
讲解Python中if语句的嵌套用法
2015/05/14 Python
从零开始学Python第八周:详解网络编程基础(socket)
2016/12/14 Python
Django框架的使用教程路由请求响应的方法
2018/07/03 Python
对python mayavi三维绘图的实现详解
2019/01/08 Python
wxPython电子表格功能wx.grid实例教程
2019/11/19 Python
基于MSELoss()与CrossEntropyLoss()的区别详解
2020/01/02 Python
在python tkinter界面中添加按钮的实例
2020/03/04 Python
Python2.7:使用Pyhook模块监听鼠标键盘事件-获取坐标实例
2020/03/14 Python
python如何保存文本文件
2020/06/07 Python
html5表单及新增的改良元素详解
2016/06/07 HTML / CSS
审核会计岗位职责
2013/11/08 职场文书
计算机专业大学生的自我评价
2013/11/14 职场文书
安全生产承诺书
2014/03/26 职场文书
关于责任的演讲稿
2014/05/20 职场文书
电子商务系毕业生自荐信
2014/05/29 职场文书
元旦联欢会策划方案
2014/06/11 职场文书
2015年幼儿教育工作总结
2015/07/24 职场文书
mysql 数据插入优化方法之concurrent_insert
2021/07/01 MySQL