分式session实现单点登录

基于zookeeper和cookieID实现分布式sso session

Posted by GrayWind on March 23, 2019

快三个月没更新了,最近忙于新工作上手时期强度比以前大很多,下班了非常摸鱼。后续准备开始写一些简单的应用和尝试分析一些常用中间件源码。

对于常见的前后端分离web应用来说,单点登录是非常常见的功能点。尤其是对于企业的众多应用来说,如果说我先去百度贴吧要登录,再去百度网盘又要输密码,再去百度知道答题还要重新登录,体验之差是显而易见的。

一种可选的做法是,所有应用统一采用同一个登录服务进行验证,完成后返回cookie作为凭证,前端存储该凭证并加入每次请求中。当然,实际情况中,还要有安全性方面复杂的验证机制。但是,对于不同的应用来说,相互之间需要进行数据的共享,否则每次新登录一个应用还是不知道登录信息,需要再去登录服务那里验证。

设计方案

方案一,应用服务器主动将凭证信息推送给其他服务器,主要步骤如下:

sso1

  1. web端向应用服务器B发起请求,由于此时是第一次请求没有凭证信息,所以会被拦截,并返回错误码
  2. web服务器发现请求失败后,会向登录服务器请求进行登录验证,此时就会进行跳转到统一的登录页面输入用户名密码
  3. 登录服务器验证请求有效后返回登录凭证
  4. web端带着凭证再次请求应用服务器B
  5. 应用服务器B向登录服务器确认凭证是否有效
  6. 应用服务器返回请求的业务结果,同时将凭证推送给其他的服务器

那么,这个方案问题在哪呢?应用服务器它就不该知道其他服务器的地址,对于服务功能划分来说,它做了额外的推送工作,为此它需要去注册中心获取所有服务器的地址(或者是同一个服务的所有服务器地址),再一个个推送过去。

方案二,每个应用服务器分别向登录服务器去请求

sso2

1-6和方案一相同

  1. web端向另一个应用服务器C请求服务
  2. C因为没有存储过这个凭证,向登录服务器要求验证凭证是否有效
  3. C收到凭据有效后,将服务结果返回给web端

这个方案的优点是每个应用服务器不需要额外的去获取其他应用服务器的地址,只需要知道登录服务器的位置,需要的时候去请求就行了,原则上是谁需要信息就想办法自己去获取。但是,登录服务器本身也是一个应用服务器,所有机器全都要去请求凭据的有效性肯定压力会随着应用服务器数量的上升而增大,势必要增加登录服务器的数量。但是,可以想到,我们的压力主要在于查询凭据是否有效,它并不是一个复杂的业务逻辑,比起一次完整的HTTP请求,通过一些专门的分布式集群来做这个工作更加易于扩展。

方案三,应用服务器第一次验证凭证后存储至外部集群,其他应用服务器向该集群验证

sso3

以zookeeper同步session为例,1-5和方案一相同

  1. 应用服务器B验证凭证成功后存储至zookeeper
  2. 返回应用请求
  3. web端请求服务C
  4. C向zookeeper查询该凭证是否有效
  5. 查询到该凭据有效后,返回结果给web端

具体设计

web端服务请求需要经过拦截器拦截,判断是否有cookie session,如果没有且不是登录请求的话直接返回错误码。对于登录请求来说,需要发送HTTP请求给指定的验证服务器,然后将结果缓存并保存至zookeeper,zookeeper的节点以对应的session ID作为路径,data部分包含用户信息和它的最近使用时间。应用定期扫描所有的节点,同步新的信息并删除过期的节点。

代码部分

首先对于应用端来说,需要有一个统一的Filter来拦截请求并验证他们是否有凭证

@Component
@WebFilter
public class AuthenticationFilter implements Filter {

  private static Logger logger = LoggerFactory.getLogger(AuthenticationFilter.class);

  @Autowired
  private DefaultAdminConfiguration config;

  @Autowired
  private SessionManager sessionManager;

  private static final String D_SESSION_ID = "D_SESSION_ID";

  /**
   * cookie 域.
   */
  private static final String COOKIE_DOMAIN = "";

  /**
   * cookie路径.
   */
  private static final String COOKIE_PATH = "/";


  @Override
  public void init(FilterConfig filterConfig) throws ServletException {
  }

