Kaggle竞赛排名Top 6%——泰坦尼克号幸存预测

Kaggle竞赛排名Top 6%——泰坦尼克号幸存预测

第一次参加Kaggle竞赛,选择了入门级的“泰坦尼克号生存预测”。最近在coursera上学习吴恩达的机器学习课程,这个竞赛能够帮助熟悉机器学习的一些基本概念,突然有了一种connecting dots的感觉,废话不多说,下面进入正题~

1. 竞赛内容

通过训练数据集分析出什么类型的人更可能幸存,并预测出测试集中的乘客是否生还。

这项竞赛属于机器学习中的二元分类问题。

2. 赛前准备

在自己实践之前,先在Kaggle上学习了一些教程,推荐教程如下:

Exploring Survival on the Titanic

还参考了另一篇文章:

机器学习(二) 如何做到Kaggle排名前2%

大致有了思路之后,开始导入数据进行实践。

3. 理解数据

train<-read.csv(file.choose(),stringsAsFactors = FALSE)
test<-read.csv(file.choose(),stringsAsFactors = FALSE)
data<-bind_rows(train,test)
train.row<-1:nrow(train)
test.row<-(1+nrow(train)):nrow(data)
str(data)
  • PassengerId: 乘客的ID,对预测没有用处
  • Survived:1代表幸存,2代表遇难,转换为factor方便后续分析
  • Pclass:可代表乘客的社会经济状况:1代表Upper,2代表Middle,3代表Lower
  • Name:除包含姓名外,还包含Title相关信息
  • Sex:性别,转换为factor方便后续分析
  • Age:年龄
  • SibSp:兄弟姐妹及配偶的个数
  • Parch:父母或子女的个数
  • Ticket:船票号
  • Fare:船票价格
  • Cabin:舱位
  • Embarked:登船口岸,转换为factor方便后续分析

4. 缺失值

sapply(data,function(x) sum(is.na(x)))
sapply(data,function(x) sum(x==""))

查看NA和空值的情况

  • Survived有418个缺失值是由于测试集的原因,是正常的
  • Fare有1个缺失值
  • Age有263个缺失值
  • Cabin有1014个缺失值
  • Embarked有2个缺失值

由于Age缺失值较多,需要通过其他的变量来预测,暂时不进行缺失值的填补,先进行特征抽取。

5. 特征工程

都说数据和特征决定了机器学习的上限,而模型和算法只是逼近这个上限而已。我们需要通过探索数据,利用先验知识,从数据中总结出特征。关于什么是特征工程,这篇文章 特征工程到底是什么? - 知乎 做了详细的解释。

1)特征处理和抽取

a. Survived变量:转换数据类型

data$Survived<-factor(data$Survived)

b. PClass变量:转换数据类型,探索与Survived的关系

data$Pclass<-factor(data$Pclass)
ggplot(data = data[1:nrow(train),],mapping = aes(x=Pclass,y=..count..,fill=Survived))+
  geom_bar(stat = "count",position = "dodge")+
  xlab("Pclass")+
  ylab("Count")+
  ggtitle("How Pclass impacts survival")+
  geom_text(stat = "count",aes(label=..count..),position = position_dodge(width = 1),vjust=-.6)+
  theme_few()

可见,Pclass为1的大部分幸存,Pclass为2的接近一半幸存,Pclass为3的小部分幸存

c. Name变量:将变量中有关Title的信息抽取出来

data$Title<-sapply(data$Name,FUN = function(x){trimws(strsplit(x,split = "[,.]")[[1]][2])})
data$Title[data$Title %in% c("Mme","Mlle")]<-"Mlle"
data$Title[data$Title %in% c("Capt","Don","Major","Sir")]<-"Sir"
data$Title[data$Title %in% c("Dona","Lady","the Countess","Jonkheer")]<-"Lady"
data$Title<-factor(data$Title)
ggplot(data = data[1:nrow(train),],mapping = aes(x=Title,y=..count..,fill=Survived))+
  geom_bar(stat = "count",position = "stack")+
  xlab("Title")+
  ylab("Count")+
  ggtitle("How Title impacts Survival")+
  geom_text(stat = "Count",aes(label=..count..),position = position_stack(vjust = 0.5))+
  theme_few()

可见,Title为Mrs和Miss的幸存率比较大,为Mr的幸存率比较小

d. Sex变量:转换数据类型,探索与Survived的关系

data$Sex<-factor(data$Sex)
ggplot(data = data[1:nrow(train),],mapping = aes(x=Sex,y=..count..,fill=Survived))+
  geom_bar(stat = "count",position = "dodge")+
  xlab("Sex")+
  ylab("Count")+
  ggtitle("How Sex impacts Survival")+
  geom_text(stat = "count",aes(label=..count..),position = position_dodge(width = 1),vjust=-.5)+
  theme_few()

可见,女性的幸存率远远大于男性

e. Age变量:探索与Survived的关系

ggplot(data = data[(!is.na(data$Age)) & row(as.matrix(data[,"Age"]))<=891,],aes(x=Age,color=Survived))+
  geom_line(aes(label=..count..),stat = "bin",binwidth=5,na.rm = TRUE)+
  labs(title="How Age impacts Survival",x="Age",y="Count",fill="Survived")

可见,未成年人的幸存率要远远高于青壮年

f. SibSp变量:探索与Survived的关系

