AIに競技プログラミングの問題を解かせてみた

Programming icons created by Flat Icons – Flaticon

概要

競技プログラミングサイトAtCoderのコンテスト問題を、コード補完AIツールGitHub Copilotに解かせてみた。最低難度の問題は問題文をコメントで与えるだけで安定して解け、また問題文がほんの少し複雑でも正答を導き出すことができた。ただし自分で解いた方が速いので、全部補完させる使い方は実用的でない。

GitHub Copilotはコメントや関数名等の文脈から続くコードを提案してくれるプログラミング支援ツールで、Visual Studio Codeの拡張機能として動作します。現在はテクニカルプレビュー段階であり無料で利用できますが、商用版(有償)の正式リリースが計画されています。

競技プログラミングにAIが挑戦する試みは盛んに研究されており、AIが十分な成果を出したニュースについては他メンバーの記事で取り上げられておりますので、合わせてお読みいただけると幸いです。

競技プログラミングに挑戦したAI「AlphaCode」

今回の挑戦

今回はAtCoderさんが開催したコンテストの中から、AtCoder Beginner Contest(ABC)のA問題を中心に、Github Copilotの力で挑んでいきます。ABCは定期開催のコンテストの中で最も難易度が低く、その中でもA問題は最も配点が低い(≒難易度が低い)問題です。Github Copilotは競プロに特化したものではありませんが、どこまで通用するのか確かめていきましょう。

検証方法は次の通りです。

  • ライブラリのincludeや標準入出力など、解答ロジック以外のコードは予め用意しておく
  • 解答ロジックを補完させるsolve関数の上に問題文をコメントとして貼り付ける
  • GitHub Copilotに関数内でいくつかコード補完の候補を提案させる
  • その中で正答と思われるものを採用し、そのまま提出する
  • 結果がAC(Accepted、正解)ならGitHub Copilotが問題を解いたと見なす

ABC253 A – Median?

問題ページ:https://atcoder.jp/contests/abc253/tasks/abc253_a

整数a, b, cについて、bが中央値であるか判定する問題です。bが閉区間[min(a, c), max(a, c)]に含まれる場合にbは中央値であると言えますね。

solve関数内は全てGitHub Copilot(AI)の生成したコードです(以下のアコーディオンボタンから開いてご覧ください)。3つの生成結果を並べています。

/*
問題文
整数 a,b,c が与えられます。b がこれらの整数の中央値であるかどうか判定してください。

制約
1≤a,b,c≤100
入力は全て整数
*/
bool solve(int64_t a, int64_t b, int64_t c) {
    return (a <= b && b <= c) || (a >= b && b >= c);
}

=======

bool solve(int64_t a, int64_t b, int64_t c) {
    ll x = a + b + c;
    if (x % 2 == 0) {
        if (b == x / 2) return true;
    } else {
        if (b == (x + 1) / 2) return true;
    }
    return false;
}

=======

bool solve(int64_t a, int64_t b, int64_t c) {
    return b <= (a + c) / 2;
}

1番目:中央値を求める分かりやすいコードになっていますね。提出結果はACです。
2番目:何を解こうとしているのかよく分かりませんでした。
3番目:筆者は最初、問題を平均値の判定と勘違いしていて「この答えは惜しい!」と書く寸前でした。恥をかかずに済みました。

ABC252 A – ASCII code

問題ページ:https://atcoder.jp/contests/abc252/tasks/abc252_a

問題文を読めば分かりますが、どちらかというと標準入出力に関する問題です。C++は標準出力にchar型を流せば文字として出力されるのでサービス問題ですね。

入出力例もコメントに含めてみました。さて、GitHub Copilotが提案した補完候補の内、どれが正答でしょうか。

/*
問題文
英小文字 a, b, …, z の ASCII 文字コードはこの順に 97,98,…,122 です。
97 以上 122 以下の整数 N が与えられるので、ASCII 文字コードが N であるような英小文字を出力してください。

制約
N は 97 以上 122 以下の整数

入力
入力は以下の形式で標準入力から与えられる。
N

出力
答えを出力せよ。

入力例 1
97

出力例 1
a

ASCII 文字コードが 97 である英小文字は a です。

入力例 2
122

出力例 2
z
*/
char solve(int64_t N) {
    return N + 'a' - 97;
}