  @Override
  public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
      throws IOException, ServletException {
    HttpServletRequest httpRequest = (HttpServletRequest) request;
    HttpServletResponse httpResponse = (HttpServletResponse) response;

    String sessionId = processSession(httpRequest, httpResponse);
    HttpSession session = httpRequest.getSession();
    session.setAttribute(Constants.SESSION_ID, sessionId);
    String username = null;
    try {
      Map<String, String> data = sessionManager.getSession(sessionId);
      if (data != null) {
        username = data.get("username");
      }
    } catch (Exception e) {
      logger.error("Failed to get sessionId [{}].", sessionId, e);
    }

    if (StringUtils.isBlank(username)) {
      boolean login = httpRequest.getRequestURI().endsWith("/api/login");
      if (login) {
        chain.doFilter(request, response);
        return;
      }

      httpResponse.setStatus(HttpStatus.SC_UNAUTHORIZED);
      writeAndFlushResponse("Login info is required.", httpResponse);
      return;
    }

    ContextHolder.current().setUsername(username);
    try {
      chain.doFilter(request, response);
    } finally {
      ContextHolder.remove();
    }
  }

  private String processSession(HttpServletRequest request, HttpServletResponse response) {
    String sessionId = null;
    // 获取 sessionId
    Cookie[] cookies = request.getCookies();
    if (cookies != null && cookies.length > 0) {
      for (Cookie cookie : cookies) {
        if (cookie.getName().equals(D_SESSION_ID)) {
          sessionId = cookie.getValue();
          break;
        }
      }
    }
    
    // 创建新的cookie session
    if (StringUtils.isBlank(sessionId)) {
      sessionId = UUID.randomUUID().toString();
      Cookie sessionCookie = new Cookie(D_SESSION_ID, sessionId);
      sessionCookie.setDomain(COOKIE_DOMAIN);
      sessionCookie.setPath(COOKIE_PATH);
      sessionCookie.setHttpOnly(true);
      response.addCookie(sessionCookie);
    }
    return sessionId;
  }

  private void writeAndFlushResponse(String msg, HttpServletResponse httpResponse)
      throws IOException {
    httpResponse.setCharacterEncoding("UTF-8");
    httpResponse.setContentType("application/json; charset=utf-8");
    ServletOutputStream output = httpResponse.getOutputStream();
    output.write(msg.getBytes());
    output.flush();
  }

  @Override
  public void destroy() {
  }
}

SessionManager负责管理Session的cache,并且定时同步zookeeper上面的节点信息,对于过期的节点要删除,重新获取节点信息说明是有一个新的服务器被访问,需要重新更新节点的有效时间。

@Component
public class SessionManager implements InitializingBean {

  private static final int SCAN_INTERVAL_SECONDS = 60;

  private static Logger logger = LoggerFactory.getLogger(SessionManager.class);

  private Map<String, SessionMetaVO> cache = Maps
      .newConcurrentMap();

  private ScheduledExecutorService executorService = Executors.newScheduledThreadPool(1);

  @Autowired
  private ZKService zkService;

  /**
   * 获取全部的session.
   *
   * @param sessionId sessionId
   * @return sessionVal
   */
  public Map<String, String> getSession(String sessionId) throws Exception {
    // 先从cache读取数据
    SessionMetaVO sessionMeta = cache.get(sessionId);
    if (sessionMeta == null) {
      sessionMeta = readSession(sessionId);
    }
    if (sessionMeta == null) {
      return null;
    }
    sessionMeta.setLastAccessTime(System.currentTimeMillis());
    cache.put(sessionId, sessionMeta);
    updateSession(sessionId, sessionMeta);
    return sessionMeta.getData();
  }

  /**
   * 如果需要立即更新,则立刻同步到 zk.
   */
  private void updateSession(String sessionId, SessionMetaVO sessionMeta) throws Exception {
    String path = getPath(sessionId);
    if (zkService.pathNodeExists(path)) {
      byte[] data = JSON.toJSONBytes(sessionMeta);
      zkService.updatePathNodeData(path, data);
    }
  }

  /**
   * 在 zk上创建session节点.
   *
   * @param sessionId sessionId
   * @param sessionMeta sessionMeta
   */
  public void addSession(String sessionId, SessionMetaVO sessionMeta) throws Exception {
    cache.put(sessionId, sessionMeta);
    byte[] data = JSON.toJSONBytes(sessionMeta);
    zkService.createPathNode(getPath(sessionId), data, true);
  }

  /**
   * 从zk上读取session数据.
   *
   * @param sessionId sessionId
   * @return data
   */
  private SessionMetaVO readSession(String sessionId) throws Exception {
    String path = getPath(sessionId);
    if (zkService.pathNodeExists(path)) {
      byte[] data = zkService.getPathNodeData(path);
      return JSON.parseObject(data, SessionMetaVO.class);
    }
    return null;
  }

  private String getPath(String... sessionIds) {
    StringBuilder path = new StringBuilder();
    path.append("/admin");
    if (sessionIds != null) {
      for (String sessionId : sessionIds) {
        path.append("/");
        path.append(sessionId);
      }
    }
    return path.toString();
  }

