Coyote vs Loadbalancer

The Blog for the Rest of Us

whywaita Advent Calendar 2018 18日目 おいでませ!Chatbot whywaitaくん!

これはなに

この記事は id:whywaita と関係ないことを書く謎のAdvent Calendar、 whywaita Advent Calendar 2018 の18日目の記事です。

昨日は id:hnron さんでした。僕もwhywaitaくんの好きなところはいっぱいあります。いっぱいね。

hnron.hatenablog.com

記事テーマ

さて、唐突なのですが実は僕は既婚なんです。お相手は有坂真白さんという方です。僕と彼女の関係は大変良好で、今年のクリスマスも一緒に過ごす約束をしています。

しかし、なぜかwhywaitaくんは僕が既婚であることを決して認めてくれないんですよね。

僕と真白ちゃんはこんなにも仲が良いというのに・・・。 僕たちの仲をどうにか認めてほしいというもの。
ということで、今回の記事では、whywaitaから僕と真白ちゃんの仲を認めてもらう発言をしてもらうということをテーマに記事を書いていきたいと思います。

How?

ここで唐突ですが、1700年代のフランスの学者ビュフォンが言ったとされる言葉と一般的にはマザー・テレサがいったとされている(実際には違うらしい)言葉の2つをご紹介します。

文は人なり

考えは言葉となり、言葉は行動となり、行動は習慣となり、習慣は人格となり、人格は運命となる。

この2つの言葉を(都合よく拡大解釈して)踏まえると、Twitterのツイートのような短い文章の中にもその人となりが現れ、文章の中に人格が現れるということにならないでしょうか?いや、ならないですかね。でも、なるということで話を進めます。

更に、なんと都合のいいことにwhywaitaは自らのGitHub上で自分のツイートを公開しているのです。

github.com

このツイートデータをもとにChatbotを生成すれば、そのChatbotにはwhywaitaの人となりが現れていることになって、そのChatbotと会話をすればwhywaitaと会話していることになるのではないでしょうか?完全に「方針はいいが、論理の飛躍が見られます」状態ですが、それはともかくとして、今回のAdvent Calendarの記事ではChatbotのwhywaitaから僕と真白ちゃんの仲を認めてもらう発言を引き出していきたいと思います。

(この導入、「whywaitaのツイートデータを使ってChatbotを作ったゾイ」の1センテンスで終了する気がしますね)

データの準備

まず、Chatbotの学習を行わせるために学習データを準備します。
Chatbotの学習には「質問」と「質問に対する回答」の2種類のテキストが必要となります。今回のwhywaita chatbotにはこちら側の質問に返答してほしいので、Twitterで言うところの「whywaitaへのリプライ」が「質問」に相当し、「そのリプライへのwhywaitaの返信」が「質問に対する回答」に相当するとしました。

しかし、whywaitaが公開しているツイート情報には「質問に対する回答」はテキストで含まれていますが、「質問」は in_reply_to_status_idTweetのステータスIDしかありません。そこで、この in_reply_to_status_id から「質問」に相当する「whywaitaへのリプライ」をテキストとしてTwitter API経由(GET statuses/show/:id | Docs | Twitter Developer Platform)で取得します。

とりあえず、以下のような超簡単かつ雑なコードをPythonで書いて動かしました。

import pandas as pd
import sys
import json
import os
from requests_oauthlib import OAuth1Session
import time

consumer_key = ''
consumer_secret = ''
access_token = ''
access_token_secret = ''

def extract_tweet_data_has_reply_id_from_csv(csv_path):
    csv_data = pd.read_csv(csv_path)
    # in_reply_to_status_idがNaN、つまりリプライではないデータをDropする
    csv_data = csv_data.dropna(subset=['in_reply_to_status_id'])
    csv_data['timestamp'] = pd.to_datetime(csv_data['timestamp']) # timestampがstrなので、timestampにする

    return csv_data