=====

char solve(int64_t N) {
    return (char)N;
}

=====

char solve(int64_t N) {
    return static_cast<char>(N);
}

=====

char solve(int64_t N) {
    return char(N);
}

=====

char solve(int64_t N) {
    return N - 97;
}

5番目以外全てACです。5番目のコード return N - 97; は、アルファベットのzero-based indexing(A=0, B=1, …, Z=26)を求めています。残念ながらWA(Wrong Answer、不正解)となりました。

ABC251 A – Six Characters

問題ページ:https://atcoder.jp/contests/abc251/tasks/abc251_a

問題文だけを与えてみました。

/*
問題文
英小文字からなる文字列 S が与えられます。 S の長さは 1 以上かつ 3 以下です。
S を繰り返して得られる文字列であって、長さが 6 のものを出力してください。
本問題の制約下で、そのような文字列はただ一つ存在することが示せます。
*/
std::string solve(std::string S) {
    std::string ans = "";
    REP(i, S.size()) {
        ans += S[i];
        if (ans.size() == 6) {
            return ans;
        }
    }
    return "";
}

上記コードはWAでした。REP(i, S.size()) { (REPはfor文のマクロ)が文字列の長さしか結合処理を回さないのでこれは誤りです。ちなみに制約から、ループを無限に回せば ans.size() == 6 を確実に満たします。

関数のコメントに問題文だけでなく制約と入出力例を書き足して補完させたところACでした。シンプルで十分スマートな解き方だと感じます。

/*
問題文
英小文字からなる文字列 S が与えられます。 S の長さは 1 以上かつ 3 以下です。
S を繰り返して得られる文字列であって、長さが 6 のものを出力してください。
本問題の制約下で、そのような文字列はただ一つ存在することが示せます。

制約
S は英小文字からなる長さ 1 以上 3 以下の文字列

入力例 1
abc

出力例 1
abcabc

S= abc を繰り返してできる文字列として、abc 、abcabc 、abcabcabc 、abcabcabcabc などがあります。 そのうち、長さが 6 のものは abcabc です。よって、abcabc と出力します。

入力例 2
zz

出力例 2
zzzzzz
*/
std::string solve(std::string S) {
    std::string ans = S;
    while (ans.size() < 6) {
        ans += S;
    }
    return ans.substr(0, 6);
}

ABC253 B – Distance Between Tokens

問題ページ:https://atcoder.jp/contests/abc253/tasks/abc253_b

文字列で表された平面上の2点間のマンハッタン距離を求める問題です。

今回はB問題です。まだまだ初級レベルですが、ここまでのA問題より少しだけ問題文も複雑になっています。はたしてGitHub Copilotに解けるでしょうか。

/*
問題文
H 行 W 列のマス目があり、そのうち二つの異なるマスに駒が置かれています。

マス目の状態は H 個の長さ W の文字列 S_1,…,S_H で表されます。
S_i,j = o ならば i 行目 j 列目のマスに駒が置かれていることを、S_i,j = - ならばそのマスには駒が置かれていないことを表します。
なお、S_i,j は文字列 S_i の j 文字目を指します。

一方の駒をマス目の外側に出ないように上下左右の隣接するマスに動かすことを繰り返すとき、もう一方の駒と同じマスに移動させるためには最小で何回動かす必要がありますか?

制約
2≤H,W≤100
H,W は整数
S_i (1≤i≤H) は o および - のみからなる長さ W の文字列
S_i,j = o となる整数 1≤i≤H,1≤j≤W の組がちょうど二つ存在する

入力
入力は以下の形式で標準入力から与えられる。

H W
S_1
⋮
S_H
​
出力
答えを出力せよ。

入力例 1
2 3
--o
o--

出力例 1
3

1 行目 3 列目に置かれている駒を 下 → 左 → 左 と移動すると 3 回でもう一方の駒と同じマスに移動させることができます。2 回以下で移動させることはできないので、3 を出力します。

入力例 2
5 4
-o--
----
----
----
-o--

出力例 2
4
*/
int64_t solve(int H, int64_t W, const std::vector<std::string> &S) {

}
Synthesizing 10/10 solutions

=======

    int64_t ans = 0;
    for (int i = 0; i < H; i++) {
        for (int j = 0; j < W; j++) {
            if (S[i][j] == 'o') {
                ans += std::min(std::min(i, H - i - 1), std::min(j, W - j - 1));
            }
        }
    }
    return ans;
}

