REST APIでフィルタリングをサポートする方法

なぜAPIにフィルタリング機能が必要になるのか

Webサービスを運営している時に、サーバサイドではAPIにフィルタリング機能を求められる事がしばしば起こります。このようなAPIの一例として、タスク管理ツールでのタスクの一覧を返す以下のようなAPIがあります。

openapi: 3.0.1
info:
  title: OpenAPI definition
  version: v0
servers:
- url: http://localhost:8080
  description: Generated server url
paths:
  /tasks:
    get:
      tags:
      - tasks-controller
      summary: タスクをリストアップする
      operationId: listTasks
      responses:
        "200":
          description: OK
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/ListTasksResponse"
components:
  schemas:
    ListTasksResponse:
      type: object
      properties:
        results:
          type: array
          items:
            $ref: "#/components/schemas/TaskDto"
    Task:
      type: object
      properties:
        id:
          type: string
          description: ID
          format: uuid
        title:
          type: string
          description: タスクの名前です。
        description:
          type: string
          description: タスクの説明です。
        completed:
          type: boolean
          description: 完了したかどうか。
        creatAt:
          type: string
          description: 作成日時です。
          format: date-time
        updateAt:
          type: string
          description: 更新日時です。
          format: date-time
      description: タスクです。

このAPIに対して、「タスクの名前でフィルタリングしたい」、「未完了のものだけ返して欲しい」とか、「未完了のもので1週間以上更新のないものをリストアップしたい」というような要望が出てくるのは一般的です。フィルタリングの処理そのものを実装するのは、RDBMSなどのデータベースへのクエリに条件を追加するだけなので難しくはないでしょう。

ではAPIの仕様はどうするのか

ですが、ここで問題となるのは、APIのクライアントに対してどのようにフィルタリングの条件を指定させるかです。例えばSQLでは以下のように条件を指定して絞り込む事になるでしょう。

SELECT id, title, description, completed, create_at, update_at FROM tasks WHERE completed = false AND update_at < '2024-11-01';

NoSQLのMongoDBでは、以下のようなクエリを実行する事になるでしょう。

db.tasks.find({completed: false, update_at: { $lt : ISODate('2024-11-01')}})

APIの利用者にとって completed = false AND update_at < '2024-11-01'{completed: false, update_at: { $lt : ISODate('2024-11-01')}} では、どちらの方が学習コストが低くて利用しやすいと思うでしょうか? 前者のような条件式の方がよいと答えるプログラマの方が多いでしょう。(嫌われている事が多いSQLですが、ほとんどのプログラマは基本的な事は知っているのです)

ですが、前者のSQLのような書式のフィルタリングの条件式をサポートするとして、サーバ側ではどのように処理をすればよいのでしょう。以下のような実装は絶対に してはいけません 。(ここではJava言語でSpring JDBCを使用した例を記載しています) SQLインジェクション攻撃を受けて個人情報が漏れてしまってIPAへ届け出する事になるでしょう。

@Repository
public class TaskRepository {
    private final JdbcTemplate jdbcTemplate;

    public TaskRepository(JdbcTemplate jdbcTemplate) {
        this.jdbcTemplate = jdbcTemplate;
    }

    /**
     * 条件を指定してタスクを取得する。
     *
     * @param filter クライアントが指定した文字列。
     * @return
     */
    public List<TaskEntity> findAllWithFilter(String filter) {
        // 絶対に駄目な実装例
        return jdbcTemplate.query("SELECT id, title, description, completed, create_at, update_at FROM tasks WHERE " + filter,
                                new DataClassRowMapper<>(TaskEntity.class));
    }
}

サーバではクライアントから受け取った値は信用せずに検証するのが基本ですので、フィルタリングの指定も検証する必要があります。検証するには受け取ったフィルタリングの条件式を解析して妥当な内容であるかを確かめます。

便利なAPIは検証が大変なのか

ここで一部の方は「パーサを書かなきゃいけないのか。大変だな。」と言いながら、嬉しそうに解析用のコードを書き出すかもしれません。ですが、他の人が書いたパーサをメンテナンスする事はあまり楽しいことではないので、保守性を考えてやめておきましょう。ANTLRという、パーサを生成するツール(パーサジェネレータとかコンパイラコンパイラといいます)がありますのでこちらを使用します。ANTLR自身はJavaで書かれていますが、JavaC#PythonJavaScript、Go、C++、Swiftと多くの言語向けにパーサを生成できます。本記事ではこのANTLRを使って、APIで指定された条件式をパースして検証していきます。

ANTLRを使って解析、検証する

システム構成について

