package tries;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;

/**
 * A character trie (prefix tree) holding a set of strings.
 *
 * <p>Each node maps the next {@code char} of a word to a child node. A node is flagged
 * {@code terminal} when the path from the root to it spells a word that was added.
 * Instances are not synchronized.
 */
public class Trie {
  /** A trie node: outgoing edges keyed by character, plus an end-of-word flag. */
  static class Node {
    /** True when the path from the root to this node spells a word that was added. */
    boolean terminal;
    /** Children keyed by the next character of the word. */
    final HashMap<Character, Node> map = new HashMap<>();

    /**
     * Reports every word stored at or below this node to {@code consumer}.
     *
     * <p>Words are reported in pre-order: a node's own word comes before its descendants'.
     * Children are visited in the map's iteration order. The walk uses an explicit stack
     * instead of recursion, so very long words cannot exhaust the call stack.
     *
     * @param builder the characters on the path from the root to this node. It is extended
     *     and trimmed during the walk and holds its original content again on normal return.
     * @param consumer receives each complete word
     */
    void each(StringBuilder builder, Consumer<String> consumer) {
      if (terminal) {
        consumer.accept(builder.toString());
      }
      // One child iterator per node on the current path; stack depth tracks builder growth.
      Deque<Iterator<Map.Entry<Character, Node>>> pending = new ArrayDeque<>();
      pending.push(map.entrySet().iterator());
      while (!pending.isEmpty()) {
        Iterator<Map.Entry<Character, Node>> children = pending.peek();
        if (!children.hasNext()) {
          pending.pop();
          // Every iterator except this node's own was pushed right after appending its
          // edge letter, so finishing it means dropping that letter again.
          if (!pending.isEmpty()) {
            builder.setLength(builder.length() - 1);
          }
          continue;
        }
        Map.Entry<Character, Node> edge = children.next();
        Node child = edge.getValue();
        builder.append(edge.getKey().charValue());
        if (child.terminal) {
          consumer.accept(builder.toString());
        }
        pending.push(child.map.entrySet().iterator());
      }
    }
  }

  /** Root node; it stands for the empty string. */
  private final Node root = new Node();

  /**
   * Adds {@code word} to the trie. Adding a word that is already present has no effect.
   *
   * @param word the word to add; may be empty
   * @throws NullPointerException if {@code word} is null
   */
  public void add(String word) {
    Objects.requireNonNull(word, "word");
    Node current = root;
    for (char letter : word.toCharArray()) {
      current = current.map.computeIfAbsent(letter, unused -> new Node());
    }
    current.terminal = true;
  }

  /**
   * Follows the characters of {@code word} from the root.
   *
   * @param word the path to follow (non-null)
   * @return the node reached after consuming every character, or {@code null} if the path
   *     leaves the trie
   */
  private Node findNode(String word) {
    Node current = root;
    for (char letter : word.toCharArray()) {
      current = current.map.get(letter);
      if (current == null) {
        return null;
      }
    }
    return current;
  }

  /**
   * Tells whether {@code word} was added as a whole word. A mere prefix of an added word
   * does not count.
   *
   * @param word the word to look up
   * @return {@code true} if {@code word} was previously passed to {@link #add(String)}
   * @throws NullPointerException if {@code word} is null
   */
  public boolean contains(String word) {
    Objects.requireNonNull(word, "word");
    Node node = findNode(word);
    return node != null && node.terminal;
  }

  /**
   * Passes every stored word that starts with {@code prefix} to {@code consumer}.
   *
   * <p>The results include {@code prefix} itself if it was added. The order is unspecified,
   * because it follows {@link HashMap} iteration order.
   *
   * @param prefix the required leading characters; the empty string matches every word
   * @param consumer receives each matching word
   * @throws NullPointerException if {@code prefix} or {@code consumer} is null
   */
  public void prefix(String prefix, Consumer<String> consumer) {
    Objects.requireNonNull(prefix, "prefix");
    Objects.requireNonNull(consumer, "consumer");
    Node node = findNode(prefix);
    if (node == null) {
      return;
    }
    node.each(new StringBuilder(prefix), consumer);
  }

  /** Small demonstration of {@code add}, {@code contains} and {@code prefix}. */
  public static void main(String[] args) {
    Trie trie = new Trie();
    trie.add("toto");
    trie.add("tuto");
    trie.add("tota");
    System.out.println(trie.contains("toto"));
    System.out.println(trie.contains("to"));
    trie.prefix("t", System.out::println);
  }
}