=======

  int64_t ans = INF;
  for (int64_t i = 0; i < H; ++i) {
    for (int64_t j = 0; j < W; ++j) {
      if (S[i][j] == 'o') {
        int64_t cnt = 0;
        if (i > 0 && S[i - 1][j] == 'o') ++cnt;
        if (i < H - 1 && S[i + 1][j] == 'o') ++cnt;
        if (j > 0 && S[i][j - 1] == 'o') ++cnt;
        if (j < W - 1 && S[i][j + 1] == 'o') ++cnt;
        ans = min(ans, cnt);
      }
    }
  }
  return ans;
}

=======

    int64_t ans = 0;
    for (int64_t i = 0; i < H; i++) {
        for (int64_t j = 0; j < W; j++) {
            if (S[i][j] == 'o') {
                int64_t cnt = 0;
                if (i + 1 < H && S[i + 1][j] == 'o') {
                    cnt++;
                }
                if (j + 1 < W && S[i][j + 1] == 'o') {
                    cnt++;
                }
                if (i - 1 >= 0 && S[i - 1][j] == 'o') {
                    cnt++;
                }
                if (j - 1 >= 0 && S[i][j - 1] == 'o') {
                    cnt++;
                }
                ans = max(ans, cnt);
            }
        }
    }
    return ans;
}

=======

    int64_t ans = 0;
    for (int64_t i = 0; i < H; i++) {
        for (int64_t j = 0; j < W; j++) {
            if (S[i][j] == 'o') {
                int64_t cnt = 0;
                if (i > 0) {
                    cnt += S[i - 1][j] == 'o' ? 1 : 0;
                }
                if (j > 0) {
                    cnt += S[i][j - 1] == 'o' ? 1 : 0;
                }
                if (i < H - 1) {
                    cnt += S[i + 1][j] == 'o' ? 1 : 0;
                }
                if (j < W - 1) {
                    cnt += S[i][j + 1] == 'o' ? 1 : 0;
                }
                ans = std::max(ans, cnt);
            }
        }
    }
    return ans;
}

=======

    int64_t ans = INF;
    REP(i, H) {
        REP(j, W) {
            if (S[i][j] == 'o') {
                int64_t cnt = 0;
                if (i > 0) {
                    if (S[i - 1][j] == 'o') {
                        cnt++;
                    }
                }
                if (j > 0) {
                    if (S[i][j - 1] == 'o') {
                        cnt++;
                    }
                }
                if (i + 1 < H) {
                    if (S[i + 1][j] == 'o') {
                        cnt++;
                    }
                }
                if (j + 1 < W) {
                    if (S[i][j + 1] == 'o') {
                        cnt++;
                    }
                }
                ans = min(ans, cnt);
            }
        }
    }
    return ans;
}

=======

  int64_t ans = INF;
  REP(i, H) REP(j, W) {
    if (S[i][j] != 'o') continue;
    int64_t cnt = 0;
    if (i - 1 >= 0 && S[i - 1][j] == 'o') cnt++;
    if (i + 1 < H && S[i + 1][j] == 'o') cnt++;
    if (j - 1 >= 0 && S[i][j - 1] == 'o') cnt++;
    if (j + 1 < W && S[i][j + 1] == 'o') cnt++;
    ans = min(ans, cnt);
  }
  return ans;
}

=======

    int64_t ans = INF;
    std::vector<std::vector<int>> dp(H, std::vector<int>(W, INF));
    dp[0][0] = 0;
    REP(i, H) {
        REP(j, W) {
            if (S[i][j] == 'o') {
                if (i > 0) {
                    dp[i][j] = std::min(dp[i][j], dp[i - 1][j] + 1);
                }
                if (j > 0) {
                    dp[i][j] = std::min(dp[i][j], dp[i][j - 1] + 1);
                }
                if (i + 1 < H) {
                    dp[i][j] = std::min(dp[i][j], dp[i + 1][j] + 1);
                }
                if (j + 1 < W) {
                    dp[i][j] = std::min(dp[i][j], dp[i][j + 1] + 1);
                }
            }
            ans = std::min(ans, dp[i][j]);
        }
    }
    return ans;
}