ANTLRは、様々なプログラミング言語向けにパーサを生成できますが、本記事では以下のような構成で解説します。

プロジェクトの作成

まずSprint initializrの画面を開いて、下図のように設定して「GENERATE」ボタンをクリックします。

SpringInitializr

ダウンロードしたファイルを適当なディレクトリに展開します。

次にpom.xmlに以下の内容を追加します。

 <properties>
    <!-- 中略 -->
        <antlr.version>4.13.2</antlr.version>
    </properties>

  <!-- 中略 -->
        <dependency>
            <groupId>org.antlr</groupId>
            <artifactId>antlr4-runtime</artifactId>
            <version>${antlr.version}</version>
        </dependency>
    <!-- 中略 -->
            <plugin>
                <groupId>org.antlr</groupId>
                <artifactId>antlr4-maven-plugin</artifactId>
                <version>${antlr.version}</version>
                <configuration>
                    <listener>true</listener>
                    <visitor>true</visitor>
                </configuration>
                <executions>
                    <execution>
                        <id>antlr-generate</id>
                        <phase>generate-sources</phase>
                        <goals>
                            <goal>antlr4</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>

そして、以下のANTLRの文法ファイルを作成します。ここではカラムの値との単純な比較と、ANDORのみをサポートするだけにしています。もっと複雑な条件式をサポートしたいこちらを参考にしてください。

ファイル名: src/main/antlr4/com/example/taskdemo/application/model/query/SQLCondition.g4

grammar SQLCondition;

query: expression EOF;

expression
    : expression binaryOp=LOGICAL_OPERATOR expression
    | '(' expression ')'
    | condition
    ;

condition
    : column COMPARISON_OPERATOR value
    ;

column: IDENTIFIER;
value: STRING | INT;

LOGICAL_OPERATOR: 'AND' | 'OR';
COMPARISON_OPERATOR: '=' | '!=' | '<' | '<=' | '>' | '>=';