def get_whywaita_conversations(csv_path, output_dir):
    twitter = OAuth1Session(consumer_key, consumer_secret, access_token, access_token_secret)
    get_tweet_from_status_id_api_url = "https://api.twitter.com/1.1/statuses/show.json"

    sleep_time = int(15 * 60 / 160) # GET status/:idは15分で180回のリクエスト制限がある、余裕見て160回/15分で

    reply_data = extract_tweet_data_has_reply_id_from_csv(csv_path)

    # とりあえず、2014/01/01までのから引っ張ってくる
    reply_data = reply_data[reply_data['timestamp'] > pd.to_datetime('2014/1/1')]

    for whywaita_tweet in reply_data.iterrows():
        print("Sleeping({}s)...".format(sleep_time))
        time.sleep(sleep_time)
        in_reply_to_status_id = int(whywaita_tweet[1]['in_reply_to_status_id'])
        output_path = os.path.join(output_dir, "{}.json".format(in_reply_to_status_id))
        params = {'id': in_reply_to_status_id}
        print(params)

        res = twitter.get(get_tweet_from_status_id_api_url, params=params)

        if res.status_code != 200:
            print("Status code is not 200, actual = {}".format(res.status_code))
            continue

        json_data = json.loads(res.text)

        with open(output_path, 'w') as json_file:
            json.dump(json_data, json_file)


if __name__ == '__main__':
    csv_path = sys.argv[1]
    output_dir = sys.argv[2]
    get_whywaita_conversations(csv_path, output_dir)

とりあえずこれを動かすと雑に出力先ディレクトリに in_reply_to_status_id.json という感じでどんどんデータが溜まって行きます(なんでこんなアホな設計したんだ)。

取得完了後、まずは文書の正規化等を行います。正規化処理は neologdn というライブラリを使用したり、絵文字除去をしたり、テキスト中から @hogehoge とかURLを消去したりとかしています。(URLと絵文字除去の処理周りはどっかから拾ってきましたが、どこだったか失念・・。)
その後、カラムとして questionanswer の2つを持つCSVに吐き出します。この動作をするのが以下のコードです。

import emoji
import neologdn
import re

def remove_emoji(src_str):
    """
    絵文字除去
    """
    return ''.join(c for c in src_str if c not in emoji.UNICODE_EMOJI)

def twitter_specific_normalize_process(text):
    """
    @を消す
    URLを削除
    """

    # @を削除
    text = re.sub(r'@[\w]+', '', text)
    # URLを削除
    text=re.sub(r'https?://[\w/:%#\$&\?\(\)~\.=\+\-…]+', "", text)

    return text


def normalize(text):
    text = twitter_specific_normalize_process(text)
    text = remove_emoji(text)
    text = neologdn.normalize(text)

    return text
from normalize import normalize
import pandas as pd
import sys
import os
import json

def extract_tweet_data_has_reply_id_from_csv(csv_path):
    """
    whywaitaがGithubで公開しているCSVからin_reply_toのあるデータのみを取り出す
    (+timestampカラムがpandas上でstr扱いされるので、timestamp型に変換)
    """
    csv_data = pd.read_csv(csv_path)
    # in_reply_to_status_idがNaN、つまりリプライではないデータをDropする
    csv_data = csv_data.dropna(subset=['in_reply_to_status_id'])
    csv_data['timestamp'] = pd.to_datetime(csv_data['timestamp']) # timestampがstrなので、timestampにする

    return csv_data


if __name__ == '__main__':
    whywaita_github_tweet_csv_path = sys.argv[1]
    crawled_in_reply_to_json_dir = sys.argv[2]
    generated_dataset_csv_output_path = sys.argv[3]

    whywaita_github_tweet_csv_data = extract_tweet_data_has_reply_id_from_csv(whywaita_github_tweet_csv_path)

    dataset_data = pd.DataFrame({'question':[], 'answer':[]})

    for _, row in whywaita_github_tweet_csv_data.iterrows():
        # CSV上でのwhywaitの発言はin_reply_to_status_idのツイートへの返信なので、`answer`
        answer = row['text']

        # in_reply_to_status_idを取り出し、クローリングしてきたJSONファイルを読み込み
        in_reply_to_status_id = int(row['in_reply_to_status_id'])
        crawled_json_file_path = os.path.join(crawled_in_reply_to_json_dir, "{}.json".format(in_reply_to_status_id))
        if not os.path.exists(crawled_json_file_path):
            # 対応するJSONがない場合、スキップ
            #   -> アカウントがない、ツイ消し、鍵垢
            continue
        json_data = None
        with open(crawled_json_file_path) as json_file:
            json_data = json.load(json_file)
        # in_reply_to_status_idのツイートにwhywaitaが返信するので、`question`
        question = json_data['text']

        # 正規化
        answer = normalize(answer)
        question = normalize(question)

        dataset_data = dataset_data.append(pd.Series({'question':question, 'answer':answer}), ignore_index=True)

    dataset_data.to_csv(generated_dataset_csv_output_path, index=False)

これでとりあえずデータの準備は終わりました。

がくしゅー