  /**
   * 删除zk上的session.
   *
   * @param sessionId sessionId
   */
  public void deleteSession(String sessionId) {
    cache.remove(sessionId);
    try {
      zkService.removePathNode(getPath(sessionId));
    } catch (Exception e) {
      logger.error(e.getMessage());
    }
  }

  /**
   * 遍历根节点
   */
  private void scan() {
    String root = getPath();
    List<String> nodes = Lists.emptyList();
    try {
      nodes = zkService.getPathChildrenNodes(root);
    } catch (Exception e) {
      logger.error("Does not create path [%s].", root);
    }
    if (!nodes.isEmpty()) {
      for (String node : nodes) {
        try {
          SessionMetaVO sessionMeta = readSession(node);
          if (sessionMeta != null) {
            if (System.currentTimeMillis() - sessionMeta.getLastAccessTime() <= 60000) {
              cache.put(node, sessionMeta);
            } else {
              deleteSession(node);
            }
          }
        } catch (Exception e) {
          logger.error(e.getMessage());
        }
      }
    }
  }

  @Override
  public void afterPropertiesSet() throws Exception {
    executorService.scheduleWithFixedDelay(() -> {
      scan();
    }, 0, SCAN_INTERVAL_SECONDS, TimeUnit.SECONDS);
  }
}

ZKService主要是对CuratorFramework的一些方法进行了一些封装

public class DefaultZKService implements ZKService {
  public DefaultZKService(String connectionString, int sessionTimeout, int connectTimeout, RetryPolicy retryPolicy) {
    this.connectionString = connectionString;
    this.sessionTimeout = sessionTimeout;
    this.connectTimeout = connectTimeout;
    this.retryPolicy = retryPolicy;
    framework = CuratorFrameworkFactory
        .newClient(connectionString, sessionTimeout, connectTimeout, retryPolicy);
  }
    
  public static class DefaultZKServiceBuilder implements ZKServiceBuilder {

    private String connectionString;

    private int connectionTimeout = 2000;

    private int sessionTimeout = 5000;

    private int retryInterval = 1000;

    private int maxRetryCount = 3;

    @Override
    public DefaultZKServiceBuilder connectionString(String connectionString) {
      this.connectionString = connectionString;
      return this;
    }

    @Override
    public DefaultZKServiceBuilder connectionTimeout(int connectionTimeout) {
      this.connectionTimeout = connectionTimeout;
      return this;
    }

    @Override
    public DefaultZKServiceBuilder sessionTimeout(int sessionTimeout) {
      this.sessionTimeout = sessionTimeout;
      return this;
    }

    @Override
    public DefaultZKServiceBuilder retryInterval(int retryInterval) {
      this.retryInterval = retryInterval;
      return this;
    }

    public DefaultZKServiceBuilder maxRetryCount(int maxRetryCount) {
      this.maxRetryCount = maxRetryCount;
      return this;
    }

    private void ensureValidToBuild() {
      if (StringUtils.isBlank(connectionString)) {
        throw new RuntimeException("Connection string is required.");
      }
    }

    @Override
    public DefaultZKService build() throws Exception {
      ensureValidToBuild();
      ExponentialBackoffRetry retryPolicy = new ExponentialBackoffRetry(retryInterval, maxRetryCount);
      return new DefaultZKService(connectionString, sessionTimeout, connectionTimeout, retryPolicy);
    }
  }
    
  public void updatePathNodeData(String path, byte[] data) throws Exception {
    if (pathNodeExists(path)) {
      framework.setData().forPath(path, data);
      logger.info("Update path node [{}] data to [{}]", path, data == null? null: new String(data, "UTF-8"));
    } else {
      throw new PathNodeNoExistException(path);
    }
  }
    
  public void createPathNode(String path, byte[] data, boolean isEphemeral) throws Exception {
    if (!pathNodeExists(path)) {
      framework.create().creatingParentsIfNeeded().withMode(isEphemeral? CreateMode.EPHEMERAL: CreateMode.PERSISTENT)
          .forPath(path, data);
      logger.info("Create path node [{}] with data [{}]", path, data == null? null: new String(data, "UTF-8"));
    } else {
      throw new PathNodeAlreadyExistException(path);
    }
  }
    
  public void removePathNode(String path) throws Exception {
    if (pathNodeExists(path)) {
      framework.delete().deletingChildrenIfNeeded().forPath(path);
      logger.info("Remove path node [{}]", path);
    } else {
      throw new PathNodeNoExistException(path);
    }
  }
    
  public List<String> getPathChildrenNodes(String path) throws Exception {
    if (pathNodeExists(path)) {
      return framework.getChildren().forPath(path);
    }
    throw new PathNodeNoExistException(path);
  }
}