STRING: '\'' ( ~['\\] | '\\.' )* '\'';
INT: [0-9]+;
IDENTIFIER: [a-zA-Z_][a-zA-Z0-9_]*;

WS: [ \t\r\n]+ -> skip;

ファイルが作成できたら以下のコマンドを実行します。

$ ./mvnw antlr4:antlr4

完了するとtargetディレクトリに以下のパーサのソースコードが生成されます。

$ cd target/generated-sources/antlr4/com/example/taskdemo/application/model/query/
$ ls *.java
SQLConditionBaseListener.java  SQLConditionLexer.java     SQLConditionParser.java
SQLConditionBaseVisitor.java   SQLConditionListener.java  SQLConditionVisitor.java

IntelliJを使用している場合は生成されたソースコードを認識させるために、下図のように「Project Structure」 - 「Modules」を選択して、target/generated-sources/antlr4を右クリックして「Sources」を選択し、「OK」ボタンをクリックします。

IntelliJの設定

これで、フィルタリングのリクエストを解析して検証する準備ができました。まず、コントローラークラスでリクエストをパースしてユースケースクラスを呼び出します。

ファイル名: src/main/java/com/example/taskdemo/controller/TaskController.java

@RestController
public class TaskController {
    private final ListTasksUseCase listTasksUseCase;

    public TaskController(ListTasksUseCase listTasksUseCase) {
        this.listTasksUseCase = listTasksUseCase;
    }

    @GetMapping("/tasks")
    public ListTaskResponse getTasks(@RequestParam(required = false) String filter) {
        var filterContext = Optional.ofNullable(filter)
                .map(filterStr -> {
                    var lexer = new SQLConditionLexer(CharStreams.fromString(filterStr));
                    var parser = new SQLConditionParser(new CommonTokenStream(lexer));
                    return parser.query();
                });
        return new ListTaskResponse(listTasksUseCase.handle(filterContext));
    }
}

なお、ここではクライアントへのレスポンス用に以下のクラスを定義しています。

ファイル名: src/main/java/com/example/taskdemo/controller/ListTaskResponse.java

public record ListTaskResponse(List<Task> results) {
}

ユースケースクラスは、リクエストの条件で指定されたプロパティが正しいのか意味的な検証をしてから、リポジトリクラスを呼び出します。

ファイル名: src/main/java/com/example/taskdemo/application/ListTasksUseCase.java

@Service
public class ListTasksUseCase {
    private final TaskRepository taskRepository;

    public ListTasksUseCase(TaskRepository taskRepository) {
        this.taskRepository = taskRepository;
    }

    public List<Task> handle(Optional<SQLConditionParser.QueryContext> filterContext) {
        filterContext.ifPresent(fCtx -> new ListTasksFilterValidator().validate(fCtx));
        return taskRepository.findAll(filterContext);
    }
}

ファイル名: src/main/java/com/example/taskdemo/application/ListTasksFilterValidator.java

public class ListTasksFilterValidator implements SQLConditionBaseListener {
    static final Set<String> FIELD_NAMES = Set.of("title", "description", "completed", "createdAt", "updatedAt");

    public void validate(SQLConditionParser.QueryContext queryContext) {
        ParseTreeWalker.DEFAULT.walk(this, queryContext);
    }

    @Override
    public void enterCondition(SQLConditionParser.ConditionContext ctx) {
        var columnName = ctx.column().getText();
        if (!FIELD_NAMES.contains(columnName)) {
            throw new IllegalFieldNameException(columnName);
        }
    }

    public static class IllegalFieldNameException extends RuntimeException {
        public IllegalFieldNameException(String fieldName) {
            super("Illegal field name: " + fieldName);
        }
    }
}

Exceptionが発生するようになるのでハンドラを定義しておきます。

ファイル名: src/main/java/com/example/taskdemo/controller/TaskDemoExceptionHandler.java

@RestControllerAdvice
public class TaskDemoExceptionHandler {
    @ExceptionHandler(value = ListTasksFilterValidator.IllegalFieldNameException.class)
    public ResponseEntity illegalFieldNameException(ListTasksFilterValidator.IllegalFieldNameException e) {
        return ResponseEntity.badRequest().body(new TaskDemoErrorResponse("ILLEGAL_FIELD_NAME", e.getMessage()));
    }
}

ファイル名: src/main/java/com/example/taskdemo/controller/TaskDemoErrorResponse.java

public record TaskDemoErrorResponse(String errorCode, String errorMessage) {
}

リポジトリクラスは、リクエストの条件を受け取って動的にSQLを組み立てデータベースに問い合わせます。

ファイル名: src/main/java/com/example/taskdemo/infrastructure/TaskRepository.java

@Repository
public class TaskRepository {
    private final NamedParameterJdbcTemplate jdbcTemplate;

    public TaskRepository(NamedParameterJdbcTemplate jdbcTemplate) {
        this.jdbcTemplate = jdbcTemplate;
    }

    public List<Task> findAll(Optional<SQLConditionParser.QueryContext> queryContext) {
        return queryContext
                .map(qCtx -> {
                    var walker = new ParseTreeWalker();
                    var builder = new ListTaskSQLQueryBuilder();
                    walker.walk(builder, qCtx);

                    return jdbcTemplate.query(builder.getSql(), builder.getParameters(),new DataClassRowMapper<>(Task.class));
                })
                .orElseGet(() ->
                        jdbcTemplate.query("SELECT id, title, description, completed, created_at, updated_at FROM tasks",
                                new DataClassRowMapper<>(Task.class)));
    }

    static class ListTaskSQLQueryBuilder implements SQLConditionBaseListener {
        static final Map<String, String> COLUMN_MAP = Map.ofEntries(
                entry("createdAt", "created_at"),
                entry("updatedAt", "updated_at")
        );
        private StringBuilder sql = new StringBuilder("SELECT id, title, description, completed, created_at, updated_at FROM tasks WHERE ");

        private MapSqlParameterSource parameters = new MapSqlParameterSource();
        private int paramIndex = 0;

        public String getSql() {
            return sql.toString();
        }

        public MapSqlParameterSource getParameters() {
            return parameters;
        }

        @Override
        public void enterExpression(SQLConditionParser.ExpressionContext ctx) {
            if (ctx.getChildCount() == 3 && ctx.getChild(0).getText().equals("(")) {
                sql.append("(");
            }
        }

        @Override
        public void exitExpression(SQLConditionParser.ExpressionContext ctx) {
            if (ctx.getChildCount() == 3 && ctx.getChild(2).getText().equals(")")) {
                sql.append(")");
                if (!ctx.getParent().isEmpty()) {
                    var parent = ctx.getParent();
                    if (parent instanceof SQLConditionParser.ExpressionContext) {
                        if (((SQLConditionParser.ExpressionContext) parent).binaryOp != null && parent.getChild(0) == ctx) {
                            sql.append(" " + ((SQLConditionParser.ExpressionContext) parent).LOGICAL_OPERATOR().getText() + " ");
                        }
                    }
                }
            }
        }

        @Override
        public void enterCondition(SQLConditionParser.ConditionContext ctx) {
            var column = COLUMN_MAP.getOrDefault(ctx.column().getText(), ctx.column().getText());
            var comparator = ctx.COMPARISON_OPERATOR().getText();
            var value = ctx.value().getText().replace("'", "");
            var valuePlaceHolder = column + paramIndex++;

            sql.append(String.format("%s %s :%s", column, comparator, valuePlaceHolder));
            parameters.addValue(valuePlaceHolder, value);
        }

        @Override
        public void exitCondition(SQLConditionParser.ConditionContext ctx) {
            if (!ctx.getParent().isEmpty()) {
                var parent = ctx.getParent();
                if (parent instanceof SQLConditionParser.ExpressionContext && !parent.getParent().isEmpty()) {
                    var grandParent = parent.getParent();
                    if (((SQLConditionParser.ExpressionContext) grandParent).binaryOp != null && grandParent.getChild(0) == parent) {
                        sql.append(" " + ((SQLConditionParser.ExpressionContext) grandParent).LOGICAL_OPERATOR().getText() + " ");
                    }
                }
            }
        }
    }
}

コードを実装できたので、データベースを起動するようにしてスキーマを定義し、サンプルのデータを投入します。まず、application.ymlファイルに以下の設定を記述します。

ファイル名: src/main/resources/application.yml

spring:
  datasource:
    driver-class-name: org.h2.Driver
    url: jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1;DB_CLOSE_ON_EXIT=false
    username: sa
    password:
  sql:
    init:
      encoding: UTF-8

次にデータベースのDDLを作成します。

ファイル名: src/main/resources/schema.sql

CREATE TABLE tasks (id uuid, title VARCHAR, description VARCHAR, completed BOOLEAN, created_at TIMESTAMP, updated_at TIMESTAMP);

サンプルのデータを投入します。

ファイル名: src/main/resources/data.sql

INSERT INTO TASKS VALUES
                      ( RANDOM_UUID() , '週次レポートを書く', 'ANTLRについて書く。', TRUE, '2024-11-01T09:30:23', '2024-11-01T13:05:14'),
                      ( RANDOM_UUID() , 'プロジェクト提案を提出する', '収益の見積もりをブラッシュアップする事。', FALSE, '2024-10-15T14:53:48', '2024-10-15T14:53:48'),
                      ( RANDOM_UUID() , '出張旅費を精算する', '催促されないようにする事。', FALSE, '2024-10-31T10:15:31', '2024-10-31T10:15:31')
                      ;

準備が出来たので、Spring Bootを起動します。

./mvnw spring-boot:run

別のコンソール画面からリストの取得をしてみます。

$ curl -s http://localhost:8080/tasks | jq .
{
  "results": [
    {
      "id": "87fe00a2-c8af-45bf-b8aa-700e89884d8f",
      "title": "週次レポートを書く",
      "description": "ANTLRについて書く。",
      "completed": true,
      "createdAt": "2024-11-01T09:30:23",
      "updatedAt": "2024-11-01T13:05:14"
    },
    {
      "id": "d698904d-b915-41c7-8e58-6d813db6424f",
      "title": "プロジェクト提案を提出する",
      "description": "収益の見積もりをブラッシュアップする事。",
      "completed": false,
      "createdAt": "2024-10-15T14:53:48",
      "updatedAt": "2024-10-15T14:53:48"
    },
    {
      "id": "3177a191-7e67-45ba-9c81-4351a2f2eb00",
      "title": "出張旅費を精算する",
      "description": "催促されないようにする事。",
      "completed": false,
      "createdAt": "2024-10-31T10:15:31",
      "updatedAt": "2024-10-31T10:15:31"
    }
  ]
}
$ curl -s --get --data-urlencode "filter=title = '週次レポートを書く'" http://localhost:8080/tasks | jq .
{
  "results": [
    {
      "id": "87fe00a2-c8af-45bf-b8aa-700e89884d8f",
      "title": "週次レポートを書く",
      "description": "ANTLRについて書く。",
      "completed": true,
      "createdAt": "2024-11-01T09:30:23",
      "updatedAt": "2024-11-01T13:05:14"
    }
  ]
}
$ curl -s --get --data-urlencode "filter=completed = false AND updatedAt < '2024-10-25'" http://localhost:8080/tasks | jq .
{
  "results": [
    {
      "id": "d698904d-b915-41c7-8e58-6d813db6424f",
      "title": "プロジェクト提案を提出する",
      "description": "収益の見積もりをブラッシュアップする事。",
      "completed": false,
      "createdAt": "2024-10-15T14:53:48",
      "updatedAt": "2024-10-15T14:53:48"
    }
  ]
}

以上のようにANTLRを使用すると、文法の定義を行えば簡単にパースして検証する事ができました。