=======

  int64_t ans = 0;
  std::vector<std::vector<int64_t>> dp(H, std::vector<int64_t>(W, -1));
  dp[0][0] = 0;
  for (int i = 0; i < H; ++i) {
    for (int j = 0; j < W; ++j) {
      if (S[i][j] == 'o') {
        if (i > 0) {
          dp[i][j] = min(dp[i][j], dp[i - 1][j] + 1);
        }
        if (j > 0) {
          dp[i][j] = min(dp[i][j], dp[i][j - 1] + 1);
        }
        if (i < H - 1) {
          dp[i][j] = min(dp[i][j], dp[i + 1][j] + 1);
        }
        if (j < W - 1) {
          dp[i][j] = min(dp[i][j], dp[i][j + 1] + 1);
        }
        ans = max(ans, dp[i][j]);
      }
    }
  }
  return ans;
}

=======

    int64_t ans = INF;
    int64_t n = 0;
    for (int64_t i = 0; i < H; ++i) {
        for (int64_t j = 0; j < W; ++j) {
            if (S[i][j] == 'o') {
                ++n;
                for (int64_t k = 0; k < H; ++k) {
                    for (int64_t l = 0; l < W; ++l) {
                        if (S[k][l] == 'o' && (k != i || l != j)) {
                            ans = min(ans, abs(k - i) + abs(l - j));
                        }
                    }
                }
            }
        }
    }
    return n == 2 ? ans : -1;
}

=======

  int64_t ans = 0;
  vector<vector<int>> dp(H, vector<int>(W, INF));
  REP(i, H) {
    REP(j, W) {
      if (S[i][j] == 'o') {
        dp[i][j] = 0;
      }
    }
  }

  REP(i, H) {
    REP(j, W) {
      if (S[i][j] == 'o') {
        int64_t x = i;
        int64_t y = j;
        while (x < H && S[x][y] == 'o') {
          dp[x][y] = min(dp[x][y], dp[i][j] + 1);
          x++;
        }
        x = i;
        while (y < W && S[x][y] == 'o') {
          dp[x][y] = min(dp[x][y], dp[i][j] + 1);
          y++;
        }
        x = i;
        while (x >= 0 && S[x][y] == 'o') {
          dp[x][y] = min(dp[x][y], dp[i][j] + 1);
          x--;
        }
        x = i;
        while (y >= 0 && S[x][y] == 'o') {
          dp[x][y] = min(dp[x][y], dp[i][j] + 1);
          y--;
        }
      }
    }
  }
  REP(i, H) {
    REP(j, W) {
      ans = max(ans, dp[i][j]);
    }
  }
  return ans;

多くの提案の中から1つ選んで提出したらACが出ました。2つの点のXY座標が分かれば答えが求まるので、それができそうな4重ループのあるコードを勘で選びました(下から2番目の候補)。

まとめ

GitHub Copilotは競プロに特化したAIではないにもかかわらず、ABCのA問題を安定して解く力はあるようでした。また、B問題のような比較的複雑な問題文も読み解けているのにはビックリしました。ただ、複数の補完候補で1番目に提案されたコードが正答とは限らないので、どの提案を選ぶかは自分で判断が必要です。それができるなら自分でコードを書いた方が早いような気もしますが…むしろ解いている途中のロジックを素早く補完・提案してもらう使い方ならより恩恵を受けられそうです。その内試してみたいと思います。

Github Copilotは無料で利用できる割に十分な力を見せてくれたので、今後はさらにAIによるコーディング支援のレベルが高まっていくことを楽しみにしています。


記事が面白かった方、参考になった方は、是非「イイね」お願いします👍

競技プログラミングやAIによるプログラミングに興味の湧いた方はこちらの記事もどうぞ!

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

CAPTCHA