From 324f75ce76646e944771ab5a0b47432a83577584 Mon Sep 17 00:00:00 2001 From: lin <648540858@qq.com> Date: Thu, 21 May 2026 12:50:35 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=A4=9A=E7=AB=AF=E5=8F=A3?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F=E6=94=AF=E6=8C=81=E9=9A=8F=E6=9C=BASSRC?= =?UTF-8?q?=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../genersoft/iot/vmp/conf/UserSetting.java | 5 + .../iot/vmp/gb28181/session/SSRCFactory.java | 96 +++--- .../iot/vmp/service/bean/RTPServerParam.java | 4 + .../service/impl/RtpServerServiceImpl.java | 83 +++-- src/main/resources/application-dev.yml | 2 + .../dao/provider/ChannelProviderTest.java | 293 ++++++++++++++++++ .../provider/DeviceChannelProviderTest.java | 177 +++++++++++ 7 files changed, 588 insertions(+), 72 deletions(-) create mode 100644 src/test/java/com/genersoft/iot/vmp/gb28181/dao/provider/ChannelProviderTest.java create mode 100644 src/test/java/com/genersoft/iot/vmp/gb28181/dao/provider/DeviceChannelProviderTest.java diff --git a/src/main/java/com/genersoft/iot/vmp/conf/UserSetting.java b/src/main/java/com/genersoft/iot/vmp/conf/UserSetting.java index adb733384..28c41b45d 100644 --- a/src/main/java/com/genersoft/iot/vmp/conf/UserSetting.java +++ b/src/main/java/com/genersoft/iot/vmp/conf/UserSetting.java @@ -128,6 +128,11 @@ public class UserSetting { */ private Boolean useCustomSsrcForParentInvite = Boolean.TRUE; + /** + * 多端口模式使用随机SSRC,端口区分流,SSRC允许重复 + */ + private Boolean ssrcRandom = Boolean.FALSE; + /** * 开启接口文档页面。 默认开启,生产环境建议关闭,遇到swagger相关的漏洞时也可以关闭 */ diff --git a/src/main/java/com/genersoft/iot/vmp/gb28181/session/SSRCFactory.java b/src/main/java/com/genersoft/iot/vmp/gb28181/session/SSRCFactory.java index 1f79971e1..29556da2d 100755 --- a/src/main/java/com/genersoft/iot/vmp/gb28181/session/SSRCFactory.java +++ b/src/main/java/com/genersoft/iot/vmp/gb28181/session/SSRCFactory.java @@ -2,6 +2,7 @@ package com.genersoft.iot.vmp.gb28181.session; import com.alibaba.fastjson2.JSONObject; import com.genersoft.iot.vmp.conf.SipConfig; +import com.genersoft.iot.vmp.conf.UserSetting; import com.genersoft.iot.vmp.media.bean.MediaServer; import com.genersoft.iot.vmp.media.service.IMediaServerService; import com.genersoft.iot.vmp.media.zlm.ZLMRESTfulUtils; @@ -24,6 +25,7 @@ import java.util.concurrent.TimeUnit; public class SSRCFactory { private final ConcurrentHashMap usedMap = new ConcurrentHashMap<>(); + private final ConcurrentHashMap lockMap = new ConcurrentHashMap<>(); private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor(r -> { Thread t = new Thread(r, "ssrc-rebuild"); t.setDaemon(true); @@ -39,6 +41,9 @@ public class SSRCFactory { @Autowired private SipConfig sipConfig; + @Autowired + private UserSetting userSetting; + private String domainPart; @PostConstruct @@ -58,53 +63,68 @@ public class SSRCFactory { return suffix != null ? "1" + suffix : null; } + public String getPlaySsrcRandom() { + return "0" + domainPart + String.format("%04d", ThreadLocalRandom.current().nextInt(10000)); + } + + public String getPlayBackSsrcRandom() { + return "1" + domainPart + String.format("%04d", ThreadLocalRandom.current().nextInt(10000)); + } + private String allocate(String mediaServerId) { - BitSet bits = usedMap.computeIfAbsent(mediaServerId, k -> new BitSet(10000)); - int start = ThreadLocalRandom.current().nextInt(10000); - int index = start; - do { - if (!bits.get(index)) { - bits.set(index); - return domainPart + String.format("%04d", index); - } - index = (index + 1) % 10000; - } while (index != start); - log.warn("[SSRC] 媒体节点 {} 的SSRC已用尽", mediaServerId); - return null; + synchronized (lockMap.computeIfAbsent(mediaServerId, k -> new Object())) { + BitSet bits = usedMap.computeIfAbsent(mediaServerId, k -> new BitSet(10000)); + int start = ThreadLocalRandom.current().nextInt(10000); + int index = start; + do { + if (!bits.get(index)) { + bits.set(index); + return domainPart + String.format("%04d", index); + } + index = (index + 1) % 10000; + } while (index != start); + log.warn("[SSRC] 媒体节点 {} 的SSRC已用尽", mediaServerId); + return null; + } } void rebuild() { List servers = mediaServerService.getAll(); for (MediaServer server : servers) { - BitSet bits = new BitSet(10000); - int count = 0; - try { - ZLMResult result = zlmresTfulUtils.getMediaList(server, null, null, "rtsp", null); - if (result != null && result.getCode() == 0 && result.getData() != null) { - List list = (List) result.getData(); - for (JSONObject obj : list) { - if (obj.getIntValue("originType") != 3) continue; - String originUrl = obj.getString("originUrl"); - if (originUrl == null) continue; - int idx = originUrl.lastIndexOf("/rtp/"); - if (idx == -1) continue; - try { - int suffix = (int) (Long.parseLong(originUrl.substring(idx + 5), 16) % 10000); - bits.set(suffix); - count++; - } catch (NumberFormatException ignored) { + if (server.isRtpEnable() && userSetting.getSsrcRandom()) { + continue; + } + synchronized (lockMap.computeIfAbsent(server.getId(), k -> new Object())) { + BitSet bits = new BitSet(10000); + int count = 0; + try { + ZLMResult result = zlmresTfulUtils.getMediaList(server, null, null, "rtsp", null); + if (result != null && result.getCode() == 0 && result.getData() != null) { + List list = (List) result.getData(); + for (JSONObject obj : list) { + if (obj.getIntValue("originType") != 3) continue; + String originUrl = obj.getString("originUrl"); + if (originUrl == null) continue; + int idx = originUrl.lastIndexOf("/rtp/"); + if (idx == -1) continue; + try { + int suffix = (int) (Long.parseLong(originUrl.substring(idx + 5), 16) % 10000); + bits.set(suffix); + count++; + } catch (NumberFormatException ignored) { + } } } + } catch (Exception e) { + log.warn("[SSRC重建] 查询媒体节点 {} 失败: {}", server.getId(), e.getMessage()); } - } catch (Exception e) { - log.warn("[SSRC重建] 查询媒体节点 {} 失败: {}", server.getId(), e.getMessage()); - } - usedMap.put(server.getId(), bits); - if (count > 8000) { - log.info("[SSRC重建] 媒体节点 {} 的SSRC使用率已超过80%,请注意扩展服务提升性能", server.getId()); - }else { - if (log.isDebugEnabled()) { - log.debug("[SSRC重建] 节点 {} 已占用 {} 个SSRC", server.getId(), count); + usedMap.put(server.getId(), bits); + if (count > 8000) { + log.info("[SSRC重建] 媒体节点 {} 的SSRC使用率已超过80%,请注意扩展服务提升性能", server.getId()); + } else { + if (log.isDebugEnabled()) { + log.debug("[SSRC重建] 节点 {} 已占用 {} 个SSRC", server.getId(), count); + } } } } diff --git a/src/main/java/com/genersoft/iot/vmp/service/bean/RTPServerParam.java b/src/main/java/com/genersoft/iot/vmp/service/bean/RTPServerParam.java index 6c2dd15d2..0868ccef5 100644 --- a/src/main/java/com/genersoft/iot/vmp/service/bean/RTPServerParam.java +++ b/src/main/java/com/genersoft/iot/vmp/service/bean/RTPServerParam.java @@ -17,6 +17,10 @@ public class RTPServerParam { private MediaServer mediaServer; private String app; private String streamId; + /** + * 传递给zlm创建rtp server的streamId,不填则使用streamId + */ + private String zlmStreamId; /** * 开启rtpServer时使用的ssrc,开启rtpServer时会根据这个ssrc进行校验,如果不填则不校验 */ diff --git a/src/main/java/com/genersoft/iot/vmp/service/impl/RtpServerServiceImpl.java b/src/main/java/com/genersoft/iot/vmp/service/impl/RtpServerServiceImpl.java index 9e5732087..342b29d08 100644 --- a/src/main/java/com/genersoft/iot/vmp/service/impl/RtpServerServiceImpl.java +++ b/src/main/java/com/genersoft/iot/vmp/service/impl/RtpServerServiceImpl.java @@ -90,12 +90,10 @@ public class RtpServerServiceImpl implements IReceiveRtpServerService { final String ssrc; if (presetSSRC != null) { ssrc = presetSSRC; - }else { - if (playback) { - ssrc = ssrcFactory.getPlayBackSsrc(mediaServer.getId()); - }else { - ssrc = ssrcFactory.getPlaySsrc(mediaServer.getId()); - } + } else if (mediaServer.isRtpEnable() && userSetting.getSsrcRandom()) { + ssrc = playback ? ssrcFactory.getPlayBackSsrcRandom() : ssrcFactory.getPlaySsrcRandom(); + } else { + ssrc = playback ? ssrcFactory.getPlayBackSsrc(mediaServer.getId()) : ssrcFactory.getPlaySsrc(mediaServer.getId()); } if (streamId == null) { streamId = String.format("%08x", Long.parseLong(ssrc)).toUpperCase(); @@ -139,18 +137,14 @@ public class RtpServerServiceImpl implements IReceiveRtpServerService { final String ssrc; if (presetSSRC != null) { ssrc = presetSSRC; - }else { + } else if (mediaServer.isRtpEnable() && userSetting.getSsrcRandom()) { + ssrc = ssrcFactory.getPlaySsrcRandom(); + } else { ssrc = ssrcFactory.getPlaySsrc(mediaServer.getId()); } - String streamId; - String streamReplace = null; - if (mediaServer.isRtpEnable()) { - streamId = String.format("%s_%s", device.getDeviceId(), channel.getDeviceId()); - }else { - streamId = String.format("%08x", Long.parseLong(ssrc)).toUpperCase(); - streamReplace = String.format("%s_%s", device.getDeviceId(), channel.getDeviceId()); - } + String streamId = String.format("%08x", Long.parseLong(ssrc)).toUpperCase(); + String streamReplace = String.format("%s_%s", device.getDeviceId(), channel.getDeviceId()); int tcpMode = device.getStreamMode().equals("TCP-ACTIVE")? 2: (device.getStreamMode().equals("TCP-PASSIVE")? 1:0); @@ -161,8 +155,8 @@ public class RtpServerServiceImpl implements IReceiveRtpServerService { Long checkSsrc = device.isSsrcCheck() ? Long.parseLong(ssrc) : 0L; - SSRCInfo ssrcInfo = new SSRCInfo(0, ssrc, MediaStreamUtil.RTP_APP, streamReplace != null ? streamReplace : streamId); - openRtpServer(mediaServer, ssrcInfo, checkSsrc, !channel.isHasAudio(), false, tcpMode, callback); + SSRCInfo ssrcInfo = new SSRCInfo(0, ssrc, MediaStreamUtil.RTP_APP, streamReplace); + openRtpServer(mediaServer, ssrcInfo, checkSsrc, !channel.isHasAudio(), false, tcpMode, callback, streamId); addAuthenticateInfo(streamId, streamReplace, channel.isHasAudio(), record, null); return ssrcInfo; } @@ -180,17 +174,16 @@ public class RtpServerServiceImpl implements IReceiveRtpServerService { } // 获取 mediaServer 可用的 ssrc - String ssrc = ssrcFactory.getPlayBackSsrc(mediaServer.getId()); - - String streamId; - String streamReplace = null; - if (mediaServer.isRtpEnable()) { - streamId = getPlaybackStream(device, channel, startTime, endTime); - }else { - streamId = String.format("%08x", Long.parseLong(ssrc)).toUpperCase(); - streamReplace = getPlaybackStream(device, channel, startTime, endTime); + String ssrc; + if (mediaServer.isRtpEnable() && userSetting.getSsrcRandom()) { + ssrc = ssrcFactory.getPlayBackSsrcRandom(); + } else { + ssrc = ssrcFactory.getPlayBackSsrc(mediaServer.getId()); } + String streamId = String.format("%08x", Long.parseLong(ssrc)).toUpperCase(); + String streamReplace = getPlaybackStream(device, channel, startTime, endTime); + int tcpMode = device.getStreamMode().equals("TCP-ACTIVE")? 2: (device.getStreamMode().equals("TCP-PASSIVE")? 1:0); if (device.isSsrcCheck() && tcpMode > 0) { @@ -200,8 +193,8 @@ public class RtpServerServiceImpl implements IReceiveRtpServerService { Long checkSsrc = device.isSsrcCheck() ? Long.parseLong(ssrc) : 0L; - SSRCInfo ssrcInfo = new SSRCInfo(0, ssrc, MediaStreamUtil.RTP_APP, streamReplace != null ? streamReplace : streamId); - openRtpServer(mediaServer, ssrcInfo, checkSsrc, !channel.isHasAudio(), false, tcpMode, callback); + SSRCInfo ssrcInfo = new SSRCInfo(0, ssrc, MediaStreamUtil.RTP_APP, streamReplace); + openRtpServer(mediaServer, ssrcInfo, checkSsrc, !channel.isHasAudio(), false, tcpMode, callback, streamId); addAuthenticateInfo(streamId, streamReplace, channel.isHasAudio(), false,null); return ssrcInfo; } @@ -233,8 +226,18 @@ public class RtpServerServiceImpl implements IReceiveRtpServerService { int tcpMode = device.getStreamMode().equals("TCP-ACTIVE")? 2: (device.getStreamMode().equals("TCP-PASSIVE")? 1:0); // 获取 mediaServer 可用的 ssrc - String ssrc = ssrcFactory.getPlayBackSsrc(mediaServer.getId()); + String ssrc; + if (mediaServer.isRtpEnable() && userSetting.getSsrcRandom()) { + ssrc = ssrcFactory.getPlayBackSsrcRandom(); + } else { + ssrc = ssrcFactory.getPlayBackSsrc(mediaServer.getId()); + } + String streamId = String.format("%08x", Long.parseLong(ssrc)).toUpperCase(); + String streamReplace = String.format("%s_%s_%s_%s", device.getDeviceId(), channel.getDeviceId(), + startTime.replace("-", "").replace(":", "").replace(" ", ""), + endTime.replace("-", "").replace(":", "").replace(" ", "")); + if (device.isSsrcCheck() && tcpMode > 0) { // 目前zlm不支持 tcp模式更新ssrc,暂时关闭ssrc校验 log.warn("[开启国标录像下载RTP收流] 平台对接时下级可能自定义ssrc,但是tcp模式zlm收流目前无法更新ssrc,可能收流超时,此时请使用udp收流或者关闭ssrc校验"); @@ -242,12 +245,12 @@ public class RtpServerServiceImpl implements IReceiveRtpServerService { Long checkSsrc = device.isSsrcCheck() ? Long.parseLong(ssrc) : 0L; - SSRCInfo ssrcInfo = new SSRCInfo(0, ssrc, MediaStreamUtil.RTP_APP, streamId); - openRtpServer(mediaServer, ssrcInfo, checkSsrc, !channel.isHasAudio(), false, tcpMode, callback); + SSRCInfo ssrcInfo = new SSRCInfo(0, ssrc, MediaStreamUtil.RTP_APP, streamReplace); + openRtpServer(mediaServer, ssrcInfo, checkSsrc, !channel.isHasAudio(), false, tcpMode, callback, streamId); long difference = DateUtil.getDifference(startTime, endTime) / 1000; - addAuthenticateInfo(streamId, null, channel.isHasAudio(), true, (int) difference); + addAuthenticateInfo(streamId, streamReplace, channel.isHasAudio(), true, (int) difference); return ssrcInfo; } @@ -278,7 +281,12 @@ public class RtpServerServiceImpl implements IReceiveRtpServerService { } // 获取 mediaServer 可用的 ssrc - String ssrc = ssrcFactory.getPlaySsrc(mediaServer.getId()); + String ssrc; + if (mediaServer.isRtpEnable() && userSetting.getSsrcRandom()) { + ssrc = ssrcFactory.getPlaySsrcRandom(); + } else { + ssrc = ssrcFactory.getPlaySsrc(mediaServer.getId()); + } SSRCInfo ssrcInfo = new SSRCInfo(0, ssrc, MediaStreamUtil.RTP_APP, streamId); openRtpServer(mediaServer, ssrcInfo, 0L, false, true, tcpMode, callback); @@ -287,8 +295,14 @@ public class RtpServerServiceImpl implements IReceiveRtpServerService { private void openRtpServer(MediaServer mediaServer, SSRCInfo ssrcInfo, Long checkSsrc, boolean disableAuto, boolean onlyAuto, int tcpMode, ErrorCallback callback) { + openRtpServer(mediaServer, ssrcInfo, checkSsrc, disableAuto, onlyAuto, tcpMode, callback, null); + } + + private void openRtpServer(MediaServer mediaServer, SSRCInfo ssrcInfo, Long checkSsrc, boolean disableAuto, boolean onlyAuto, int tcpMode, + ErrorCallback callback, String zlmStreamId) { RTPServerParam rtpServerParam = new RTPServerParam(mediaServer, MediaStreamUtil.RTP_APP, ssrcInfo.getStream(), checkSsrc, null, onlyAuto, disableAuto, false, tcpMode); + rtpServerParam.setZlmStreamId(zlmStreamId); int rtpServerPort = openCommonRTPServer(rtpServerParam, ((code, msg, data) -> { if (code == InviteErrorCode.SUCCESS.getCode()) { OpenRTPServerResult openRTPServerResult = new OpenRTPServerResult(); @@ -336,7 +350,8 @@ public class RtpServerServiceImpl implements IReceiveRtpServerService { int rtpServerPort; if (rtpServerParam.getMediaServer().isRtpEnable()) { - rtpServerPort = mediaServerService.createRTPServer(rtpServerParam.getMediaServer(), rtpServerParam.getApp(), rtpServerParam.getStreamId(), + String effectiveStreamId = rtpServerParam.getZlmStreamId() != null ? rtpServerParam.getZlmStreamId() : rtpServerParam.getStreamId(); + rtpServerPort = mediaServerService.createRTPServer(rtpServerParam.getMediaServer(), rtpServerParam.getApp(), effectiveStreamId, Objects.requireNonNullElse(rtpServerParam.getSsrc(), 0L), rtpServerParam.getPort(), rtpServerParam.isOnlyAuto(), rtpServerParam.isDisableAudio(), rtpServerParam.isReUsePort(), rtpServerParam.getTcpMode()); } else { diff --git a/src/main/resources/application-dev.yml b/src/main/resources/application-dev.yml index 4b10dc6a5..b052e1e06 100644 --- a/src/main/resources/application-dev.yml +++ b/src/main/resources/application-dev.yml @@ -99,6 +99,8 @@ media: user-settings: # 点播/录像回放 等待超时时间,单位:毫秒 play-timeout: 180000 + # [可选] 多端口模式使用随机SSRC,SSRC允许重复(默认false) + ssrc-random: false # [可选] 自动点播, 使用固定流地址进行播放时,如果未点播则自动进行点播, 需要rtp.enable=true auto-apply-play: true # 推流直播是否录制 diff --git a/src/test/java/com/genersoft/iot/vmp/gb28181/dao/provider/ChannelProviderTest.java b/src/test/java/com/genersoft/iot/vmp/gb28181/dao/provider/ChannelProviderTest.java new file mode 100644 index 000000000..c1ce63699 --- /dev/null +++ b/src/test/java/com/genersoft/iot/vmp/gb28181/dao/provider/ChannelProviderTest.java @@ -0,0 +1,293 @@ +package com.genersoft.iot.vmp.gb28181.dao.provider; + +import com.genersoft.iot.vmp.gb28181.bean.Device; +import com.genersoft.iot.vmp.gb28181.bean.Group; +import com.genersoft.iot.vmp.web.custom.bean.CameraGroup; +import com.genersoft.iot.vmp.web.custom.bean.Point; +import org.junit.jupiter.api.Test; + +import java.util.*; + +import static org.junit.jupiter.api.Assertions.*; + +class ChannelProviderTest { + + private final ChannelProvider provider = new ChannelProvider(); + + // ========== queryByGbDeviceIds ========== + + @Test + void queryByGbDeviceIds_shouldUseBindVariables() { + Map params = new HashMap<>(); + params.put("deviceIds", Arrays.asList("DEV001", "DEV002")); + String sql = provider.queryByGbDeviceIds(params); + assertTrue(sql.contains("#{deviceIds[0]}"), "should use #{deviceIds[0]}"); + assertTrue(sql.contains("#{deviceIds[1]}"), "should use #{deviceIds[1]}"); + assertFalse(sql.contains("'DEV001'"), "should not contain raw quoted value"); + assertFalse(sql.contains("'DEV002'"), "should not contain raw quoted value"); + } + + @Test + void queryByGbDeviceIds_shouldNotQuoteBindVariables() { + Map params = new HashMap<>(); + params.put("deviceIds", Collections.singletonList("INJECT' OR 1=1 --")); + String sql = provider.queryByGbDeviceIds(params); + assertTrue(sql.contains("#{deviceIds[0]}"), "should use bind variable for injection attempt"); + assertFalse(sql.contains("1=1"), "should not contain injection payload in SQL"); + } + + // ========== queryByGroupList ========== + + @Test + void queryByGroupList_shouldUseBindVariables() { + Map params = new HashMap<>(); + Group g1 = new Group(); + g1.setDeviceId("GRP001"); + Group g2 = new Group(); + g2.setDeviceId("GRP002"); + params.put("groupList", Arrays.asList(g1, g2)); + String sql = provider.queryByGroupList(params); + assertTrue(sql.contains("#{groupList[0].deviceId}"), "should use #{groupList[0].deviceId}"); + assertTrue(sql.contains("#{groupList[1].deviceId}"), "should use #{groupList[1].deviceId}"); + assertFalse(sql.contains("GRP001"), "should not contain raw deviceId"); + assertFalse(sql.contains("GRP002"), "should not contain raw deviceId"); + } + + // ========== queryOnlineListsByGbDeviceIds ========== + + @Test + void queryOnlineListsByGbDeviceIds_shouldUseBindVariables() { + Map params = new HashMap<>(); + Device d1 = new Device(); + d1.setId(101); + Device d2 = new Device(); + d2.setId(102); + params.put("deviceList", Arrays.asList(d1, d2)); + String sql = provider.queryOnlineListsByGbDeviceIds(params); + assertTrue(sql.contains("#{deviceList[0].id}"), "should use #{deviceList[0].id}"); + assertTrue(sql.contains("#{deviceList[1].id}"), "should use #{deviceList[1].id}"); + assertFalse(sql.contains("101"), "should not contain raw id"); + assertFalse(sql.contains("102"), "should not contain raw id"); + } + + @Test + void queryOnlineListsByGbDeviceIds_withEmptyList_shouldNotHaveInClause() { + Map params = new HashMap<>(); + params.put("deviceList", Collections.emptyList()); + String sql = provider.queryOnlineListsByGbDeviceIds(params); + assertFalse(sql.contains("data_device_id in ("), "should not have IN clause when empty"); + } + + @Test + void queryOnlineListsByGbDeviceIds_withNullList_shouldNotHaveInClause() { + Map params = new HashMap<>(); + params.put("deviceList", null); + String sql = provider.queryOnlineListsByGbDeviceIds(params); + assertFalse(sql.contains("data_device_id in ("), "should not have IN clause when null"); + } + + // ========== queryListWithChildForSy ========== + + @Test + void queryListWithChildForSy_shouldUseBindVariables() { + Map params = new HashMap<>(); + CameraGroup cg1 = new CameraGroup(); + cg1.setDeviceId("CG001"); + CameraGroup cg2 = new CameraGroup(); + cg2.setDeviceId("CG002"); + params.put("groupList", Arrays.asList(cg1, cg2)); + String sql = provider.queryListWithChildForSy(params); + assertTrue(sql.contains("#{groupList[0].deviceId}"), "should use #{groupList[0].deviceId}"); + assertTrue(sql.contains("#{groupList[1].deviceId}"), "should use #{groupList[1].deviceId}"); + assertFalse(sql.contains("'CG001'"), "should not contain raw quoted value"); + } + + @Test + void queryListWithChildForSy_withQuery_shouldUseBindVariable() { + Map params = new HashMap<>(); + params.put("query", "search-term"); + params.put("groupList", Collections.singletonList(new CameraGroup())); + String sql = provider.queryListWithChildForSy(params); + assertTrue(sql.contains("#{query}"), "should use #{query} bind variable"); + assertFalse(sql.contains("search-term"), "should not contain raw query"); + } + + @Test + void queryListWithChildForSy_withSort_shouldUseWhitelist() { + Map params = new HashMap<>(); + params.put("groupList", Collections.singletonList(new CameraGroup())); + params.put("sortName", "gbId"); + params.put("order", true); + String sql = provider.queryListWithChildForSy(params); + assertTrue(sql.contains("order by gb_id"), "should sort by gb_id"); + assertTrue(sql.contains("ASC"), "should be ascending"); + } + + @Test + void queryListWithChildForSy_withSortDesc_shouldUseDesc() { + Map params = new HashMap<>(); + params.put("groupList", Collections.singletonList(new CameraGroup())); + params.put("sortName", "gbId"); + params.put("order", false); + String sql = provider.queryListWithChildForSy(params); + assertTrue(sql.contains("DESC"), "should be descending"); + } + + // ========== queryListInBox ========== + + @Test + void queryListInBox_shouldUseBindVariables() { + Map params = new HashMap<>(); + CameraGroup cg = new CameraGroup(); + cg.setDeviceId("BOX001"); + params.put("groupList", Collections.singletonList(cg)); + params.put("level", 3); + String sql = provider.queryListInBox(params); + assertTrue(sql.contains("#{groupList[0].deviceId}"), "should use bind variable"); + assertFalse(sql.contains("'BOX001'"), "should not contain raw value"); + assertTrue(sql.contains("#{level}"), "should use #{level} bind variable"); + assertTrue(sql.contains("#{minLongitude}"), "should use #{minLongitude}"); + assertTrue(sql.contains("#{maxLatitude}"), "should use #{maxLatitude}"); + } + + // ========== queryListInCircleForMysql ========== + + @Test + void queryListInCircleForMysql_shouldUseBindVariablesForGeometry() { + Map params = new HashMap<>(); + CameraGroup cg = new CameraGroup(); + cg.setDeviceId("CIRCLE001"); + params.put("groupList", Collections.singletonList(cg)); + params.put("centerLongitude", 116.397); + params.put("centerLatitude", 39.908); + params.put("radius", 1000); + + String sql = provider.queryListInCircleForMysql(params); + assertTrue(sql.contains("#{centerLongitude}"), "should use #{centerLongitude} bind variable"); + assertTrue(sql.contains("#{centerLatitude}"), "should use #{centerLatitude} bind variable"); + assertTrue(sql.contains("#{radius}"), "should use #{radius} bind variable"); + assertFalse(sql.contains("116.397"), "should not contain raw longitude"); + assertFalse(sql.contains("39.908"), "should not contain raw latitude"); + assertTrue(sql.contains("CONCAT('point(', #{centerLongitude}, ' ', #{centerLatitude}, ')')"), + "should build WKT via CONCAT with bind variables"); + } + + // ========== queryListInCircleForKingBase ========== + + @Test + void queryListInCircleForKingBase_shouldUseBindVariablesForGeometry() { + Map params = new HashMap<>(); + CameraGroup cg = new CameraGroup(); + cg.setDeviceId("CIRCLE002"); + params.put("groupList", Collections.singletonList(cg)); + params.put("centerLongitude", 121.473); + params.put("centerLatitude", 31.230); + params.put("radius", 500); + + String sql = provider.queryListInCircleForKingBase(params); + assertTrue(sql.contains("#{centerLongitude}"), "should use #{centerLongitude}"); + assertTrue(sql.contains("#{centerLatitude}"), "should use #{centerLatitude}"); + assertTrue(sql.contains("#{radius}"), "should use #{radius}"); + assertFalse(sql.contains("121.473"), "should not contain raw longitude"); + assertFalse(sql.contains("31.230"), "should not contain raw latitude"); + assertTrue(sql.contains("CONCAT('point(', #{centerLongitude}, ' ', #{centerLatitude}, ')')"), + "should build WKT via CONCAT with bind variables"); + } + + // ========== queryListInPolygonForMysql ========== + + @Test + void queryListInPolygonForMysql_shouldUseBindVariablesForPoints() { + Map params = new HashMap<>(); + CameraGroup cg = new CameraGroup(); + cg.setDeviceId("POLY001"); + params.put("groupList", Collections.singletonList(cg)); + + List points = new ArrayList<>(); + Point p1 = new Point(); + p1.setLng(116.0); + p1.setLat(39.0); + Point p2 = new Point(); + p2.setLng(117.0); + p2.setLat(40.0); + points.add(p1); + points.add(p2); + params.put("pointList", points); + + String sql = provider.queryListInPolygonForMysql(params); + assertTrue(sql.contains("#{pointList[0].lng}"), "should use #{pointList[0].lng}"); + assertTrue(sql.contains("#{pointList[0].lat}"), "should use #{pointList[0].lat}"); + assertTrue(sql.contains("#{pointList[1].lng}"), "should use #{pointList[1].lng}"); + assertTrue(sql.contains("#{pointList[1].lat}"), "should use #{pointList[1].lat}"); + assertFalse(sql.contains("116.0"), "should not contain raw lng"); + assertFalse(sql.contains("117.0"), "should not contain raw lat"); + assertTrue(sql.contains("CONCAT('POLYGON(('"), "should use CONCAT to build polygon WKT"); + } + + // ========== queryListInPolygonForKingBase ========== + + @Test + void queryListInPolygonForKingBase_shouldUseBindVariablesForPoints() { + Map params = new HashMap<>(); + CameraGroup cg = new CameraGroup(); + cg.setDeviceId("POLY002"); + params.put("groupList", Collections.singletonList(cg)); + + List points = new ArrayList<>(); + Point p1 = new Point(); + p1.setLng(116.0); + p1.setLat(39.0); + points.add(p1); + params.put("pointList", points); + + String sql = provider.queryListInPolygonForKingBase(params); + assertTrue(sql.contains("#{pointList[0].lng}"), "should use #{pointList[0].lng}"); + assertTrue(sql.contains("#{pointList[0].lat}"), "should use #{pointList[0].lat}"); + assertFalse(sql.contains("116.0"), "should not contain raw lng"); + assertFalse(sql.contains("39.0"), "should not contain raw lat"); + assertTrue(sql.contains("ST_MakePoint"), "should use KingBase specific function"); + } + + // ========== queryListInCircleForMysql with injection attempt ========== + + @Test + void queryListInCircleForMysql_shouldNotContainInjectionPayload() { + Map params = new HashMap<>(); + CameraGroup cg = new CameraGroup(); + cg.setDeviceId("NORMAL"); + params.put("groupList", Collections.singletonList(cg)); + params.put("centerLongitude", "0) OR 1=1 -- "); + params.put("centerLatitude", "0"); + params.put("radius", 1000); + + String sql = provider.queryListInCircleForMysql(params); + assertTrue(sql.contains("#{centerLongitude}"), "should use bind variable for injection payload"); + assertFalse(sql.contains("1=1"), "should not contain 1=1 in SQL text"); + assertFalse(sql.contains("OR 1=1"), "should not contain injection"); + } + + // ========== queryByGbDeviceIds single element ========== + + @Test + void queryByGbDeviceIds_withSingleElement() { + Map params = new HashMap<>(); + params.put("deviceIds", Collections.singletonList("SINGLE01")); + String sql = provider.queryByGbDeviceIds(params); + assertEquals(1, countOccurrences(sql, "#{deviceIds[0]}"), + "should have exactly one bind variable for single element"); + assertFalse(sql.contains("#{deviceIds[0]},"), "should not have trailing comma in IN clause"); + assertFalse(sql.contains(",#{deviceIds[0]}"), "should not have leading comma in IN clause"); + } + + // ========== helper ========== + + private int countOccurrences(String str, String substr) { + int count = 0; + int idx = 0; + while ((idx = str.indexOf(substr, idx)) != -1) { + count++; + idx += substr.length(); + } + return count; + } +} diff --git a/src/test/java/com/genersoft/iot/vmp/gb28181/dao/provider/DeviceChannelProviderTest.java b/src/test/java/com/genersoft/iot/vmp/gb28181/dao/provider/DeviceChannelProviderTest.java new file mode 100644 index 000000000..bf0d1f890 --- /dev/null +++ b/src/test/java/com/genersoft/iot/vmp/gb28181/dao/provider/DeviceChannelProviderTest.java @@ -0,0 +1,177 @@ +package com.genersoft.iot.vmp.gb28181.dao.provider; + +import org.junit.jupiter.api.Test; + +import java.util.*; + +import static org.junit.jupiter.api.Assertions.*; + +class DeviceChannelProviderTest { + + private final DeviceChannelProvider provider = new DeviceChannelProvider(); + + @Test + void queryChannels_withChannelIds_shouldUseBindVariables() { + Map params = new HashMap<>(); + params.put("channelIds", Arrays.asList("CH001", "CH002", "CH003")); + String sql = provider.queryChannels(params); + assertTrue(sql.contains("#{channelIds[0]}"), "should use #{channelIds[0]}"); + assertTrue(sql.contains("#{channelIds[1]}"), "should use #{channelIds[1]}"); + assertTrue(sql.contains("#{channelIds[2]}"), "should use #{channelIds[2]}"); + assertFalse(sql.contains("CH001"), "should not contain raw channel id"); + assertFalse(sql.contains("CH002"), "should not contain raw channel id"); + assertTrue(sql.contains("dc.device_id in ("), "should have IN clause"); + } + + @Test + void queryChannels_withoutChannelIds_shouldNotContainInClause() { + Map params = new HashMap<>(); + String sql = provider.queryChannels(params); + assertFalse(sql.contains("device_id in ("), "should not have IN clause when no channelIds"); + assertTrue(sql.contains("ORDER BY"), "should have ORDER BY"); + } + + @Test + void queryChannels_withEmptyChannelIds_shouldNotContainInClause() { + Map params = new HashMap<>(); + params.put("channelIds", Collections.emptyList()); + String sql = provider.queryChannels(params); + assertFalse(sql.contains("device_id in ("), "should not have IN clause when channelIds empty"); + } + + @Test + void queryChannels_withDataDeviceId_shouldUseBindVariable() { + Map params = new HashMap<>(); + params.put("dataDeviceId", 42); + String sql = provider.queryChannels(params); + assertTrue(sql.contains("#{dataDeviceId}"), "should use #{dataDeviceId}"); + } + + @Test + void queryChannels_withQuery_shouldUseBindVariable() { + Map params = new HashMap<>(); + params.put("query", "test"); + String sql = provider.queryChannels(params); + assertTrue(sql.contains("#{query}"), "should use #{query} bind variable"); + assertFalse(sql.contains("'test'"), "should not contain raw query value"); + } + + @Test + void queryChannels_withOnline_shouldFilterStatus() { + Map params = new HashMap<>(); + params.put("online", true); + String sql = provider.queryChannels(params); + assertTrue(sql.contains("'ON'"), "should filter for ON status"); + assertFalse(sql.contains("'OFF'"), "should not filter for OFF status"); + } + + @Test + void queryChannels_withOffline_shouldFilterStatus() { + Map params = new HashMap<>(); + params.put("online", false); + String sql = provider.queryChannels(params); + assertTrue(sql.contains("'OFF'"), "should filter for OFF status"); + assertFalse(sql.contains("'ON'"), "should not filter for ON status"); + } + + @Test + void queryChannels_withBusinessGroupId_shouldFilter() { + Map params = new HashMap<>(); + params.put("businessGroupId", "group-1"); + String sql = provider.queryChannels(params); + assertTrue(sql.contains("#{businessGroupId}"), "should use bind variable"); + } + + @Test + void queryChannelsByDeviceDbId_shouldUseBindVariable() { + Map params = new HashMap<>(); + params.put("dataDeviceId", 99); + String sql = provider.queryChannelsByDeviceDbId(params); + assertTrue(sql.contains("#{dataDeviceId}"), "should use #{dataDeviceId}"); + } + + @Test + void queryChannelsByDeviceDbId_shouldFilterByDataType() { + Map params = new HashMap<>(); + params.put("dataDeviceId", 1); + String sql = provider.queryChannelsByDeviceDbId(params); + assertTrue(sql.contains("data_type = 1"), "should filter by GB28181 data type"); + } + + @Test + void getOne_shouldUseBindVariable() { + Map params = new HashMap<>(); + params.put("id", 123); + String sql = provider.getOne(params); + assertTrue(sql.contains("#{id}"), "should use #{id} bind variable"); + assertTrue(sql.contains("where"), "should have WHERE clause"); + assertTrue(sql.contains("#{id}"), "should have bind variable"); + } + + @Test + void getOneByDeviceId_shouldUseBindVariables() { + Map params = new HashMap<>(); + params.put("dataDeviceId", 10); + params.put("channelId", "CH999"); + String sql = provider.getOneByDeviceId(params); + assertTrue(sql.contains("#{dataDeviceId}"), "should use #{dataDeviceId}"); + assertTrue(sql.contains("#{channelId}"), "should use #{channelId}"); + } + + @Test + void queryByDeviceId_shouldUseBindVariable() { + Map params = new HashMap<>(); + params.put("gbDeviceId", "GB-TEST-123"); + String sql = provider.queryByDeviceId(params); + assertTrue(sql.contains("#{gbDeviceId}"), "should use #{gbDeviceId}"); + } + + @Test + void queryById_shouldUseBindVariable() { + Map params = new HashMap<>(); + params.put("gbId", 456); + String sql = provider.queryById(params); + assertTrue(sql.contains("#{gbId}"), "should use #{gbId}"); + } + + @Test + void queryList_withQuery_shouldUseBindVariable() { + Map params = new HashMap<>(); + params.put("query", "search-term"); + String sql = provider.queryList(params); + assertTrue(sql.contains("#{query}"), "should use #{query} bind variable"); + assertFalse(sql.contains("search-term"), "should not contain raw query value"); + } + + @Test + void queryList_withOnline_shouldFilter() { + Map params = new HashMap<>(); + params.put("online", true); + String sql = provider.queryList(params); + assertTrue(sql.contains("'ON'"), "should filter for ON"); + } + + @Test + void queryList_withHasCivilCode_shouldFilter() { + Map params = new HashMap<>(); + params.put("hasCivilCode", true); + String sql = provider.queryList(params); + assertTrue(sql.contains("civil_code) is not null"), "should filter for not null civil code"); + } + + @Test + void queryList_withHasGroup_shouldFilter() { + Map params = new HashMap<>(); + params.put("hasGroup", true); + String sql = provider.queryList(params); + assertTrue(sql.contains("parent_id) is not null"), "should filter for not null parent"); + } + + @Test + void queryChannels_withHasStream_shouldFilter() { + Map params = new HashMap<>(); + params.put("hasStream", true); + String sql = provider.queryChannels(params); + assertTrue(sql.contains("stream_id IS NOT NULL"), "should filter for not null stream_id"); + } +}