package ai.giskard.learnspringwebsockets; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequestWrapper; import jakarta.servlet.http.HttpServletResponse; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.stereotype.Component; import org.springframework.web.filter.OncePerRequestFilter; import java.io.IOException; import java.text.MessageFormat; import java.util.Collections; import java.util.Enumeration; import java.util.List; import java.util.StringTokenizer; @SpringBootApplication public class LearnSpringWebsocketsApplication { public static void main(String[] args) { SpringApplication.run(LearnSpringWebsocketsApplication.class, args); } @Component class RequestWrapperFilter extends OncePerRequestFilter { @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { final HttpServletRequestWrapper reqWrapper = new HttpServletRequestWrapper(request) { @Override public Enumeration getHeaders(String name) { if ("connection".equals(name) && "UPGRADE".equals(super.getHeaders(name).nextElement())) { return Collections.enumeration(Collections.singleton("Upgrade")); } return super.getHeaders(name); } @Override public String getHeader(String name) { if ("connection".equals(name) && "UPGRADE".equals(super.getHeader(name))) { return "Upgrade"; } return super.getHeader(name); } }; filterChain.doFilter(reqWrapper, response); } } }