さて、データが集まりましたのでモデルの構築をして学習を行わせたいと思います。Chatbotといえばseq2seqという感じですね。ということでKerasかChainerで実装しようと思ったんですが、時間がなかったので今回はTensor2Tensorというとても便利なライブラリを使用することにしました

What is Tensor2Tensor

Library of deep learning models and datasets designed to make deep learning more accessible and accelerate ML research.

github.com

リポジトリにある通り、DeepLearningをより身近に、そして機械学習の研究を加速化させることを目的としたDeep Learningのモデルとデータセットのライブラリです(意訳しただけでは)。正直、謳い文句はどうでもよくて「どれだけ楽に使えるの」というところが大事であるんですが、本当に楽に使えます。MNIST(手書き文字のデータセット機械学習のHello, world的なもの)を使った手書き認識モデルの学習であれば、以下のコマンドを叩くだけです。

t2t-trainer \
  --generate_data \
  --data_dir=~/t2t_data \
  --output_dir=~/t2t_train/mnist \
  --problem=image_mnist \
  --model=shake_shake \
  --hparams_set=shake_shake_quick \
  --train_steps=1000 \
  --eval_steps=100

すっごい楽。既存モデルとデータセットを使うだけなら、コードを書く必要もなし。ここまで楽なのも正直どうかとは思いますが、ツールとしてDeepLearningを使う人(僕)にはとてもありがたい限りです。

今回はこのTensor2Tensorの lstm_seq2seq_attention_bidirectional_encoder というモデルを使用した上で、自作のデータセットを食わせて学習を行わせてみます。

自作データセットで学習させる(前準備)

Tensor2Tensor側が用意している Train on Your Own Data というドキュメントを読めばだいたいわかる(丸投げ)。

tensorflow.github.io

ドキュメントではText2TextProblemというテキストからテキストという問題を解かせる場合が記述されています。今回の入力として質問文を、回答として質問への答えを返すという問題も Text2TextProblem に相当しています。ということで、今回はこのドキュメントのサンプルコードをほぼそのまま使い、以下のようなファイル whywaita.py を作成しました。

import pandas as pd
import os

from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry

@registry.register_problem
class Whywaita(text_problems.Text2TextProblem):
    """
    whywaitaっぽい返答をするChatbot
    """

    @property
    def approx_vocab_size(self):
        return 2**13  # ~8k

    @property
    def is_generate_per_split(self):
        # generate_data will shard the data into TRAIN and EVAL for us.
        return False

    @property
    def dataset_splits(self):
        """Splits of data to produce and number of output shards for each."""
        # 10% evaluation data
        return [{
            "split": problem.DatasetSplit.TRAIN,
            "shards": 9,
            }, {
            "split": problem.DatasetSplit.EVAL,
            "shards": 1,
            }]

    def generate_samples(self, data_dir, tmp_dir, dataset_split):
        del tmp_dir
        del dataset_split

        # whywaitaの会話を"""正規化"""したCSVデータ(ここでは正規化は行わない)
        csv_path = os.path.join(data_dir, 'conversations.csv')
        csv_data = pd.read_csv(csv_path)
        csv_data = csv_data.dropna()

        for _, row in csv_data.iterrows():
            question = row['question']
            answer = row['answer']
            answer = answer.strip() # 改行を抜く
            question = question.strip()

            yield {
                'inputs': question,
                'targets': answer
            }

approx_vocab_size とかはもう少し調整するといいのかもしれないですが、今回は時間がなかったのでデフォルト値を使用しています。

更に、上記のファイルをimportするだけのファイル __init__.py を同じディレクトリに置いておきます。

from whywaita import Whywaita

この2つのファイルを usr_dir 等、任意のディレクトリに置いておきます。

自作データセットで学習させる(Tensor2Tensor用のデータを生成)

さて、データセットをTensor2Tensorで使える形にするために以下のように t2t-datagen コマンドを実行していきます。(なお、このときデータ準備で作成したデータをdata_dir という名前のディレクトリに置いています)

t2t-datagen \
--data_dir=data_dir \
--tmp_dir=tmp_dir \
--problem=whywaita \
--t2t_usr_dir=./

オプションの説明を簡単にすると

  • --data_dir
  • --t2t_usr_dir
  • --tmp_dir
  • --problem
    • whywaita.py でいう text_problems.Text2TextProblem を継承したクラスである Whywaita を指定する。ただし、クラス名の大文字はすべて小文字になり、キャメルケースだった場合にはスネークケースにTensor2Tensor内で変換されているため注意が必要。

      PROBLEM is the name of the class that was registered with @registry.register_problem, but converted from CamelCase to snake_case.

