diff --git a/src/main/java/ch/psi/daq/queryrest/config/QueryRestConfig.java b/src/main/java/ch/psi/daq/queryrest/config/QueryRestConfig.java index 511623b..97c0664 100644 --- a/src/main/java/ch/psi/daq/queryrest/config/QueryRestConfig.java +++ b/src/main/java/ch/psi/daq/queryrest/config/QueryRestConfig.java @@ -9,11 +9,9 @@ import java.util.function.Function; import javax.annotation.PostConstruct; import javax.annotation.Resource; -import javax.servlet.Filter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Import; @@ -25,10 +23,6 @@ import org.springframework.util.StringUtils; import org.springframework.validation.Validator; import org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.core.JsonFactory; -import com.fasterxml.jackson.databind.ObjectMapper; - import ch.psi.daq.common.statistic.StorelessStatistics; import ch.psi.daq.domain.DataEvent; import ch.psi.daq.query.analyzer.QueryAnalyzer; @@ -38,12 +32,16 @@ import ch.psi.daq.query.model.Aggregation; import ch.psi.daq.query.model.Query; import ch.psi.daq.query.model.QueryField; import ch.psi.daq.queryrest.controller.validator.QueryValidator; -import ch.psi.daq.queryrest.filter.CorsFilter; import ch.psi.daq.queryrest.model.PropertyFilterMixin; import ch.psi.daq.queryrest.response.csv.CSVResponseStreamWriter; import ch.psi.daq.queryrest.response.json.JSONResponseStreamWriter; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.databind.ObjectMapper; + @Configuration +@Import(value=QueryRestConfigCORS.class) @PropertySource(value = {"classpath:queryrest.properties"}) @PropertySource(value = {"file:${user.home}/.config/daq/queryrest.properties"}, ignoreResourceNotFound = true) public class QueryRestConfig extends WebMvcConfigurerAdapter { @@ -56,7 +54,8 @@ public class QueryRestConfig extends WebMvcConfigurerAdapter { // a nested configuration // this guarantees that the ordering of the properties file is as expected - // see: https://jira.spring.io/browse/SPR-10409?focusedCommentId=101393&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-101393 + // see: + // https://jira.spring.io/browse/SPR-10409?focusedCommentId=101393&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-101393 @Configuration @Import({QueryConfig.class}) static class InnerConfiguration { @@ -74,7 +73,7 @@ public class QueryRestConfig extends WebMvcConfigurerAdapter { @Resource private ObjectMapper objectMapper; - + @PostConstruct public void afterPropertiesSet() { // only include non-null values @@ -85,6 +84,7 @@ public class QueryRestConfig extends WebMvcConfigurerAdapter { objectMapper.addMixIn(StorelessStatistics.class, PropertyFilterMixin.class); objectMapper.addMixIn(EnumMap.class, PropertyFilterMixin.class); } + /** * {@inheritDoc} @@ -115,7 +115,7 @@ public class QueryRestConfig extends WebMvcConfigurerAdapter { public JSONResponseStreamWriter jsonResponseStreamWriter() { return new JSONResponseStreamWriter(); } - + @Bean public CSVResponseStreamWriter csvResponseStreamWriter() { return new CSVResponseStreamWriter(); @@ -165,23 +165,17 @@ public class QueryRestConfig extends WebMvcConfigurerAdapter { return new QueryValidator(); } - @Bean - @ConditionalOnProperty("queryrest.cors.enable") - public Filter corsFilter() { - return new CorsFilter(); - } - @Bean(name = BEAN_NAME_CORS_ALLOWEDORIGINS) - public String configuredOrigins(){ - String value = env.getProperty(QUERYREST_CORS_ALLOWEDORIGINS, "http://localhost:8080, *"); - LOGGER.debug("Load '{}={}'", QUERYREST_CORS_ALLOWEDORIGINS, value); - return value; + public String allowedOrigins() { + String value = env.getProperty(QUERYREST_CORS_ALLOWEDORIGINS, "http://localhost:8080, *"); + LOGGER.debug("Load '{}={}'", QUERYREST_CORS_ALLOWEDORIGINS, value); + return value; } - + @Bean(name = BEAN_NAME_CORS_FORCEALLHEADERS) - public Boolean forceAllHeaders(){ - Boolean value = env.getProperty(QUERYREST_CORS_FORCEALLHEADERS, Boolean.class, true); - LOGGER.debug("Load '{}={}'", QUERYREST_CORS_FORCEALLHEADERS, value); - return value; + public Boolean forceAllHeaders() { + Boolean value = env.getProperty(QUERYREST_CORS_FORCEALLHEADERS, Boolean.class, true); + LOGGER.debug("Load '{}={}'", QUERYREST_CORS_FORCEALLHEADERS, value); + return value; } } diff --git a/src/main/java/ch/psi/daq/queryrest/config/QueryRestConfigCORS.java b/src/main/java/ch/psi/daq/queryrest/config/QueryRestConfigCORS.java new file mode 100644 index 0000000..de93f8f --- /dev/null +++ b/src/main/java/ch/psi/daq/queryrest/config/QueryRestConfigCORS.java @@ -0,0 +1,72 @@ +package ch.psi.daq.queryrest.config; + +import java.util.Arrays; +import java.util.Set; +import java.util.stream.Collectors; + +import javax.annotation.PostConstruct; +import javax.annotation.Resource; + +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.PropertySource; +import org.springframework.core.env.Environment; +import org.springframework.web.servlet.config.annotation.CorsRegistry; +import org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter; + + +@Configuration +@PropertySource(value = {"classpath:queryrest.properties"}) +@PropertySource(value = {"file:${user.home}/.config/daq/queryrest.properties"}, ignoreResourceNotFound = true) +public class QueryRestConfigCORS extends WebMvcConfigurerAdapter { + + @Resource + private Environment env; + + @Resource(name = QueryRestConfig.BEAN_NAME_CORS_ALLOWEDORIGINS) + private String configuredOrigins; + + @Resource(name = QueryRestConfig.BEAN_NAME_CORS_FORCEALLHEADERS) + private Boolean forceAllHeaders; + + private String[] allowedOrigins; + + @PostConstruct + public void afterPropertiesSet() { + Set origs = Arrays.stream(configuredOrigins.split(",")).map(s -> s.trim()).collect(Collectors.toSet()); + allowedOrigins = origs.toArray(new String[origs.size()]); + } + + /** + * @{inheritDoc + */ + @Override + public void addCorsMappings(CorsRegistry registry) { + + boolean corsEnabled = Boolean.valueOf(env.getProperty("queryrest.cors.enable", "false")); + boolean forceAll = Boolean.valueOf(env.getProperty("queryrest.cors.forceallheaders", "false")); + + if (corsEnabled) { + + if (forceAll) { + registry + .addMapping("/**") + .allowedOrigins("*") + .allowedMethods("GET", "POST", "PUT", "DELETE", "OPTIONS") + .allowedHeaders("Origin", "Accept", "X-Requested-With", "Content-Type", "Access-Control-Request-Method", + "Access-Control-Request-Headers") + .allowCredentials(true) + .maxAge(1800); + } else if (!forceAll) { + // see https://spring.io/blog/2015/06/08/cors-support-in-spring-framework + registry + .addMapping("/**") + .allowedOrigins(allowedOrigins) + .allowedMethods("GET", "POST", "PUT", "DELETE", "OPTIONS") + .allowedHeaders("Origin", "Accept", "X-Requested-With", "Content-Type", "Access-Control-Request-Method", + "Access-Control-Request-Headers") + .allowCredentials(true) + .maxAge(1800); + } + } + } +} diff --git a/src/main/java/ch/psi/daq/queryrest/filter/CorsFilter.java b/src/main/java/ch/psi/daq/queryrest/filter/CorsFilter.java deleted file mode 100644 index 017d202..0000000 --- a/src/main/java/ch/psi/daq/queryrest/filter/CorsFilter.java +++ /dev/null @@ -1,79 +0,0 @@ -package ch.psi.daq.queryrest.filter; - -import java.io.IOException; -import java.util.Arrays; -import java.util.Set; -import java.util.stream.Collectors; - -import javax.annotation.PostConstruct; -import javax.annotation.Resource; -import javax.servlet.FilterChain; -import javax.servlet.ServletException; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - -import org.springframework.web.filter.OncePerRequestFilter; - -import ch.psi.daq.queryrest.config.QueryRestConfig; - -public class CorsFilter extends OncePerRequestFilter { - - private static final String ALLOW_ORIGIN_HEADER = "Access-Control-Allow-Origin"; - - @Resource(name = QueryRestConfig.BEAN_NAME_CORS_ALLOWEDORIGINS) - private String configuredOrigins; - - @Resource(name = QueryRestConfig.BEAN_NAME_CORS_FORCEALLHEADERS) - private Boolean forceAllHeaders; - - private Set allowedOrigins; - - @PostConstruct - public void afterPropertiesSet(){ - allowedOrigins = Arrays.stream(configuredOrigins.split(",")) - .map(s -> s.trim()) - .collect(Collectors.toSet()); - } - - @Override - protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) - throws ServletException, IOException { - String originHeader = request.getHeader("Origin"); - if (forceAllHeaders) { - // include headers no matter what - good for development - if (allowedOrigins.contains(originHeader)) { - response.addHeader(ALLOW_ORIGIN_HEADER, originHeader); - } else { - response.addHeader(ALLOW_ORIGIN_HEADER, "*"); - } - setDefaultCorsHeaders(response); - - } else if (request.getHeader("Access-Control-Request-Method") != null && "OPTIONS".equals(request.getMethod())) { - // this is for 'real' Cross-site browser requests - if (allowedOrigins.contains(originHeader)) { - response.addHeader(ALLOW_ORIGIN_HEADER, originHeader); - } - setDefaultCorsHeaders(response); - } - - filterChain.doFilter(request, response); - } - - - private void setDefaultCorsHeaders(HttpServletResponse response) { - response.addHeader("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"); - response.addHeader("Access-Control-Allow-Headers", "Origin, Authorization, Accept, Content-Type"); - response.addHeader("Access-Control-Max-Age", "1800"); - } - - - public void setConfiguredOrigins(String configuredOrigins) { - this.configuredOrigins = configuredOrigins; - } - - - public void setForceAllHeaders(boolean forceAllHeaders) { - this.forceAllHeaders = forceAllHeaders; - } - -} diff --git a/src/main/resources/queryrest.properties b/src/main/resources/queryrest.properties index 0baa29d..5fbaf69 100644 --- a/src/main/resources/queryrest.properties +++ b/src/main/resources/queryrest.properties @@ -7,8 +7,12 @@ queryrest.default.response.aggregations=min,mean,max # enables / disables the CORS servlet filter. Adds multiple CORS headers to the response queryrest.cors.enable=true -# includes the CORS headers no matter what request or preflight was sent. If an Origin header is set, this header will be used. + +# Includes the CORS headers no matter what request or preflight was sent. If an Origin header is set, this header will be used. # If no Origin header is set, '*' will be used. -queryrest.cors.forceallheaders=true -# defines the allowed origins for CORS requests. Only relevant if queryrest.enableCORS==true (see above). -queryrest.cors.allowedorigins=http://localhost:8080, * \ No newline at end of file +queryrest.cors.forceallheaders=false + +# Defines the allowed origins for CORS requests. Only relevant if queryrest.enableCORS==true (see above). +# If this is set to '*', then all requests are allowed, from any source. If it's set to, say, http://ui-data-api.psi.ch, then only requests +# originating from that domain (Origin header set to that value) will be allowed. Otherwise a 403 error will be returned. +queryrest.cors.allowedorigins=* diff --git a/src/test/java/ch/psi/daq/test/queryrest/controller/QueryRestControllerCsvTest.java b/src/test/java/ch/psi/daq/test/queryrest/controller/QueryRestControllerCsvTest.java index 1b8bf14..88a703a 100644 --- a/src/test/java/ch/psi/daq/test/queryrest/controller/QueryRestControllerCsvTest.java +++ b/src/test/java/ch/psi/daq/test/queryrest/controller/QueryRestControllerCsvTest.java @@ -40,7 +40,6 @@ import ch.psi.daq.query.model.impl.DAQQueries; import ch.psi.daq.query.model.impl.DAQQuery; import ch.psi.daq.query.model.impl.DAQQueryElement; import ch.psi.daq.queryrest.controller.QueryRestController; -import ch.psi.daq.queryrest.filter.CorsFilter; import ch.psi.daq.test.cassandra.admin.CassandraTestAdmin; import ch.psi.daq.test.queryrest.AbstractDaqRestTest; @@ -55,9 +54,6 @@ public class QueryRestControllerCsvTest extends AbstractDaqRestTest { @Resource private CassandraDataGen dataGen; - @Resource - private CorsFilter corsFilter; - public static final String TEST_CHANNEL = "testChannel"; public static final String TEST_CHANNEL_01 = TEST_CHANNEL + "1"; public static final String TEST_CHANNEL_02 = TEST_CHANNEL + "2"; diff --git a/src/test/java/ch/psi/daq/test/queryrest/controller/QueryRestControllerJsonTest.java b/src/test/java/ch/psi/daq/test/queryrest/controller/QueryRestControllerJsonTest.java index f8afe18..ceb5db4 100644 --- a/src/test/java/ch/psi/daq/test/queryrest/controller/QueryRestControllerJsonTest.java +++ b/src/test/java/ch/psi/daq/test/queryrest/controller/QueryRestControllerJsonTest.java @@ -27,7 +27,6 @@ import ch.psi.daq.query.model.impl.DAQQuery; import ch.psi.daq.query.model.impl.DAQQueryElement; import ch.psi.daq.query.request.ChannelsRequest; import ch.psi.daq.queryrest.controller.QueryRestController; -import ch.psi.daq.queryrest.filter.CorsFilter; import ch.psi.daq.test.cassandra.admin.CassandraTestAdmin; import ch.psi.daq.test.queryrest.AbstractDaqRestTest; @@ -42,8 +41,6 @@ public class QueryRestControllerJsonTest extends AbstractDaqRestTest { @Resource private CassandraDataGen dataGen; - @Resource - private CorsFilter corsFilter; public static final String TEST_CHANNEL_01 = "testChannel1"; public static final String TEST_CHANNEL_02 = "testChannel2"; @@ -195,8 +192,7 @@ public class QueryRestControllerJsonTest extends AbstractDaqRestTest { @Test public void testCorsFilterNoHeaders() throws Exception { - corsFilter.setForceAllHeaders(false); - this.mockMvc = MockMvcBuilders.webAppContextSetup(webApplicationContext).addFilters(corsFilter).build(); + this.mockMvc = MockMvcBuilders.webAppContextSetup(webApplicationContext).build(); this.mockMvc.perform( MockMvcRequestBuilders @@ -211,71 +207,29 @@ public class QueryRestControllerJsonTest extends AbstractDaqRestTest { @Test public void testCorsFilterIncludesHeaders() throws Exception { // all headers are set - this.mockMvc = MockMvcBuilders.webAppContextSetup(webApplicationContext).addFilters(corsFilter).build(); - + this.mockMvc = MockMvcBuilders.webAppContextSetup(webApplicationContext).build(); + + // curl -H "Origin: *" -H "Access-Control-Request-Method: POST" -X OPTIONS -v http://localhost:8080/channels this.mockMvc.perform( MockMvcRequestBuilders .options(QueryRestController.CHANNELS) .header("Origin", "*") + .header("Access-Control-Request-Method", "POST") .contentType(MediaType.APPLICATION_JSON)) .andDo(MockMvcResultHandlers.print()) .andExpect(MockMvcResultMatchers.status().isOk()) - // we didn't set the 'Origin' header so no access-control .andExpect(MockMvcResultMatchers.header().string("Access-Control-Allow-Origin", "*")); + // curl -H "Origin: http://localhost:8080" -H "Access-Control-Request-Method: POST" -X OPTIONS -v http://localhost:8080/channels this.mockMvc.perform( MockMvcRequestBuilders .options(QueryRestController.CHANNELS) .header("Origin", "http://localhost:8080") + .header("Access-Control-Request-Method", "POST") .contentType(MediaType.APPLICATION_JSON)) .andDo(MockMvcResultHandlers.print()) .andExpect(MockMvcResultMatchers.status().isOk()) - // we didn't set the 'Origin' header so no access-control .andExpect(MockMvcResultMatchers.header().string("Access-Control-Allow-Origin", "http://localhost:8080")); - - this.mockMvc.perform( - MockMvcRequestBuilders - .options(QueryRestController.CHANNELS) - .header("Origin", "someBogusDomain.com") - .contentType(MediaType.APPLICATION_JSON)) - .andDo(MockMvcResultHandlers.print()) - .andExpect(MockMvcResultMatchers.status().isOk()) - .andExpect(MockMvcResultMatchers.header().string("Access-Control-Allow-Origin", "*")); - - } - - @Test - public void testCorsFilterMismatchSpecificOrigin() throws Exception { - corsFilter.setForceAllHeaders(false); - this.mockMvc = MockMvcBuilders.webAppContextSetup(webApplicationContext).addFilters(corsFilter).build(); - - this.mockMvc - .perform( - MockMvcRequestBuilders - .options(QueryRestController.CHANNELS) - .header("Origin", "*") - .header("Access-Control-Request-Method", "GET") - .contentType(MediaType.APPLICATION_JSON)) - .andDo(MockMvcResultHandlers.print()) - .andExpect(MockMvcResultMatchers.status().isOk()) - .andExpect(MockMvcResultMatchers.header().string("Access-Control-Allow-Origin", "*")) - .andExpect( - MockMvcResultMatchers.header().string("Access-Control-Allow-Headers", - "Origin, Authorization, Accept, Content-Type")); - - this.mockMvc - .perform( - MockMvcRequestBuilders - .options(QueryRestController.CHANNELS) - .header("Origin", "someBogusDomain.com") - .header("Access-Control-Request-Method", "GET") - .contentType(MediaType.APPLICATION_JSON)) - .andDo(MockMvcResultHandlers.print()) - .andExpect(MockMvcResultMatchers.status().isOk()) - .andExpect(MockMvcResultMatchers.header().doesNotExist("Access-Control-Allow-Origin")) - .andExpect( - MockMvcResultMatchers.header().string("Access-Control-Allow-Headers", - "Origin, Authorization, Accept, Content-Type")); } @Test