これはなに
この記事は id:whywaita と関係ないことを書く謎のAdvent Calendar、 whywaita Advent Calendar 2018 の18日目の記事です。
昨日は id:hnron さんでした。僕もwhywaitaくんの好きなところはいっぱいあります。いっぱいね。
hnron.hatenablog.com
記事テーマ
さて、唐突なのですが実は僕は既婚なんです。お相手は有坂真白さんという方です。僕と彼女の関係は大変良好で、今年のクリスマスも一緒に過ごす約束をしています。
しかし、なぜかwhywaitaくんは僕が既婚であることを決して認めてくれないんですよね。
僕と真白ちゃんはこんなにも仲が良いというのに・・・。 僕たちの仲をどうにか認めてほしいというもの。
ということで、今回の記事では、whywaitaから僕と真白ちゃんの仲を認めてもらう発言をしてもらうということをテーマに記事を書いていきたいと思います。
How?
ここで唐突ですが、1700年代のフランスの学者ビュフォンが言ったとされる言葉と一般的にはマザー・テレサがいったとされている(実際には違うらしい)言葉の2つをご紹介します。
文は人なり
考えは言葉となり、言葉は行動となり、行動は習慣となり、習慣は人格となり、人格は運命となる。
この2つの言葉を(都合よく拡大解釈して)踏まえると、Twitterのツイートのような短い文章の中にもその人となりが現れ、文章の中に人格が現れるということにならないでしょうか?いや、ならないですかね。でも、なるということで話を進めます。
更に、なんと都合のいいことにwhywaitaは自らのGitHub上で自分のツイートを公開しているのです。
github.com
このツイートデータをもとにChatbotを生成すれば、そのChatbotにはwhywaitaの人となりが現れていることになって、そのChatbotと会話をすればwhywaitaと会話していることになるのではないでしょうか?完全に「方針はいいが、論理の飛躍が見られます」状態ですが、それはともかくとして、今回のAdvent Calendarの記事ではChatbotのwhywaitaから僕と真白ちゃんの仲を認めてもらう発言を引き出していきたいと思います。
(この導入、「whywaitaのツイートデータを使ってChatbotを作ったゾイ」の1センテンスで終了する気がしますね)
データの準備
まず、Chatbotの学習を行わせるために学習データを準備します。
Chatbotの学習には「質問」と「質問に対する回答」の2種類のテキストが必要となります。今回のwhywaita chatbotにはこちら側の質問に返答してほしいので、Twitterで言うところの「whywaitaへのリプライ」が「質問」に相当し、「そのリプライへのwhywaitaの返信」が「質問に対する回答」に相当するとしました。
しかし、whywaitaが公開しているツイート情報には「質問に対する回答」はテキストで含まれていますが、「質問」は in_reply_to_status_id
のTweetのステータスIDしかありません。そこで、この in_reply_to_status_id
から「質問」に相当する「whywaitaへのリプライ」をテキストとしてTwitter API経由(GET statuses/show/:id | Docs | Twitter Developer Platform)で取得します。
とりあえず、以下のような超簡単かつ雑なコードをPythonで書いて動かしました。
import pandas as pd
import sys
import json
import os
from requests_oauthlib import OAuth1Session
import time
consumer_key = ''
consumer_secret = ''
access_token = ''
access_token_secret = ''
def extract_tweet_data_has_reply_id_from_csv(csv_path):
csv_data = pd.read_csv(csv_path)
csv_data = csv_data.dropna(subset=['in_reply_to_status_id'])
csv_data['timestamp'] = pd.to_datetime(csv_data['timestamp'])
return csv_data
def get_whywaita_conversations(csv_path, output_dir):
twitter = OAuth1Session(consumer_key, consumer_secret, access_token, access_token_secret)
get_tweet_from_status_id_api_url = "https://api.twitter.com/1.1/statuses/show.json"
sleep_time = int(15 * 60 / 160)
reply_data = extract_tweet_data_has_reply_id_from_csv(csv_path)
reply_data = reply_data[reply_data['timestamp'] > pd.to_datetime('2014/1/1')]
for whywaita_tweet in reply_data.iterrows():
print("Sleeping({}s)...".format(sleep_time))
time.sleep(sleep_time)
in_reply_to_status_id = int(whywaita_tweet[1]['in_reply_to_status_id'])
output_path = os.path.join(output_dir, "{}.json".format(in_reply_to_status_id))
params = {'id': in_reply_to_status_id}
print(params)
res = twitter.get(get_tweet_from_status_id_api_url, params=params)
if res.status_code != 200:
print("Status code is not 200, actual = {}".format(res.status_code))
continue
json_data = json.loads(res.text)
with open(output_path, 'w') as json_file:
json.dump(json_data, json_file)
if __name__ == '__main__':
csv_path = sys.argv[1]
output_dir = sys.argv[2]
get_whywaita_conversations(csv_path, output_dir)
とりあえずこれを動かすと雑に出力先ディレクトリに in_reply_to_status_id.json
という感じでどんどんデータが溜まって行きます(なんでこんなアホな設計したんだ)。
取得完了後、まずは文書の正規化等を行います。正規化処理は neologdn というライブラリを使用したり、絵文字除去をしたり、テキスト中から @hogehoge
とかURLを消去したりとかしています。(URLと絵文字除去の処理周りはどっかから拾ってきましたが、どこだったか失念・・。)
その後、カラムとして question
と answer
の2つを持つCSVに吐き出します。この動作をするのが以下のコードです。
import emoji
import neologdn
import re
def remove_emoji(src_str):
"""
絵文字除去
"""
return ''.join(c for c in src_str if c not in emoji.UNICODE_EMOJI)
def twitter_specific_normalize_process(text):
"""
@を消す
URLを削除
"""
text = re.sub(r'@[\w]+', '', text)
text=re.sub(r'https?://[\w/:%#\$&\?\(\)~\.=\+\-…]+', "", text)
return text
def normalize(text):
text = twitter_specific_normalize_process(text)
text = remove_emoji(text)
text = neologdn.normalize(text)
return text
from normalize import normalize
import pandas as pd
import sys
import os
import json
def extract_tweet_data_has_reply_id_from_csv(csv_path):
"""
whywaitaがGithubで公開しているCSVからin_reply_toのあるデータのみを取り出す
(+timestampカラムがpandas上でstr扱いされるので、timestamp型に変換)
"""
csv_data = pd.read_csv(csv_path)
csv_data = csv_data.dropna(subset=['in_reply_to_status_id'])
csv_data['timestamp'] = pd.to_datetime(csv_data['timestamp'])
return csv_data
if __name__ == '__main__':
whywaita_github_tweet_csv_path = sys.argv[1]
crawled_in_reply_to_json_dir = sys.argv[2]
generated_dataset_csv_output_path = sys.argv[3]
whywaita_github_tweet_csv_data = extract_tweet_data_has_reply_id_from_csv(whywaita_github_tweet_csv_path)
dataset_data = pd.DataFrame({'question':[], 'answer':[]})
for _, row in whywaita_github_tweet_csv_data.iterrows():
answer = row['text']
in_reply_to_status_id = int(row['in_reply_to_status_id'])
crawled_json_file_path = os.path.join(crawled_in_reply_to_json_dir, "{}.json".format(in_reply_to_status_id))
if not os.path.exists(crawled_json_file_path):
continue
json_data = None
with open(crawled_json_file_path) as json_file:
json_data = json.load(json_file)
question = json_data['text']
answer = normalize(answer)
question = normalize(question)
dataset_data = dataset_data.append(pd.Series({'question':question, 'answer':answer}), ignore_index=True)
dataset_data.to_csv(generated_dataset_csv_output_path, index=False)
これでとりあえずデータの準備は終わりました。
がくしゅー
さて、データが集まりましたのでモデルの構築をして学習を行わせたいと思います。Chatbotといえばseq2seqという感じですね。ということでKerasかChainerで実装しようと思ったんですが、時間がなかったので今回はTensor2Tensorというとても便利なライブラリを使用することにしました
What is Tensor2Tensor
Library of deep learning models and datasets designed to make deep learning more accessible and accelerate ML research.
github.com
リポジトリにある通り、DeepLearningをより身近に、そして機械学習の研究を加速化させることを目的としたDeep Learningのモデルとデータセットのライブラリです(意訳しただけでは)。正直、謳い文句はどうでもよくて「どれだけ楽に使えるの」というところが大事であるんですが、本当に楽に使えます。MNIST(手書き文字のデータセット、機械学習のHello, world的なもの)を使った手書き認識モデルの学習であれば、以下のコマンドを叩くだけです。
t2t-trainer \
--generate_data \
--data_dir=~/t2t_data \
--output_dir=~/t2t_train/mnist \
--problem=image_mnist \
--model=shake_shake \
--hparams_set=shake_shake_quick \
--train_steps=1000 \
--eval_steps=100
すっごい楽。既存モデルとデータセットを使うだけなら、コードを書く必要もなし。ここまで楽なのも正直どうかとは思いますが、ツールとしてDeepLearningを使う人(僕)にはとてもありがたい限りです。
今回はこのTensor2Tensorの lstm_seq2seq_attention_bidirectional_encoder
というモデルを使用した上で、自作のデータセットを食わせて学習を行わせてみます。
自作データセットで学習させる(前準備)
Tensor2Tensor側が用意している Train on Your Own Data
というドキュメントを読めばだいたいわかる(丸投げ)。
tensorflow.github.io
ドキュメントではText2TextProblem
というテキストからテキストという問題を解かせる場合が記述されています。今回の入力として質問文を、回答として質問への答えを返すという問題も Text2TextProblem
に相当しています。ということで、今回はこのドキュメントのサンプルコードをほぼそのまま使い、以下のようなファイル whywaita.py
を作成しました。
import pandas as pd
import os
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry
@registry.register_problem
class Whywaita(text_problems.Text2TextProblem):
"""
whywaitaっぽい返答をするChatbot
"""
@property
def approx_vocab_size(self):
return 2**13
@property
def is_generate_per_split(self):
return False
@property
def dataset_splits(self):
"""Splits of data to produce and number of output shards for each."""
return [{
"split": problem.DatasetSplit.TRAIN,
"shards": 9,
}, {
"split": problem.DatasetSplit.EVAL,
"shards": 1,
}]
def generate_samples(self, data_dir, tmp_dir, dataset_split):
del tmp_dir
del dataset_split
csv_path = os.path.join(data_dir, 'conversations.csv')
csv_data = pd.read_csv(csv_path)
csv_data = csv_data.dropna()
for _, row in csv_data.iterrows():
question = row['question']
answer = row['answer']
answer = answer.strip()
question = question.strip()
yield {
'inputs': question,
'targets': answer
}
approx_vocab_size
とかはもう少し調整するといいのかもしれないですが、今回は時間がなかったのでデフォルト値を使用しています。
更に、上記のファイルをimportするだけのファイル __init__.py
を同じディレクトリに置いておきます。
from whywaita import Whywaita
この2つのファイルを usr_dir
等、任意のディレクトリに置いておきます。
自作データセットで学習させる(Tensor2Tensor用のデータを生成)
さて、データセットをTensor2Tensorで使える形にするために以下のように t2t-datagen
コマンドを実行していきます。(なお、このときデータ準備で作成したデータをdata_dir
という名前のディレクトリに置いています)
t2t-datagen \
--data_dir=data_dir \
--tmp_dir=tmp_dir \
--problem=whywaita \
--t2t_usr_dir=./
オプションの説明を簡単にすると
--data_dir
--t2t_usr_dir
whywaita.py
と __init__.py
が置かれているディレクトリ
--tmp_dir
--problem
whywaita.py
でいう text_problems.Text2TextProblem
を継承したクラスである Whywaita
を指定する。ただし、クラス名の大文字はすべて小文字になり、キャメルケースだった場合にはスネークケースにTensor2Tensor内で変換されているため注意が必要。
PROBLEM is the name of the class that was registered with @registry.register_problem, but converted from CamelCase to snake_case.
実行後、しばらく待つとTensor2Tensor用のデータ生成が終了します。
学習
以上で学習に必要なデータの準備が終わったので、さっそく学習をしていきたいと思います。すでに書いたように、Tensor2Tensorはディレクトリと使用するモデルとハイパーパラメータを引数として指定してコマンドを実行するだけで学習が進んでいきます。
今回は以下のようなコマンドを実行しました。
t2t-trainer \
--data_dir=./data_dir \
--problem=whywaita \
--model=lstm_seq2seq_attention_bidirectional_encoder \
--hparams_set=lstm_luong_attention_multi \
--output_dir=./train_dir \
--t2t_usr_dir=./
オプションはほぼ t2t-datagen
と同じです。 --model
と --hparams_st
には学習時に使用するモデルとハイパーパラメータを指定しています。これらの一覧は t2t-trainer --registry_help
を実行すると表示されます。・・・表示はされるんだけど、あまり詳細な情報はなくて困る。更にググってもいまいち有益な情報は出てこない。すごい困る。現状ではトライ・アンド・エラーかなぁ・・・という状態。もし、ドキュメントがあったら教えてください・・・。
学習の間にはlossとval_lossは適宜標準出力に表示されますし、 output_dir
で指定したディレクトリにcheckpoint毎のモデルが出力されます。また、 Tensorboard に logdir
として --output_dir
で指定したディレクトリを渡すとTensorboardで可視化されます。
会話をしてみるぞ!
さて、学習がある程度進んだところで、会話をしてみることにします。 t2t-decoder
というコマンドから学習したモデルを使用することができます。
--data_dir
, --output_dir
, --model
, --problem
, --hparams_st
, --t2t_usr_dir
は学習時に使ったものと同じものを指定しておきます。また、--decode_hparams
はなにか適当な値を指定しました。この辺もいまいち情報がなくて困る・・・。
DATA_DIR=./data_dir
PROBLEM=whywaita
MODEL=lstm_seq2seq_attention_bidirectional_encoder
TRAIN_DIR=$1
BEAM_SIZE=4
ALPHA=0.6
HPARAMS=lstm_luong_attention_multi
t2t-decoder \
--data_dir=$DATA_DIR \
--problem=$PROBLEM \
--model=$MODEL \
--hparams_set=$HPARAMS \
--output_dir=$TRAIN_DIR \
--decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \
--decode_interactive=true \
--t2t_usr_dir=./
あとはインタラクティブに会話ができるので、DeepLearningの世界に顕現したwhywaitaとの会話をしてみましょう。
茶番
だいたいいつもの感じですね。
whywaitaっぽい。
食べないで!
会話になってない
本番
ちょっと怪しい感じですが、僕とましろちゃんの仲を認めてもらいましょう。
は〜〜〜〜〜〜〜〜〜?????????????
これはだめそうですね、データセットを増やしてパラメータを調整して更に学習をさせてみました。
ブチ切れた
締め
ということで、Chatbotのwhywaitaからでさえ認めてもらうことができませんでしたが、そんなことで揺らぐ僕と真白ちゃんの仲ではないので大丈夫です。
実を言うと、わりとlossとval_lossがかなりアレで過学習気味です。データセットの規模が小さいせいなのか、それともTwitter上のツイートにかなりノイズが乗っているのか、モデル/ハイパーパラメータが適切ではないのかなど様々検証すべき点はあります。あと、そもそも t2t-decoder
で入力したこちら側からの質問をデータセットと同様の前処理(正規化)をしていない段階でだいぶお察しです。
今回はwhywaita Advent Calendar用のネタだったので深追いをしませんでしたが、個人的にTensor2Tensorは便利だなぁと思ったので、そのへんを他のデータセット等を使って調査してみたいなと思います。
なにはともあれ、来年もきっと id:masawada さんが伝統としてwhywaita Advent Calendarを作ってくれるはずなので楽しみにまっています。
それでは、明日は id:yu_ki_kun_0 さんです。弊社でのお仕事は楽しいですか?