Implementing table-field-level data lineage with Neo4j

Requirement background

The front-end page needs to display all upstream and downstream lineage of the current table's fields, to support further data diagnosis and governance. The general effect diagram is as follows:

[Figure: general effect diagram of field-level lineage]

First, an explanation of what table-field lineage (sometimes translated literally as "blood relationship") is, using a SQL example:

CREATE TABLE IF NOT EXISTS table_b
AS SELECT order_id, order_status FROM table_a;

In the DDL statement above, the order_id and order_status fields of the newly created table_b are derived from table_a. table_a is therefore the source table of table_b, also called its upstream table, and table_b is a downstream table of table_a. Likewise, table_a.order_id is the upstream field of table_b.order_id. The two fields have a lineage relationship.

INSERT INTO table_c
SELECT a.order_id, b.order_status
FROM table_a a JOIN table_b b ON a.order_id = b.order_id;

In the DML statement above, the order_id field of table_c comes from table_a and the order_status field comes from table_b, so table_c has a lineage relationship with both table_a and table_b.

As the examples show, lineage can only be stored after the SQL has been parsed. The parsing mainly relies on the parser of the open-source project Apache Calcite and is not covered in this article, which focuses on how to store the lineage and how to display it.
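To make the target of the parsing step concrete, the following sketch writes out by hand the field-level lineage that the two SQL statements above imply. It is only an illustration: the class name LineageEdgesExample and the map layout are assumptions made for this example, and a real parser would produce this information automatically.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class LineageEdgesExample {
  // Maps each downstream field to the upstream fields it is derived from.
  static Map<String, List<String>> upstreamFieldsByField() {
    Map<String, List<String>> lineage = new LinkedHashMap<>();
    // CREATE TABLE table_b AS SELECT order_id, order_status FROM table_a
    lineage.put("table_b.order_id", List.of("table_a.order_id"));
    lineage.put("table_b.order_status", List.of("table_a.order_status"));
    // INSERT INTO table_c SELECT a.order_id, b.order_status FROM table_a a JOIN table_b b ...
    lineage.put("table_c.order_id", List.of("table_a.order_id"));
    lineage.put("table_c.order_status", List.of("table_b.order_status"));
    return lineage;
  }

  public static void main(String[] args) {
    // Print one "upstream -> downstream" line per downstream field.
    upstreamFieldsByField()
        .forEach((field, upstream) -> System.out.println(upstream + " -> " + field));
  }
}
```

Each entry corresponds to one or more edges of the graph that the rest of the article stores and displays.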

Environment configuration

Refer to another article: SpringBoot configures embedded Neo4j

Node data structure definition

Because the goal is to display the lineage between table fields, the table fields themselves are stored as graph nodes, and the lineage between fields is represented by relationships between those nodes. The node is defined as follows:

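public class ColumnVertex {
  // Unique key in the form catalogName.databaseName.tableName.columnName
  private String name;

  // No-arg constructor, used when a vertex is rebuilt from a query result
  public ColumnVertex() {}

  public ColumnVertex(String catalogName, String databaseName, String tableName, String columnName) {
    this.name = catalogName + "." + databaseName + "." + tableName + "." + columnName;
  }

  public String getName() {
    return name;
  }

  // Returns this so that calls can be chained: new ColumnVertex().setName(...)
  public ColumnVertex setName(String name) {
    this.name = name;
    return this;
  }

  // The accessors split the key on literal dots, so the individual names must not contain dots
  public String getCatalogName() {
    return name.split("\\.")[0];
  }

  public String getDatabaseName() {
    return name.split("\\.")[1];
  }

  public String getTableName() {
    return name.split("\\.")[2];
  }

  public String getColumnName() {
    return name.split("\\.")[3];
  }
}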
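The accessors of the node recover the individual names by splitting the key on dots. Because String.split takes a regular expression, the dot has to be escaped. The short sketch below shows the round trip and the pitfall of an unescaped dot. The class name VertexKeyExample and the sample names are hypothetical and serve only as an illustration.

```java
class VertexKeyExample {
  // Builds a key in the catalog.database.table.column form used by the node.
  static String key(String catalog, String database, String table, String column) {
    return String.join(".", catalog, database, table, column);
  }

  // Splits a key back into its four parts; the dot is escaped because split() takes a regex.
  static String[] parts(String key) {
    return key.split("\\.");
  }

  public static void main(String[] args) {
    String key = key("default", "ods", "table_a", "order_id");
    // Index 2 is the table name.
    System.out.println(parts(key)[2]);
    // An unescaped "." matches every character, so no parts remain after splitting.
    System.out.println(key.split(".").length);
  }
}
```

The scheme only works if none of the four names contains a dot. A name with a dot would shift the indexes of the parts that follow it.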

General service definition

import java.util.List;

public interface EmbeddedGraphService {
    // Add the current node, its upstream node, and the relationship between them
    void addColumnVertex(ColumnVertex currentVertex, ColumnVertex upstreamVertex);

    // Find the direct upstream nodes of the current node
    List<ColumnVertex> findUpstreamColumnVertex(ColumnVertex currentVertex);

    // Find the direct downstream nodes of the current node
    List<ColumnVertex> findDownstreamColumnVertex(ColumnVertex currentVertex);
}

Service implementation

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import javax.annotation.Resource;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Result;
import org.neo4j.graphdb.Transaction;
import org.springframework.stereotype.Service;

@Service
public class EmbeddedGraphServiceImpl implements EmbeddedGraphService {

  @Resource private GraphDatabaseService graphDb;

  @Override
  public void addColumnVertex(ColumnVertex currentVertex, ColumnVertex upstreamVertex) {
    try (Transaction tx = graphDb.beginTx()) {
      // MERGE creates the nodes and the relationship only if they do not exist yet
      tx.execute(
          "MERGE (c:ColumnVertex {name: $currentName}) MERGE (u:ColumnVertex {name: $upstreamName})"
              + " MERGE (u)-[:UPSTREAM]->(c)",
          Map.of("currentName", currentVertex.getName(), "upstreamName", upstreamVertex.getName()));
      tx.commit();
    }
  }

  @Override
  public List<ColumnVertex> findUpstreamColumnVertex(ColumnVertex currentVertex) {
    List<ColumnVertex> result = new ArrayList<>();
    try (Transaction tx = graphDb.beginTx()) {
      Result queryResult =
          tx.execute(
              "MATCH (u:ColumnVertex)-[:UPSTREAM]->(c:ColumnVertex) WHERE c.name = $name RETURN"
                  + " u.name AS name",
              Map.of("name", currentVertex.getName()));
      while (queryResult.hasNext()) {
        Map<String, Object> row = queryResult.next();
        result.add(new ColumnVertex().setName((String) row.get("name")));
      }
      tx.commit();
    }
    return result;
  }

  @Override
  public List<ColumnVertex> findDownstreamColumnVertex(ColumnVertex currentVertex) {
    List<ColumnVertex> result = new ArrayList<>();
    try (Transaction tx = graphDb.beginTx()) {
      Result queryResult =
          tx.execute(
              "MATCH (c:ColumnVertex)-[:UPSTREAM]->(d:ColumnVertex) WHERE c.name = $name RETURN"
                   + " d.name AS name",
              Map.of("name", currentVertex.getName()));
      while (queryResult.hasNext()) {
        Map<String, Object> row = queryResult.next();
        result.add(new ColumnVertex().setName((String) row.get("name")));
      }
      tx.commit();
    }
    return result;
  }
}
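To illustrate the semantics of the three operations without a running Neo4j instance, here is a minimal in-memory sketch. It is not part of the service above: the class InMemoryLineageGraph is hypothetical, it uses plain strings in place of ColumnVertex, and it only mirrors the behaviour of the queries (idempotent insertion, single-hop lookups).

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

class InMemoryLineageGraph {
  // Vertex name -> names of its direct upstream / downstream vertices.
  private final Map<String, Set<String>> upstream = new HashMap<>();
  private final Map<String, Set<String>> downstream = new HashMap<>();

  // Mirrors addColumnVertex: the sets ignore duplicates, much like MERGE does.
  void addColumnVertex(String current, String upstreamName) {
    upstream.computeIfAbsent(current, k -> new LinkedHashSet<>()).add(upstreamName);
    downstream.computeIfAbsent(upstreamName, k -> new LinkedHashSet<>()).add(current);
  }

  // Direct (single-hop) upstream vertices of the given vertex.
  List<String> findUpstream(String current) {
    return new ArrayList<>(upstream.getOrDefault(current, Set.of()));
  }

  // Direct (single-hop) downstream vertices of the given vertex.
  List<String> findDownstream(String current) {
    return new ArrayList<>(downstream.getOrDefault(current, Set.of()));
  }

  public static void main(String[] args) {
    InMemoryLineageGraph graph = new InMemoryLineageGraph();
    graph.addColumnVertex("cat.db.table_b.order_id", "cat.db.table_a.order_id");
    graph.addColumnVertex("cat.db.table_c.order_id", "cat.db.table_a.order_id");
    // Adding the same pair again leaves the graph unchanged.
    graph.addColumnVertex("cat.db.table_c.order_id", "cat.db.table_a.order_id");
    System.out.println(graph.findDownstream("cat.db.table_a.order_id"));
  }
}
```

A stand-in of this kind can also be handy in unit tests of the traversal logic, since both lookups return only direct neighbours, exactly like the Cypher queries.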

Traverse graph nodes

Implementation logic:

  1. Input parameters of the REST interface: the current table (catalogName, databaseName, tableName);
  2. Define the data structure returned to the front end. It consists of nodes and edges, from which the front end renders the complete lineage graph:

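import java.util.List;
import lombok.Builder;
import lombok.Data;

// Lombok's @Data and @Builder are assumed here, because the code that follows
// calls setNodes(...), builder() and the generated getters.
@Data
public class ColumnLineageVO {
  private List<ColumnLineageNode> nodes;
  private List<ColumnLineageEdge> edges;
}

@Data
@Builder
public class ColumnLineageNode {
  private String databaseName;
  private String tableName;
  // Table type, set through .type(...) in the traversal code
  private String type;
  private List<String> columnNames;
}

@Data
@Builder
public class ColumnLineageEdge {
  private ColumnLineageEdgePoint source;
  private ColumnLineageEdgePoint target;
}

@Data
@Builder
public class ColumnLineageEdgePoint {
  private String databaseName;
  private String tableName;
  private String columnName;
}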

  3. Query the fields of the table;
  4. Starting from each field of the current table, recursively traverse all upstream and downstream graph nodes associated with it;
  5. Encapsulate all nodes and edges in a ColumnLineageVO and return it to the front end.

  public ColumnLineageVO getColumnLineage(Table table) {
    ColumnLineageVO columnLineageVO = new ColumnLineageVO();
    List<ColumnLineageNode> nodes = new ArrayList<>();
    List<ColumnLineageEdge> edges = new ArrayList<>();
    // Deduplication
    Set<String> visitedNodes = new HashSet<>();
    Set<String> visitedEdges = new HashSet<>();
    Map<String, List<ColumnVertex>> upstreamCache = new HashMap<>();
    Map<String, List<ColumnVertex>> downstreamCache = new HashMap<>();

    ColumnLineageNode currentNode =
        ColumnLineageNode.builder()
            .databaseName(table.getDatabaseName())
            .tableName(table.getTableName())
            .type(TableType.EXTERNAL_TABLE.getDesc())
            .build();
    nodes.add(currentNode);
    visitedNodes.add(currentNode.getDatabaseName() + "." + currentNode.getTableName());

    for (String columnName : table.getColumnNames()) {
      ColumnVertex currentVertex =
          new ColumnVertex(
              table.getScriptId(), table.getDatabaseName(), table.getTableName(), columnName);
      traverseUpstreamColumnVertex(
          currentVertex, nodes, edges, visitedNodes, visitedEdges, upstreamCache);
      traverseDownstreamColumnVertex(
          currentVertex, nodes, edges, visitedNodes, visitedEdges, downstreamCache);
    }

    columnLineageVO.setNodes(nodes);
    columnLineageVO.setEdges(edges);
    return columnLineageVO;
  }

  private void traverseUpstreamColumnVertex(
      ColumnVertex currentVertex,
      List<ColumnLineageNode> nodes,
      List<ColumnLineageEdge> edges,
      Set<String> visitedNodes,
      Set<String> visitedEdges,
      Map<String, List<ColumnVertex>> cache) {
    List<ColumnVertex> upstreamVertices;
    if (cache.containsKey(currentVertex.getName())) {
      upstreamVertices = cache.get(currentVertex.getName());
    } else {
      upstreamVertices = embeddedGraphService.findUpstreamColumnVertex(currentVertex);
      cache.put(currentVertex.getName(), upstreamVertices);
    }
    for (ColumnVertex upstreamVertex : upstreamVertices) {
      String nodeKey = upstreamVertex.getDatabaseName() + "." + upstreamVertex.getTableName();
      if (!visitedNodes.contains(nodeKey)) {
        ColumnLineageNode upstreamNode =
            ColumnLineageNode.builder()
                .databaseName(upstreamVertex.getDatabaseName())
                .tableName(upstreamVertex.getTableName())
                .type(TableType.EXTERNAL_TABLE.getDesc())
                .build();
        nodes.add(upstreamNode);
        visitedNodes.add(nodeKey);
      }
      String edgeKey =
          upstreamVertex.getDatabaseName()
              + upstreamVertex.getTableName()
              + upstreamVertex.getColumnName()
              + currentVertex.getDatabaseName()
              + currentVertex.getTableName()
              + currentVertex.getColumnName();
      if (!visitedEdges.contains(edgeKey)) {
        ColumnLineageEdge edge = createEdge(upstreamVertex, currentVertex);
        edges.add(edge);
        visitedEdges.add(edgeKey);
      }
      traverseUpstreamColumnVertex(upstreamVertex, nodes, edges, visitedNodes, visitedEdges, cache);
    }
  }
  
  private void traverseDownstreamColumnVertex(
      ColumnVertex currentVertex,
      List<ColumnLineageNode> nodes,
      List<ColumnLineageEdge> edges,
      Set<String> visitedNodes,
      Set<String> visitedEdges,
      Map<String, List<ColumnVertex>> cache) {
    List<ColumnVertex> downstreamVertices;
    if (cache.containsKey(currentVertex.getName())) {
      downstreamVertices = cache.get(currentVertex.getName());
    } else {
      downstreamVertices = embeddedGraphService.findDownstreamColumnVertex(currentVertex);
      cache.put(currentVertex.getName(), downstreamVertices);
    }
    for (ColumnVertex downstreamVertex : downstreamVertices) {
      String nodeKey = downstreamVertex.getDatabaseName() + "." + downstreamVertex.getTableName();
      if (!visitedNodes.contains(nodeKey)) {
        ColumnLineageNode downstreamNode =
            ColumnLineageNode.builder()
                .databaseName(downstreamVertex.getDatabaseName())
                .tableName(downstreamVertex.getTableName())
                .type(TableType.EXTERNAL_TABLE.getDesc())
                .build();
        nodes.add(downstreamNode);
        visitedNodes.add(nodeKey);
      }
      String edgeKey =
          currentVertex.getDatabaseName()
              + currentVertex.getTableName()
              + currentVertex.getColumnName()
              + downstreamVertex.getDatabaseName()
              + downstreamVertex.getTableName()
              + downstreamVertex.getColumnName();
      if (!visitedEdges.contains(edgeKey)) {
        ColumnLineageEdge edge = createEdge(currentVertex, downstreamVertex);
        edges.add(edge);
        visitedEdges.add(edgeKey);
      }
      traverseDownstreamColumnVertex(
          downstreamVertex, nodes, edges, visitedNodes, visitedEdges, cache);
    }
  }
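
One thing to watch in the recursive traversal: the recursive call is made for every neighbour, whether or not that vertex has already been expanded. If the lineage contains a cycle (for example, a table that is populated from itself), the recursion would not terminate. A possible alternative is an iterative breadth-first traversal with a set of visited vertices. The sketch below is a simplified illustration under assumptions: the class IterativeLineageTraversal is hypothetical, vertices are plain strings, and the neighbour lookup is passed in as a function in place of the graph service.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

class IterativeLineageTraversal {
  // Collects every edge reachable from start by repeatedly asking `neighbors`
  // for the adjacent vertices. Each edge is returned as [vertex, neighbor].
  static List<List<String>> collectEdges(
      String start, Function<String, List<String>> neighbors) {
    List<List<String>> edges = new ArrayList<>();
    Set<String> visited = new HashSet<>();
    Deque<String> queue = new ArrayDeque<>();
    visited.add(start);
    queue.add(start);
    while (!queue.isEmpty()) {
      String current = queue.poll();
      for (String next : neighbors.apply(current)) {
        edges.add(List.of(current, next));
        // Each vertex is enqueued only once, so cycles terminate.
        if (visited.add(next)) {
          queue.add(next);
        }
      }
    }
    return edges;
  }

  public static void main(String[] args) {
    // Hypothetical upstream lookup that contains a cycle: a <- b <- c <- a
    Map<String, List<String>> upstream =
        Map.of("a", List.of("b"), "b", List.of("c"), "c", List.of("a"));
    System.out.println(collectEdges("a", v -> upstream.getOrDefault(v, List.of())));
  }
}
```

Because every vertex is expanded exactly once, the lookup function is also called at most once per vertex, which makes a separate result cache unnecessary in this variant.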