ggplot(data = data[1:nrow(train),],mapping = aes(x=SibSp,y=..count..,fill=Survived))+
  geom_bar(stat = "count",position = "dodge")+
  xlab("SibSp")+
  ylab("Count")+
  ggtitle("How SibSp impacts Survival")+
  geom_text(stat = "count",aes(label=..count..),position = position_dodge(width = 1),vjust=-.5)+
  theme_few()

可见,SibSp为1或2的乘客幸存率最高

g. Parch变量:探索与Survived的关系

ggplot(data = data[1:nrow(train),],mapping = aes(x=Parch,y=..count..,fill=Survived))+
  geom_bar(stat = "count",position = "dodge")+
  xlab("Parch")+
  ylab("Count")+
  ggtitle("How Parch impacts Survival")+
  geom_text(stat = "count",aes(label=..count..),position = position_dodge(width = 1),vjust=-.5)+
  theme_few()

可见,Parch为1到3的乘客幸存率最高

h. FamilySize变量:新增变量,探索与Survived的关系

data$FamilySize<-data$Parch+data$SibSp+1
ggplot(data = data[1:nrow(train),],mapping = aes(x=FamilySize,y=..count..,fill=Survived))+
  geom_bar(stat = "count",position = "dodge")+
  xlab("FamilySize")+
  ylab("Count")+
  ggtitle("How FamilySize impacts Survival")+
  geom_text(stat = "count",aes(label=..count..),position = position_dodge(width = 1),vjust=-.5)+
  theme_few()

可见,FamilySize为2到4的乘客幸存率最高

i. Fare变量:探索与Survived的关系

ggplot(data = data[(!is.na(data$Fare)) & row(as.matrix(data[,"Fare"]))<=891,],aes(x=Fare,color=Survived))+
  geom_line(aes(label=..count..),stat = "bin",binwidth=5,na.rm = TRUE)+
  labs(title="How Fare impacts Survival",x="Fare",y="Count",fill="Survived")

可见,票价越高幸存率越高

j. Embarked变量:探索与Survived的关系

ggplot(data = data[1:nrow(train),],mapping = aes(x=Embarked,y=..count..,fill=Survived))+
  geom_bar(stat = "count",position = "dodge")+
  xlab("Embarked")+
  ylab("Count")+
  ggtitle("How Embarked impacts Survival")+
  geom_text(stat = "count",aes(label=..count..),position = position_dodge(width = 1),vjust=-.5)+
  theme_few()

可见,Embarked为C或NA的幸存率最高

k. 填补缺失值

Fare只有1个缺失值,采取中位数填补法

data$Fare[is.na(data$Fare)]<-median(data$Fare,na.rm = TRUE)

Embarked有2个缺失值,先将这两个缺失值对应的乘客信息选取出来

data[is.na(data$Embarked),c("PassengerId","Pclass","Fare","Embarked")]

发现Pclass都为1,Fare都为80。

ggplot(data = data[data$Embarked!="",],aes(x=Embarked,y=Fare,fill=Pclass))+
  geom_boxplot()+
  geom_hline(aes(yintercept=80),color="red",linetype="dashed",lwd=2)+
  theme_few()

发现Embarked为C的Pclass属于1的Fare中位数正好是80,所以将缺失值填补为C

data$Embarked[c(62,830)]<-"C"
data$Embarked<-factor(data$Embarked)

Cabin的缺失值较多,目前来看对预测没有什么意义,先忽略。

最后就是要预测填补Age的缺失值,这里用到了决策树方法,先看一下已有年龄的分布:

hist(data$Age[!is.na(data$Age)],freq = F,main = "Age Distribution")
age.model<-rpart(Age~Pclass+Sex+SibSp+Parch+Fare+Embarked+Title+FamilySize,data = data[!is.na(data$Age),],method = "anova")
data$Age[is.na(data$Age)]<-predict(age.model,data[is.na(data$Age),])
hist(data$Age,freq = F,main = "Age Distribution")

至此,特征都处理抽取结束,接下来就要进行特征选择。

2)特征选择

在以上操作中我们尽可能多地抽取特征,但过多地特征会造成冗余,产生过拟合问题,因此我们需要对特征进行筛选。这里采用最简单的相关性分析。

在上一篇实践分享 Kaggle数据分析实践--优秀员工为何离职 - 知乎专栏 中也提到了相关性分析,这里再补充一下:

相关性分析是衡量两个变量之间的线性关系,数值在[-1.0,1.0]之间。数值越接近0,说明两个变量越线性不相关。

最终,选取了如下变量,开始建模。

6. 建模

这里用的模型是随机森林,一开始使用的是random forest,后来换了cforest,发现排名有提升。关于两者的区别,可以参考这篇文章:【R语言进行数据挖掘】决策树和随机森林 - 推酷

set.seed(102)
model<-cforest(Survived~Pclass+Title+Sex+Age+SibSp+Parch+FamilySize+Fare+Embarked,data=data[train.row,],controls = cforest_unbiased(ntree=2000,mtry=3))
prediction<-predict(model,data[test.row,],OOB=TRUE,type = "response") 
output<-data.frame(PassengerId=test$PassengerId,Survived=prediction)
write.csv(output,file = "Prediction1.csv",row.names = FALSE)

这里应该还要做模型验证的,但由于还不是很懂这部分的知识,暂时略过了。

最终将预测结果上传,Kaggle排名还不错,不过也是深度模仿的结果~

这次实践花了较长时间,很多概念也是第一次碰到,查阅了不少文章,希望在以后的实践中能够灵活运用~

编辑于 2017-07-02

文章被以下专栏收录