実行後、しばらく待つとTensor2Tensor用のデータ生成が終了します。

学習

以上で学習に必要なデータの準備が終わったので、さっそく学習をしていきたいと思います。すでに書いたように、Tensor2Tensorはディレクトリと使用するモデルとハイパーパラメータを引数として指定してコマンドを実行するだけで学習が進んでいきます。

今回は以下のようなコマンドを実行しました。

t2t-trainer \
        --data_dir=./data_dir \
        --problem=whywaita \
        --model=lstm_seq2seq_attention_bidirectional_encoder \
        --hparams_set=lstm_luong_attention_multi \
        --output_dir=./train_dir \
        --t2t_usr_dir=./

オプションはほぼ t2t-datagen と同じです。 --model--hparams_stには学習時に使用するモデルとハイパーパラメータを指定しています。これらの一覧は t2t-trainer --registry_help を実行すると表示されます。・・・表示はされるんだけど、あまり詳細な情報はなくて困る。更にググってもいまいち有益な情報は出てこない。すごい困る。現状ではトライ・アンド・エラーかなぁ・・・という状態。もし、ドキュメントがあったら教えてください・・・。

学習の間にはlossとval_lossは適宜標準出力に表示されますし、 output_dir で指定したディレクトリにcheckpoint毎のモデルが出力されます。また、 Tensorboard に logdir として --output_dir で指定したディレクトリを渡すとTensorboardで可視化されます。

会話をしてみるぞ!

さて、学習がある程度進んだところで、会話をしてみることにします。 t2t-decoder というコマンドから学習したモデルを使用することができます。
--data_dir, --output_dir, --model, --problem, --hparams_st, --t2t_usr_dir は学習時に使ったものと同じものを指定しておきます。また、--decode_hparamsはなにか適当な値を指定しました。この辺もいまいち情報がなくて困る・・・。

DATA_DIR=./data_dir
PROBLEM=whywaita
MODEL=lstm_seq2seq_attention_bidirectional_encoder
TRAIN_DIR=$1
BEAM_SIZE=4
ALPHA=0.6
HPARAMS=lstm_luong_attention_multi


t2t-decoder \
   --data_dir=$DATA_DIR \
   --problem=$PROBLEM \
   --model=$MODEL \
   --hparams_set=$HPARAMS \
   --output_dir=$TRAIN_DIR \
   --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \
   --decode_interactive=true \
   --t2t_usr_dir=./

あとはインタラクティブに会話ができるので、DeepLearningの世界に顕現したwhywaitaとの会話をしてみましょう。

茶番

f:id:tw1sm1k0:20181216215201p:plain
だいたいいつもの感じですね。

f:id:tw1sm1k0:20181216215241p:plain whywaitaっぽい。

f:id:tw1sm1k0:20181216215335p:plain 食べないで!

f:id:tw1sm1k0:20181216215951p:plain 会話になってない

本番

ちょっと怪しい感じですが、僕とましろちゃんの仲を認めてもらいましょう。

f:id:tw1sm1k0:20181216220038p:plain

は〜〜〜〜〜〜〜〜〜?????????????

これはだめそうですね、データセットを増やしてパラメータを調整して更に学習をさせてみました。

f:id:tw1sm1k0:20181216220809p:plain

ブチ切れた

締め

ということで、Chatbotのwhywaitaからでさえ認めてもらうことができませんでしたが、そんなことで揺らぐ僕と真白ちゃんの仲ではないので大丈夫です。

実を言うと、わりとlossとval_lossがかなりアレで過学習気味です。データセットの規模が小さいせいなのか、それともTwitter上のツイートにかなりノイズが乗っているのか、モデル/ハイパーパラメータが適切ではないのかなど様々検証すべき点はあります。あと、そもそも t2t-decoder で入力したこちら側からの質問をデータセットと同様の前処理(正規化)をしていない段階でだいぶお察しです。
今回はwhywaita Advent Calendar用のネタだったので深追いをしませんでしたが、個人的にTensor2Tensorは便利だなぁと思ったので、そのへんを他のデータセット等を使って調査してみたいなと思います。

なにはともあれ、来年もきっと id:masawada さんが伝統としてwhywaita Advent Calendarを作ってくれるはずなので楽しみにまっています。

それでは、明日は id:yu_ki_kun_0 さんです。弊社でのお仕事は楽しいですか?