{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Импортируем нужные библиотеки."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import xgboost as xgb"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Распакуйте архив с данными в папку,где находится этот jupyter notebook (baseline.ipynb). У вас будет папка data, содержащая необходимые файлы. \n",
"\n",
"В данном соревновании перед вами ставится задача предсказания категории возраста, к которой принадлежит клиент банка, на основании его транзакций.\n",
"В обучающем наборе содержатся информация по транзакциям 30000 клиентов, она находится в файле **transactions_train.csv**. Правильная категория возраста для обучающего набора находится в файле **train_target.csv**."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Считаем данные по транзакциям и правильные ответы."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"transactions_train=pd.read_csv('data/transactions_train.csv')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"train_target=pd.read_csv('data/train_target.csv')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Посмотрим на данные."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" client_id | \n",
" trans_date | \n",
" small_group | \n",
" amount_rur | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 33172 | \n",
" 6 | \n",
" 4 | \n",
" 71.463 | \n",
"
\n",
" \n",
" 1 | \n",
" 33172 | \n",
" 6 | \n",
" 35 | \n",
" 45.017 | \n",
"
\n",
" \n",
" 2 | \n",
" 33172 | \n",
" 8 | \n",
" 11 | \n",
" 13.887 | \n",
"
\n",
" \n",
" 3 | \n",
" 33172 | \n",
" 9 | \n",
" 11 | \n",
" 15.983 | \n",
"
\n",
" \n",
" 4 | \n",
" 33172 | \n",
" 10 | \n",
" 11 | \n",
" 21.341 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" client_id trans_date small_group amount_rur\n",
"0 33172 6 4 71.463\n",
"1 33172 6 35 45.017\n",
"2 33172 8 11 13.887\n",
"3 33172 9 11 15.983\n",
"4 33172 10 11 21.341"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"transactions_train.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* client_id - уникальный идентификатор клиента\n",
"* trans_date - дата совершения транзакции\n",
"* small_group - категория покупки\n",
"* amount_rur - сумма транзакции"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" client_id | \n",
" bins | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 24662 | \n",
" 2 | \n",
"
\n",
" \n",
" 1 | \n",
" 1046 | \n",
" 0 | \n",
"
\n",
" \n",
" 2 | \n",
" 34089 | \n",
" 2 | \n",
"
\n",
" \n",
" 3 | \n",
" 34848 | \n",
" 1 | \n",
"
\n",
" \n",
" 4 | \n",
" 47076 | \n",
" 3 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" client_id bins\n",
"0 24662 2\n",
"1 1046 0\n",
"2 34089 2\n",
"3 34848 1\n",
"4 47076 3"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_target.head(5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* client_id - уникальный идентификатор клиента, соответствует полю client_id из транзакций\n",
"* bins - целевая переменная, которую нужно предсказать, это категория возраста клиента"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Далее представлен простой вариант решения задачи. Вы можете решать соревнование используя совершенно другой подход."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Посчитаем по каждому клиенту самые простые аггрегационные признаки."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"agg_features=transactions_train.groupby('client_id')['amount_rur'].agg(['sum','mean','std','min','max']).reset_index()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" client_id | \n",
" sum | \n",
" mean | \n",
" std | \n",
" min | \n",
" max | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 4 | \n",
" 28404.121 | \n",
" 39.450168 | \n",
" 73.511624 | \n",
" 0.043 | \n",
" 1341.802 | \n",
"
\n",
" \n",
" 1 | \n",
" 6 | \n",
" 15720.739 | \n",
" 21.535259 | \n",
" 26.200397 | \n",
" 0.045 | \n",
" 315.781 | \n",
"
\n",
" \n",
" 2 | \n",
" 7 | \n",
" 53630.036 | \n",
" 69.379089 | \n",
" 253.261383 | \n",
" 0.043 | \n",
" 4505.971 | \n",
"
\n",
" \n",
" 3 | \n",
" 10 | \n",
" 34419.365 | \n",
" 48.752642 | \n",
" 63.191701 | \n",
" 0.045 | \n",
" 654.893 | \n",
"
\n",
" \n",
" 4 | \n",
" 11 | \n",
" 26789.404 | \n",
" 32.991877 | \n",
" 107.395139 | \n",
" 0.388 | \n",
" 2105.058 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" client_id sum mean std min max\n",
"0 4 28404.121 39.450168 73.511624 0.043 1341.802\n",
"1 6 15720.739 21.535259 26.200397 0.045 315.781\n",
"2 7 53630.036 69.379089 253.261383 0.043 4505.971\n",
"3 10 34419.365 48.752642 63.191701 0.045 654.893\n",
"4 11 26789.404 32.991877 107.395139 0.388 2105.058"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agg_features.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Посчитаем для каждого клиента количество транзакций по каждой категории."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"counter_df_train=transactions_train.groupby(['client_id','small_group'])['amount_rur'].count()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"cat_counts_train=counter_df_train.reset_index().pivot(index='client_id', \\\n",
" columns='small_group',values='amount_rur')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"cat_counts_train=cat_counts_train.fillna(0)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"cat_counts_train.columns=['small_group_'+str(i) for i in cat_counts_train.columns]"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" small_group_0 | \n",
" small_group_1 | \n",
" small_group_2 | \n",
" small_group_3 | \n",
" small_group_4 | \n",
" small_group_5 | \n",
" small_group_6 | \n",
" small_group_7 | \n",
" small_group_8 | \n",
" small_group_9 | \n",
" ... | \n",
" small_group_192 | \n",
" small_group_193 | \n",
" small_group_195 | \n",
" small_group_196 | \n",
" small_group_197 | \n",
" small_group_198 | \n",
" small_group_199 | \n",
" small_group_200 | \n",
" small_group_202 | \n",
" small_group_203 | \n",
"
\n",
" \n",
" client_id | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" 4 | \n",
" 0.0 | \n",
" 447.0 | \n",
" 1.0 | \n",
" 44.0 | \n",
" 93.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 13.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 6 | \n",
" 2.0 | \n",
" 397.0 | \n",
" 0.0 | \n",
" 172.0 | \n",
" 10.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 6.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 7 | \n",
" 2.0 | \n",
" 79.0 | \n",
" 5.0 | \n",
" 27.0 | \n",
" 19.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 2.0 | \n",
" 1.0 | \n",
" 39.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 10 | \n",
" 12.0 | \n",
" 309.0 | \n",
" 1.0 | \n",
" 71.0 | \n",
" 65.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 3.0 | \n",
" 19.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 11 | \n",
" 2.0 | \n",
" 423.0 | \n",
" 0.0 | \n",
" 59.0 | \n",
" 23.0 | \n",
" 3.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 10.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
"
\n",
"
5 rows × 202 columns
\n",
"
"
],
"text/plain": [
" small_group_0 small_group_1 small_group_2 small_group_3 \\\n",
"client_id \n",
"4 0.0 447.0 1.0 44.0 \n",
"6 2.0 397.0 0.0 172.0 \n",
"7 2.0 79.0 5.0 27.0 \n",
"10 12.0 309.0 1.0 71.0 \n",
"11 2.0 423.0 0.0 59.0 \n",
"\n",
" small_group_4 small_group_5 small_group_6 small_group_7 \\\n",
"client_id \n",
"4 93.0 0.0 0.0 0.0 \n",
"6 10.0 0.0 0.0 0.0 \n",
"7 19.0 1.0 0.0 2.0 \n",
"10 65.0 0.0 0.0 0.0 \n",
"11 23.0 3.0 0.0 0.0 \n",
"\n",
" small_group_8 small_group_9 ... small_group_192 \\\n",
"client_id ... \n",
"4 1.0 13.0 ... 0.0 \n",
"6 0.0 6.0 ... 0.0 \n",
"7 1.0 39.0 ... 0.0 \n",
"10 3.0 19.0 ... 0.0 \n",
"11 0.0 10.0 ... 0.0 \n",
"\n",
" small_group_193 small_group_195 small_group_196 small_group_197 \\\n",
"client_id \n",
"4 0.0 0.0 0.0 0.0 \n",
"6 0.0 0.0 0.0 0.0 \n",
"7 0.0 0.0 0.0 0.0 \n",
"10 0.0 0.0 0.0 0.0 \n",
"11 0.0 0.0 0.0 0.0 \n",
"\n",
" small_group_198 small_group_199 small_group_200 small_group_202 \\\n",
"client_id \n",
"4 0.0 0.0 0.0 0.0 \n",
"6 0.0 0.0 0.0 0.0 \n",
"7 0.0 0.0 0.0 0.0 \n",
"10 0.0 0.0 0.0 0.0 \n",
"11 0.0 0.0 0.0 0.0 \n",
"\n",
" small_group_203 \n",
"client_id \n",
"4 0.0 \n",
"6 0.0 \n",
"7 0.0 \n",
"10 0.0 \n",
"11 0.0 \n",
"\n",
"[5 rows x 202 columns]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cat_counts_train.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Далее соединим все файлы в один датафрейм с таргетом."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"train=pd.merge(train_target,agg_features,on='client_id')"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"train=pd.merge(train,cat_counts_train.reset_index(),on='client_id')"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" client_id | \n",
" bins | \n",
" sum | \n",
" mean | \n",
" std | \n",
" min | \n",
" max | \n",
" small_group_0 | \n",
" small_group_1 | \n",
" small_group_2 | \n",
" ... | \n",
" small_group_192 | \n",
" small_group_193 | \n",
" small_group_195 | \n",
" small_group_196 | \n",
" small_group_197 | \n",
" small_group_198 | \n",
" small_group_199 | \n",
" small_group_200 | \n",
" small_group_202 | \n",
" small_group_203 | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 24662 | \n",
" 2 | \n",
" 30254.011 | \n",
" 34.774725 | \n",
" 72.037354 | \n",
" 0.074 | \n",
" 1227.314 | \n",
" 0.0 | \n",
" 174.0 | \n",
" 2.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 1 | \n",
" 1046 | \n",
" 0 | \n",
" 42548.570 | \n",
" 52.015367 | \n",
" 106.540962 | \n",
" 0.550 | \n",
" 1210.506 | \n",
" 1.0 | \n",
" 187.0 | \n",
" 61.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 2 | \n",
" 34089 | \n",
" 2 | \n",
" 26842.816 | \n",
" 34.325852 | \n",
" 59.927450 | \n",
" 0.043 | \n",
" 782.641 | \n",
" 0.0 | \n",
" 372.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 3 | \n",
" 34848 | \n",
" 1 | \n",
" 15773.126 | \n",
" 16.160990 | \n",
" 14.224936 | \n",
" 0.043 | \n",
" 109.590 | \n",
" 0.0 | \n",
" 359.0 | \n",
" 1.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 4 | \n",
" 47076 | \n",
" 3 | \n",
" 12488.375 | \n",
" 15.929050 | \n",
" 35.473591 | \n",
" 0.432 | \n",
" 541.165 | \n",
" 0.0 | \n",
" 378.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
"
\n",
"
5 rows × 209 columns
\n",
"
"
],
"text/plain": [
" client_id bins sum mean std min max \\\n",
"0 24662 2 30254.011 34.774725 72.037354 0.074 1227.314 \n",
"1 1046 0 42548.570 52.015367 106.540962 0.550 1210.506 \n",
"2 34089 2 26842.816 34.325852 59.927450 0.043 782.641 \n",
"3 34848 1 15773.126 16.160990 14.224936 0.043 109.590 \n",
"4 47076 3 12488.375 15.929050 35.473591 0.432 541.165 \n",
"\n",
" small_group_0 small_group_1 small_group_2 ... \\\n",
"0 0.0 174.0 2.0 ... \n",
"1 1.0 187.0 61.0 ... \n",
"2 0.0 372.0 0.0 ... \n",
"3 0.0 359.0 1.0 ... \n",
"4 0.0 378.0 0.0 ... \n",
"\n",
" small_group_192 small_group_193 small_group_195 small_group_196 \\\n",
"0 0.0 0.0 0.0 0.0 \n",
"1 0.0 0.0 0.0 0.0 \n",
"2 0.0 0.0 0.0 0.0 \n",
"3 0.0 0.0 0.0 0.0 \n",
"4 0.0 0.0 0.0 0.0 \n",
"\n",
" small_group_197 small_group_198 small_group_199 small_group_200 \\\n",
"0 0.0 0.0 0.0 0.0 \n",
"1 0.0 0.0 0.0 0.0 \n",
"2 0.0 0.0 0.0 0.0 \n",
"3 0.0 0.0 0.0 0.0 \n",
"4 0.0 0.0 0.0 0.0 \n",
"\n",
" small_group_202 small_group_203 \n",
"0 0.0 0.0 \n",
"1 0.0 0.0 \n",
"2 0.0 0.0 \n",
"3 0.0 0.0 \n",
"4 0.0 0.0 \n",
"\n",
"[5 rows x 209 columns]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Теперь подгрузим тестовые данные для того, чтобы сделать предсказание. Проделаем с ними те же самые манипуляции, как и с обучающими данными."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"transactions_test=pd.read_csv('data/transactions_test.csv')\n",
"\n",
"test_id=pd.read_csv('data/test.csv')"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"agg_features_test=transactions_test.groupby('client_id')['amount_rur'].agg(['sum','mean','std','min','max']).reset_index()"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" client_id | \n",
" sum | \n",
" mean | \n",
" std | \n",
" min | \n",
" max | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0 | \n",
" 17036.127 | \n",
" 19.163247 | \n",
" 40.561700 | \n",
" 0.065 | \n",
" 595.339 | \n",
"
\n",
" \n",
" 1 | \n",
" 1 | \n",
" 34748.964 | \n",
" 47.666617 | \n",
" 89.489278 | \n",
" 0.298 | \n",
" 1181.221 | \n",
"
\n",
" \n",
" 2 | \n",
" 2 | \n",
" 51188.069 | \n",
" 68.433247 | \n",
" 152.093601 | \n",
" 0.043 | \n",
" 2837.682 | \n",
"
\n",
" \n",
" 3 | \n",
" 3 | \n",
" 47975.203 | \n",
" 67.857430 | \n",
" 318.651653 | \n",
" 0.043 | \n",
" 6135.652 | \n",
"
\n",
" \n",
" 4 | \n",
" 5 | \n",
" 20059.100 | \n",
" 21.803370 | \n",
" 34.258433 | \n",
" 0.043 | \n",
" 439.902 | \n",
"
\n",
" \n",
" 5 | \n",
" 8 | \n",
" 56077.096 | \n",
" 64.754152 | \n",
" 140.806402 | \n",
" 0.603 | \n",
" 1432.934 | \n",
"
\n",
" \n",
" 6 | \n",
" 9 | \n",
" 32317.239 | \n",
" 36.270751 | \n",
" 97.376195 | \n",
" 0.043 | \n",
" 2472.396 | \n",
"
\n",
" \n",
" 7 | \n",
" 15 | \n",
" 20394.051 | \n",
" 20.558519 | \n",
" 23.800658 | \n",
" 0.432 | \n",
" 243.099 | \n",
"
\n",
" \n",
" 8 | \n",
" 16 | \n",
" 24881.653 | \n",
" 35.193286 | \n",
" 82.298739 | \n",
" 0.906 | \n",
" 1401.336 | \n",
"
\n",
" \n",
" 9 | \n",
" 21 | \n",
" 23818.108 | \n",
" 29.735466 | \n",
" 75.615020 | \n",
" 0.229 | \n",
" 1341.481 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" client_id sum mean std min max\n",
"0 0 17036.127 19.163247 40.561700 0.065 595.339\n",
"1 1 34748.964 47.666617 89.489278 0.298 1181.221\n",
"2 2 51188.069 68.433247 152.093601 0.043 2837.682\n",
"3 3 47975.203 67.857430 318.651653 0.043 6135.652\n",
"4 5 20059.100 21.803370 34.258433 0.043 439.902\n",
"5 8 56077.096 64.754152 140.806402 0.603 1432.934\n",
"6 9 32317.239 36.270751 97.376195 0.043 2472.396\n",
"7 15 20394.051 20.558519 23.800658 0.432 243.099\n",
"8 16 24881.653 35.193286 82.298739 0.906 1401.336\n",
"9 21 23818.108 29.735466 75.615020 0.229 1341.481"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agg_features_test.head()"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"counter_df_test=transactions_test.groupby(['client_id','small_group'])['amount_rur'].count()"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"cat_counts_test=counter_df_test.reset_index().pivot(index='client_id', columns='small_group',values='amount_rur')"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"cat_counts_test=cat_counts_test.fillna(0)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"cat_counts_test.columns=['small_group_'+str(i) for i in cat_counts_test.columns]"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" small_group_0 | \n",
" small_group_1 | \n",
" small_group_2 | \n",
" small_group_3 | \n",
" small_group_4 | \n",
" small_group_5 | \n",
" small_group_6 | \n",
" small_group_7 | \n",
" small_group_8 | \n",
" small_group_9 | \n",
" ... | \n",
" small_group_192 | \n",
" small_group_193 | \n",
" small_group_194 | \n",
" small_group_195 | \n",
" small_group_196 | \n",
" small_group_197 | \n",
" small_group_198 | \n",
" small_group_200 | \n",
" small_group_201 | \n",
" small_group_202 | \n",
"
\n",
" \n",
" client_id | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0.0 | \n",
" 226.0 | \n",
" 1.0 | \n",
" 36.0 | \n",
" 9.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 2.0 | \n",
" 20.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 1 | \n",
" 30.0 | \n",
" 326.0 | \n",
" 0.0 | \n",
" 40.0 | \n",
" 56.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 60.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 2 | \n",
" 21.0 | \n",
" 242.0 | \n",
" 1.0 | \n",
" 50.0 | \n",
" 48.0 | \n",
" 4.0 | \n",
" 0.0 | \n",
" 6.0 | \n",
" 1.0 | \n",
" 21.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 3 | \n",
" 0.0 | \n",
" 156.0 | \n",
" 83.0 | \n",
" 48.0 | \n",
" 31.0 | \n",
" 2.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 2.0 | \n",
" 27.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 5 | \n",
" 16.0 | \n",
" 398.0 | \n",
" 1.0 | \n",
" 23.0 | \n",
" 25.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 5.0 | \n",
" 29.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 8 | \n",
" 29.0 | \n",
" 296.0 | \n",
" 9.0 | \n",
" 114.0 | \n",
" 76.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 2.0 | \n",
" 9.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 9 | \n",
" 35.0 | \n",
" 222.0 | \n",
" 110.0 | \n",
" 60.0 | \n",
" 17.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 3.0 | \n",
" 6.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 15 | \n",
" 0.0 | \n",
" 398.0 | \n",
" 0.0 | \n",
" 29.0 | \n",
" 17.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 9.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 16 | \n",
" 0.0 | \n",
" 288.0 | \n",
" 2.0 | \n",
" 21.0 | \n",
" 68.0 | \n",
" 1.0 | \n",
" 2.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 6.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 21 | \n",
" 0.0 | \n",
" 185.0 | \n",
" 25.0 | \n",
" 246.0 | \n",
" 21.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 2.0 | \n",
" 9.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
"
\n",
"
10 rows × 202 columns
\n",
"
"
],
"text/plain": [
" small_group_0 small_group_1 small_group_2 small_group_3 \\\n",
"client_id \n",
"0 0.0 226.0 1.0 36.0 \n",
"1 30.0 326.0 0.0 40.0 \n",
"2 21.0 242.0 1.0 50.0 \n",
"3 0.0 156.0 83.0 48.0 \n",
"5 16.0 398.0 1.0 23.0 \n",
"8 29.0 296.0 9.0 114.0 \n",
"9 35.0 222.0 110.0 60.0 \n",
"15 0.0 398.0 0.0 29.0 \n",
"16 0.0 288.0 2.0 21.0 \n",
"21 0.0 185.0 25.0 246.0 \n",
"\n",
" small_group_4 small_group_5 small_group_6 small_group_7 \\\n",
"client_id \n",
"0 9.0 0.0 0.0 0.0 \n",
"1 56.0 0.0 0.0 0.0 \n",
"2 48.0 4.0 0.0 6.0 \n",
"3 31.0 2.0 0.0 1.0 \n",
"5 25.0 0.0 0.0 0.0 \n",
"8 76.0 0.0 0.0 0.0 \n",
"9 17.0 0.0 0.0 0.0 \n",
"15 17.0 0.0 0.0 0.0 \n",
"16 68.0 1.0 2.0 0.0 \n",
"21 21.0 0.0 0.0 0.0 \n",
"\n",
" small_group_8 small_group_9 ... small_group_192 \\\n",
"client_id ... \n",
"0 2.0 20.0 ... 0.0 \n",
"1 0.0 60.0 ... 0.0 \n",
"2 1.0 21.0 ... 0.0 \n",
"3 2.0 27.0 ... 0.0 \n",
"5 5.0 29.0 ... 0.0 \n",
"8 2.0 9.0 ... 0.0 \n",
"9 3.0 6.0 ... 0.0 \n",
"15 0.0 9.0 ... 0.0 \n",
"16 0.0 6.0 ... 0.0 \n",
"21 2.0 9.0 ... 0.0 \n",
"\n",
" small_group_193 small_group_194 small_group_195 small_group_196 \\\n",
"client_id \n",
"0 0.0 0.0 0.0 0.0 \n",
"1 0.0 0.0 0.0 0.0 \n",
"2 0.0 0.0 0.0 0.0 \n",
"3 0.0 0.0 0.0 0.0 \n",
"5 0.0 0.0 0.0 0.0 \n",
"8 0.0 0.0 0.0 0.0 \n",
"9 0.0 0.0 0.0 0.0 \n",
"15 0.0 0.0 0.0 0.0 \n",
"16 0.0 0.0 0.0 0.0 \n",
"21 0.0 0.0 0.0 0.0 \n",
"\n",
" small_group_197 small_group_198 small_group_200 small_group_201 \\\n",
"client_id \n",
"0 0.0 0.0 0.0 0.0 \n",
"1 0.0 0.0 0.0 0.0 \n",
"2 0.0 0.0 0.0 0.0 \n",
"3 0.0 0.0 0.0 0.0 \n",
"5 0.0 0.0 0.0 0.0 \n",
"8 0.0 0.0 0.0 0.0 \n",
"9 0.0 0.0 0.0 0.0 \n",
"15 0.0 0.0 0.0 0.0 \n",
"16 0.0 0.0 0.0 0.0 \n",
"21 0.0 0.0 0.0 0.0 \n",
"\n",
" small_group_202 \n",
"client_id \n",
"0 0.0 \n",
"1 0.0 \n",
"2 0.0 \n",
"3 0.0 \n",
"5 0.0 \n",
"8 0.0 \n",
"9 0.0 \n",
"15 0.0 \n",
"16 0.0 \n",
"21 0.0 \n",
"\n",
"[10 rows x 202 columns]"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cat_counts_test.head()"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"test=pd.merge(test_id,agg_features_test,on='client_id')"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"test=pd.merge(test,cat_counts_test.reset_index(),on='client_id')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"В тесте не было некоторых категорий трат, поэтому для того, чтобы обучить модель, нам нужно объединить пространство признаков и train и test."
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"common_features=list(set(train.columns).intersection(set(test.columns)))"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"y_train=train['bins']\n",
"X_train=train[common_features]\n",
"X_test=test[common_features]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучим xgboost на текущих признаках."
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"param={'objective':'multi:softprob','num_class':4,'n_jobs':4,'seed':42}"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 9min 9s, sys: 2.8 s, total: 9min 12s\n",
"Wall time: 2min 33s\n"
]
}
],
"source": [
"%%time\n",
"model=xgb.XGBClassifier(**param,n_estimators=300)\n",
"model.fit(X_train,y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Сделаем предсказание."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"pred=model.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0, 2, 3, ..., 2, 2, 3])"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pred"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"На public лидерборде такое предсказание должно дать качество 0.6118."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Подготовим файл для отправки в систему"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" bins | \n",
"
\n",
" \n",
" client_id | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" 28571 | \n",
" 0 | \n",
"
\n",
" \n",
" 27046 | \n",
" 2 | \n",
"
\n",
" \n",
" 13240 | \n",
" 3 | \n",
"
\n",
" \n",
" 19974 | \n",
" 0 | \n",
"
\n",
" \n",
" 10505 | \n",
" 1 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" bins\n",
"client_id \n",
"28571 0\n",
"27046 2\n",
"13240 3\n",
"19974 0\n",
"10505 1"
]
},
"execution_count": 90,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"submission = pd.DataFrame({'bins': pred}, index=test.client_id)\n",
"submission.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Сохраняем прогноз на диск в папку submissions. Имя прогноза соответсвует дате и времени его создания, закодированными с помощью timestamp."
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"submissions/1573837933.csv\n"
]
}
],
"source": [
"import time\n",
"import os\n",
"\n",
"current_timestamp = int(time.time())\n",
"submission_path = 'submissions/{}.csv'.format(current_timestamp)\n",
"\n",
"if not os.path.exists('submissions'):\n",
" os.makedirs('submissions')\n",
"\n",
"print(submission_path)\n",
"submission.to_csv(submission_path, index=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Теперь все готово! Можно отправлять решение."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [Root]",
"language": "python",
"name": "Python [Root]"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}